In [None]:
# Re-import edited project modules (e.g. vitascopic_nca) automatically before
# each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from vitascopic_nca.trainer import Trainer
from vitascopic_nca.config import DefaultTrainerConfig
from tqdm.auto import tqdm
from IPython.display import display, clear_output
import panel as pn

# Seed torch's RNGs (CPU and CUDA) so initialisation, fire-rate masks and
# sampled rollout lengths are reproducible across Restart & Run All.
# NOTE(review): if Trainer also draws from `random`/`numpy`, seed those too.
SEED = 0
torch.manual_seed(SEED)

# Anomaly detection makes autograd much slower; keep it off unless you are
# hunting NaN/inf gradients.
torch.autograd.set_detect_anomaly(False)

# Load Panel's notebook extension so Panel objects render inline.
pn.extension()

In [None]:
# Use the GPU when one is visible. Otherwise fall back to CPU so the notebook
# still runs (slowly) instead of crashing; behaviour on CUDA machines is unchanged.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

config = DefaultTrainerConfig(
    # ===== NCA CONFIG =====
    message_channels=12,
    visual_channels=1,  # overridden (NOTE(review): by Trainer/config internals — confirm where)
    hidden_channels=128,
    fire_rate=0.99,
    alive_threshold=0.1,
    zero_initialization=False,
    mass_conserving="normal",
    padding_type="circular",
    beta=1.0,
    num_embs=5,
    msg_type="random",
    # ===== OPTIMIZATION CONFIG =====
    loss_type="mse",
    lr=0.0001,
    batch_size=24,
    # ===== DECODER CONFIG =====
    n_layers=3,
    hidden_dim=128,
    in_dim=1,
    pooling_fn=torch.amax,
    # ===== TRAINER CONFIG =====
    H=64,
    W=64,
    device=DEVICE,
    checkpoint_path="./checkpoints",
)

In [None]:
# Build the trainer from `config`. sanity_check() is presumably a quick check of
# the setup before the long run — TODO confirm what it validates in
# vitascopic_nca.trainer.
trainer = Trainer(config)
trainer.sanity_check()

In [None]:
# Gradient-free 150-step rollout to preview the model before training starts;
# the trainer's plotting helper renders the result.
with torch.no_grad():
    info = trainer.optim_step(steps=150)
    preview = trainer.display_optim_step(info)
    display(preview)

In [None]:
# Main training loop. Each iteration takes one training step. Every EVAL_EVERY
# iterations it also runs a gradient-free preview rollout, redraws the plots and
# writes a checkpoint. A checkpoint is also written when the loop finishes or
# the kernel is interrupted, so no completed training steps are lost.
N_ITERS = 50_000
TRAIN_STEPS = (20, 100)  # `steps` argument for training optim_step calls (meaning defined by Trainer)
EVAL_EVERY = 250
EVAL_STEPS = 100

pbar = tqdm(range(N_ITERS))
saved_at = None  # loop index at which the most recent checkpoint was written
i = None
try:
    for i in pbar:
        info = trainer.optim_step(steps=TRAIN_STEPS)
        pbar.set_description(f"Loss: {info['loss']:.4f}")

        if i % EVAL_EVERY == 0:
            # Separate name so `info` keeps holding the latest *training* result.
            with torch.no_grad():
                eval_info = trainer.optim_step(steps=EVAL_STEPS)

            clear_output(wait=True)
            display(pbar.container)
            display(trainer.display_optim_step(eval_info))
            trainer.save_checkpoint()
            saved_at = i
except KeyboardInterrupt:
    # Interrupting the kernel is the usual way to stop a long run early. Persist
    # progress made since the last periodic save, then re-raise so
    # "Run All" still halts at this cell.
    if i is not None and saved_at != i:
        trainer.save_checkpoint()
    raise
else:
    # Periodic saves only fire when i % EVAL_EVERY == 0. Without this, the
    # iterations after the last multiple of EVAL_EVERY would never be saved.
    if saved_at != i:
        trainer.save_checkpoint()

In [None]:
# # Example: reload a saved trainer and one of its checkpoints, then preview it (uncomment to run)

# trainer = Trainer.load_last_trainer(checkpoint_path="checkpoints")
# trainer.load_checkpoint(40001)

# with torch.no_grad():
#     info = trainer.optim_step(steps=110)
#     plots = trainer.display_optim_step(info)
#     display(plots)

---