In [1]:
%load_ext tensorboard

In [2]:
from pathlib import Path

import jax
import jax_verify
import numpy as np
import optax
from einops import reduce
from flax import nnx
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vml_final.data import CSVDataset, CSVDatasetEpochLoader
from vml_final.model import TemporalConvolutionalNetwork
from vml_final.training import do_eval_epoch, do_train_epoch

In [3]:
project_root = Path("../../")
dset = CSVDataset(project_root / "csv_dataset" / "AB09")


Processing trial: 06

Processing trial: 07

Processing trial: 02

Processing trial: 01

Processing trial: 04

Processing trial: 03

Processing trial: 05


In [4]:
train_loader = CSVDatasetEpochLoader(dset, 64)
eval_loader = CSVDatasetEpochLoader(dset, 64, train=False)

In [5]:
%tensorboard --logdir ../../logs --port 6006

Reusing TensorBoard on port 6006 (pid 846430), started 1 day, 12:03:34 ago. (Use '!kill 846430' to kill it.)

In [6]:
rngs = nnx.Rngs(0)

model = TemporalConvolutionalNetwork(
    input_channels=dset.x.shape[-1],
    hidden_dims=[32, 64, 128],
    kernel_size=3,
    stride=3,
    dropout=0.0,
    rngs=rngs,
)

optim = nnx.Optimizer(model, optax.adam(1.0e-4))

In [7]:
writer = SummaryWriter(project_root / "logs")

In [8]:
num_epochs = 1024

# pbar.close()

pbar = tqdm(total=num_epochs)
for i in range(num_epochs):
    train_loss = do_train_epoch(optim, train_loader, pbar=False)
    # print(f"Epoch {i} loss: {epoch_loss:2.5f}")
    validation_loss = do_eval_epoch(optim, eval_loader, pbar=False)
    pbar.update()
    pbar.set_postfix({"Train Loss": train_loss, "Val Loss": validation_loss})
    writer.add_scalars(
        "loss",
        {"train": train_loss, "validation": validation_loss},
        global_step=optim.step,
    )

pbar.close()

100%|██████████| 1024/1024 [00:29<00:00, 34.71it/s, Train Loss=0.901, Val Loss=0.908]


In [9]:
do_eval_epoch(optim, eval_loader)

100%|██████████| 4/4 [00:00<00:00, 287.45it/s]


np.float32(0.9264138)

In [10]:
nnx.display(optim.model)

In [11]:
np.unique_counts(dset.y)

UniqueCountsResult(values=array([0.0465, 0.0474, 0.048 , 0.0489, 0.0495, 0.0504, 0.051 , 0.0519,
       0.0525, 0.0534, 0.054 , 0.0549, 0.0555, 0.0564, 0.057 , 0.0579,
       0.0585, 0.0594, 0.06  , 0.0609, 0.0615, 0.0624, 0.063 , 0.0639,
       0.0645, 0.0654, 0.066 , 0.0669, 0.0675, 0.0684, 0.069 , 0.0699,
       0.0705, 0.0714, 0.072 , 0.0729, 0.0735, 0.0744, 0.075 , 0.0759,
       0.0765, 0.0774, 0.078 , 0.0789, 0.0795, 0.0804, 0.081 , 0.0819,
       0.0825, 0.0834, 0.084 , 0.0849, 0.0855, 0.0864, 0.087 , 0.0879,
       0.0885, 0.0894, 0.09  , 0.0909, 0.0915, 0.0924, 0.093 , 0.0939,
       0.0945, 0.0954, 0.096 , 0.0969, 0.0975, 0.0984, 0.099 , 0.0999,
       0.1005, 0.1014, 0.102 , 0.1029, 0.1035, 0.1044, 0.105 , 0.1059,
       0.1065, 0.1074, 0.108 , 0.1089, 0.1095, 0.1104, 0.111 , 0.1119,
       0.1125, 0.1134, 0.114 , 0.1149, 0.1155, 0.1164, 0.117 , 0.1179,
       0.1185, 0.1194, 0.12  , 0.1209, 0.1215, 0.1224, 0.123 , 0.1239,
       0.1245, 0.1254, 0.126 , 0.1269, 0.1275, 0.12

In [12]:
channel_max = reduce(dset.x, "e t c -> c", "max")
channel_min = reduce(dset.x, "e t c -> c", "min")

upper = np.array([channel_max] * dset.x.shape[-2])
lower = np.array([channel_min] * dset.x.shape[-2])

In [13]:
model_graphdef, model_state = nnx.split(model)

model.eval()

@jax.jit
def pure_call(x):
    model = nnx.merge(model_graphdef, model_state)
    return model(x)

In [14]:
output_bound = jax_verify.backward_crown_bound_propagation(
    model,
    jax_verify.IntervalBound(lower, upper),
)

AttributeError: 'ClosedJaxpr' object has no attribute 'invars'

In [None]:
output_bound.lower, output_bound.upper

(Array([-160.58766], dtype=float32), Array([87.69452], dtype=float32))