In [12]:
import os
from grl.p2sro.payoff_table import PayoffTableStrategySpec, PayoffTable
from grl.rl_apps.kuhn_poker_p2sro.poker_multi_agent_env import PokerMultiAgentEnv
import json
from grl.p2sro.payoff_table import PayoffTable
from ray.rllib.policy import Policy
from ray.rllib.models.action_dist import ActionDistribution

import deepdish

from typing import Dict, Callable, List, Tuple
from open_spiel.python import rl_environment


import numpy as np


import ray

from open_spiel.python.policy import Policy as OpenSpielPolicy, tabular_policy_from_callable
from open_spiel.python.algorithms.exploitability import exploitability
from pyspiel import Game as OpenSpielGame

import pyspiel
from grl.p2sro.p2sro_manager.utils import get_latest_metanash_strategies
from grl.p2sro.payoff_table import PayoffTableStrategySpec
from grl.rl_apps.kuhn_poker_p2sro.poker_multi_agent_env import PokerMultiAgentEnv
from ray.rllib.utils import try_import_torch
from ray.rllib.utils.typing import TensorType

from grl.rl_apps.kuhn_poker_p2sro.general_psro_approx_br import train_poker_approx_best_response_xdfo

# `try_import_torch` returns a 2-tuple whose first element is the torch module
# (None if torch is unavailable). Index it instead of unpacking into `_`: in a
# notebook, assigning `_` shadows IPython's last-output variable and stops
# IPython from updating it.
torch = try_import_torch()[0]


def train_br_for_metanash(args):
    """Train an approximate best response (BR) against an opponent metanash mixture.

    Intended as a ``Pool.map`` worker, so every input arrives packed in one tuple.

    Args:
        args: ``(player, metanash_policy_specs, metanash_weights, scenario)`` --
            the seat the BR is trained for, the opponent policy specs making up
            the mixture, their mixture weights, and the scenario forwarded to
            the trainer.

    Returns:
        The value returned by ``train_poker_approx_best_response_xdfo``
        (presumably the final BR reward; the stats loop averages these values).
    """
    br_player, opponent_specs, opponent_weights, br_scenario = args

    def make_stopping_condition():
        # A new stopping condition is built every time the trainer asks for one.
        # NOTE(review): max_train_episodes (100) is far below
        # dont_check_plateau_before_n_episodes (60,000), so the plateau check can
        # never fire and training always stops after 100 episodes -- confirm this
        # short budget is intentional and not a leftover debug value.
        return EpisodesSingleBRRewardPlateauStoppingCondition(
            br_policy_id="best_response",
            dont_check_plateau_before_n_episodes=60_000,
            check_plateau_every_n_episodes=40_000,
            minimum_reward_improvement_otherwise_plateaued=0.01,
            max_train_episodes=100,
        )

    trainer_overrides = {"metrics_smoothing_episodes": 6000}
    br_policy_overrides = {
        "lr": 0.001,
        "model": {"fcnet_hiddens": [128, 128]},
    }

    return train_poker_approx_best_response_xdfo(
        br_player=br_player,
        ray_head_address=None,
        scenario=br_scenario,
        # (sic) the triple "r" matches the callee's keyword name.
        general_trainer_config_overrrides=trainer_overrides,
        br_policy_config_overrides=br_policy_overrides,
        get_stopping_condition=make_stopping_condition,
        metanash_policy_specs=opponent_specs,
        metanash_weights=opponent_weights,
        results_dir="/tmp",
        print_train_results=True,
    )




def _metanash_probs_for_player(payoff_table, player: int, n_policies: int):
    """Return ``player``'s metanash mixture weights using the first ``n_policies`` policies.

    The mixture is read from the metanash computed with ``as_player`` set to the
    other seat (two-player game assumed), exactly as the original per-player calls did.
    """
    metanash = get_latest_metanash_strategies(payoff_table=payoff_table,
                                              as_player=1 - player,
                                              as_policy_num=n_policies,
                                              fictitious_play_iters=2000,
                                              mix_with_uniform_dist_coeff=0.0,
                                              print_matrix=False)
    return metanash[player].probabilities_for_each_strategy()


