In [3]:
import numpy as np
from players.base import play_game
from players.mcts import SearchParameters, PlayerMCTS
from players.strategy import StrategyTokenProducer

np.random.seed(3)

mcts_params = SearchParameters(
    num_simulations=100
)
from network.checkpoints import Checkpoint
ckpt = Checkpoint.from_json_file("./network/data/checkpoints/tr/2.json")

player_mcts = PlayerMCTS(ckpt.params, ckpt.model.create_caching_model(), mcts_params)

player = player_mcts

result = play_game(player, player, print_board=True, token_producer=StrategyTokenProducer())

print(result.tokens1)
print(result.tokens2)

|   [31mR[0m  [31mR[0m  [31mR[0m  [31mR[0m   | |   [34mB[0m  [34mB[0m  [31mR[0m  [31mR[0m   |
|[34mB[0m     [34mB[0m  [34mB[0m  [34mB[0m   | |   [34mB[0m  [31mR[0m  [34mB[0m  [31mR[0m   |
|                | |                |
|                | |                |
|   5  5  5  5   | |   5  5  5     5|
|   5  5  5  5   | |   5  5  5  5   |
blue=0 red=0 blue=0 red=0
0
|   [31mR[0m  [31mR[0m  [31mR[0m  [31mR[0m   | |   [34mB[0m  [34mB[0m  [31mR[0m  [31mR[0m   |
|[34mB[0m     [34mB[0m  [34mB[0m  [34mB[0m   | |   [34mB[0m  [31mR[0m     [31mR[0m   |
|                | |         [34mB[0m      |
|      5         | |                |
|   5     5  5   | |   5  5  5     5|
|   5  5  5  5   | |   5  5  5  5   |
blue=0 red=0 blue=0 red=0
1
|   [31mR[0m  [31mR[0m  [31mR[0m  [31mR[0m   | |   [34mB[0m  [34mB[0m  [31mR[0m  [31mR[0m   |
|[34mB[0m        [34mB[0m  [34mB[0m   | |   [34mB[0m  [31mR[0m     [31mR[0m   

IndexError: index 23 is out of bounds for axis 0 with size 2

In [40]:
import numpy as np
from tqdm import tqdm

import env.state as game
from env.state import State
from game_analytics import state_to_str
from players.base import is_action_to_enter_deadlock
import batch


def play_game(
    tokens: np.ndarray,
    colors: np.ndarray,
    actions: np.ndarray,
    print_board: bool = True
):
    state = State.create(tokens[:8, game.Token.COLOR])
    state.board[game.POS_P] = tokens[:8, game.Token.X] + tokens[:8, game.Token.Y] * 6

    if tokens[8, game.Token.ID] < 8:
        init_t = 1
        state.board[game.POS_P, tokens[8, game.Token.ID]] = tokens[8, game.Token.X] + tokens[8, game.Token.Y] * 6
    else:
        init_t = 0

    tokens_ = np.zeros((len(tokens), 7), dtype=np.uint8)
    tokens_[:, :5] = tokens
    tokens = tokens_

    turn_player = -1

    attacked_count = np.zeros((2, 8))
    last_capturing_t = 0

    for t in range(init_t, tokens[:, game.Token.T].max()):
        p = 0 if turn_player == 1 else 1

        action = actions[tokens[:, game.Token.T] == t][0]

        if is_action_to_enter_deadlock(state, action, turn_player) and turn_player == 1:
            p_id, _ = game.action_to_id(action)
            attacked_count[p, p_id] = 1

        state, result = state.step(action, turn_player)
        tokens_i = result.tokens

        for a in result.afterstates:
            state, result = state.step_afterstate(a, colors[a.piece_id])
            tokens_i += result.tokens

        if any([t[game.Token.X] == 6 for t in tokens_i]) and turn_player == -1:
            mask = tokens[:, game.Token.T] == last_capturing_t
            mask[np.all(tokens == 0, axis=-1)] = 0
            # mask *= tokens[:, game.Token.X] == 6

            col_p = state.board[game.COL_P]

            count_r = np.sum((attacked_count[0] == 1) * (col_p == game.RED))
            count_b = np.sum((attacked_count[0] == 1) * (col_p == game.BLUE))

            tokens[mask, 5] = count_b * 5 + count_r

            attacked_count[0] = 0

            last_capturing_t = t

        if print_board:
            s = state_to_str(
                state=state,
                predicted_color=[0.5]*8,
                colored=True
            )
            print(s)
            print(t)

        if result.winner != 0:
            break

        turn_player = -turn_player

    # print(tokens[:, game.Token.T].max())
    return tokens[:, 5]


b = batch.load("./data/replay_buffer/run-7.npy")
b = b.reshape((-1, b.shape[-1]))
tokens, actions, rewards, colors = batch.FORMAT_XARC.astuple(b)

print(tokens.shape)

count = np.zeros(tokens.shape[:2], dtype=np.int16)

for i in tqdm(range(10000)):
    if tokens[i, 0, game.Token.Y] > 2:
        continue

    count[i] = play_game(
        tokens=tokens[i],
        colors=colors[i],
        actions=actions[i],
        print_board=False
    )

print(np.bincount(count.flatten()))

(1268736, 220, 5)


100%|██████████| 10000/10000 [00:27<00:00, 369.30it/s]


[279088733     22214       144         1         0     10347       376
         8         0         0        93         4]
