In [3]:
import numpy as np

from batch import load

batch_org = load("../data/replay_buffer/run-4.npy")

indices = np.arange(len(batch_org))

np.random.seed(1)
np.random.shuffle(indices)

batch_org = batch_org[indices]
print(batch_org.shape)

(577536, 1329)


In [2]:
from batch import FORMAT_X5_PVC, FORMAT_X7_ST_PVC

x, _, p, v, c = FORMAT_X7_ST_PVC.get_features(batch_org)

x = x[..., :5]

x.shape, p.shape, v.shape, c.shape

batch_new = FORMAT_X5_PVC.from_tuple(x, p, v, c)

np.save("../data/replay_buffer/run-4.npy", batch_new)

In [10]:
import itertools

import jax
from jax import random
import optax
from network.transformer import TransformerConfig, TrainStateTransformer
from network.train import fit, MinibatchProducerSimple
from network.checkpoints import Checkpoint, CheckpointManager
from batch import FORMAT_X5_PVC

# jax.config.update("jax_debug_nans", True)

batch = batch_org[:500000]

n_train = int(batch.shape[0] * 0.8)
train_batch = batch[:n_train]
test_batch = batch[n_train:]

minibatch_producer = MinibatchProducerSimple(batch_size=256)

heads = 4,
dims = 256, 
num_layers = 4,

for h, d, n in itertools.product(heads, dims, num_layers):
    model_config = TransformerConfig(
        num_heads=h,
        embed_dim=d,
        num_hidden_layers=n,
    )
    model = model_config.create_model()

    init_x, _, _, _ = FORMAT_X5_PVC.get_features(train_batch[:1])
    print(init_x.shape)

    variables = model.init(random.PRNGKey(0), init_x)
    state = TrainStateTransformer.create(
        apply_fn=model.apply,
        params=variables['params'],
        tx=optax.adam(learning_rate=0.001),
        dropout_rng=random.PRNGKey(0),
        epoch=0,
    )

    ckpt_dir = f'../data/checkpoints/tr_h{h}_d{d}_n{n}'

    checkpoint_manager = CheckpointManager(ckpt_dir)
    checkpoint_manager.save(Checkpoint(state.epoch, model_config, state.params))

    state = fit(
        state, model_config, checkpoint_manager,
        train_batches=train_batch,
        test_batches=test_batch,
        minibatch_producer=minibatch_producer,
        epochs=8,
        log_wandb=False
    )

"""
st
Epoch: 1, Loss: (3.809, 3.415), P: (1.710, 1.431), V: (1.563, 1.495), C: (0.536, 0.488)
Epoch: 2, Loss: (3.412, 3.295), P: (1.428, 1.342), V: (1.497, 1.479), C: (0.487, 0.474)
Epoch: 3, Loss: (3.316, 3.239), P: (1.362, 1.302), V: (1.477, 1.466), C: (0.477, 0.470)
Epoch: 4, Loss: (3.263, 3.197), P: (1.326, 1.277), V: (1.465, 1.456), C: (0.471, 0.465)

not-st
Epoch: 1, Loss: (3.812, 3.409), P: (1.713, 1.426), V: (1.562, 1.496), C: (0.536, 0.487)
Epoch: 2, Loss: (3.402, 3.287), P: (1.422, 1.338), V: (1.497, 1.473), C: (0.484, 0.476)
"""

(1, 220, 5)
save ../data/checkpoints/tr_h4_d256_n4/0.json


100%|██████████| 1562/1562 [01:12<00:00, 21.64it/s, loss=3.753]
100%|██████████| 390/390 [00:07<00:00, 54.02it/s, loss=3.586]


Epoch: 1, Loss: (4.213, 3.581), P: (2.037, 1.559), V: (1.594, 1.522), C: (0.582, 0.500)
save ../data/checkpoints/tr_h4_d256_n4/1.json


100%|██████████| 1562/1562 [01:08<00:00, 22.67it/s, loss=3.557]
100%|██████████| 390/390 [00:05<00:00, 65.76it/s, loss=3.424]


Epoch: 2, Loss: (3.544, 3.388), P: (1.541, 1.424), V: (1.506, 1.486), C: (0.496, 0.478)
save ../data/checkpoints/tr_h4_d256_n4/2.json


100%|██████████| 1562/1562 [01:08<00:00, 22.69it/s, loss=3.465]
100%|██████████| 390/390 [00:05<00:00, 65.38it/s, loss=3.335]


Epoch: 3, Loss: (3.411, 3.302), P: (1.444, 1.362), V: (1.485, 1.470), C: (0.482, 0.470)
save ../data/checkpoints/tr_h4_d256_n4/3.json


100%|██████████| 1562/1562 [01:08<00:00, 22.80it/s, loss=3.407]
100%|██████████| 390/390 [00:05<00:00, 65.46it/s, loss=3.306]


Epoch: 4, Loss: (3.337, 3.249), P: (1.391, 1.324), V: (1.472, 1.459), C: (0.475, 0.466)
save ../data/checkpoints/tr_h4_d256_n4/4.json


100%|██████████| 1562/1562 [01:08<00:00, 22.93it/s, loss=3.377]
100%|██████████| 390/390 [00:06<00:00, 64.94it/s, loss=3.274]


Epoch: 5, Loss: (3.288, 3.219), P: (1.355, 1.300), V: (1.462, 1.457), C: (0.470, 0.462)
save ../data/checkpoints/tr_h4_d256_n4/5.json


100%|██████████| 1562/1562 [01:08<00:00, 22.65it/s, loss=3.332]
100%|██████████| 390/390 [00:05<00:00, 65.24it/s, loss=3.240]


Epoch: 6, Loss: (3.252, 3.188), P: (1.330, 1.280), V: (1.455, 1.448), C: (0.467, 0.460)
save ../data/checkpoints/tr_h4_d256_n4/6.json


100%|██████████| 1562/1562 [01:08<00:00, 22.70it/s, loss=3.318]
100%|██████████| 390/390 [00:05<00:00, 65.11it/s, loss=3.220]


Epoch: 7, Loss: (3.225, 3.164), P: (1.312, 1.264), V: (1.449, 1.443), C: (0.464, 0.457)
save ../data/checkpoints/tr_h4_d256_n4/7.json


100%|██████████| 1562/1562 [01:08<00:00, 22.95it/s, loss=3.300]
100%|██████████| 390/390 [00:05<00:00, 65.43it/s, loss=3.209]


Epoch: 8, Loss: (3.203, 3.152), P: (1.297, 1.255), V: (1.445, 1.442), C: (0.461, 0.455)
save ../data/checkpoints/tr_h4_d256_n4/8.json


'\nst\nEpoch: 1, Loss: (3.809, 3.415), P: (1.710, 1.431), V: (1.563, 1.495), C: (0.536, 0.488)\nEpoch: 2, Loss: (3.412, 3.295), P: (1.428, 1.342), V: (1.497, 1.479), C: (0.487, 0.474)\nEpoch: 3, Loss: (3.316, 3.239), P: (1.362, 1.302), V: (1.477, 1.466), C: (0.477, 0.470)\nEpoch: 4, Loss: (3.263, 3.197), P: (1.326, 1.277), V: (1.465, 1.456), C: (0.471, 0.465)\n\nnot-st\nEpoch: 1, Loss: (3.812, 3.409), P: (1.713, 1.426), V: (1.562, 1.496), C: (0.536, 0.487)\nEpoch: 2, Loss: (3.402, 3.287), P: (1.422, 1.338), V: (1.497, 1.473), C: (0.484, 0.476)\n'