In [1]:
%load_ext autoreload
%autoreload 2
from games.chopsticks import ChopsticksGame
import torch

In [2]:
# Build a game instance (reused by the following test cells) and run the
# built-in get_next_state() walkthrough, which prints before/after states.
# NOTE(review): in the printed output, TEST 3's title says "4+2=6 mod 5 = 1"
# but the resulting hand is shown as Right=0 — confirm which rule is intended.
game = ChopsticksGame()
game.test_get_next_state()

CHOPSTICKS GAME - get_next_state() TEST CASES

--- TEST 1: Initial state - P1 left taps P2 left (Action 0) ---
Initial State:
  P1: Left=1, Right=1
  P2: Left=1, Right=1
  Current Player: P1
Action: Current left hand taps opponent's left hand
After Action:
  P1: Left=1, Right=1
  P2: Left=2, Right=1
  Current Player: P2

--- TEST 2: Tap causing hand to die (3+2=5 -> 0) (Action 0) ---
Initial State:
  P1: Left=3, Right=2
  P2: Left=2, Right=1
  Current Player: P1
Action: Current left hand taps opponent's left hand
After Action:
  P1: Left=3, Right=2
  P2: Left=0, Right=1
  Current Player: P2

--- TEST 3: Tap causing overflow (4+2=6 mod 5 = 1) (Action 4) ---
Initial State:
  P1: Left=2, Right=4
  P2: Left=1, Right=2
  Current Player: P1
Action: Current right hand taps opponent's right hand
After Action:
  P1: Left=2, Right=4
  P2: Left=1, Right=0
  Current Player: P2

--- TEST 4: P1 left taps P1 right (Action 2) ---
Initial State:
  P1: Left=2, Right=1
  P2: Left=3, Right=2
  Current Pla

In [3]:
# Run the built-in get_legal_actions() walkthrough on the `game` instance
# created above; it prints each test state with its legal action ids.
game.test_get_legal_actions()

CHOPSTICKS GAME - get_legal_actions() TEST CASES

--- TEST 1: Initial state - all hands alive ---
State:
  P1: Left=1, Right=1
  P2: Left=1, Right=1
  Current Player: P1
Legal actions: [0, 1, 2, 3, 4, 5]
Descriptions:
  0: Current left hand taps opponent's left hand
  1: Current left hand taps opponent's right hand
  2: Current left hand taps current right hand
  3: Current right hand taps opponent's left hand
  4: Current right hand taps opponent's right hand
  5: Current right hand taps current left hand

--- TEST 2: P1 left hand is dead (0 fingers) ---
State:
  P1: Left=0, Right=3
  P2: Left=2, Right=1
  Current Player: P1
Legal actions: [3, 4, 6, 7]
Descriptions:
  3: Current right hand taps opponent's left hand
  4: Current right hand taps opponent's right hand
  6: Split fingers: left hand gets 1 finger
  7: Split fingers: left hand gets 2 fingers

--- TEST 3: P1 right hand is dead (0 fingers) ---
State:
  P1: Left=3, Right=0
  P2: Left=2, Right=1
  Current Player: P1
Legal actio

In [4]:
# Run the built-in check_winner() walkthrough on the `game` instance created
# above; it prints the state before and after each candidate winning move.
game.test_check_winner()

CHOPSTICKS GAME - check_winner() TEST CASES

--- TEST 1: Initial state - no winner ---
State:
  P1: Left=1, Right=1
  P2: Left=1, Right=1
  Current Player: P1
  Winner: None (Game continues)

--- TEST 2: P1 kills P2's last hand (winning move) ---
Initial State:
  P1: Left=1, Right=2
  P2: Left=0, Right=4
  Current Player: P1
Action: Current left hand taps opponent's right hand
After Action:
  P1: Left=1, Right=2
  P2: Left=0, Right=0
  Current Player: P2

--- TEST 3: P2 kills P1's last hand (winning move) ---
Initial State:
  P1: Left=0, Right=3
  P2: Left=2, Right=2
  Current Player: P2
Action: Current right hand taps opponent's right hand
After Action:
  P1: Left=0, Right=0
  P2: Left=2, Right=2
  Current Player: P1

--- TEST 4: P1 kills one hand but P2 still has another - no winner ---
Initial State:
  P1: Left=3, Right=1
  P2: Left=2, Right=1
  Current Player: P1
Action: Current left hand taps opponent's left hand
After Action:
  P1: Left=3, Right=1
  P2: Left=0, Right=1
  Current 

In [None]:
from models.chopsticks import ChopsticksMLP
from sims.tree import MCTS

# Seed torch's global RNG so a Restart & Run All starts from the same state.
# NOTE(review): presumably ChopsticksMLP initialises its weights from torch's
# RNG, and MCTS may draw from other RNGs (random / numpy) — confirm and seed
# those as well if they are used.
torch.manual_seed(0)

model = ChopsticksMLP(64)
game = ChopsticksGame()
game.reset()

# Run a search from the freshly reset position; `node` is whatever
# mcts.run() returns and is inspected further down via `node.children`.
mcts = MCTS(game, model, {'c_puct': 1.0, 'num_simulations': 1000})
node = mcts.run(game.state)

