In [1]:
import numpy as np

from batch import FORMAT_XARC, load
from players.base import play_game
from players.simple import PlayerTracing
from players.strategy import StrategyTokenProducer
import  env.state as game

def create_pos_history_from_tokens(tokens: np.ndarray, color_o: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Rebuild the per-step action sequence of one recorded game from its tokens.

    ``tokens`` is a 2-D array whose rows unpack as ``(color, id, x, y, t)`` and
    whose columns are addressed through ``game.Token``. Board cells are indexed
    as ``x + 6 * y``; a token with ``x >= 6`` or ``y >= 6`` is stored as cell
    ``36`` (off the board). Rows made only of zeros are treated as padding.

    Steps:

    1. If the first token sits in the lower half (``Y >= 3``) the record is
       rotated by 180 degrees: ids are mirrored within each group of eight,
       on-board coordinates become ``5 - coord`` and the first eight rows are
       reversed, so that the first eight rows line up with cells
       ``1, 2, 3, 4, 7, 8, 9, 10``.
    2. If exactly one of the first eight rows is not on its start cell, a copy
       of that row is inserted at index 8 with ``t = 1``; later rows are
       shifted down by one (the last row is dropped) and their ``t`` is
       incremented. Presumably the recording folded the first event into the
       initial layout -- TODO confirm against the producer of these tokens.
    3. The rows are replayed in order. Every on-board row with ``t > 0`` whose
       ``t`` differs from the previous row's yields one action at index
       ``t - 1``: ``id * 4 + direction`` for ids below 8, and the mirrored
       ``(15 - id) * 4 + 3 - direction`` otherwise, where a cell delta of
       ``-6, -1, +1, +6`` maps to direction ``0, 1, 2, 3``.

    Args:
        tokens: Token array as described above. It is not modified.
        color_o: Colour array of the other side, passed through to the result
            (reversed when row 8 belongs to an id below 8).

    Returns:
        ``(actions, colors_first, colors_second)``. ``actions`` is the action
        history truncated after the last recorded move (empty when the record
        contains no move). When row 8 carries an id below 8 the colours are
        ``(tokens[:8, COLOR], color_o[::-1])``, otherwise
        ``(color_o, tokens[:8, COLOR])``.

    Raises:
        AssertionError: if more than one initial token is displaced, or if a
            move's cell delta is not one of ``-6, -1, +1, +6``.
    """
    token = game.Token

    # Work on a private copy: the normalisation below rewrites ids, coordinates
    # and row order, and the caller's buffer must not be touched.
    tokens = tokens.copy()

    n_rows = tokens.shape[0]
    pos_history = np.zeros((n_rows, 16), dtype=np.uint8)
    action_history = np.zeros(n_rows, dtype=np.uint8)

    # Start cells of the 16 pieces, indexed by piece id.
    pos = np.array([1, 2, 3, 4, 7, 8, 9, 10, 25, 26, 27, 28, 31, 32, 33, 34])

    if tokens[0, token.Y] >= 3:
        # Rotate the whole record by 180 degrees.
        non_empty = np.any(tokens != 0, axis=1)

        # Both masks are computed before any id is rewritten.
        id_p_mask = non_empty & (tokens[:, token.ID] < 8)
        id_o_mask = non_empty & (tokens[:, token.ID] >= 8)

        tokens[id_p_mask, token.ID] = 7 - tokens[id_p_mask, token.ID]
        tokens[id_o_mask, token.ID] = 7 - (tokens[id_o_mask, token.ID] - 8) + 8

        # Coordinate 6 marks "off board" and must survive the rotation as is.
        on_board_rows = (tokens[:, token.X] != 6) & (tokens[:, token.Y] != 6) & non_empty

        tokens[on_board_rows, token.X] = 5 - tokens[on_board_rows, token.X]
        tokens[on_board_rows, token.Y] = 5 - tokens[on_board_rows, token.Y]

        tokens[:8] = tokens[:8][::-1].copy()

    diff_mask = pos[:8] != (tokens[:8, token.X] + tokens[:8, token.Y] * 6)

    if np.any(diff_mask):
        assert np.sum(diff_mask) == 1, f"{pos[:8]}, {tokens[:8, game.Token.X]}, {tokens[:8, game.Token.Y]}"

        diff_id = np.arange(8)[diff_mask][0]

        # Shift the tail down by one row to make room at index 8. Only real
        # rows get their time step bumped: incrementing all-zero padding rows
        # would make them look like moves of piece 0 in the replay below.
        shifted = tokens[8:-1].copy()
        real_rows = np.any(shifted != 0, axis=1)
        shifted[real_rows, token.T] += 1
        tokens[9:] = shifted

        tokens[8] = tokens[diff_id]
        tokens[8, token.T] = 1

    direction_index = {-6: 0, -1: 1, 1: 2, 6: 3}

    # Stays 0 when the record holds no move, giving an empty action history.
    last_t = 0

    for i, (c, piece_id, x, y, t) in enumerate(tokens):
        if np.all(tokens[i] == 0):
            break

        on_board = x < 6 and y < 6
        pos[piece_id] = x + 6 * y if on_board else 36

        pos_history[t] = pos

        # Only the first row of each time step is turned into an action.
        if t > 0 and on_board and (tokens[i - 1, token.T] != tokens[i, token.T]):
            d = int(pos_history[t, piece_id]) - int(pos_history[t - 1, piece_id])

            if d not in direction_index:
                raise AssertionError(
                    f"{(c, piece_id, x, y, t)}, {pos_history[t, piece_id]}, {pos_history[t - 1, piece_id]}"
                )
            d_i = direction_index[d]

            if piece_id < 8:
                action_history[t - 1] = piece_id * 4 + d_i
            else:
                action_history[t - 1] = (15 - piece_id) * 4 + 3 - d_i

            last_t = t

    if tokens[8, token.ID] < 8:
        return action_history[:last_t], tokens[:8, token.COLOR], color_o[::-1]
    else:
        return action_history[:last_t], color_o, tokens[:8, token.COLOR]

# Read the raw replay buffer and collapse every leading axis, leaving one
# recorded sample per row.
batch = load("../data/replay_buffer/run-4.npy")
batch = batch.reshape((-1, batch.shape[-1]))
print(batch.shape)


import multiprocessing

def func(b):
    """Replay one stored sample and return its re-encoded form.

    The sample is split with ``FORMAT_XARC``, its tokens are converted into an
    action sequence, and that sequence is replayed by a single
    ``PlayerTracing`` instance acting for both sides while a
    ``StrategyTokenProducer`` emits the tokens. The returned value is whatever
    ``create_sample_p(token_length=220)`` produces for the finished game.
    """
    tokens, _, _, opponent_colors = FORMAT_XARC.astuple(b)
    actions, color_p, color_o = create_pos_history_from_tokens(tokens, opponent_colors)

    # One tracing player drives both seats, replaying the recorded actions.
    tracer = PlayerTracing(actions)
    outcome = play_game(
        tracer,
        tracer,
        color1=color_p,
        color2=color_o,
        token_producer=StrategyTokenProducer(),
        print_board=False,
        game_length=199,
    )

    return outcome.create_sample_p(token_length=220)

# Convert every sample of `batch` with `func`, stack the results and store them.
# Flip to False to debug serially on the first 400 samples (tracebacks are
# easier to read without worker processes).
USE_PROCESS_POOL = True

if USE_PROCESS_POOL:
    # The context manager tears the 20 workers down once `map` has returned,
    # so re-running the cell does not pile up orphaned processes.
    with multiprocessing.Pool(20) as pool:
        results = pool.map(func, iterable=batch[:])
else:
    results = [func(sample) for sample in batch[:400]]

batch_dst = np.stack(results, axis=0)
print(batch_dst.shape)

np.save("../data/replay_buffer/run-4-st.npy", batch_dst)

(577536, 1329)
not enough length 197
not enough length 198
not enough length 198
not enough length 197
not enough length 198
not enough length 197
not enough length 198
(577536, 1769)


In [1]:
import numpy as np

from batch import load

# Load the converted replay buffer and flatten it to one sample per row.
batch_org = load("../data/replay_buffer/run-4-st.npy")
batch_org = batch_org.reshape((-1, batch_org.shape[-1]))

# Shuffle the rows with a fixed seed so the later train/test split is
# reproducible. The global NumPy RNG is seeded on purpose.
np.random.seed(1)
indices = np.arange(len(batch_org))
np.random.shuffle(indices)

batch_org = batch_org[indices]
print(batch_org.shape)

(577536, 1833)


In [3]:
import numpy as np
from batch import FORMAT_X7_ST_PVC

# Split the samples with FORMAT_X7_ST_PVC and keep only the first component.
x, _, _, _, _ = FORMAT_X7_ST_PVC.astuple(batch_org)

print(x.shape)

# Value counts of the last two token features (indices 5 and 6) over all
# samples; presumably these are the strategy columns -- TODO confirm.
tuple(np.bincount(x[..., column].ravel()) for column in (5, 6))

(577536, 220, 7)


(array([125364278,   1310138,    383504]),
 array([125366215,   1178321,    513384]))

In [3]:
import itertools

import jax
from jax import random
import optax
from network.transformer import TransformerConfig, TrainStateTransformer
from network.train import fit, MinibatchProducerSimple
from network.checkpoints import Checkpoint, CheckpointManager
from batch import FORMAT_X7_ST_PVC

# Uncomment to make JAX raise as soon as a NaN is produced (slow; debugging only).
# jax.config.update("jax_debug_nans", True)

# Train on the first 500000 shuffled samples, split 80/20 into train and test.
batch = batch_org[:500000]

n_train = int(batch.shape[0] * 0.8)
train_batch = batch[:n_train]
test_batch = batch[n_train:]

minibatch_producer = MinibatchProducerSimple(batch_size=1024)

# Hyper-parameter grid. Each entry is a tuple so that more values can be added
# and swept with itertools.product.
heads = (4,)
dims = (256,)
num_layers = (4,)

for h, d, n in itertools.product(heads, dims, num_layers):
    model_config = TransformerConfig(
        num_heads=h,
        embed_dim=d,
        num_hidden_layers=n,
        strategy=True,
    )
    model = model_config.create_model()

    # A single sample is enough to trace the parameter shapes.
    init_x, init_st, _, _, _ = FORMAT_X7_ST_PVC.get_features(train_batch[:1])
    print(init_x.shape)

    # JAX keys must not be reused: derive independent keys for parameter
    # initialisation and for dropout from one root key.
    init_key, dropout_key = random.split(random.PRNGKey(0))

    variables = model.init(init_key, init_x, init_st)
    state = TrainStateTransformer.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate=0.0005),
        dropout_rng=dropout_key,
        epoch=0,
    )

    # NOTE(review): every grid point writes to the same directory, so a sweep
    # with more than one configuration would overwrite earlier checkpoints.
    ckpt_dir = '../data/checkpoints/tr-st-test'

    checkpoint_manager = CheckpointManager(ckpt_dir)
    checkpoint_manager.save(Checkpoint(state.epoch, model_config, state.params))

    state = fit(
        state, model_config, checkpoint_manager,
        train_batches=train_batch,
        test_batches=test_batch,
        minibatch_producer=minibatch_producer,
        epochs=4,
        log_wandb=False
    )

# Results noted from earlier runs (kept as comments so the cell does not echo
# them as its output):
#
# st
# Epoch: 1, Loss: (3.809, 3.415), P: (1.710, 1.431), V: (1.563, 1.495), C: (0.536, 0.488)
# Epoch: 2, Loss: (3.412, 3.295), P: (1.428, 1.342), V: (1.497, 1.479), C: (0.487, 0.474)
# Epoch: 3, Loss: (3.316, 3.239), P: (1.362, 1.302), V: (1.477, 1.466), C: (0.477, 0.470)
# Epoch: 4, Loss: (3.263, 3.197), P: (1.326, 1.277), V: (1.465, 1.456), C: (0.471, 0.465)
#
# not-st
# Epoch: 1, Loss: (3.812, 3.409), P: (1.713, 1.426), V: (1.562, 1.496), C: (0.536, 0.487)
# Epoch: 2, Loss: (3.402, 3.287), P: (1.422, 1.338), V: (1.497, 1.473), C: (0.484, 0.476)

(1, 220, 7)
save ../data/checkpoints/tr-st-test/0.json


100%|██████████| 390/390 [01:20<00:00,  4.82it/s, loss=4.395]
100%|██████████| 97/97 [00:07<00:00, 13.50it/s, loss=4.205]


Epoch: 1, Loss: (5.047, 4.228), P: (2.670, 2.020), V: (1.675, 1.548), C: (0.703, 0.660)
save ../data/checkpoints/tr-st-test/1.json


  8%|▊         | 30/390 [00:05<01:07,  5.32it/s, loss=4.315]


KeyboardInterrupt: 