In [1]:
from pettingzoo.classic import texas_holdem_v4
import copy
from agent_configs.cfr_config import CFRConfig
from active_player import ActivePlayer
from cfr_agent import CFRAgent
import torch
from cfr_network import CFRNetwork
game = texas_holdem_v4.env(num_players=2)



In [2]:
hidden_dim = 128
input_dim = 72
output_dim = 4
num_players = 2
replay_buffer_size = 4000000
minibatch_size = 5000
steps_per_epoch = 1000
traversals = 1500
training_steps = 100
lr = 0.001
optimizer = None
p_v_networks = {'input_shape':input_dim, 'output_shape':output_dim, 'hidden_size':hidden_dim, 'learning_rate':lr, 'optimizer':optimizer}
active_player_obj = ActivePlayer(num_players)
config = CFRConfig(
    config_dict={'network': {'policy': p_v_networks, 'value': p_v_networks, 'num_players':num_players},
                 'replay_buffer_size':replay_buffer_size,
                 'minibatch_size':minibatch_size,
                 'steps_per_epoch':steps_per_epoch,
                 'traversals': traversals,
                 'training_steps': training_steps,
                 'active_player_obj': active_player_obj,
                 },
    game_config={'num_players':num_players,
                 'observation_space':72,
                 'action_space':4,},


)

CFRConfig


In [None]:
modelselect = CFRAgent(
    env=game,
    config=config,
)
modelselect.train()


In [4]:
agent = torch.load('checkpoints/1744571971.415963.pt')

In [6]:
model = CFRNetwork(
     config = {'policy': p_v_networks, 'value': p_v_networks, 'num_players':num_players}
)
model.load_state_dict(agent)

<All keys matched successfully>

In [7]:
model.eval()

CFRNetwork(
  (policy): PolicyNetwork(
    (layers): ModuleList(
      (0): Linear(in_features=72, out_features=128, bias=True)
      (1-2): 2 x Linear(in_features=128, out_features=128, bias=True)
      (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (4): Linear(in_features=128, out_features=4, bias=True)
    )
  )
)

In [11]:
model.values

[ValueNetwork(
   (layers): ModuleList(
     (0): Linear(in_features=72, out_features=128, bias=True)
     (1-2): 2 x Linear(in_features=128, out_features=128, bias=True)
     (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
     (4): Linear(in_features=128, out_features=4, bias=True)
   )
 ),
 ValueNetwork(
   (layers): ModuleList(
     (0): Linear(in_features=72, out_features=128, bias=True)
     (1-2): 2 x Linear(in_features=128, out_features=128, bias=True)
     (3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
     (4): Linear(in_features=128, out_features=4, bias=True)
   )
 )]

In [88]:
eval_games = 100000
import numpy as np
rewards = []
for i in range(eval_games):
    # FOR EACH EVAL GAME, RESET ENVIRONEMENT (DEBATABLE STEP) BUT RESET WITH SET SEED FOR RECREATION
    random_seed = np.random.randint(0, 2**32 - 1)
    observation, reward, termination, truncation, infos =  modelselect.env.last()

    modelselect.env.reset(seed=random_seed)
    active_player =  modelselect.env.agent_selection[-1]
    modelselect.active_player_obj.set_active_player(int(active_player))
    while not termination and not truncation:
        # GET CURRENT STATE
        observation, reward, termination, truncation, infos =  modelselect.env.last()
        if termination or truncation:
            break
        active_player =  modelselect.active_player_obj.get_active_player()
        if active_player == 1:
            predictions = model.policy(torch.tensor(observation['observation'], dtype=torch.float32).reshape(1,72)).detach().numpy()[0]

            sample, policy = modelselect.select_actions(predictions, info=torch.from_numpy(observation["action_mask"]).type(torch.float), mask_actions=True)
        else:
            # predictions = np.ones(4) / 4
            # sample, policy = modelselect.select_actions(predictions, info=torch.from_numpy(observation["action_mask"]).type(torch.float), mask_actions=True)
            predictions = model.policy(torch.tensor(observation['observation'], dtype=torch.float32).reshape(1,72)).detach().numpy()[0]
            sample, policy = modelselect.select_actions(predictions, info=torch.from_numpy(observation["action_mask"]).type(torch.float), mask_actions=True)
        # if active player, branch off and traverse
        modelselect.env.step(sample)
        modelselect.active_player_obj.next()
    final_rewards = modelselect.env.rewards["player_1"]  # dict of {agent_0: r0, agent_1: r1}
    print(final_rewards)
    rewards.append(final_rewards)
    modelselect.env.close()

print(np.mean(rewards))
print(np.std(rewards))
print(rewards)

0
2.0
0
0.5
0
-0.5
0
-0.5
0
-1.0
0
0.5
0
0.5
0
1.0
0
8.0
0
-5.0
0
-0.5
0
-1.0
0
-1.0
0
-1.0
0
0.5
0
0.5
0
0.5
0
0.5
0
0.5
0
1.0
0
0.5
0
-0.5
0
0.5
0
0.5
0
-1.0
0
-0.5
0
0.5
0
-0.5
0
-1.0
0
2.0
0
-0.5
0
0.5
0
-1.0
0
-1.0
0
1.0
0
1.0
0
-3.0
0
1.0
0
4.0
0
-0.5
0
0.5
0
-0.5
0
0.5
0
-1.0
0
-1.0
0
0.5
0
-4.0
0
-0.5
0
-0.5
0
0.5
0
-1.0
0
-0.5
0
-5.0
0
2.0
0
0.5
0
2.0
0
-1.0
0
-0.5
0
-1.0
0
0.5
0
0.5
0
-1.0
0
-1.0
0
-1.0
0
-1.0
0
0.5
0
2.0
0
0.5
0
1.0
0
0.5
0
1.0
0
-0.5
0
-0.5
0
1.0
0
-1.0
0
-1.0
0
3.0
0
-1.0
0
0.5
0
1.0
0
1.0
0
1.0
0
0.5
0
1.0
0
-0.5
0
-0.5
0
1.0
0
-0.5
0
0.5
0
-1.0
0
1.0
0
-0.5
0
-0.5
0
0.5
0
0.5
0
0.5
0
-1.0
0
0.5
0
0.5
0
1.0
0
1.0
0
-0.5
0
-0.5
0
-3.0
0
-0.5
0
0.5
0
1.0
0
-0.5
0
-2.0
0
0.5
0
-1.0
0
1.0
0
-2.0
0
-0.5
0
1.0
0
1.0
0
2.0
0
1.0
0
-7.0
0
0.5
0
2.0
0
-3.0
0
-1.0
0
1.0
0
-1.0
0
-0.5
0
0.5
0
2.0
0
0.5
0
1.0
0
1.0
0
-0.5
0
-1.0
0
-1.0
0
0.5
0
1.0
0
1.0
0
-1.0
0
2.0
0
0.5
0
2.0
0
-0.5
0
-0.5
0
-0.5
0
0.5
0
-1.0
0
-1.0
0
-1.0
0
-1.0
0
3.0
0
-1.0
0
-0.5
0
-1.0
0
1.0
0


In [44]:
modelselect.env.reset()

In [37]:
modelselect.env.last()

({'observation': array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
         0., 0., 0., 0.], dtype=float32),
  'action_mask': array([0, 0, 0, 0], dtype=int8)},
 -0.5,
 True,
 False,
 {})

In [52]:
modelselect.env.close()