In [11]:
import numpy as np

from network.cnn import create_pos_history_from_tokens
from batch import load, get_tokens, get_color, get_action, get_reward


# Load the replay buffer and collapse every leading axis (keeping the last),
# so each row of `batch` is one stored record.
batch = load("../data/replay_buffer/run-7.npy")
batch = np.reshape(batch, (-1, batch.shape[-1]))

# Tokenise only the first 10000 records.
tokens = get_tokens(batch[:10000])
print(tokens.shape)

def create_layers(pos_p, pos_o, color_p, color_o) -> np.ndarray:
    """Encode one position snapshot as a (36, 12) float array of feature planes.

    Axis 0 is indexed by position value (0..35). Axis 1 holds 12 channels:

    * 0 / 1   -- own pieces: ``color_p`` and ``1 - color_p`` at each on-board
                 position (a one-hot pair when colours are 0/1).
    * 2 / 3   -- opponent pieces: ``color_o`` and ``1 - color_o``.
    * 4 .. 7  -- one-hot count (0..3) of opponent pieces with colour 0 whose
                 position is -1, set on every row.
    * 8 .. 11 -- the same count for opponent pieces with colour 1.

    Args:
        pos_p: integer array of own-piece positions; values outside [0, 36)
            are not placed on the board.
        pos_o: integer array of opponent-piece positions; values outside
            [0, 36) are not placed, and -1 is counted as captured.
        color_p: array aligned with ``pos_p`` giving each own piece's colour.
        color_o: array aligned with ``pos_o`` giving each opponent piece's colour.

    Returns:
        A float64 array of shape (36, 12).

    Raises:
        ValueError: if either captured count exceeds 3 and so cannot be
            represented by its 4-channel one-hot group.
    """
    n_positions = 36
    max_captured = 3  # each count group has 4 one-hot channels: 0..3

    layers = np.zeros((n_positions, 12))

    pos_p_mask = (0 <= pos_p) & (pos_p < n_positions)
    pos_o_mask = (0 <= pos_o) & (pos_o < n_positions)

    layers[pos_p[pos_p_mask], 0] = color_p[pos_p_mask]
    layers[pos_p[pos_p_mask], 1] = 1 - color_p[pos_p_mask]

    layers[pos_o[pos_o_mask], 2] = color_o[pos_o_mask]
    layers[pos_o[pos_o_mask], 3] = 1 - color_o[pos_o_mask]

    captured_o = pos_o == -1
    n_cap_r = int(np.sum((color_o == 0) & captured_o))
    n_cap_b = int(np.sum((color_o == 1) & captured_o))

    # Without this guard a count of 4 would silently spill into the next group
    # (colour 0 -> channel 8) or index past the last channel (colour 1 -> 12).
    if n_cap_r > max_captured or n_cap_b > max_captured:
        raise ValueError(
            f"captured counts must be at most {max_captured} to fit the "
            f"4-channel one-hot encoding, got {n_cap_r} (colour 0) and "
            f"{n_cap_b} (colour 1)"
        )

    layers[:, 4 + n_cap_r] = 1
    layers[:, 8 + n_cap_b] = 1

    return layers


# Encode every tokenised sample as a sequence of CNN input rows. Each row is
# the 12 * 36 board features for one step of the sample's position history,
# followed by the action at that step and the sample's reward.
N_BOARD_FEATURES = 12 * 36

batch_cnn = np.zeros((tokens.shape[0], tokens.shape[1], N_BOARD_FEATURES + 2))

# Walk exactly the samples tokenised above so every slot of `batch_cnn` along
# the sample axis gets encoded.
for i, batch_i in enumerate(batch[:tokens.shape[0]]):
    tokens_i = get_tokens(batch_i)

    color_p = tokens_i[:8, 0]
    color_o = get_color(batch_i)

    a = get_action(batch_i)
    r = get_reward(batch_i)

    pos_history = create_pos_history_from_tokens(tokens_i)

    for j, pos in enumerate(pos_history):
        layers = create_layers(pos[:8], pos[8:], color_p, color_o)
        # TODO(review): an unfinished transform was meant to apply here when
        # `tokens_i[0, 3] == 5` (presumably re-orienting the board). Its intent
        # is not recoverable from this cell; restore it once it is known.

        batch_cnn[i, j, :N_BOARD_FEATURES] = layers.reshape(N_BOARD_FEATURES)
        batch_cnn[i, j, N_BOARD_FEATURES] = a[j]
        batch_cnn[i, j, N_BOARD_FEATURES + 1] = r

# Flatten (sample, step) into rows; the training cell reads this as `batch`.
# NOTE(review): steps beyond len(pos_history) stay all-zero and are kept in
# the flattened output — confirm that padding rows should be trained on.
batch = batch_cnn.reshape((-1, N_BOARD_FEATURES + 2))

(10000, 220, 5)
[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]

[25 26 27 28 31 32 33 34]
[[0. 0. 1. 0. 1. 0.]
 [0. 1. 1. 0.

In [8]:
import itertools

import optax
from jax import random
from network.cnn import CNNConfig, TrainStateCNN
from network.train import fit, MinibatchProducerSimple
from network.checkpoints import Checkpoint, CheckpointManager


# Train a CNN on the flattened rows in `batch`. In this notebook each row holds
# 12 * 36 board features followed by the action and the reward.
# NOTE(review): the last recorded run failed inside `fit` with "cannot reshape
# array of shape (16,) ... into shape (-1, 6, 6, 12)". The 16 equals
# `batch_size`, which suggests the model received a single value per row rather
# than the 432 board features — confirm how `fit` / `MinibatchProducerSimple`
# expect `train_batches` to be laid out.
TRAIN_FRACTION = 0.8
BATCH_SIZE = 16
LEARNING_RATE = 0.0005
EPOCHS = 4

n_train = int(batch.shape[0] * TRAIN_FRACTION)
train_batch = batch[:n_train]
test_batch = batch[n_train:]

minibatch_producer = MinibatchProducerSimple(batch_size=BATCH_SIZE)

model_config = CNNConfig(num_filters=[64, 64])
model = model_config.create_model()

# One row's board features in the (1, 6, 6, 12) layout used to initialise the model.
init_x = train_batch[0, :12 * 36].reshape(1, 6, 6, 12)

variables = model.init(random.PRNGKey(0), init_x)
state = TrainStateCNN.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=LEARNING_RATE),
    dropout_rng=random.PRNGKey(0),
    epoch=0,
)

ckpt_dir = './data/checkpoints/cnn'

checkpoint_manager = CheckpointManager(ckpt_dir)
# Save the untrained parameters at the starting epoch before training.
checkpoint_manager.save(Checkpoint(state.epoch, model_config, state.params))

state = fit(
    state, model_config, checkpoint_manager,
    train_batches=train_batch,
    test_batches=test_batch,
    minibatch_producer=minibatch_producer,
    epochs=EPOCHS,
    log_wandb=False
)

  0%|          | 0/110000 [00:00<?, ?it/s]


TypeError: cannot reshape array of shape (16,) (size 16) into shape (-1, 6, 6, 12) because the product of specified axis sizes (432) does not evenly divide 16