In [1]:
import numpy as np
from tqdm import tqdm

from batch import load, get_tokens, get_color, get_action, get_reward, get_seq_len

import numpy as np

def create_pos_history_from_tokens(tokens: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    pos_history = np.zeros((tokens.shape[0], 16), dtype=np.uint8)
    action_history = np.zeros((tokens.shape[0]), dtype=np.uint8)

    if tokens[0, 3] < 3:
        invert = False
        pos = np.array([1, 2, 3, 4, 7, 8, 9, 10, 25, 26, 27, 28, 31, 32, 33, 34])
    else:
        invert = True
        pos = 35 - np.array([25, 26, 27, 28, 31, 32, 33, 34, 1, 2, 3, 4, 7, 8, 9, 10])

    for i, (c, id, x, y, t) in enumerate(tokens):
        if np.all(tokens[i] == 0):
            break

        if x < 6 and y < 6:
            pos[id] = x + 6 * y
            if invert:
                pos[id] = 35 - pos[id]
        else:
            pos[id] = 36

        pos_history[t] = pos

        if t > 1 and x < 6 and y < 6 and (tokens[i - 1, 4] != tokens[i, 4]):
            d = int(pos_history[t, id]) - int(pos_history[t - 1, id])

            if d == -6:
                d_i = 0
            elif d == -1:
                d_i = 1
            elif d == 1:
                d_i = 2
            elif d == 6:
                d_i = 3
            else:
                assert False, f"{(c, id, x, y, t)}, {pos_history[t, id]}, {pos_history[t - 1, id]}"

            action_history[t - 1] = pos_history[t - 1, id] * 4 + d_i
    
    return pos_history, action_history

batch = load("../data/replay_buffer/run-7.npy")
batch = batch.reshape(-1, batch.shape[-1])
print(batch.shape)

seq_len = get_seq_len(batch.shape[-1])

batch_new = np.zeros((batch.shape[0], batch.shape[1] + 16 * seq_len), dtype=np.uint8)

for i, batch_i in tqdm(list(enumerate(batch))):
    t = get_tokens(batch_i)
    c = get_color(batch_i)
    a = get_action(batch_i)
    r = get_reward(batch_i)

    p, a = create_pos_history_from_tokens(t)

    if t[0, 3] > 3:
        mask = np.all(t == 0, axis=-1)

        t[t[:, 2] != 6, 2] = 5 - t[t[:, 2] != 6, 2]
        t[t[:, 3] != 6, 3] = 5 - t[t[:, 3] != 6, 3]

        t[mask] = 0

        c = c[::-1]

    a = a[t[:, 4]]
    p = p[t[:, 4]]

    p = p.reshape(seq_len * p.shape[-1])
    t = t.reshape(seq_len * t.shape[-1])

    batch_new[i] = np.concatenate(
        [p, t, a, np.array([r]), c],
        axis=-1,
        dtype=np.uint8
    )

np.save("../data/replay_buffer/run-7-cnn.npy", batch_new)

(1268736, 1329)


100%|██████████| 1268736/1268736 [09:53<00:00, 2138.63it/s]


In [None]:

from jax import numpy as jnp
from network.transformer import create_concat_input
from batch import get_tokens, get_posses, get_color


def pos_to_board(
    pos1: jnp.ndarray,
    pos2: jnp.ndarray,
    color1: jnp.ndarray,
    color2: jnp.ndarray
) -> jnp.ndarray:
    batch_shape = pos1.shape[:-1]

    pos1 = pos1.reshape(-1, 8)
    pos2 = pos2.reshape(-1, 8)
    color1 = color1.reshape(-1, 8)
    color2 = color2.reshape(-1, 8)

    def scan_f(x_i) -> jnp.ndarray:
        p1, p2, c1, c2 = [x_i[i*8: (i+1)*8] for i in range(4)]

        board = jnp.zeros((37, 4), dtype=jnp.uint8)
        board = board.at[p1, 0].set(c1)
        board = board.at[p1, 1].set(255 - c1)
        board = board.at[p2, 2].set(c2)
        board = board.at[p2, 3].set(255 - c2)

        return None, board

    xs = jnp.concatenate([pos1, pos2, color1, color2], axis=-1, dtype=jnp.uint8)

    _, board = jnp.apply_along_axis(scan_f, axis=-1, arr=xs)

    board = board[..., :36, :].reshape((*batch_shape, 6, 6, 4))

    return board

j = 3

x = get_tokens(batch_org[j: j+10])
pos = get_posses(batch_org[j: j+10])
col = get_color(batch_org[j: j+10])
concat = create_concat_input(x, pos, col)

color_1 = jnp.stack([x[..., :8, 0]]*x.shape[-2], axis=-2) * 200 + 20
color_2 = jnp.stack([col]*x.shape[-2], axis=-2) * 200 + 20
board = pos_to_board(pos[..., :8], pos[..., 8:], color_1, color_2)
# board = board.astype(jnp.float16) / 255.0

for j in range(10):
    for i in range(20, 21):
        print(i, x[j, i, :], concat[j, i])
        print()
        print(pos[j, i, :8], pos[j, i, 8:])
        print(color_1[j, i], color_2[j, i])
        print()
        print(board[j, i, :, :, 0])
        print()
        print(board[j, i, :, :, 1])
        print()
        print(board[j, i, :, :, 2])
        print()
        print(board[j, i, :, :, 3])
        print()
        print()


In [1]:
import numpy as np

from batch import load, astuple, get_reward

batch_org = load("../data/replay_buffer/run-7-cnn.npy")

batch_org = batch_org.reshape(-1, batch_org.shape[-1])

indices = np.arange(len(batch_org))

np.random.seed(1)
np.random.shuffle(indices)

batch_org = batch_org[indices]

In [2]:
import itertools

import jax
from jax import random, numpy as jnp
import optax
from network.cnn import CNNConfig
from network.transformer import TransformerConfig, TrainStateTransformer, create_concat_input
from network.train import fit, MinibatchProducerSimple
from network.checkpoints import Checkpoint, CheckpointManager
from batch import get_tokens, get_posses, get_color, get_reward

# jax.config.update("jax_debug_nans", True)

batch = batch_org[:200000]

# r = get_reward(batch)
# batch = batch[get_reward(batch) != 3]

n_train = int(batch.shape[0] * 0.8)
train_batch = batch[:n_train]
test_batch = batch[n_train:]

minibatch_producer = MinibatchProducerSimple(batch_size=16)

heads = 4,
dims = 256,
num_layers = 4,

for h, d, n in itertools.product(heads, dims, num_layers):
    model_config = TransformerConfig(
        num_heads=h,
        embed_dim=d,
        num_hidden_layers=n,
        cnn_config=CNNConfig(num_filters=[128, 128]),
    )
    model = model_config.create_model()

    init_x = get_tokens(train_batch[:1])
    init_pos = get_posses(train_batch[:1])
    init_color = get_color(train_batch[:1])
    init_concat = create_concat_input(init_x, init_pos, init_color)

    # variables = model.init(random.PRNGKey(0), init_x)
    variables = model.init(random.PRNGKey(0), init_x, pos=init_pos, concat=init_concat)
    state = TrainStateTransformer.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate=0.0005),
        dropout_rng=random.PRNGKey(0),
        epoch=0,
    )

    ckpt_dir = f'./data/checkpoints/tr'

    checkpoint_manager = CheckpointManager(ckpt_dir)
    checkpoint_manager.save(Checkpoint(state.epoch, model_config, state.params))

    state = fit(
        state, model_config, checkpoint_manager,
        train_batches=train_batch,
        test_batches=test_batch,
        minibatch_producer=minibatch_producer,
        epochs=8,
        log_wandb=False
    )

