In [None]:
# Register the TensorBoard notebook extension so the `%tensorboard` magic is available.
%load_ext tensorboard

In [None]:
from pathlib import Path

import jax
import jax_verify
import numpy as np
import optax
from einops import reduce
from flax import nnx
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from vml_final.data import CSVDataset, CSVDatasetEpochLoader
from vml_final.model import TemporalConvolutionalNetwork
from vml_final.training import do_eval_epoch, do_train_epoch

In [None]:
# Repository root, given relative to the current working directory.
project_root = Path("../../")

# Directory handed to the dataset reader.
dataset_dir = project_root / "csv_dataset" / "AB09"
dset = CSVDataset(dataset_dir, stack_size=512)

In [None]:
# Both loaders wrap the same dataset object; the evaluation loader is built
# with train=False. The second positional argument is presumably the batch
# size -- TODO confirm against CSVDatasetEpochLoader.
train_loader, eval_loader = (
    CSVDatasetEpochLoader(dset, 1024),
    CSVDatasetEpochLoader(dset, 256, train=False),
)

In [None]:
# Launch TensorBoard inline on port 6006, reading from ../../logs -- the same
# directory the SummaryWriter below writes to (project_root / "logs").
%tensorboard --logdir ../../logs --port 6006

In [None]:
# Seeded RNG container handed to the model constructor.
rngs = nnx.Rngs(0)

# Architecture hyperparameters, gathered in one place so they are easy to tweak.
# The input channel count is read from the last axis of the dataset inputs.
tcn_hparams = dict(
    input_channels=dset.x.shape[-1],
    conv_hidden_dims=[8, 8, 8, 16, 32],
    # mlp_hidden_dims=[128, 128],
    kernel_size=5,
    stride=5,
    dropout=0.0,
)
model = TemporalConvolutionalNetwork(**tcn_hparams, rngs=rngs)

# Adam optimizer wrapping the model.
optim = nnx.Optimizer(model, optax.adam(learning_rate=1.0e-3))

In [None]:
# TensorBoard event writer; events land in <project_root>/logs.
writer = SummaryWriter(log_dir=project_root / "logs")

In [None]:
# Number of train + validation passes to run.
num_epochs = 128

# The context manager closes the progress bar even when training is interrupted
# or raises, so re-running this cell never leaves a stale bar behind.
with tqdm(total=num_epochs) as pbar:
    for epoch in range(num_epochs):
        # Per-epoch progress bars are disabled; only the outer bar is shown.
        train_loss = do_train_epoch(optim, train_loader, pbar=False)
        validation_loss = do_eval_epoch(optim, eval_loader, pbar=False)

        pbar.update()
        pbar.set_postfix({"Train Loss": train_loss, "Val Loss": validation_loss})

        # Log both curves under one "loss" chart, indexed by the optimizer's
        # step counter rather than by the epoch number.
        writer.add_scalars(
            "loss",
            {"train": train_loss, "validation": validation_loss},
            global_step=optim.step,
        )

# Push buffered events to disk so TensorBoard shows the complete run.
writer.flush()

In [None]:
# Draw a single batch of 128 with train=False and run the model on it.
batch_x, batch_y = dset.get_batch(128, train=False)
batch_pred = model(batch_x)

# Pair each prediction with its target along a new trailing axis for display.
np.stack((batch_pred, batch_y), axis=-1)

In [None]:
# Run one more evaluation pass over eval_loader and display the returned value
# as the cell output.
do_eval_epoch(optim, eval_loader)

In [None]:
# Render the model held by the optimizer for inspection.
nnx.display(optim.model)

In [None]:
# Per-channel extremes, reduced over the first two axes of the dataset inputs.
channel_max = reduce(dset.x, "e t c -> c", "max")
channel_min = reduce(dset.x, "e t c -> c", "min")

# Repeat the per-channel extremes once per position along the second-to-last
# axis, giving bound arrays with one row per position and one column per channel.
num_steps = dset.x.shape[-2]
upper = np.array([channel_max for _ in range(num_steps)])
lower = np.array([channel_min for _ in range(num_steps)])

In [None]:
# Switch to inference mode BEFORE splitting. `eval()` changes mode flags on the
# module, and the graph definition returned by `nnx.split` is a snapshot taken
# at split time; splitting first would make `pure_call` rebuild a
# training-mode model.
model.eval()

model_graphdef, model_state = nnx.split(model)


@jax.jit
def pure_call(x):
    """Apply the trained model to ``x`` as a pure, jit-compiled function.

    The module is rebuilt from the captured graph definition and state, so the
    function closes only over those two objects and not over the live,
    mutable ``model``.

    Args:
        x: Input passed straight through to the model's ``__call__``.

    Returns:
        Whatever the model returns for ``x``.
    """
    # Distinct local name so the notebook-level `model` is not shadowed.
    eval_model = nnx.merge(model_graphdef, model_state)
    return eval_model(x)

In [None]:
# Interval input set: every input entry is constrained to [lower, upper].
input_bound = jax_verify.IntervalBound(lower, upper)

# Propagate the input interval through the network with backward CROWN.
# NOTE(review): the module itself is passed here although a jitted `pure_call`
# wrapper is defined earlier in the notebook and is otherwise unused --
# confirm which callable is intended.
output_bound = jax_verify.backward_crown_bound_propagation(model, input_bound)

In [None]:
# Display the lower and upper output bounds computed above as a tuple.
output_bound.lower, output_bound.upper