In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch.utils.data import DataLoader, TensorDataset

from games.chopsticks import ChopsticksGame
from models.chopsticks import ChopsticksMLP
from sims.tree import MCTS
from utils import trainer

# Training

In [2]:
game = ChopsticksGame(draw_limit=30)
model = ChopsticksMLP(game.state_dim(), 64, game.num_actions())
mcts = MCTS(game, model, {'c_puct': 1.0, 'num_simulations': 1000})

In [6]:
num_games = 50
batch_size = 64
num_epochs = 50

In [49]:
samples = trainer.self_play(mcts, game, num_games)
dataloader = trainer.create_dataloader(samples, batch_size=batch_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
trainer.train_model(model, dataloader, optimizer, num_epochs)

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [07:38<00:00,  9.18s/it]


Epoch 1/50, Loss: 2.0744
Epoch 2/50, Loss: 1.8088
Epoch 3/50, Loss: 1.7233
Epoch 4/50, Loss: 1.6321
Epoch 5/50, Loss: 1.5915
Epoch 6/50, Loss: 1.5983
Epoch 7/50, Loss: 1.5207
Epoch 8/50, Loss: 1.4463
Epoch 9/50, Loss: 1.4862
Epoch 10/50, Loss: 1.4853
Epoch 11/50, Loss: 1.4596
Epoch 12/50, Loss: 1.4306
Epoch 13/50, Loss: 1.4320
Epoch 14/50, Loss: 1.4452
Epoch 15/50, Loss: 1.4008
Epoch 16/50, Loss: 1.3922
Epoch 17/50, Loss: 1.3867
Epoch 18/50, Loss: 1.3929
Epoch 19/50, Loss: 1.3830
Epoch 20/50, Loss: 1.3599
Epoch 21/50, Loss: 1.3442
Epoch 22/50, Loss: 1.3310
Epoch 23/50, Loss: 1.3395
Epoch 24/50, Loss: 1.3140
Epoch 25/50, Loss: 1.3153
Epoch 26/50, Loss: 1.2832
Epoch 27/50, Loss: 1.3442
Epoch 28/50, Loss: 1.2909
Epoch 29/50, Loss: 1.3028
Epoch 30/50, Loss: 1.3121
Epoch 31/50, Loss: 1.2680
Epoch 32/50, Loss: 1.2829
Epoch 33/50, Loss: 1.2760
Epoch 34/50, Loss: 1.2895
Epoch 35/50, Loss: 1.2990
Epoch 36/50, Loss: 1.3024
Epoch 37/50, Loss: 1.2155
Epoch 38/50, Loss: 1.2549
Epoch 39/50, Loss: 1.

In [62]:
trainer.train_model(model, dataloader, optimizer, num_epochs)

Epoch 1/50, Loss: 0.8329
Epoch 2/50, Loss: 0.8134
Epoch 3/50, Loss: 0.8129
Epoch 4/50, Loss: 0.8371
Epoch 5/50, Loss: 0.8101
Epoch 6/50, Loss: 0.8170
Epoch 7/50, Loss: 0.8148
Epoch 8/50, Loss: 0.8173
Epoch 9/50, Loss: 0.8192
Epoch 10/50, Loss: 0.8143
Epoch 11/50, Loss: 0.8013
Epoch 12/50, Loss: 0.8031
Epoch 13/50, Loss: 0.8149
Epoch 14/50, Loss: 0.8001
Epoch 15/50, Loss: 0.8349
Epoch 16/50, Loss: 0.7954
Epoch 17/50, Loss: 0.8040
Epoch 18/50, Loss: 0.8041
Epoch 19/50, Loss: 0.8168
Epoch 20/50, Loss: 0.8267
Epoch 21/50, Loss: 0.8008
Epoch 22/50, Loss: 0.8088
Epoch 23/50, Loss: 0.8420
Epoch 24/50, Loss: 0.8129
Epoch 25/50, Loss: 0.8388
Epoch 26/50, Loss: 0.8095
Epoch 27/50, Loss: 0.8310
Epoch 28/50, Loss: 0.8234
Epoch 29/50, Loss: 0.8340
Epoch 30/50, Loss: 0.8113
Epoch 31/50, Loss: 0.8235
Epoch 32/50, Loss: 0.8172
Epoch 33/50, Loss: 0.8279
Epoch 34/50, Loss: 0.8413
Epoch 35/50, Loss: 0.8338
Epoch 36/50, Loss: 0.8187
Epoch 37/50, Loss: 0.8214
Epoch 38/50, Loss: 0.8081
Epoch 39/50, Loss: 0.

In [65]:
game.reset()
node = mcts.run(game.state)
node.children

{0: tensor([1., 1., 2., 1., 1., 1.]) Prior: 0.02 Count: 1 Value: tensor([0.8859]),
 1: tensor([1., 1., 1., 2., 1., 1.]) Prior: 0.00 Count: 1 Value: tensor([0.7036]),
 2: tensor([1., 2., 1., 1., 1., 1.]) Prior: 0.90 Count: 974 Value: tensor([-0.1474]),
 3: tensor([1., 1., 2., 1., 1., 1.]) Prior: 0.02 Count: 1 Value: tensor([0.8859]),
 4: tensor([1., 1., 1., 2., 1., 1.]) Prior: 0.00 Count: 1 Value: tensor([0.7036]),
 5: tensor([2., 1., 1., 1., 1., 1.]) Prior: 0.06 Count: 22 Value: tensor([-0.0011])}

In [64]:
game.reset()
game_continues = True
while game_continues:
    game.print_state(game.state)
    root = mcts.run(game.state)
    print(f"Action probabilities:{root.action_probs()}")
    action = root.sample_action()  # sample action according to visit counts
    # action = root.best_action()  # choose the most visited action
    game_continues = game.play(action)
    print(f"Action taken:{game.describe_action(action)}")
    print("-" * 30)
    if not game_continues:
        game.print_state(game.state)
        reward = game.reward(game.state)
        if reward == 0:
            print("It's a draw!")
        else:
            winner = 1 - game.state[-1]
            print(f"Player {int(winner + 1)} wins!")
        print("Game over.")


  Rounds Played: 0
  P1: Left=1, Right=1
  P2: Left=1, Right=1
  Current Player: P1
Action probabilities:{0: 0.001, 1: 0.001, 2: 0.974, 3: 0.001, 4: 0.001, 5: 0.022}
Action taken:Current left hand taps current right hand
------------------------------

  Rounds Played: 1
  P1: Left=1, Right=2
  P2: Left=1, Right=1
  Current Player: P2
Action probabilities:{0: 0.001, 1: 0.591, 2: 0.143, 3: 0.001, 4: 0.217, 5: 0.047}
Action taken:Current right hand taps opponent's right hand
------------------------------

  Rounds Played: 2
  P1: Left=1, Right=3
  P2: Left=1, Right=1
  Current Player: P1
Action probabilities:{0: 0.001, 1: 0.0, 2: 0.0, 3: 0.015, 4: 0.001, 5: 0.065, 7: 0.918}
Action taken:Split fingers: left hand gets 2 fingers
------------------------------

  Rounds Played: 3
  P1: Left=2, Right=2
  P2: Left=1, Right=1
  Current Player: P2
Action probabilities:{0: 0.002, 1: 0.001, 2: 0.001, 3: 0.005, 4: 0.001, 5: 0.99}
Action taken:Current right hand taps current left hand
------------

# Testing

In [None]:
game = ChopsticksGame()
game.test_get_next_state()

In [None]:
game.test_get_legal_actions()

In [None]:
game.test_reward()