"""
Epoch: 1, Loss: (4.711, 4.095), P: (2.292, 1.753), V: (1.720, 1.660), C: (0.699, 0.683)
Epoch: 2, Loss: (3.987, 3.741), P: (1.717, 1.567), V: (1.639, 1.617), C: (0.631, 0.557)
Epoch: 3, Loss: (3.738, 3.603), P: (1.583, 1.479), V: (1.609, 1.601), C: (0.546, 0.524)
Epoch: 4, Loss: (3.619, 3.527), P: (1.510, 1.426), V: (1.589, 1.588), C: (0.520, 0.513)
Epoch: 5, Loss: (3.547, 3.460), P: (1.463, 1.390), V: (1.576, 1.575), C: (0.508, 0.495)
Epoch: 6, Loss: (3.491, 3.429), P: (1.428, 1.367), V: (1.563, 1.568), C: (0.500, 0.494)
Epoch: 7, Loss: (3.450, 3.396), P: (1.401, 1.344), V: (1.554, 1.563), C: (0.494, 0.488)
Epoch: 8, Loss: (3.416, 3.371), P: (1.380, 1.328), V: (1.546, 1.558), C: (0.491, 0.485)
"""

2024-07-29 17:54:07.341113: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|██████████| 2500/2500 [01:07<00:00, 36.78it/s, loss=3.543]
100%|██████████| 625/625 [00:07<00:00, 78.19it/s, loss=4.163] 


