In [None]:
import os

import hydra
import lightning

from openretina.data_io.cyclers import LongCycler


In [None]:
# Experiment config to load from ../configs.
# Swap in vystrcilova_2024_nm_ln.yaml for the natural movie dataset.
config_name = "vystrcilova_2024_wn_ln.yaml"

# hydra.initialize is used as a context manager, so the global Hydra state it
# sets up only lives for the duration of the compose call.
with hydra.initialize(version_base="1.3", config_path="../configs"):
    cfg = hydra.compose(config_name=config_name)

## Dataloader
If this is your first time working with the `sridhar_2025` dataset, it will be downloaded to the `cfg.paths.cache_dir` location the first time the dataloader is instantiated.

In [None]:
# Display the cache directory from the composed config, i.e. where the dataset download ends up.
cfg.paths.cache_dir

In [None]:
# Build the object described by the `dataloader` section of the config via Hydra.
# Later cells read dataloader["train"] and dataloader["validation"] from the result.
dataloader = hydra.utils.instantiate(cfg.dataloader)

In [None]:
# The model config is composed separately, straight from the model sub-directory
# of the config tree, so it can be adjusted before the model is instantiated.
with hydra.initialize(version_base="1.3", config_path="../configs/model/"):
    model_cfg = hydra.compose(config_name="single_cell_lnp.yaml")


In [None]:
# Retina selected in the dataloader config; used to pick that retina's training loader.
retina_index = cfg.dataloader["retina_index"]

# Pull a single training batch and read the shape of its first element,
# dropping the leading axis (presumably the batch dimension — confirm).
input_shape = next(iter(dataloader["train"][retina_index]))[0].shape[1:]

# Keep axis 0 and every axis from index 2 onward as observed in the data, but
# set axis 1 to the number of frames configured for the dataloader.
model_cfg["in_shape"] = (
    input_shape[0],
    cfg.dataloader["num_of_frames"],
    *input_shape[2:],
)

model = hydra.utils.instantiate(model_cfg)

# Hand the model the first entry of the training dataset's `locations`.
location = dataloader["train"][retina_index].dataset.locations
model.location = location[0]

In [None]:
# Make sure the configured log directory exists (no error if it already does).
# NOTE(review): none of the visible cells pass this directory to the Trainer or
# the checkpoint callback — confirm where logs/checkpoints are expected to land.
os.makedirs(cfg.paths.log_dir, exist_ok=True)

# Stop training when the monitored validation metric (higher is better) has not
# improved by at least `min_delta` for `patience` consecutive validation checks.
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_validation_metric",
    mode="max",
    min_delta=0.001,
    patience=10,
    verbose=False,
)

# Record the learning rate once per epoch.
lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

# Checkpoint on the same metric as early stopping; save_weights_only=False keeps
# full checkpoints rather than only the model weights.
model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_validation_metric",
    mode="max",
    save_weights_only=False,
)

In [None]:
# Configure the Lightning trainer.
# - max_epochs is set very high; the early-stopping callback is what is expected
#   to end training in practice.
# - accelerator="auto" lets Lightning pick the available hardware, so the notebook
#   still uses a GPU when one is present but no longer errors out on CPU-only machines.
trainer = lightning.Trainer(
    max_epochs=10000,
    logger=None,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    accelerator="auto",
    log_every_n_steps=10,
)

# Wrap the train and validation entries of `dataloader` in LongCycler so each can
# be handed to the trainer as a single iterable (see openretina.data_io.cyclers
# for the exact iteration order).
train_loader = LongCycler(dataloader["train"])
val_loader = LongCycler(dataloader["validation"])

In [None]:
# Run training; the callbacks registered on the trainer handle early stopping,
# learning-rate monitoring and checkpointing.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
