In [None]:
import os

import hydra
import lightning
import wandb

from openretina.data_io.cyclers import LongCycler
from openretina.models.core_readout import UnifiedCoreReadout
from openretina.utils.model_utils import get_core_output_based_on_dimensions


## Initialize config

In [None]:
# Compose the experiment config from ../configs.
# Swap in "vystrcilova_2024_nm_cnn.yaml" to train on the natural movie dataset instead.
config_name = "vystrcilova_2024_wn_cnn.yaml"

with hydra.initialize(version_base="1.3", config_path="../configs"):
    cfg = hydra.compose(config_name=config_name)

## Downloading datasets
The datasets are downloaded automatically, and only once, the first time the data is loaded below. The files are stored at the location given by `cfg.paths.cache_dir`.


In [None]:
# Directory where the downloaded datasets are cached.
cfg["paths"]["cache_dir"]

In [None]:
# Load stimuli and neural responses. Both dicts are indexed by the same session keys.
movies_dict = hydra.utils.call(cfg.data_io.stimuli)
neuron_data_dict = hydra.utils.call(cfg.data_io.responses)

# Optional sanity check: compare each session's responses against that session's stimulus.
if cfg.check_stimuli_responses_match:
    for session_key in neuron_data_dict:
        session_responses = neuron_data_dict[session_key]
        session_responses.check_matching_stimulus(movies_dict[session_key])

In [None]:
print(cfg)

# Build the per-split dataloaders ("train", "validation", ...) from the config.
dataloaders = hydra.utils.instantiate(cfg.dataloader)

# Map every session in the training split to the number of neurons recorded in it.
retina_indices = list(dataloaders["train"].keys())
n_neurons_dict = {
    retina_index: dataloaders["train"][retina_index].dataset.n_neurons
    for retina_index in retina_indices
}


In [None]:
# Sessions (retinas) available in the training split. A bare last expression uses
# Jupyter's rich display instead of plain print output.
retina_indices

In [None]:
# Peek at one training batch of the first session and fill in the parts of the model
# config that depend on the data.
first_train_batch = next(iter(dataloaders["train"][retina_indices[0]]))
# Element 0 of a batch is the stimulus tensor; drop its leading batch dimension.
input_shape = list(first_train_batch[0].shape[1:])

cfg.model["in_shape"] = input_shape
# Assigned once (previously this line was duplicated further down the cell).
cfg.model["n_neurons_dict"] = n_neurons_dict
# The core's channel list starts with the stimulus' first (channel) dimension,
# followed by the configured hidden channels.
cfg.model["core"]["channels"] = [input_shape[0]] + cfg.model.hidden_channels
cfg.model["readout"]["in_shape"] = get_core_output_based_on_dimensions(cfg.model)
print(cfg.model.in_shape)

In [None]:
# Input shape the model was configured with (per sample, without the batch dimension).
cfg.model.in_shape

In [None]:
# Inspect one training batch: element 0 is the stimulus, element 1 the recorded responses.
batch = next(iter(dataloaders["train"][retina_indices[0]]))
batch_inputs, batch_targets = batch[0], batch[1]
print(f"input img shape {batch_inputs.shape}")
print(f"response shape {batch_targets.shape}")

In [None]:
# Build the core + readout model from the completed config and cast its
# floating-point parameters and buffers to float32.
model = UnifiedCoreReadout(**cfg.model).float()

In [None]:
# Logs for this run are written under `<log_dir>/cnns_wn/`.
# NOTE(review): the sub-folder name is hard-coded for the white-noise config; change it
# when switching `config_name` to the natural-movie config so runs don't get mixed up.
log_save_path = os.path.join(cfg.paths.log_dir, "cnns_wn/")
os.makedirs(log_save_path, exist_ok=True)

logger = lightning.pytorch.loggers.WandbLogger(
    name="",
    save_dir=log_save_path,
)

# Early stopping and checkpointing must follow the same metric in the same direction,
# so both callbacks read these two shared settings.
monitored_metric = "val_validation_loss"
# NOTE(review): mode="max" keeps the *largest* value of a metric named "...loss". That is
# only right if the model logs this key as a score to maximise (e.g. a correlation); for a
# true loss both callbacks need mode="min". Confirm against the model's validation step.
monitor_mode = "max"

early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor=monitored_metric,
    patience=10,
    mode=monitor_mode,
    verbose=False,
    min_delta=0.001,
)

lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor=monitored_metric, mode=monitor_mode, save_weights_only=False
)

trainer = lightning.Trainer(
    max_epochs=100,
    # The WandbLogger above was previously built but never handed to the Trainer
    # (logger=None), so nothing reached W&B and the closing `wandb.finish()` had no run.
    logger=logger,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    # "auto" picks a GPU when one is available and falls back to CPU otherwise,
    # instead of erroring out on machines without a GPU.
    accelerator="auto",
    log_every_n_steps=10,
)

In [None]:
# Combine each split's per-session dataloaders into a single iterable for the Trainer.
# The exact ordering across sessions is defined by openretina's LongCycler.
train_loader, val_loader = (
    LongCycler(dataloaders["train"]),
    LongCycler(dataloaders["validation"]),
)

In [None]:
# Optional post-mortem debugging: set to True to drop into pdb automatically whenever a cell
# raises. It is off by default because an interactive debugger prompt is unwanted in
# non-interactive "Restart & Run All" executions.
DEBUG_ON_ERROR = False
if DEBUG_ON_ERROR:
    get_ipython().run_line_magic("pdb", "on")

In [None]:
# Train the model. Early stopping, LR monitoring and checkpointing come from the trainer's callbacks.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Finish the active Weights & Biases run (if any), uploading any remaining data.
wandb.finish()