Epoch: 1, Loss: (4.543, 3.778), P: (2.148, 1.441), V: (1.696, 1.646), C: (0.699, 0.691)


100%|██████████| 2500/2500 [00:50<00:00, 49.22it/s, loss=3.305]
100%|██████████| 625/625 [00:04<00:00, 125.27it/s, loss=3.872]


Epoch: 2, Loss: (3.661, 3.540), P: (1.372, 1.255), V: (1.599, 1.604), C: (0.690, 0.681)


100%|██████████| 2500/2500 [00:50<00:00, 49.28it/s, loss=3.152]
100%|██████████| 625/625 [00:04<00:00, 125.13it/s, loss=3.661]


Epoch: 3, Loss: (3.443, 3.310), P: (1.251, 1.185), V: (1.573, 1.587), C: (0.619, 0.538)


100%|██████████| 2500/2500 [00:50<00:00, 49.38it/s, loss=2.985]
100%|██████████| 625/625 [00:05<00:00, 123.35it/s, loss=3.540]


Epoch: 4, Loss: (3.273, 3.226), P: (1.183, 1.136), V: (1.552, 1.573), C: (0.538, 0.517)


100%|██████████| 2500/2500 [00:51<00:00, 48.48it/s, loss=2.904]
100%|██████████| 625/625 [00:05<00:00, 118.65it/s, loss=3.487]


Epoch: 5, Loss: (3.186, 3.166), P: (1.134, 1.103), V: (1.536, 1.560), C: (0.516, 0.503)


100%|██████████| 2500/2500 [00:52<00:00, 47.37it/s, loss=2.834]
100%|██████████| 625/625 [00:05<00:00, 116.93it/s, loss=3.460]


Epoch: 6, Loss: (3.128, 3.125), P: (1.100, 1.082), V: (1.524, 1.550), C: (0.505, 0.492)


100%|██████████| 2500/2500 [00:53<00:00, 47.08it/s, loss=2.809]
100%|██████████| 625/625 [00:05<00:00, 117.36it/s, loss=3.438]


Epoch: 7, Loss: (3.086, 3.101), P: (1.074, 1.064), V: (1.514, 1.547), C: (0.498, 0.489)


100%|██████████| 2500/2500 [00:53<00:00, 46.57it/s, loss=2.768]
100%|██████████| 625/625 [00:05<00:00, 119.72it/s, loss=3.403]


Epoch: 8, Loss: (3.050, 3.071), P: (1.054, 1.048), V: (1.506, 1.541), C: (0.490, 0.482)


'\nEpoch: 1, Loss: (4.711, 4.095), P: (2.292, 1.753), V: (1.720, 1.660), C: (0.699, 0.683)\nEpoch: 2, Loss: (3.987, 3.741), P: (1.717, 1.567), V: (1.639, 1.617), C: (0.631, 0.557)\nEpoch: 3, Loss: (3.738, 3.603), P: (1.583, 1.479), V: (1.609, 1.601), C: (0.546, 0.524)\nEpoch: 4, Loss: (3.619, 3.527), P: (1.510, 1.426), V: (1.589, 1.588), C: (0.520, 0.513)\nEpoch: 5, Loss: (3.547, 3.460), P: (1.463, 1.390), V: (1.576, 1.575), C: (0.508, 0.495)\nEpoch: 6, Loss: (3.491, 3.429), P: (1.428, 1.367), V: (1.563, 1.568), C: (0.500, 0.494)\nEpoch: 7, Loss: (3.450, 3.396), P: (1.401, 1.344), V: (1.554, 1.563), C: (0.494, 0.488)\nEpoch: 8, Loss: (3.416, 3.371), P: (1.380, 1.328), V: (1.546, 1.558), C: (0.491, 0.485)\n'

In [7]:
import flax.linen
import jax
from jax import random, numpy as jnp
import flax
from network.cnn import pos_to_board
from network.transformer import TransformerConfig, TrainStateTransformer, create_concat_input
from network.train import fit, MinibatchProducerSimple
from network.checkpoints import Checkpoint, CheckpointManager
from batch import get_tokens, get_posses, get_color

