# In [None]:
import dataclasses
import numpy as np
from players.base import play_game
from players.mcts import SearchParameters, PlayerMCTS
from players.strategy import StrategyTokenProducer, Strategy

# Experiment script: repeatedly pit one MCTS player against itself under two
# different fixed strategy tables, tallying the game outcomes and the strategy
# tokens each side produced.

# Fixed seed so NumPy-driven randomness is repeatable between runs.
np.random.seed(3)

# Number of self-play games to run.
NUM_GAMES = 1000

mcts_params = SearchParameters(
    num_simulations=100
)
from network.checkpoints import Checkpoint
ckpt = Checkpoint.from_json_file("./data/checkpoints/tr-st/4.json")

# The two strategy tables differ only in slice 1 of their first axis:
# all ones for player 1, all zeros for player 2.
st1 = Strategy(table=Strategy.create_empty_table())
st1.table[1, :, :] = 1

st2 = Strategy(table=Strategy.create_empty_table())
st2.table[1, :, :] = 0

player_mcts = PlayerMCTS(
    ckpt.params,
    ckpt.model.create_caching_model(),
    mcts_params
)

# Running counters accumulated over all games. They must be wide integers:
# with a uint8 dtype any cell that exceeds 255 silently wraps around to 0,
# which corrupts the statistics long before NUM_GAMES games are played.
# strategy[0] / strategy[1] collect the per-game tables of player 1 / player 2.
strategy = np.zeros((2, 4, 4, 2, 2), dtype=np.int64)
win_count = np.zeros(7, dtype=np.int64)

for i in range(NUM_GAMES):
    # A fresh copy of the base player is built for every game, differing only
    # in the strategy it is given; a new token producer is created per game.
    result = play_game(
        player1=dataclasses.replace(player_mcts, strategy=st1),
        player2=dataclasses.replace(player_mcts, strategy=st2),
        print_board=False,
        token_producer=StrategyTokenProducer()
    )

    # Signed outcome shifted by +3 to index the 7 bins.
    # NOTE(review): presumably winner is in {-1, 0, 1} and win_type.value in
    # 1..3, so the index spans 0..6 -- confirm against play_game's result type.
    win_count[int(result.win_type.value) * result.winner + 3] += 1

    strategy[0] += StrategyTokenProducer.create_strategy_table(result.tokens1)
    strategy[1] += StrategyTokenProducer.create_strategy_table(result.tokens2)

    # Progress report after every game: outcome histogram, then for each
    # player the totals of the [..., 1, 0] and [..., 1, 1] strategy slices.
    print(i, win_count)
    print(strategy[0, :, :, 1, 0].sum(), strategy[0, :, :, 1, 1].sum())
    print(strategy[1, :, :, 1, 0].sum(), strategy[1, :, :, 1, 1].sum())