In [None]:
# Self-play a single game with MCTS and turn it into training samples.
#
# Produces `samples`: a list of (state, action_probs, outcome) tuples, one per
# move played, where `outcome` is a float32 scalar tensor: +1.0 if the player
# to move in `state` went on to win, -1.0 otherwise.
#
# NOTE(review): the loop only ends when game.play() returns a falsy value;
# if a game can cycle without finishing this cell will not terminate —
# consider a move cap.
game.reset()
mcts = MCTS(game, model, {'c_puct': 1.0, 'num_simulations': 1000})

policy_samples = []
game_continues = True
while game_continues:
    root = mcts.run(game.state)

    # Scatter the visited-action probabilities into a dense policy vector;
    # actions missing from the mapping keep probability 0.
    action_probs = torch.zeros(game.action_size())
    for action, prob in root.action_probs().items():
        action_probs[action] = prob

    # Clone the state so later moves cannot alter the recorded position.
    policy_samples.append((game.state.clone(), action_probs))
    game_continues = game.play(root.best_action())

# The game is over. The last element of the final state holds the player who
# would move next, i.e. the loser, so the winner is the other player.
winner = 1 - game.state[-1]

samples = []
for state, action_probs in policy_samples:
    # Label each position from the point of view of the player to move in it.
    r = 1.0 if winner == state[-1] else -1.0
    samples.append((state, action_probs, torch.tensor(r, dtype=torch.float32)))

In [55]:
# Display the (state, action_probs, outcome) tuples built by the self-play cell.
# NOTE(review): the saved output shows the outcome as a plain float (-1.0 / 1.0)
# while the self-play cell wraps it in torch.tensor, so this output looks stale
# — re-run after a kernel restart.
samples

[(tensor([1., 1., 1., 1., 0.]),
  tensor([0.4290, 0.3280, 0.0590, 0.0460, 0.1020, 0.0360, 0.0000, 0.0000, 0.0000,
          0.0000]),
  -1.0),
 (tensor([1., 1., 2., 1., 1.]),
  tensor([0.0450, 0.0780, 0.0860, 0.0400, 0.6870, 0.0640, 0.0000, 0.0000, 0.0000,
          0.0000]),
  1.0),
 (tensor([1., 2., 2., 1., 0.]),
  tensor([0.0870, 0.5410, 0.1550, 0.0640, 0.1210, 0.0320, 0.0000, 0.0000, 0.0000,
          0.0000]),
  -1.0),
 (tensor([1., 2., 2., 2., 1.]),
  tensor([0.1540, 0.2090, 0.0440, 0.1450, 0.2000, 0.0400, 0.1210, 0.0000, 0.0870,
          0.0000]),
  1.0),
 (tensor([1., 4., 2., 2., 0.]),
  tensor([0.1020, 0.0850, 0.0740, 0.2730, 0.2950, 0.0030, 0.0000, 0.0880, 0.0800,
          0.0000]),
  -1.0),
 (tensor([1., 4., 2., 0., 1.]),
  tensor([0.0040, 0.0710, 0.0000, 0.0000, 0.0000, 0.0000, 0.9250, 0.0000, 0.0000,
          0.0000]),
  1.0),
 (tensor([1., 4., 1., 1., 0.]),
  tensor([0.0800, 0.3410, 0.1210, 0.1030, 0.0960, 0.0030, 0.0000, 0.1700, 0.0860,
          0.0000]),
  -1.0),
 (

In [33]:
# Display the children of `node` (the result of mcts.run in the setup cell),
# shown as a mapping from action id to child node.
# NOTE(review): the visit counts in the saved output total 50000, not the
# num_simulations=1000 configured in the setup cell — the output appears to
# come from an earlier configuration; re-run to refresh.
node.children

{0: tensor([1., 1., 2., 1., 1.]) Prior: 0.10 Count: 18721 Value: tensor([-0.0098]),
 1: tensor([1., 1., 1., 2., 1.]) Prior: 0.12 Count: 6519 Value: tensor([-0.0069]),
 2: tensor([1., 2., 1., 1., 1.]) Prior: 0.09 Count: 1763 Value: tensor([8.5081e-06]),
 3: tensor([1., 1., 2., 1., 1.]) Prior: 0.09 Count: 18548 Value: tensor([-0.0101]),
 4: tensor([1., 1., 1., 2., 1.]) Prior: 0.09 Count: 1804 Value: tensor([-0.0008]),
 5: tensor([2., 1., 1., 1., 1.]) Prior: 0.09 Count: 2645 Value: tensor([-0.0039])}

In [None]:
# Replay a fixed opening (actions 0 then 1) from a reset game and display the
# legal actions available in the resulting position.
# NOTE(review): the saved output shows four "Next State" lines although only
# two actions are played here — the output looks stale; re-run to refresh.
game.reset()
for action in [0, 1]:
    game.play(action)
game.get_legal_actions(game.state)

Next State: [1, 1, 2, 1, 1]
Next State: [3, 1, 2, 1, 0]
Next State: [3, 1, 0, 1, 1]
Next State: [4, 1, 0, 1, 0]


tensor([1, 2, 4, 5, 7, 8], dtype=torch.int32)