In this notebook we will train a Core/Readout model on data from Hoefling et al., 2024: ["A chromatic feature detector in the retina signals visual context changes"](https://elifesciences.org/articles/86860).

We will closely follow the structure of our unified training script, `openretina.cli.train.py`, including using Hydra to import and examine model config files. 

Note that the recommended way to train models is through `openretina.cli.train.py` and the corresponding `openretina train` command, since training can take a long time for some configurations.


# Imports

In [1]:
import logging
import os

import hydra
import lightning
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from einops import rearrange

from openretina.data_io.base import compute_data_info
from openretina.data_io.cyclers import LongCycler, ShortCycler
from openretina.data_io.hoefling_2024.dataloaders import natmov_dataloaders_v2
from openretina.data_io.hoefling_2024.responses import filter_responses, make_final_responses
from openretina.data_io.hoefling_2024.stimuli import movies_from_pickle
from openretina.eval.metrics import correlation_numpy, feve
from openretina.models.core_readout import CoreReadout
from openretina.utils.file_utils import get_local_file_path
from openretina.utils.h5_handling import load_h5_into_dict
from openretina.utils.misc import CustomPrettyPrinter
from openretina.utils.plotting import (
    numpy_to_mp4_video,
)

matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)  # to display logs in jupyter

%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

pp = CustomPrettyPrinter(indent=4, max_lines=30)

Let's also import the global config file for this model using Hydra.

In [2]:
torch.cuda.empty_cache()

In [16]:
with hydra.initialize(config_path=os.path.join("..", "configs"), version_base="1.3"):
    cfg = hydra.compose(config_name="hoefling_2024_core_readout_low_res.yaml")

# Loading data

The first step in loading data is determining from where it will be fetched / stored.

Let's see how this is handled in the configs:

In [4]:
pp.pprint(cfg.paths)

{   'cache_dir': '${oc.env:OPENRETINA_CACHE_DIRECTORY}',
    'data_dir': '${paths.cache_dir}/data/',
    'load_model_path': None,
    'log_dir': '.',
    'movies_path': 'https://huggingface.co/datasets/open-retina/open-retina/blob/main/euler_lab/hoefling_2024/stimuli/rgc_natstim_18x16_joint_normalized_2024-01-11.zip',
    'output_dir': '${hydra:runtime.output_dir}',
    'responses_path': 'https://huggingface.co/datasets/open-retina/open-retina/resolve/main/euler_lab/hoefling_2024/responses/rgc_natstim_2024-08-14.zip'}


The config contains the path from where files will be downloaded, and also requires the `cache_dir` to be set by the user: this is the directory where the data will be stored on download.

When using the training script, if `cache_dir` is not set by the user in the config files or somewhere in the script, it falls back to the `OPENRETINA_CACHE_DIRECTORY` environment variable, which by default points to `~/openretina_cache`.

If set, the `cache_dir` is also what the package will use in place of the default openretina cache folder. Let's set both here:

In [21]:
your_chosen_root_folder = "/home/bethge/bkr618/openretina_cache"  # Change this with your desired path.

cfg.paths.cache_dir = your_chosen_root_folder

# We will also overwrite the output directory for the logs/model to the local folder.
cfg.paths.log_dir = your_chosen_root_folder
cfg.paths.output_dir = your_chosen_root_folder

os.environ["OPENRETINA_CACHE_DIRECTORY"] = your_chosen_root_folder
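The precedence described above (an explicit `cache_dir`, then the `OPENRETINA_CACHE_DIRECTORY` environment variable, then the default folder) can be sketched in a few lines. `resolve_cache_dir` is a hypothetical helper written for illustration only; it is not part of openretina, and the package's actual resolution logic may differ.

```python
import os
from pathlib import Path

# Default location stated in the text above.
DEFAULT_CACHE_DIR = Path.home() / "openretina_cache"


def resolve_cache_dir(explicit=None, env=None):
    """Hypothetical helper: explicit value first, then env variable, then default."""
    env = os.environ if env is None else env
    if explicit:
        return Path(explicit).expanduser()
    from_env = env.get("OPENRETINA_CACHE_DIRECTORY")
    if from_env:
        return Path(from_env).expanduser()
    return DEFAULT_CACHE_DIR


print(resolve_cache_dir("/tmp/my_cache"))
print(resolve_cache_dir(env={"OPENRETINA_CACHE_DIRECTORY": "/data/cache"}))
print(resolve_cache_dir(env={}))
```

Passing `env` explicitly keeps the sketch testable without touching the real environment.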

## Stimuli

Loading of the stimuli is achieved, in the training script, via:
```
movies_dict = hydra.utils.call(cfg.data_io.stimuli)
```

Let's unpack it here.

In [7]:
pp.pprint(cfg.data_io.stimuli)

{   '_convert_': 'object',
    '_target_': 'openretina.data_io.hoefling_2024.stimuli.movies_from_pickle',
    'file_path': {   '_convert_': 'object',
                     '_target_': 'openretina.utils.file_utils.get_local_file_path',
                     'cache_folder': '${paths.cache_dir}',
                     'file_path': '${paths.movies_path}'}}


Essentially, using the `get_local_file_path` function, if `file_path` is not a local file, it will be downloaded to the cache folder and read from there.

In [7]:
movies_path = get_local_file_path(file_path=cfg.paths.movies_path, cache_folder=cfg.paths.data_dir)

movies_dict = movies_from_pickle(movies_path)

2025-10-23 09:40:08,639 - INFO - Fetching file list for open-retina/open-retina...
2025-10-23 09:40:08,823 - INFO - Downloading euler_lab/hoefling_2024/stimuli/rgc_natstim_18x16_joint_normalized_2024-01-11.zip...


euler_lab/hoefling_2024/stimuli/rgc_nats(…):   0%|          | 0.00/73.7M [00:00<?, ?B/s]

  2025-10-23T07:40:11.564277Z ERROR  Python exception updating progress:, error: PyErr { type: <class 'LookupError'>, value: LookupError(<ContextVar name='shell_parent' at 0x7f93e6d30860>), traceback: Some(<traceback object at 0x7f92ad1e25c0>) }, caller: "src/progress_update.rs:313"
    at /home/runner/work/xet-core/xet-core/error_printer/src/lib.rs:28



Cancellation requested; stopping current tasks.


KeyboardInterrupt: 

In [17]:
cfg.paths.movies_path

'https://huggingface.co/datasets/open-retina/open-retina/blob/main/euler_lab/hoefling_2024/stimuli/rgc_natstim_18x16_joint_normalized_2024-01-11.zip'

In [5]:
file_path = '/home/bethge/bkr618/openretina_cache/euler_lab/hoefling_2024/stimuli/rgc_natstim_72x64_joint_normalized_2024-10-11.pkl'
movie_stimuli = movies_from_pickle(file_path)

In [45]:
numpy_to_mp4_video(movie_stimuli.test_movie, fps=30)

Let us also visualize a few seconds of the training video:

In [55]:
numpy_to_mp4_video(movie_stimuli.train[:, :600, ...])

## Responses

In the training script, responses are loaded through:

```
neuron_data_dict = hydra.utils.call(cfg.data_io.responses)
```

Let's unpack it here.

In [9]:
pp.pprint(cfg.data_io.responses)

{   '_convert_': 'object',
    '_target_': 'openretina.data_io.hoefling_2024.responses.make_final_responses',
    'data_dict': {   '_convert_': 'object',
                     '_target_': 'openretina.data_io.hoefling_2024.responses.filter_responses',
                     'all_responses': {   '_convert_': 'object',
                                          '_target_': 'openretina.utils.h5_handling.load_h5_into_dict',
                                          'file_path': {   '_convert_': 'object',
                                                           '_target_': 'openretina.utils.file_utils.get_local_file_path',
                                                           'cache_folder': '${paths.cache_dir}',
                                                           'file_path': '${paths.responses_path}'}},
                     'cell_types_list': '${quality_checks.cell_types_list}',
                     'chirp_qi': '${quality_checks.chirp_qi}',
                     'classifier_confid

While this may look complex, it effectively amounts to resolving a few intermediate steps in loading the data, and should be read from the inside out.
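To make the inside-out reading concrete, here is a minimal sketch of how nested `_target_` entries could be resolved. It is a simplification and not Hydra's implementation: it ignores `${...}` interpolation and `_convert_`, and the functions `get_path`, `load` and `summarise` are hypothetical stand-ins registered by name rather than real openretina functions.

```python
calls = []  # records the order in which targets are called


def get_path(file_path, cache_folder):
    calls.append("get_path")
    return f"{cache_folder}/{file_path.rsplit('/', 1)[-1]}"


def load(file_path):
    calls.append("load")
    return {"source": file_path}


def summarise(data):
    calls.append("summarise")
    return f"loaded {data['source']}"


# Hypothetical registry standing in for Hydra's import-by-dotted-path lookup.
REGISTRY = {"get_path": get_path, "load": load, "summarise": summarise}


def call_nested(node):
    """Resolve child nodes first, then call the node's own target with the results."""
    if isinstance(node, dict) and "_target_" in node:
        kwargs = {k: call_nested(v) for k, v in node.items() if not k.startswith("_")}
        return REGISTRY[node["_target_"]](**kwargs)
    return node


toy_cfg = {
    "_target_": "summarise",
    "data": {
        "_target_": "load",
        "file_path": {
            "_target_": "get_path",
            "file_path": "https://example.org/responses.h5",
            "cache_folder": "/cache",
        },
    },
}

result = call_nested(toy_cfg)
print(result)
print(calls)
```

The recorded call order shows the innermost target running first and the outermost last, which is the same order in which the responses config is evaluated.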

Written out as plain function calls, the responses config above is equivalent to the following:

In [6]:
#responses_path = get_local_file_path(file_path=cfg.paths.responses_path, cache_folder=cfg.paths.data_dir)
responses_path = "/home/bethge/bkr618/openretina_cache/data/euler_lab/hoefling_2024/responses/rgc_natstim_2024-08-14.h5"
responses_dict = load_h5_into_dict(file_path=responses_path)

filtered_responses_dict = filter_responses(responses_dict, **cfg.quality_checks)

final_responses = make_final_responses(filtered_responses_dict, response_type="natural")

Loading HDF5 file contents:   0%|          | 0/2077 [00:00<?, ?item/s]

Original dataset contains 7863 neurons over 67 fields
 ------------------------------------ 
Dropped 0 fields that did not contain the target cell types (67 remaining)
Overall, dropped 3034 neurons of non-target cell types (-38.59%).
 ------------------------------------ 
Dropped 0 fields with quality indices below threshold (67 remaining)
Overall, dropped 980 neurons over quality checks (-20.29%).
 ------------------------------------ 
Dropped 0 fields with classifier confidences below 0.25
Overall, dropped 705 neurons with classifier confidences below 0.25 (-18.32%).
 ------------------------------------ 
 ------------------------------------ 
Final dataset contains 3144 neurons over 67 fields
Total number of cells dropped: 4719 (-60.02%)


Upsampling natural spikes traces to get final responses.:   0%|          | 0/67 [00:00<?, ?it/s]
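The percentages in the filtering log are easiest to read once you notice that each one is relative to the neurons remaining after the previous stage, while the final figure is relative to the original dataset. The short check below reproduces the logged numbers from the counts alone; the stage labels are paraphrased from the log.

```python
n_original = 7863
stages = [
    ("non-target cell types", 3034),
    ("quality checks", 980),
    ("classifier confidence below 0.25", 705),
]

remaining = n_original
stage_percentages = []
for name, dropped in stages:
    # Percentage relative to the neurons that survived the previous stage.
    pct = round(100 * dropped / remaining, 2)
    stage_percentages.append(pct)
    remaining -= dropped
    print(f"{name}: dropped {dropped} ({pct}%), {remaining} remaining")

total_dropped = n_original - remaining
total_pct = round(100 * total_dropped / n_original, 2)
print(f"Total dropped: {total_dropped} ({total_pct}% of the original dataset)")
```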

And here is how the final responses will be organised:

In [5]:
pp.pprint(final_responses)

{   'session_1_ventral1_20200226': ResponsesTrainTestSplit(train=numpy.ndarray(shape=(80, 16200)),
                                                           test_dict={   'test': numpy.ndarray(shape=(80, 750))},
                                                           test_by_trial=numpy.ndarray(shape=(80, 750, 3)),
                                                           stim_id='natural',
                                                           session_kwargs={   'eye': 'left',
                                                                              'group_assignment': numpy.ndarray(shape=(80,)),
                                                                              'roi_ids': numpy.ndarray(shape=(80,)),
                                                                              'roi_mask': numpy.ndarray(shape=(64, 64)),
                                                                              'scan_sequence_idx': np.int64(18)}),
    'session_1_ventral1_20200

# Creating dataloaders

The corresponding code in `train.py` is:
```
dataloaders = hydra.utils.instantiate(
        cfg.dataloader,
        neuron_data_dictionary=neuron_data_dict,
        movies_dictionary=movies_dict,
    )
```

In [5]:
pp.pprint(cfg.dataloader)

{   '_convert_': 'object',
    '_target_': 'openretina.data_io.hoefling_2024.dataloaders.natmov_dataloaders_v2',
    'allow_over_boundaries': True,
    'batch_size': 128,
    'train_chunk_size': 50,
    'validation_clip_indices': list(len=15)}


In [7]:
dataloaders = natmov_dataloaders_v2(
    neuron_data_dictionary=final_responses,
    movies_dictionary=movie_stimuli,
    allow_over_boundaries=True,
    batch_size=128,
    train_chunk_size=50,
    validation_clip_indices=cfg.dataloader.validation_clip_indices,
)

Creating movie dataloaders:   0%|          | 0/67 [00:00<?, ?it/s]

In [10]:
print(dataloaders.keys())

dict_keys(['train', 'validation', 'test'])


In [62]:
for key, loader in dataloaders['train'].items():
    batch = next(iter(loader))  # get first batch from this DataLoader
    print(key)
    print(batch.inputs.shape)   # if batch is a DataPoint
    print(batch.targets.shape)  # if available



session_1_ventral1_20200226
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 80])
session_1_ventral1_20200528
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 42])
session_1_ventral1_20200707
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 74])
session_1_ventral1_20201021
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 32])
session_1_ventral1_20201030
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 40])
session_1_ventral1_20210929
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 48])
session_1_ventral1_20210930
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 26])
session_1_ventral2_20200302
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 41])
session_1_ventral2_20200707
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 56])
session_1_ventral2_20201021
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 39])
session_1_ventral2_20201022
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 42])
session_1_ventral2_20201030
torch.Size([128

KeyboardInterrupt: 

In [52]:
len(dataloaders["train"]['session_1_ventral1_20200226'])

2

In [42]:
for i, batch in enumerate(dataloaders["train"]["session_1_ventral1_20200226"]):
    print(f"n_sessions: {len(dataloaders['train'])}")
    
    print(batch.inputs.shape)
    print(batch.targets.shape)

n_sessions: 67
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 80])
n_sessions: 67
torch.Size([128, 2, 50, 72, 64])
torch.Size([128, 50, 80])


In [47]:
for batch in dataloaders["test"]["session_2_ventral2_20201022"]:
    print(f"n_sessions: {len(dataloaders['train'])}")
    print(batch.inputs.shape)
    print(batch.targets.shape)

n_sessions: 67
torch.Size([1, 2, 750, 72, 64])
torch.Size([1, 750, 48])


In [11]:
for batch in dataloaders["validation"]["session_1_ventral1_20200226"]:
    print(f"n_sessions: {len(dataloaders['train'])}")
    print(batch.inputs.shape)
    print(batch.targets.shape)
    break

n_sessions: 67
torch.Size([15, 2, 150, 72, 64])
torch.Size([15, 150, 80])


In [49]:
pp.pprint(dataloaders['train'])

{   'session_1_ventral1_20200226': torch.utils.data.DataLoader(Dataset: MovieDataSet with 80 neuron responses to a movie of shape [2, 13950, 72, 64].),
    'session_1_ventral1_20200528': torch.utils.data.DataLoader(Dataset: MovieDataSet with 42 neuron responses to a movie of shape [2, 13950, 72, 64].),
    'session_1_ventral1_20200707': torch.utils.data.DataLoader(Dataset: MovieDataSet with 74 neuron responses to a movie of shape [2, 13950, 72, 64].),
    'session_1_ventral1_20201021': torch.utils.data.DataLoader(Dataset: MovieDataSet with 32 neuron responses to a movie of shape [2, 13950, 72, 64].),
    'session_1_ventral1_20201030': torch.utils.data.DataLoader(Dataset: MovieDataSet with 40 neuron responses to a movie of shape [2, 13950, 72, 64].),
    'session_1_ventral1_20210929': torch.utils.data.DataLoader(Dataset: MovieDataSet with 48 neuron responses to a movie of shape [2, 13950, 72, 64].),
    'session_1_ventral1_20210930': torch.utils.data.DataLoader(Dataset: MovieDataSet wit
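The shapes printed above fit together with some simple bookkeeping. The sketch below assumes that the 15 validation clips, each 150 frames long as the validation batch shape suggests, are held out from the 16200 training frames; this is an inference from the printed shapes, not something read from the dataloader code.

```python
total_frames = 16200      # training responses per neuron, from final_responses
clip_length = 150         # frames per clip, from the validation batch shape
n_validation_clips = 15   # length of cfg.dataloader.validation_clip_indices
train_chunk_size = 50     # frames per training sample

n_clips = total_frames // clip_length
train_frames = total_frames - n_validation_clips * clip_length
n_chunks = train_frames // train_chunk_size  # non-overlapping chunks only

print(f"{n_clips} clips in total, {n_clips - n_validation_clips} left for training")
print(f"{train_frames} training frames, i.e. {n_chunks} non-overlapping chunks of {train_chunk_size} frames")
```

The resulting number of training frames matches the movie length reported by each training `MovieDataSet`.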

In [None]:
print(f"n sessions train: {len(dataloaders['train'])}")
print(dataloaders["train"]["session_1_ventral1_20200226"])
# `dataloaders` is a dict of dicts, so report the number of sessions per split instead of `.shape`
print(f"n sessions validation: {len(dataloaders['validation'])}")
print(f"n sessions test: {len(dataloaders['test'])}")

In [7]:
pp.pprint(dataloaders)

{   'test': {   'session_1_ventral1_20200226': torch.utils.data.DataLoader(Dataset: MovieDataSet with 80 neuron responses to a movie of shape [2, 750, 72, 64].),
                'session_1_ventral1_20200528': torch.utils.data.DataLoader(Dataset: MovieDataSet with 42 neuron responses to a movie of shape [2, 750, 72, 64].),
                'session_1_ventral1_20200707': torch.utils.data.DataLoader(Dataset: MovieDataSet with 74 neuron responses to a movie of shape [2, 750, 72, 64].),
                'session_1_ventral1_20201021': torch.utils.data.DataLoader(Dataset: MovieDataSet with 32 neuron responses to a movie of shape [2, 750, 72, 64].),
                'session_1_ventral1_20201030': torch.utils.data.DataLoader(Dataset: MovieDataSet with 40 neuron responses to a movie of shape [2, 750, 72, 64].),
                'session_1_ventral1_20210929': torch.utils.data.DataLoader(Dataset: MovieDataSet with 48 neuron responses to a movie of shape [2, 750, 72, 64].),
                'session_1_v

Let's also compute `data_info`, which is used to initialise certain model components and to save important metadata about stimuli and responses within the model.

In [8]:
data_info = compute_data_info(neuron_data_dictionary=final_responses, movies_dictionary=movie_stimuli)

pp.pprint(data_info)

{   'input_shape': (2, 72, 64),
    'movie_norm_dict': {   'default': {   'norm_mean': 36.979288270899204,
                                          'norm_std': 36.98463253226166}},
    'n_neurons_dict': {   'session_1_ventral1_20200226': 80,
                          'session_1_ventral1_20200528': 42,
                          'session_1_ventral1_20200707': 74,
                          'session_1_ventral1_20201021': 32,
                          'session_1_ventral1_20201030': 40,
                          'session_1_ventral1_20210929': 48,
                          'session_1_ventral1_20210930': 26,
                          'session_1_ventral2_20200302': 41,
                          'session_1_ventral2_20200707': 56,
                          'session_1_ventral2_20201021': 39,
                          'session_1_ventral2_20201022': 42,
                          'session_1_ventral2_20201030': 84,
                          'session_1_ventral2_20201117': 46,
                         

# Transformer attempts

In [7]:
from openretina.models.transformer_core import ViViTCore


# Model initialisation

Relevant `train.py` section:
```
cfg.model.n_neurons_dict = data_info["n_neurons_dict"]

model = hydra.utils.instantiate(cfg.model, data_info=data_info)
```

The config for the model will contain all the relevant hyperparameters for it:

In [18]:
pp.pprint(cfg.model)

{   'core': {   '_convert_': 'object',
                '_target_': 'openretina.modules.core.base_core.SimpleCoreWrapper',
                'channels': '???',
                'color_squashing_weights': None,
                'convolution_type': 'custom_separable',
                'cut_first_n_frames': 0,
                'downsample_input_kernel_size': None,
                'dropout_rate': 0.0,
                'gamma_hidden': 0.0,
                'gamma_in_sparse': 0.0,
                'gamma_input': 0.0,
                'gamma_temporal': 40.0,
                'hidden_padding': [0, 2, 2],
                'input_padding': 0,
                'maxpool_every_n_layers': None,
                'spatial_kernel_sizes': [11, 5],
                'temporal_kernel_sizes': [21, 11]},
    'hidden_channels': [16, 16],
    'in_shape': [2, 150, 18, 16],
    'learning_rate': 0.01,
    'n_neurons_dict': '???',
    'readout': {   '_convert_': 'object',
                   '_target_': 'openretina.modules.readout

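As you can see, the value for `n_neurons_dict` is missing (shown as `'???'`) and needs to be set from `data_info`. Instead of instantiating the model from this config, the next cell builds a `TransformerCoreReadout` directly: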

In [19]:
n_neurons_dict = data_info["n_neurons_dict"]
from openretina.models.core_readout import TransformerCoreReadout

model = TransformerCoreReadout(
    in_shape=(128,2, 50, 72, 64),  # (batch, channels, time, height, width)
    n_neurons_dict=n_neurons_dict,
    emb_dim=32,
    in_channels=2,
    img_size=(72,64),
    patch_size=(8, 8),
    temporal_patch_size=5,
    num_spatial_blocks=4,
    num_temporal_blocks=4,
    num_heads=2,
    mlp_ratio=4.0,
    dropout=0.1,  
    chunk_size=64,
    pad_frame=True,
    temporal_stride=1,
    spatial_stride=1,
    ptoken=0.1,
    readout_bias=True,
    readout_init_mu_range=0.5,  
    readout_init_sigma_range=4.0, 
    readout_gamma=0.4,
    readout_reg_avg=False,
    learning_rate=0.01,
    data_info=data_info,
)


2025-10-31 16:09:32,711 - INFO - in_shape_readout=(32, 11, 10, 9)


1. Starting init
2. Basic attributes set
3. Padding computed: t=4, h=7, w=7
4. Padding layer created
5. Tokenizer created
6. Transformer created
7. Patch dimensions computed
8. About to initialize weights
✓ Weights initialized
9. Weights initialized - INIT COMPLETE


In [10]:
from lightning.pytorch.utilities.model_summary import summarize


summary = summarize(model, max_depth=-1)  # full depth
print(summary)



    | Name                                                         | Type                               | Params | Mode 
------------------------------------------------------------------------------------------------------------------------------
0   | core                                                         | TransformerCoreWrapper             | 442 K  | train
1   | core.pad                                                     | ZeroPad3d                          | 0      | train
2   | core.tokenizer                                               | VideoTokenizer                     | 42.3 K | train
3   | core.tokenizer.ln                                            | LayerNorm                          | 1.3 K  | train
4   | core.tokenizer.fc                                            | Linear                             | 41.0 K | train
5   | core.transformer                                             | SpatioTemporalTransformer          | 400 K  | train
6   | core.transformer.spa

# Training

With data imported, models initialised and dataloaders set up, we can turn to training. 

```
log_folder = os.path.join(cfg.paths.output_dir, cfg.exp_name)
os.makedirs(log_folder, exist_ok=True)
logger_array = []
for _, logger_params in cfg.logger.items():
    logger = hydra.utils.instantiate(logger_params, save_dir=log_folder)
    logger_array.append(logger)

callbacks = [
    hydra.utils.instantiate(callback_params) for callback_params in cfg.get("training_callbacks", {}).values()
]

trainer = hydra.utils.instantiate(cfg.trainer, logger=logger_array, callbacks=callbacks)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
```

This section is a bit more involved in `train.py`, to leave flexibility for different loggers and callbacks configurations. We are going to keep it simple here.

Let's first initialise a simple tensorboard logger:

In [22]:
log_save_path = os.path.join(cfg.paths.output_dir, "notebook_example")
os.makedirs(log_save_path, exist_ok=True)

logger = lightning.pytorch.loggers.TensorBoardLogger(
    name="tensorboard/",
    save_dir=log_save_path,
)

Then some training callbacks (i.e. utility functions that will be called during training):

In [23]:
early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    patience=10,
    mode="max",
    verbose=False,
    min_delta=0.001,
)

lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation", mode="max", save_weights_only=False
)
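To build intuition for `patience` and `min_delta`, here is a minimal sketch of the early stopping rule in `mode="max"`: a score counts as an improvement only if it beats the best value so far by more than `min_delta`, and training stops after `patience` consecutive checks without improvement. `first_stop_epoch` is a hypothetical helper for illustration; the Lightning callback handles further cases (for example non-finite values and thresholds) that are left out here.

```python
def first_stop_epoch(scores, patience=10, min_delta=0.001):
    """Return the index of the epoch at which stopping triggers, or None."""
    best = float("-inf")
    wait = 0
    for epoch, score in enumerate(scores):
        if score - min_delta > best:
            # Improvement: remember it and reset the counter.
            best, wait = score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The validation correlation plateaus after three epochs: gains of 0.0005 are below min_delta.
plateau = [0.10, 0.20, 0.25] + [0.2505] * 10
stop_epoch = first_stop_epoch(plateau)
print(stop_epoch)

# A steadily improving score never triggers early stopping.
print(first_stop_epoch([0.1 * i for i in range(5)]))
```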

We can then instantiate the trainer:

In [24]:
trainer = lightning.Trainer(max_epochs=100, logger=logger, callbacks=[early_stopping, lr_monitor, model_checkpoint])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Finally, we can start training. Before doing so though, we can initialise the tensorboard jupyter integration, to visualize how training progresses.

Run the following cell once or twice until the tensorboard extension UI shows up. Once it shows, note that at the beginning it will show no data (unless you have run this notebook before), because we have not started the trainer yet.

When you run the cell containing `trainer.fit` you can then come back to the tensorboard extension, reload the window *within the extension* by clicking the refresh icon in the top right, and follow the training.

In [23]:
%reload_ext tensorboard

%tensorboard --logdir {log_save_path}

Reusing TensorBoard on port 6006 (pid 4040261), started 0:00:02 ago. (Use '!kill 4040261' to kill it.)

The last important step before calling the trainer is to convert the dictionary of dataloaders we have into a unified iterator that will cycle through all sessions during training and evaluation:

In [25]:
train_loader = LongCycler(dataloaders["train"])
val_loader = ShortCycler(dataloaders["validation"])
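As a rough mental model of what such cyclers do, the sketch below uses plain lists in place of dataloaders. It is one plausible reading of the two names, not openretina's implementation: `long_cycle` alternates between sessions until the longest loader has been seen once (restarting shorter ones), while `short_cycle` visits every batch exactly once. Both function names are hypothetical.

```python
from itertools import cycle


def long_cycle(loaders):
    """Alternate between sessions until the longest loader is exhausted."""
    max_len = max(len(batches) for batches in loaders.values())
    iterators = {key: cycle(batches) for key, batches in loaders.items()}
    for _ in range(max_len):
        for key in loaders:
            yield key, next(iterators[key])


def short_cycle(loaders):
    """Visit every batch of every session exactly once, session by session."""
    for key, batches in loaders.items():
        for batch in batches:
            yield key, batch


toy_loaders = {"session_a": ["a0", "a1", "a2"], "session_b": ["b0"]}
long_result = list(long_cycle(toy_loaders))
short_result = list(short_cycle(toy_loaders))
print(long_result)
print(short_result)
```

In this toy example the shorter session is repeated in the long cycle, which is why such a scheme suits training, whereas visiting each batch once avoids double counting during evaluation.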

In [26]:

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


Total parameters: 246,168
Trainable parameters: 246,168


In [27]:
import torch

# Check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")

# Number of GPUs
print(f"Number of GPUs: {torch.cuda.device_count()}")

# Current GPU
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    
    # Detailed info for each GPU
    for i in range(torch.cuda.device_count()):
        print(f"\n=== GPU {i} ===")
        print(f"Name: {torch.cuda.get_device_name(i)}")
        props = torch.cuda.get_device_properties(i)
        print(f"Total Memory: {props.total_memory / 1e9:.2f} GB")
        print(f"Compute Capability: {props.major}.{props.minor}")
        
    # Current memory usage
    print(f"\n=== Current Memory Usage ===")
    print(f"Allocated: {torch.cuda.memory_allocated(0) / 1e9:.2f} GB")
    print(f"Cached: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB")
    print(f"Free: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1e9:.2f} GB")

CUDA available: True
Number of GPUs: 1
Current GPU: 0
GPU Name: NVIDIA GeForce RTX 2080 Ti

=== GPU 0 ===
Name: NVIDIA GeForce RTX 2080 Ti
Total Memory: 11.55 GB
Compute Capability: 7.5

=== Current Memory Usage ===
Allocated: 0.00 GB
Cached: 0.00 GB
Free: 11.55 GB


In [11]:
def get_model_memory(model):
    """Calculate memory used by model parameters"""
    param_size = 0
    buffer_size = 0
    
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    total_size = param_size + buffer_size
    
    print(f"Parameter memory: {param_size / 1e6:.2f} MB")
    print(f"Buffer memory: {buffer_size / 1e6:.2f} MB")
    print(f"Total model memory: {total_size / 1e6:.2f} MB")
    
    return total_size

# Usage
get_model_memory(model)

Parameter memory: 0.98 MB
Buffer memory: 0.00 MB
Total model memory: 0.98 MB


984672
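The reported size can be checked by hand, assuming all parameters are stored as 32-bit floats (4 bytes each) and the model has no buffers, as the output above indicates:

```python
n_parameters = 246_168   # total parameters reported for the model
bytes_per_float32 = 4    # assumption: every parameter is a float32

parameter_bytes = n_parameters * bytes_per_float32
print(f"{parameter_bytes} bytes = {parameter_bytes / 1e6:.2f} MB")
```

The weights themselves therefore occupy under 1 MB, a negligible share of the GPU's memory.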

In [12]:
def estimate_memory_for_batch_size(model, input_shape, batch_sizes):
    """Print the peak GPU memory of a forward pass for inputs of shape (B, C, T, H, W)."""
    import gc

    model = model.to("cuda").eval()  # model and input must live on the same device
    torch.cuda.empty_cache()

    for bs in batch_sizes:
        dummy = out = None  # defined up front so the cleanup below never fails
        try:
            torch.cuda.reset_peak_memory_stats()

            # Replace batch dimension in input_shape
            shape = (bs, *input_shape[1:])
            dummy = torch.randn(*shape, device="cuda")

            with torch.no_grad():
                out = model(dummy)

            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f"Batch size {bs:3d}: {peak:.2f} GB")

        except RuntimeError:  # CUDA out-of-memory errors are RuntimeErrors
            print(f"Batch size {bs:3d}: OOM (Out of Memory)")
            break

        finally:
            # Release the tensors of this iteration before trying the next batch size
            del dummy, out
            gc.collect()
            torch.cuda.empty_cache()



In [27]:
import torch
torch.cuda.empty_cache()
import gc
gc.collect()

3488

In [29]:
import torch
print(torch.cuda.get_device_name(0))
print(f"Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")
print(f"Total: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f} GB")


NVIDIA GeForce RTX 2080 Ti
Allocated: 0.01 GB
Reserved: 1.19 GB
Total: 11.55 GB


In [30]:
model = model.to("cuda")  # the model must be on the same device as the input

with torch.cuda.profiler.profile():
    with torch.no_grad():
        _ = model(torch.randn(1, 2, 50, 72, 64, device="cuda"))

In [23]:
!nvidia-smi

In [None]:
import torch

# Clear cache first
torch.cuda.empty_cache()

# Check before
print("=== BEFORE MODEL ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Move model to GPU
model = model.to('cuda')

print("\n=== AFTER MODEL TO GPU ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Create dummy batch
dummy_input = torch.randn(128, 2, 50, 72, 64).to('cuda')

print("\n=== AFTER CREATING INPUT ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Forward pass
output = model(dummy_input)

print("\n=== AFTER FORWARD PASS ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Backward pass
loss = output.sum()
loss.backward()

print("\n=== AFTER BACKWARD PASS ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

=== BEFORE MODEL ===
Allocated: 0.00 GB
Reserved: 0.00 GB

=== AFTER MODEL TO GPU ===
Allocated: 0.00 GB
Reserved: 0.00 GB

=== AFTER CREATING INPUT ===
Allocated: 0.24 GB
Reserved: 0.24 GB


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.41 GiB. GPU 0 has a total capacity of 10.75 GiB of which 225.62 MiB is free. Including non-PyTorch memory, this process has 10.53 GiB memory in use. Of the allocated memory 10.35 GiB is allocated by PyTorch, and 1004.00 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
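The 0.24 GB allocated after creating the input can be reproduced from the tensor shape alone, assuming float32 storage. Comparing it with the GPU capacity quoted in the error message gives a feel for how little headroom there is once intermediate activations are stored for the backward pass.

```python
import math

input_shape = (128, 2, 50, 72, 64)  # (batch, channels, time, height, width)
bytes_per_float32 = 4               # assumption: float32 input

input_bytes = math.prod(input_shape) * bytes_per_float32
print(f"Input batch: {input_bytes / 1e9:.2f} GB")

gpu_capacity_bytes = 10.75 * 2**30  # total capacity quoted in the error message (GiB)
n_input_sized_tensors = int(gpu_capacity_bytes // input_bytes)
print(f"The GPU can hold about {n_input_sized_tensors} tensors of this size")
```

Since the parameters take under 1 MB and the input about 0.24 GB, nearly all of the memory in use at the time of the error must come from elsewhere, most plausibly intermediate activations.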


In [None]:
import torch, gc

del model  # or any large tensors
gc.collect()
torch.cuda.empty_cache()



In [28]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                               | Params | Mode 
--------------------------------------------------------------------------------
0 | core             | TransformerCoreWrapper             | 123 K  | train
1 | readout          | MultiSampledGaussianReadoutWrapper | 122 K  | train
2 | loss             | PoissonLoss3d                      | 0      | train
3 | correlation_loss | CorrelationLoss3d                  | 0      | train
--------------------------------------------------------------------------------
246 K     Trainable params
0         Non-trainable params
246 K     Total params
0.985     Total estimated model params size (MB)
170       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 170.00 MiB. GPU 0 has a total capacity of 10.75 GiB of which 107.62 MiB is free. Including non-PyTorch memory, this process has 10.64 GiB memory in use. Of the allocated memory 10.39 GiB is allocated by PyTorch, and 78.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Evaluation

Once the model is done training, we can turn to evaluation.

First, let's use the trainer to compute the Poisson loss and the correlation on each of the dataloaders.

In [None]:
test_loader = ShortCycler(dataloaders["test"])
trainer.test(model, dataloaders=[train_loader, val_loader, test_loader], ckpt_path="best")

We can also look at further evals, like the fraction of explainable variance explained for an example session.

In [7]:
# Let's pick an example session
example_session = list(final_responses.keys())[0]

# Extract responses by trial:
responses_by_trial = final_responses[example_session].test_by_trial

responses_by_trial.shape

(80, 750, 3)

In [8]:
# Get the test movie for that session:
test_movie = dataloaders["test"][example_session].dataset.movies

# Pass it through the model: move to gpu and add batch dimension
with torch.no_grad():
    model_predictions = model(test_movie.to(model.device).unsqueeze(0), data_key=example_session)

model_predictions.shape

NameError: name 'dataloaders' is not defined

In [None]:
help(feve)

In [None]:
# We need to reshape the predictions and responses by trial to match what the function expects

feve_score = feve(
    rearrange(responses_by_trial, "neurons time trials -> time trials neurons")[20:],
    model_predictions.squeeze(0).cpu().numpy(),
)

print(f"Average FEVe score for session {example_session}: {feve_score.mean():.2f}")
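For intuition about what this score measures, here is a minimal NumPy sketch of one common definition of the fraction of explainable variance explained: the trial-to-trial (noise) variance is subtracted from both the prediction error and the total response variance. `feve_sketch` is a hypothetical function using the same array layout as above (`time trials neurons` for responses, `time neurons` for predictions). openretina's `feve` may differ in details such as the variance estimator, so check `help(feve)` for the exact conventions.

```python
import numpy as np


def feve_sketch(responses, predictions):
    """responses: (time, trials, neurons); predictions: (time, neurons)."""
    # Noise variance: variance across trials, averaged over time.
    noise_var = responses.var(axis=1).mean(axis=0)
    # Total variance of the single-trial responses.
    total_var = responses.var(axis=(0, 1))
    # Mean squared error between each single-trial response and the prediction.
    mse = ((responses - predictions[:, None, :]) ** 2).mean(axis=(0, 1))
    return 1 - (mse - noise_var) / (total_var - noise_var)


rng = np.random.default_rng(0)
signal = np.sin(np.linspace(0, 6 * np.pi, 200))[:, None, None]  # shared stimulus-driven part
toy_responses = signal + 0.5 * rng.standard_normal((200, 3, 4))  # 200 frames, 3 trials, 4 neurons

# Predicting the trial average explains all explainable variance.
perfect = feve_sketch(toy_responses, toy_responses.mean(axis=1))
# Predicting each neuron's overall mean explains none of it.
constant = feve_sketch(toy_responses, np.broadcast_to(toy_responses.mean(axis=(0, 1)), (200, 4)))
print(perfect, constant)
```

With this definition a model cannot be penalised for failing to predict trial-to-trial noise, which is why FEVE is often reported alongside plain correlation.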

Finally, we can plot an example neuron's predictions and its ground truth response.

In [None]:
neuron_idx = 4
session_idx = 0


example_session = list(final_responses.keys())[session_idx]

test_sample = next(iter(dataloaders["test"][example_session]))
responses_by_trial = final_responses[example_session].test_by_trial
mean_test_responses = final_responses[example_session].test_response

input_samples = test_sample.inputs
targets = test_sample.targets

model.eval()
model.cpu()

with torch.no_grad():
    reconstructions = model(input_samples.cpu(), example_session)
reconstructions = reconstructions.cpu().numpy().squeeze()

feve_score = feve(
    rearrange(responses_by_trial, "neurons time trials -> time trials neurons")[20:],
    model_predictions.squeeze(0).cpu().numpy(),
)

correlations = correlation_numpy(mean_test_responses.T[20:], model_predictions.squeeze(0).cpu().numpy(), axis=0)


targets = targets.cpu().numpy().squeeze()
window = 750
plt.figure(figsize=(10, 5))
plt.plot(np.arange(0, window), targets[:window, neuron_idx], label="target")
plt.plot(np.arange(20, window), reconstructions[:window, neuron_idx], label="prediction")
plt.suptitle(f"Neuron {neuron_idx} - FEVE: {feve_score[neuron_idx]:.2f} - Correlation: {correlations[neuron_idx]:.2f}")

plt.legend()
sns.despine()

---