def get_stats_for_single_payoff_table(payoff_table:PayoffTable, highest_policy_num: int, scenario):
    """Estimate exploitability and cumulative BR training cost for each PSRO generation.

    For each generation size ``n_policies`` in ``1..highest_policy_num``, the metanash
    over each player's first ``n_policies`` policies is computed. Then an approximate
    best response is trained against each player's mixture, one worker process per
    BR player. The mean of the two returned BR rewards is recorded as that
    generation's exploitability.

    Args:
        payoff_table: two-player PSRO payoff table holding at least
            ``highest_policy_num`` policies per player.
        highest_policy_num: number of generations to evaluate; generation ``n`` uses
            pure-strategy indexes ``0..n-1`` for each player.
        scenario: forwarded unchanged to the best-response trainer.

    Returns:
        Dict of parallel lists with one entry per generation:
        ``num_policies``, ``exploitability``, ``timesteps`` and ``episodes``.
        ``timesteps`` and ``episodes`` are running totals of the
        ``timesteps_training_br`` / ``episodes_training_br`` metadata of every policy
        added so far, summed over both players.
    """
    # The payoff table is not modified in this loop, so fetch the ordered specs once.
    specs_by_player = [payoff_table.get_ordered_spec_list_for_player(player=p) for p in range(2)]

    num_policies_per_generation = []
    exploitability_per_generation = []
    total_steps_per_generation = []
    total_episodes_per_generation = []

    cumulative_steps = 0
    cumulative_episodes = 0

    for n_policies in range(1, highest_policy_num + 1):
        metanash_probs_0 = _metanash_probs_for_player(payoff_table, player=0, n_policies=n_policies)
        metanash_probs_1 = _metanash_probs_for_player(payoff_table, player=1, n_policies=n_policies)

        policy_specs_0 = specs_by_player[0][:n_policies]
        policy_specs_1 = specs_by_player[1][:n_policies]

        assert len(metanash_probs_1) == len(policy_specs_1), f"len(metanash_probs_1): {len(metanash_probs_1)}, len(policy_specs_1): {len(policy_specs_1)}"
        assert len(metanash_probs_0) == len(policy_specs_0)
        assert len(policy_specs_0) == len(policy_specs_1)

        # Player 0's BR trains against player 1's mixture, and vice versa.
        br_jobs = [(0, policy_specs_1, metanash_probs_1, scenario),
                   (1, policy_specs_0, metanash_probs_0, scenario)]
        # Fresh pool per generation, sized to the job count. The context manager
        # terminates the workers even if a training job raises, so no processes leak.
        with Pool(processes=len(br_jobs)) as pool:
            final_br_rewards = pool.map(train_br_for_metanash, br_jobs)
        exploitability_this_gen = np.mean(final_br_rewards)
        print(f"{n_policies} policies, {exploitability_this_gen} exploitability")

        # The policy added in this generation has pure-strategy index n_policies - 1.
        newest_specs = [payoff_table.get_spec_for_player_and_pure_strat_index(
            player=p, pure_strat_index=n_policies - 1) for p in range(2)]
        cumulative_steps += sum(spec.metadata["timesteps_training_br"] for spec in newest_specs)
        cumulative_episodes += sum(spec.metadata["episodes_training_br"] for spec in newest_specs)

        num_policies_per_generation.append(n_policies)
        exploitability_per_generation.append(exploitability_this_gen)
        total_steps_per_generation.append(cumulative_steps)
        total_episodes_per_generation.append(cumulative_episodes)

    stats_out = {'num_policies': num_policies_per_generation, 'exploitability': exploitability_per_generation,
                 'timesteps': total_steps_per_generation, 'episodes': total_episodes_per_generation}

    return stats_out












