In [None]:
import os

import hydra
import lightning
import wandb

from openretina.data_io.cyclers import LongCycler

In [None]:
# Compose the experiment configuration with Hydra and echo it for inspection.
with hydra.initialize(config_path="../configs", version_base="1.3"):
    cfg = hydra.compose(config_name="vystrcilova_2024_nm_cnn.yaml")
print(cfg)

# Keep the dataloader settings in sync with the model: the frame count comes from the
# first temporal kernel size and the layer count from the number of temporal kernels.
temporal_kernel_sizes = cfg.model.temporal_kernel_sizes
cfg.dataloader.num_of_frames = temporal_kernel_sizes[0]
cfg.dataloader.num_of_layers = len(temporal_kernel_sizes)
dataloaders = hydra.utils.instantiate(cfg.dataloader)

# `n_neurons` of each training dataset, keyed by the entries of cfg.dataloader.files.
n_neurons_dict = {
    session_key: dataloaders["train"][session_key].dataset.n_neurons
    for session_key in cfg.dataloader.files.keys()
}

In [None]:
# Draw one training batch from session "01" to discover the model's input shape.
# NOTE(review): despite its name, `input_shape` holds the whole batch, not a shape.
input_shape = next(iter(dataloaders["train"]["01"]))
sample_inputs = input_shape[0]

# Drop the leading (presumably batch — TODO confirm) dimension before storing the shape.
cfg.model.in_shape = list(sample_inputs.shape[1:])
cfg.model.n_neurons_dict = n_neurons_dict
print(cfg.model["in_shape"])

In [None]:
# Build the model from the (now completed) model config, then grab one batch from
# session "01" and report the shapes of its first two elements.
model = hydra.utils.instantiate(cfg.model)
batch = next(iter(dataloaders["train"]["01"]))
for label, element in (("input img shape", batch[0]), ("response shape", batch[1])):
    print(label, element.shape)

In [None]:
# Smoke test: run one forward pass on the sampled inputs so that shape/config
# mismatches surface here rather than inside trainer.fit. The result is displayed
# as the cell output.
model(batch[0])

In [None]:
# Root directory for logs and checkpoints. It defaults to the original cluster location;
# set the OPENRETINA_LOG_DIR environment variable to run the notebook on a machine where
# that path does not exist, without editing the code.
log_root = os.environ.get(
    "OPENRETINA_LOG_DIR",
    "/projects/extern/nhr/nhr_ni/nim00010/dir.project/logs",
)
log_save_path = os.path.join(log_root, "cnns_nm/")
os.makedirs(
    log_save_path,
    exist_ok=True,
)

# Weights & Biases logger writing its local files under log_save_path.
# NOTE(review): the run name is an empty string — confirm whether an explicit name is wanted.
logger = lightning.pytorch.loggers.WandbLogger(
    name="",
    save_dir=log_save_path,
)

# Stop training once "val_correlation" (higher is better) has failed to improve by at
# least `min_delta` for `patience` consecutive validation checks.
# Assumes the model logs a metric called "val_correlation" — TODO confirm.
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    patience=10,
    mode="max",
    verbose=False,
    min_delta=0.001,
)

# Record the learning rate once per epoch.
lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

# Checkpoint on the same metric as early stopping; save_weights_only=False stores full
# checkpoints rather than model weights alone.
model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation", mode="max", save_weights_only=False
)

# Trainer: at most 100 epochs, on GPU, logging every 10 steps.
trainer = lightning.Trainer(
    max_epochs=100,
    logger=logger,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    accelerator="gpu",
    log_every_n_steps=10,
)

In [None]:
# Wrap the per-session dataloaders of each split in a LongCycler so the trainer
# receives a single iterable per split.
train_loader, val_loader = (
    LongCycler(dataloaders[split_name]) for split_name in ("train", "validation")
)

In [None]:
# Run training with validation; the callbacks configured on the trainer handle
# early stopping, learning-rate logging and checkpointing.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

In [None]:
# Close the active Weights & Biases run so that re-running the notebook starts a new one.
wandb.finish()