from env.state import State
from game_analytics import state_to_str

ckpt = Checkpoint.from_json_file("data/checkpoints/tr/4.json")

model = ckpt.model.create_model()

tokens = get_tokens(train_batch[:1])

p, v, c = model.apply({"params": ckpt.params}, tokens)

for i in range(100):
    print(", ".join([f"{f:.3f}" for f in flax.linen.softmax(v[0, i])]))

0.313, 0.082, 0.172, 0.006, 0.251, 0.037, 0.139
0.312, 0.080, 0.172, 0.006, 0.252, 0.038, 0.141
0.310, 0.080, 0.172, 0.006, 0.253, 0.038, 0.141
0.307, 0.081, 0.171, 0.006, 0.255, 0.038, 0.141
0.305, 0.080, 0.172, 0.006, 0.257, 0.038, 0.141
0.304, 0.081, 0.172, 0.006, 0.257, 0.038, 0.142
0.304, 0.081, 0.171, 0.006, 0.258, 0.039, 0.142
0.303, 0.080, 0.171, 0.006, 0.258, 0.039, 0.142
0.260, 0.093, 0.178, 0.007, 0.268, 0.042, 0.152
0.289, 0.090, 0.179, 0.008, 0.251, 0.038, 0.145
0.291, 0.088, 0.169, 0.009, 0.280, 0.037, 0.126
0.298, 0.089, 0.180, 0.009, 0.259, 0.035, 0.130
0.303, 0.092, 0.169, 0.010, 0.261, 0.039, 0.125
0.302, 0.090, 0.175, 0.010, 0.271, 0.033, 0.118
0.303, 0.093, 0.173, 0.012, 0.261, 0.040, 0.118
0.307, 0.090, 0.172, 0.010, 0.273, 0.031, 0.117
0.304, 0.092, 0.176, 0.011, 0.263, 0.037, 0.118
0.305, 0.091, 0.171, 0.012, 0.274, 0.031, 0.115
0.305, 0.094, 0.178, 0.011, 0.255, 0.039, 0.117
0.306, 0.091, 0.172, 0.012, 0.276, 0.031, 0.113
0.295, 0.098, 0.151, 0.008, 0.289, 0.037

In [42]:
import flax.linen
import jax
from jax import random, numpy as jnp
import flax
from network.cnn import pos_to_board
from network.transformer import TransformerConfig, TrainStateTransformer, create_concat_input
from network.train import fit, MinibatchProducerSimple
from network.checkpoints import Checkpoint, CheckpointManager
from batch import get_tokens, get_posses, get_color

from env.state import State
from game_analytics import state_to_str

ckpt = Checkpoint.from_json_file("data/checkpoints/tr/4.json")

model = ckpt.model.cnn_config.create_model()

pos1 = jnp.array([[1, 2, 3, 20, 7, 8, 9, 35]], dtype=jnp.uint8)
col1 = jnp.array([[0, 0, 0, 0, 1, 1, 1, 1]], dtype=jnp.uint8) * 255

pos2 = jnp.array([[21, 33, 32, 31, 26, 36, 36, 36]], dtype=jnp.uint8)
col2 = jnp.array([[255, 128, 128, 128, 128, 128, 128, 128]], dtype=jnp.uint8)

concat = jnp.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=jnp.uint8)

n_cap_r = 3
n_cap_b = 0

concat = concat.at[n_cap_r].set(1)
concat = concat.at[4 + n_cap_b].set(1)

board = pos_to_board(pos1, pos2, col1, col2)

p, v = model.apply({"params": ckpt.params["cnn"]}, board, concat=concat)

print(", ".join([f"{f:.3f}" for f in flax.linen.softmax(v[0])]))

# 0.221, 0.082, 0.315, 0.015, 0.231, 0.062, 0.074
# 0.333, 0.098, 0.185, 0.011, 0.208, 0.060, 0.105

print(state_to_str(State(jnp.array([pos1[0], pos2[0], col1[0] // 255, jnp.array([3]*8)]), n_ply=10), predicted_color=[0.5]*8, colored=True))

0.081, 0.027, 0.067, 0.004, 0.745, 0.042, 0.034
|   [31mR[0m  [31mR[0m  [31mR[0m      |
|   [34mB[0m  [34mB[0m  [34mB[0m      |
|                |
|      [31mR[0m  5      |
|      5         |
|   5  5  5     [34mB[0m|
blue=0 red=0