player: 1 iter: 55600, average_policy_player_1_iter_55600.json
{'checkpoint_path': '/home/jblanier/git/grl/grl/data/30_no_limit_leduc_nfsp_dqn_gpu_sparse_05.58.20PM_Jan-30-2021hz7ctui_/avg_policy_checkpoints/average_policy_player_1_iter_55600.h5', 'timesteps_training': 56949760, 'episodes_training': 14170891}
player: 1 iter: 7200, average_policy_player_1_iter_7200.json
{'checkpoint_path': '/home/jblanier/git/grl/grl/data/30_no_limit_leduc_nfsp_dqn_gpu_sparse_05.58.20PM_Jan-30-2021hz7ctui_/avg_policy_checkpoints/average_policy_player_1_iter_7200.h5', 'timesteps_training': 7388160, 'episodes_training': 1631989}
player: 0 iter: 35600, average_policy_player_0_iter_35600.json
{'checkpoint_path': '/home/jblanier/git/grl/grl/data/30_no_limit_leduc_nfsp_dqn_gpu_sparse_05.58.20PM_Jan-30-2021hz7ctui_/avg_policy_checkpoints/average_policy_player_0_iter_35600.h5', 'timesteps_training': 36469760, 'episodes_training': 8902984}
player: 1 iter: 24400, average_policy_player_1_iter_24400.json
{'checkpoi

{'checkpoint_path': '/home/jblanier/git/grl/grl/data/30_no_limit_leduc_nfsp_dqn_gpu_sparse_05.58.20PM_Jan-30-2021hz7ctui_/avg_policy_checkpoints/average_policy_player_0_iter_400.h5', 'timesteps_training': 424960, 'episodes_training': 105896}
{'checkpoint_path': '/home/jblanier/git/grl/grl/data/30_no_limit_leduc_nfsp_dqn_gpu_sparse_05.58.20PM_Jan-30-2021hz7ctui_/avg_policy_checkpoints/average_policy_player_1_iter_400.h5', 'timesteps_training': 424960, 'episodes_training': 105896}


In [14]:
from grl.rl_apps.nfsp.general_approx_br_nfsp import train_poker_approx_best_response_nfsp
from grl.rl_apps.scenarios.stopping_conditions import EpisodesSingleBRRewardPlateauStoppingCondition
from grl.rl_apps.scenarios.ray_setup import init_ray_for_scenario
from grl.rl_apps.scenarios.poker import scenarios
from grl.rllib_tools.valid_actions_fcnet import get_valid_action_fcn_class
from multiprocessing.pool import Pool
import logging
from grl.utils import datetime_str
# Logger named after the current module; records propagate to the root handler
# that logging.basicConfig() installs.
logger = logging.getLogger(name=__name__)





all policies are [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209]
policies to evaulate are [  0  52 104 157 

ValueError: `callbacks` must be a callable method that returns a subclass of DefaultCallbacks, got None

In [None]:
import json
import numpy as np
import re
from grl.p2sro.payoff_table import PayoffTable

# psro_payoff_table_checkpoint_path = "/home/jb/git/grl/grl/data/simple_push_psro/manager_08.41.48PM_Dec-28-2020/payoff_table_checkpoints/payoff_table_checkpoint_25.json"
# psro_policy_nums_path = "/home/jb/git/grl/grl/data/simple_push_psro/manager_08.41.48PM_Dec-28-2020/payoff_table_checkpoints/policy_nums_checkpoint_25.json"

# psro_payoff_table_checkpoint_path = "/home/jb/git/grl/grl/data/simple_push_psro/manager_10.59.03PM_Dec-28-2020/payoff_table_checkpoints/payoff_table_checkpoint_30.json"
# psro_policy_nums_path = "/home/jb/git/grl/grl/data/simple_push_psro/manager_10.59.03PM_Dec-28-2020/payoff_table_checkpoints/policy_nums_checkpoint_30.json"

scenario_name = "30_no_limit_leduc_nfsp_dqn_gpu"

# TODO(review): these checkpoints come from a `simple_push_psro` run, while
# `scenario_name` above is a no-limit Leduc scenario. The best responses trained
# later use `scenario`, so the payoff table and the scenario should belong to the
# same game -- confirm these paths.
psro_payoff_table_checkpoint_path = "/home/jb/git/grl/grl/data/simple_push_psro/manager_01.04.06AM_Dec-29-2020/payoff_table_checkpoints/payoff_table_checkpoint_81.json"
psro_policy_nums_path = "/home/jb/git/grl/grl/data/simple_push_psro/manager_01.04.06AM_Dec-29-2020/payoff_table_checkpoints/policy_nums_checkpoint_81.json"


logging.basicConfig()


try:
    scenario = scenarios[scenario_name]
except KeyError:
    # `from None` hides the internal KeyError; the message already names the bad key.
    raise NotImplementedError(f"Unknown scenario name: \'{scenario_name}\'. Existing scenarios are:\n"
                              f"{list(scenarios.keys())}") from None


# Timestamp of the PSRO run embedded in the checkpoint path (e.g. "01.04.06AM_Dec-29-2020"),
# used to name the output file. The pattern is a raw string: "\d" inside a normal string
# literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
original_date_match = re.search(r"\d\d\.\d\d.\d\d.\w\w\w\w\w\W\d\d\W\d\d\d\d", psro_payoff_table_checkpoint_path)
if original_date_match is None:
    raise ValueError(f"Could not find a run timestamp in {psro_payoff_table_checkpoint_path!r}")
original_date = original_date_match[0]

with open(psro_policy_nums_path, "r") as policy_nums_file:
    policy_nums = json.load(policy_nums_file)
# Largest policy number listed under 'fixed_policies' for each player.
largest_fixed_policy_nums = (max(policy_nums['0']['fixed_policies']), max(policy_nums['1']['fixed_policies']))

payoff_table = PayoffTable.from_json_file(json_file_path=psro_payoff_table_checkpoint_path)

In [None]:
from ray.rllib.agents.dqn import SimpleQTorchPolicy, SIMPLE_Q_DEFAULT_CONFIG
from ray.rllib.utils import merge_dicts


num_players = 2

# `largest_fixed_policy_nums` holds policy *numbers*, which start at 0, but
# `highest_policy_num` is a policy *count* (generation n uses policies 0..n-1).
# Add 1 so the final generation, which includes the largest fixed policy, is not
# skipped. The smaller of the two players' values keeps every evaluated generation
# populated for both players.
stats = get_stats_for_single_payoff_table(
    payoff_table=payoff_table,
    highest_policy_num=min(largest_fixed_policy_nums) + 1,
    scenario=scenario
)


In [None]:
# Persist the per-generation stats. The run timestamp in the name ties the file to its PSRO run.
save_path = f"/home/jb/git/grl/grl/data/psro_approx_stats_{scenario_name}_{original_date}.json"
# Write-only mode: the original "+w" also requested read access, which is never used.
with open(save_path, "w") as json_file:
    json.dump(stats, json_file)
print(f"saved to {save_path}")