In [1]:
# Register the TensorBoard notebook extension so that the `%tensorboard`
# line magic used further down is available in this kernel.
%load_ext tensorboard

In [2]:
from pathlib import Path

import jax_verify
import optax
from flax import nnx
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vml_final.data import CSVDataset, CSVDatasetEpochLoader
from vml_final.model import TemporalConvolutionalNetwork
from vml_final.training import do_eval_epoch, do_train_epoch

In [3]:
# Repository root, expressed relative to the notebook's working directory.
project_root = Path("../../")

# Build the dataset from the AB09 folder inside the project's csv_dataset tree.
ab09_dir = project_root.joinpath("csv_dataset", "AB09")
dset = CSVDataset(ab09_dir)


Processing trial: 06

Processing trial: 07

Processing trial: 02

Processing trial: 01

Processing trial: 04

Processing trial: 03

Processing trial: 05


In [4]:
# Two epoch loaders over the same CSVDataset instance.
# NOTE(review): the second positional argument (64 / 512) is presumably the
# batch size, and `train=False` presumably selects a held-out/evaluation
# split or disables shuffling — confirm against CSVDatasetEpochLoader.
train_loader = CSVDatasetEpochLoader(dset, 64)
eval_loader = CSVDatasetEpochLoader(dset, 512, train=False)

In [65]:
# Show TensorBoard inline, reading event files from ../../logs and serving on
# port 6006. If an instance is already running on that port it is reused.
%tensorboard --logdir ../../logs --port 6006

Reusing TensorBoard on port 6006 (pid 842308), started 0:22:49 ago. (Use '!kill 842308' to kill it.)

In [69]:
# Seeded RNG container handed to the model constructor.
rngs = nnx.Rngs(0)

# Architecture hyperparameters gathered in one place so they are easy to scan
# and tweak; the input width is the size of the last axis of `dset.x`.
tcn_kwargs = dict(
    input_channels=dset.x.shape[-1],
    extractor_hidden_features=4,
    extractor_groups=1,
    extractor_kernel_size=5,
    hidden_dims=[4, 4, 16],
    kernel_size=3,
    dropout=0.2,
)
model = TemporalConvolutionalNetwork(**tcn_kwargs, rngs=rngs)

# Adam with a learning rate of 2.5e-3, wrapped in an NNX optimizer that holds
# the model (later cells reach the model through `optim.model`).
adam = optax.adam(2.5e-3)
optim = nnx.Optimizer(model, adam)

In [70]:
# TensorBoard event writer targeting <project_root>/logs, the same directory
# the `%tensorboard --logdir ../../logs` cell reads from.
tensorboard_log_dir = project_root / "logs"
writer = SummaryWriter(log_dir=tensorboard_log_dir)

In [71]:
# Number of full train + validation passes to run.
num_epochs = 128

# Drive the progress bar through a context manager so it is closed even when
# training raises or is interrupted from the notebook (a bare `pbar.close()`
# after the loop is skipped in that case and leaves a dangling bar).
with tqdm(total=num_epochs) as pbar:
    for i in range(num_epochs):
        # One pass over each loader; the helpers' own progress bars are
        # disabled because the outer bar already tracks epochs.
        train_loss = do_train_epoch(optim, train_loader, pbar=False)
        validation_loss = do_eval_epoch(optim, eval_loader, pbar=False)

        pbar.update()
        pbar.set_postfix({"Train Loss": train_loss, "Val Loss": validation_loss})

        # Log both curves under a shared "loss" tag so TensorBoard overlays
        # them; the x-axis is the optimizer's step counter, not the epoch.
        writer.add_scalars(
            "loss",
            {"train": train_loss, "validation": validation_loss},
            global_step=optim.step,
        )

# SummaryWriter buffers events; flush so the final epochs reach disk and show
# up in TensorBoard without waiting for the writer to be closed.
writer.flush()

100%|██████████| 128/128 [00:44<00:00,  2.90it/s, Train Loss=0.0177, Val Loss=0.103]


In [72]:
# Render the model object held by the optimizer (after the training loop
# above) using NNX's display helper.
nnx.display(optim.model)