In [1]:
from grl.p2sro.p2sro_manager import P2SROManager
from grl.p2sro.payoff_table import PayoffTable
from grl.p2sro.p2sro_manager.utils import get_latest_metanash_strategies, PolicySpecDistribution
from grl.rl_examples.particle_games.simple_push_multi_agent_env import SimplePushMultiAgentEnv

import numpy as np
import time
from typing import Type
import os
import numpy as np
import argparse
import deepdish
import ray
from ray.rllib.env.multi_agent_env import MultiAgentEnv

from ray.rllib.utils import merge_dicts
from ray.rllib.policy.policy import Policy
from ray.rllib.agents.sac import SACTorchPolicy, DEFAULT_CONFIG as DEFAULT_SAC_CONFIG
from grl.p2sro.eval_dispatcher.remote import RemoteEvalDispatcherClient
from grl.rl_examples.particle_games.config import simple_push_sac_params_small
from grl.p2sro.payoff_table import PayoffTableStrategySpec


def arcus_path_to_goku_path(path: str):
    return path.replace("/jblanier/", "/jb/")

def load_weights(policy: Policy, pure_strat_spec: PayoffTableStrategySpec = None, checkpoint_path=None):
    assert pure_strat_spec is None or checkpoint_path is None
    if pure_strat_spec:
        pure_strat_checkpoint_path = pure_strat_spec.metadata["checkpoint_path"]
    else:
        pure_strat_checkpoint_path = checkpoint_path

    pure_strat_checkpoint_path = arcus_path_to_goku_path(pure_strat_checkpoint_path)

    checkpoint_data = deepdish.io.load(path=pure_strat_checkpoint_path)
    weights = checkpoint_data["weights"]
    weights = {k.replace("_dot_", "."): v for k, v in weights.items()}
    policy.set_weights(weights=weights)



def run_episode(env, policies_for_each_player, render=False) -> np.ndarray:

    num_players = len(policies_for_each_player)

    obs = env.reset()
    dones = {}
    game_length = 0
    policy_states = [None] * num_players

    payoffs_per_player_this_episode = np.zeros(shape=num_players, dtype=np.float64)
    while True:
        if "__all__" in dones:
            if dones["__all__"]:
                break
        game_length += 1
        time.sleep(0.1)
        action_dict = {}
        for player in range(num_players):
            if player in obs:
                action_index, new_policy_state, action_info = policies_for_each_player[player].compute_single_action(
                    obs=obs[player], state=policy_states[player])
                policy_states[player] = new_policy_state
                action_dict[player] = action_index

        obs, rewards, dones, infos = env.step(action_dict=action_dict)
        if render:
            env.render()

        for player in range(num_players):
            payoff_so_far = payoffs_per_player_this_episode[player]
            payoffs_per_player_this_episode[player] = payoff_so_far + rewards.get(player, 0.0)

    return payoffs_per_player_this_episode


Instructions for updating:
non-resource variables are not supported in the long term


In [2]:



simple_push_large_params_psro_payoff_table_checkpoint_path = "/home/jb/git/grl/grl/data/manager_12.45.57AM_Dec-21-2020/payoff_table_checkpoints/payoff_table_checkpoint_70.json"
simple_push_large_params_psro_payoff_table = PayoffTable.from_json_file(
    json_file_path=simple_push_large_params_psro_payoff_table_checkpoint_path)


psro_strats = get_latest_metanash_strategies(payoff_table=simple_push_large_params_psro_payoff_table,
                                             as_player=0,
                                             as_policy_num=70,
                                             fictitious_play_iters=2000)

payoff matrix as 1 (row) against 0 (columns):
[[ 0.         -0.02401407 -0.04443765 ...  1.28148699  0.58778667
   0.51695776]
 [ 0.02401407  0.          1.28426635 ...  2.04232836  1.05170274
   0.66932929]
 [ 0.04443765 -1.28426635  0.         ...  2.47086668  2.96118617
   1.81787074]
 ...
 [-1.28148699 -2.04232836 -2.47086668 ...  0.          0.46119061
  -0.50427115]
 [-0.58778667 -1.05170274 -2.96118617 ... -0.46119061  0.
  -0.72462857]
 [-0.51695776 -0.66932929 -1.81787074 ...  0.50427115  0.72462857
   0.        ]]


In [3]:
probabilities_for_each_strategy = psro_strats[1].probabilities_for_each_strategy()

In [4]:
simple_push_large_params_psro_payoff_table.shape()


(71, 71)

In [5]:
env = SimplePushMultiAgentEnv(env_config={})

policy_class = SACTorchPolicy
policies = [policy_class(env.observation_space,
                         env.action_space,
                         merge_dicts(DEFAULT_SAC_CONFIG, simple_push_sac_params_small(action_space=env.action_space)))
            for _ in range(2)]





In [6]:
while True:

    # load_weights(policy=policies[0], checkpoint_path="/home/jb/Downloads/best_response_0_1_12.03.04AM_Dec-27-2020.h5")
    # load_weights(policy=policies[1], checkpoint_path="/home/jb/Downloads/best_response_1_1_12.03.04AM_Dec-27-2020.h5")

    load_weights(policy=policies[0], checkpoint_path="/home/jb/Downloads/best_response_0_71_10.15.13PM_Dec-27-2020.h5")
    load_weights(policy=policies[1], checkpoint_path="/home/jb/Downloads/best_response_1_71_10.15.13PM_Dec-27-2020.h5")


    payoffs = run_episode(env=env, policies_for_each_player=policies, render=True)
    print(payoffs)

[ 4.10058123 -4.10058123]
[-73.19352985  73.19352985]
[ 3.61957171 -3.61957171]
[ 9.39552767 -9.39552767]
[ 8.61421926 -8.61421926]
[-33.91738001  33.91738001]
[ 10.61900742 -10.61900742]
[ 5.69334235 -5.69334235]
[ 5.9652022 -5.9652022]
[-56.59009153  56.59009153]
[ 4.50155334 -4.50155334]
[ 0.11645411 -0.11645411]
[-23.64318313  23.64318313]
[ 4.83279944 -4.83279944]
[-25.34578466  25.34578466]
[-148.36690618  148.36690618]
[-91.88343715  91.88343715]
[-33.14098525  33.14098525]
[ 5.7619823 -5.7619823]
[-93.01432321  93.01432321]
[-12.6078147  12.6078147]
[ 3.2221016 -3.2221016]
[ 1.23025996 -1.23025996]
[ 4.37209543 -4.37209543]
[ 12.98634674 -12.98634674]
[-98.60562249  98.60562249]
[-80.33828414  80.33828414]
[-11.55596129  11.55596129]
[ 6.11246916 -6.11246916]
[-40.12408555  40.12408555]
[-103.95790819  103.95790819]
[ 7.43730874 -7.43730874]
[ 7.16083778 -7.16083778]
[-38.41514604  38.41514604]
[-21.25103596  21.25103596]
[-55.81686949  55.81686949]
[ 2.32663623 -2.32663623]
[ 

KeyboardInterrupt: 