In [1]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch

from openretina.data_io.hoefling_2024.stimuli import movies_from_pickle
from openretina.utils.plotting import (
    numpy_to_mp4_video,
)
from openretina.utils.file_utils import get_local_file_path
from openretina.utils.h5_handling import load_h5_into_dict
from openretina.data_io.cyclers import LongCycler, ShortCycler
from openretina.data_io.hoefling_2024.dataloaders import natmov_dataloaders_v2
from openretina.data_io.hoefling_2024.responses import filter_responses, make_final_responses
from openretina.data_io.hoefling_2024.stimuli import movies_from_pickle
import os
import hydra

In [2]:
# Compose the Hydra configuration named "hoefling_2024_core_readout_low_res.yaml"
# from the "../configs" directory and keep it in `cfg` for the following cells.
_config_dir = os.path.join("..", "configs")
with hydra.initialize(config_path=_config_dir, version_base="1.3"):
    cfg = hydra.compose(config_name="hoefling_2024_core_readout_low_res.yaml")

/home/bethge/bkr618/openretina_cache/notebook_example/tensorboard

In [5]:
# Root folder used for the data cache, the logs and the model outputs.
# The default lives in the current user's home directory instead of a path that
# only exists on one machine; replace it with any other location if desired.
your_chosen_root_folder = os.path.join(os.path.expanduser("~"), "openretina_cache")  # Change this with your desired path.

cfg.paths.cache_dir = your_chosen_root_folder

# We will also overwrite the output directory for the logs/model to the local folder.
cfg.paths.log_dir = your_chosen_root_folder
cfg.paths.output_dir = your_chosen_root_folder

# Expose the same location through the OPENRETINA_CACHE_DIRECTORY environment variable.
os.environ["OPENRETINA_CACHE_DIRECTORY"] = your_chosen_root_folder

In [4]:
# Build the stimulus path from the configured root folder, so that changing
# `your_chosen_root_folder` is enough to relocate the data.
# NOTE(review): unlike the responses path used in the next cell, this path has no
# "data" component below the root — confirm that this is intended.
file_path = os.path.join(
    your_chosen_root_folder,
    "euler_lab",
    "hoefling_2024",
    "stimuli",
    "rgc_natstim_72x64_joint_normalized_2024-10-11.pkl",
)
movie_stimuli = movies_from_pickle(file_path)

In [5]:
# Build the responses path from the configured root folder, so that changing
# `your_chosen_root_folder` is enough to relocate the data.
responses_path = os.path.join(
    your_chosen_root_folder,
    "data",
    "euler_lab",
    "hoefling_2024",
    "responses",
    "rgc_natstim_2024-08-14.h5",
)
responses_dict = load_h5_into_dict(file_path=responses_path)

# Apply the quality checks configured under `cfg.quality_checks`.
filtered_responses_dict = filter_responses(responses_dict, **cfg.quality_checks)

# Turn the filtered recordings into the final responses for the "natural" stimulus type.
final_responses = make_final_responses(filtered_responses_dict, response_type="natural")

Loading HDF5 file contents:   0%|          | 0/2077 [00:00<?, ?item/s]

Original dataset contains 7863 neurons over 67 fields
 ------------------------------------ 
Dropped 0 fields that did not contain the target cell types (67 remaining)
Overall, dropped 3034 neurons of non-target cell types (-38.59%).
 ------------------------------------ 
Dropped 0 fields with quality indices below threshold (67 remaining)
Overall, dropped 980 neurons over quality checks (-20.29%).
 ------------------------------------ 
Dropped 0 fields with classifier confidences below 0.25
Overall, dropped 705 neurons with classifier confidences below 0.25 (-18.32%).
 ------------------------------------ 
 ------------------------------------ 
Final dataset contains 3144 neurons over 67 fields
Total number of cells dropped: 4719 (-60.02%)


Upsampling natural spikes traces to get final responses.:   0%|          | 0/67 [00:00<?, ?it/s]

In [6]:
# Options of the movie dataloaders; the validation clips come from the Hydra config.
_loader_options = dict(
    allow_over_boundaries=True,
    batch_size=128,
    train_chunk_size=50,
    validation_clip_indices=cfg.dataloader.validation_clip_indices,
)

# Combine the responses with the movie stimuli; the result is indexed by split
# name ("train", "validation") further below.
dataloaders = natmov_dataloaders_v2(
    neuron_data_dictionary=final_responses,
    movies_dictionary=movie_stimuli,
    **_loader_options,
)

Creating movie dataloaders:   0%|          | 0/67 [00:00<?, ?it/s]

In [7]:
from openretina.data_io.base import compute_data_info

# Dataset metadata derived from responses and stimuli; its "n_neurons_dict"
# entry is read when the model is built.
data_info = compute_data_info(
    movies_dictionary=movie_stimuli,
    neuron_data_dictionary=final_responses,
)


In [8]:
# Wrap the per-split dataloaders in the cyclers that are handed to the trainer.
train_loader, val_loader = (
    LongCycler(dataloaders["train"]),
    ShortCycler(dataloaders["validation"]),
)

In [9]:
from openretina.models.core_readout import ViViTCoreReadout

n_neurons_dict = data_info["n_neurons_dict"]

# Baseline ViViT core + readout configuration.
model = ViViTCoreReadout(
    # presumably (batch, channels, time, height, width), matching the dummy
    # inputs built later in this notebook — TODO confirm
    input_shape=(128, 2, 50, 72, 64),
    n_neurons_dict=n_neurons_dict,
    Demb=128,  # Embedding dimension
    patch_size=8,  # Spatial patch size (H, W)
    temporal_patch_size=6,  # Temporal patch size
    num_spatial_blocks=3,  # Number of spatial transformer blocks
    num_temporal_blocks=3,  # Number of temporal transformer blocks
    num_heads=4,  # Number of attention heads
    mlp_ratio=4.0,  # MLP expansion ratio
    dropout=0.1,
    pad_frame=True,
    temporal_stride=1,
    spatial_stride=6,
    ptoken=0.1,  # Token dropout probability
    readout_bias=True,
    readout_init_mu_range=0.05,
    readout_init_sigma_range=0.01,
    readout_gamma=0.4,
    readout_reg_avg=False,
    learning_rate=0.001,
    norm="layernorm",
    patch_mode=1,
)

# Use the GPU when one is available and fall back to the CPU otherwise, instead
# of failing on a hard-coded "cuda" device.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")






1. Creating Tokenizer...
2. Tokenizer created. Output shape: (50, 132, 128)
3. ViViT created with input shape: (50, 132, 128)
4. Spatial shape after patching: h=12, w=11
5. Core output shape: (128, 50, 12, 11)


In [10]:
from openretina.models.core_readout import ViViTCoreReadout

n_neurons_dict = data_info["n_neurons_dict"]

# Smaller, more strongly regularised variant. Rebinding `model` replaces the
# configuration built in the previous cell.
model = ViViTCoreReadout(
    input_shape=(128, 2, 50, 72, 64),
    n_neurons_dict=n_neurons_dict,
    Demb=96,                    # ↓ smaller embedding dimension
    patch_size=8,
    temporal_patch_size=6,
    num_spatial_blocks=2,       # ↓ fewer spatial transformer blocks
    num_temporal_blocks=2,      # ↓ fewer temporal transformer blocks
    num_heads=4,
    mlp_ratio=3.0,              # ↓ smaller MLP expansion ratio
    dropout=0.3,                # ↑ stronger dropout
    pad_frame=True,
    temporal_stride=1,
    spatial_stride=6,
    ptoken=0.2,                 # ↑ stronger token dropout
    readout_bias=True,
    chunk_size=64,
    readout_init_mu_range=0.05,
    readout_init_sigma_range=0.01,
    readout_gamma=0.4,
    readout_reg_avg=True,       # ↑ regularization on readout weights
    learning_rate=5e-4,         # ↓ smaller learning rate for smoother convergence
    norm="layernorm",
    patch_mode=1,
    # Use the GPU when one is available and fall back to the CPU otherwise.
).to("cuda" if torch.cuda.is_available() else "cpu")


1. Creating Tokenizer...
2. Tokenizer created. Output shape: (50, 132, 96)
3. ViViT created with input shape: (50, 132, 96)
4. Spatial shape after patching: h=12, w=11
5. Core output shape: (96, 50, 12, 11)


In [10]:
# Import from the `lightning` package that the logger, callbacks and trainer of
# this notebook use, instead of the separate `pytorch_lightning` distribution.
from lightning.pytorch.utilities.model_summary import summarize

# max_depth=-1 lists every nested submodule instead of only the top-level ones.
summary = summarize(model, max_depth=-1)
print(summary)

    | Name                                                  | Type                               | Params | Mode 
-----------------------------------------------------------------------------------------------------------------------
0   | core                                                  | ViViTCoreWrapper                   | 1.3 M  | train
1   | core.tokenizer                                        | Tokenizer                          | 98.6 K | train
2   | core.tokenizer.pad                                    | ZeroPad3d                          | 0      | train
3   | core.tokenizer.proj                                   | Conv3d                             | 98.3 K | train
4   | core.tokenizer.norm                                   | LayerNorm                          | 256    | train
5   | core.vivit                                            | ViViT                              | 1.2 M  | train
6   | core.vivit.spatial_transformer                        | Transformer         

In [11]:
import lightning

In [12]:
# All logs of this notebook go below "<output_dir>/notebook_example".
log_save_path = os.path.join(cfg.paths.output_dir, "notebook_example")
os.makedirs(log_save_path, exist_ok=True)

# TensorBoard logger writing into `log_save_path` under the run name "tensorboard/".
logger = lightning.pytorch.loggers.TensorBoardLogger(
    save_dir=log_save_path,
    name="tensorboard/",
)

In [13]:
# Early stopping and checkpointing watch the same validation metric, which is
# treated as "higher is better".
_monitored_metric = "val_correlation"
_callbacks = lightning.pytorch.callbacks

# Stop after 10 validation checks without an improvement of at least 0.001.
early_stopping = _callbacks.EarlyStopping(
    monitor=_monitored_metric,
    mode="max",
    min_delta=0.001,
    patience=10,
    verbose=False,
)

# Log the learning rate once per epoch.
lr_monitor = _callbacks.LearningRateMonitor(logging_interval="epoch")

# Checkpoint on the monitored metric; save_weights_only=False stores the full checkpoint.
model_checkpoint = _callbacks.ModelCheckpoint(
    monitor=_monitored_metric,
    mode="max",
    save_weights_only=False,
)

In [14]:

# Train for at most 100 epochs with 16-bit mixed precision, using the logger and
# the callbacks defined above.
trainer = lightning.Trainer(
    max_epochs=100,
    precision="16-mixed",
    logger=logger,
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [15]:
# Fit the model on the cycled training loader and validate on the cycled validation loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/bethge/bkr618/.local/lib/python3.13/site-packages/lightning/pytorch/utilities/model_summary/model_summary.py:231: Precision 16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name             | Type                               | Params | Mode 
--------------------------------------------------------------------------------
0 | core             | ViViTCoreWrapper                   | 443 K  | train
1 | readout          | MultiSampledGaussianReadoutWrapper | 323 K  | train
2 | loss             | PoissonLoss3d  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  regularization_loss_core = self.core.regularizer()


Training: |          | 0/? [00:00<?, ?it/s]

  regularization_loss_core = self.core.regularizer()


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [13]:
# memory_summary() returns a multi-line string. Print it so the table is rendered
# line by line instead of being shown as a string repr with escaped newlines.
print(torch.cuda.memory_summary(device='cuda', abbreviated=True))




In [14]:
# Memory currently allocated on the CUDA device, converted from bytes to GB (1e9 bytes).
_allocated_bytes = torch.cuda.memory_allocated()
print(_allocated_bytes / 1e9)


0.652217344


In [10]:
def _print_cuda_memory(header):
    """Print `header`, then the allocated and the reserved CUDA memory in GB."""
    print(header)
    print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")


# Start from an emptied allocator cache so the first reading is a clean baseline.
torch.cuda.empty_cache()
_print_cuda_memory("=== BEFORE MODEL ===")

# Move the model to the GPU and measure again.
model = model.to('cuda')
_print_cuda_memory("\n=== AFTER MODEL TO GPU ===")


=== BEFORE MODEL ===
Allocated: 0.00 GB
Reserved: 0.00 GB

=== AFTER MODEL TO GPU ===
Allocated: 0.00 GB
Reserved: 0.00 GB


In [21]:
# Create dummy batch
dummy_input = torch.randn(64, 2, 50, 72, 64).to('cuda')

print("\n=== AFTER CREATING INPUT ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")



=== AFTER CREATING INPUT ===
Allocated: 0.13 GB
Reserved: 0.14 GB


In [24]:
# Forward pass
output = model(dummy_input)

print("\n=== AFTER FORWARD PASS ===")
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

OutOfMemoryError: CUDA out of memory. Tried to allocate 226.00 MiB. GPU 0 has a total capacity of 39.39 GiB of which 73.94 MiB is free. Including non-PyTorch memory, this process has 39.31 GiB memory in use. Of the allocated memory 38.73 GiB is allocated by PyTorch, and 94.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [20]:
import gc

# Drop the notebook-level references so the objects they hold can be freed.
# `pop(..., None)` tolerates names that were never created (for example `output`
# when the forward-pass cell failed), where a plain `del` raises NameError and
# aborts the cleanup.
for _name in ("model", "dummy_input", "output"):
    globals().pop(_name, None)

# Collect unreachable objects first, then ask the caching allocator to release
# its unused blocks.
gc.collect()
torch.cuda.empty_cache()


In [None]:
# Backward pass through a scalar reduction of the forward output.
# This cell needs `output` from the forward-pass cell, so it has to run before
# the cleanup cell that deletes `output` (otherwise it fails with NameError).
loss = output.sum()
loss.backward()

# Memory footprint after the backward pass.
print(
    "\n=== AFTER BACKWARD PASS ===",
    f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB",
    f"Reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB",
    sep="\n",
)

NameError: name 'output' is not defined

: 

In [None]:
import gc
from tqdm import tqdm
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): VideoTokenizer and SpatioTemporalTransformer are not imported in
# any visible cell; they must be imported before this cell runs on a fresh kernel.
tokenizer = VideoTokenizer(
    img_size=(72, 64),
    patch_size=(8, 8),
    temporal_patch_size=5,
    in_channels=2,
    Demb=128,
    ptoken=0.1
).to(device)

transformer = SpatioTemporalTransformer(
    Demb=128,
    num_spatial_blocks=4,
    num_temporal_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    dropout=0.1,
    chunk_size=64
).to(device)

# Inference only: both modules are put in eval mode and run under torch.no_grad().
tokenizer.eval()
transformer.eval()

# Maps session name -> transformer output on the CPU.
# The result is keyed by session, so a later batch of the same session overwrites
# an earlier one (the recorded run iterates 134 times and keeps 67 entries).
outputs_dict = {}

# Each item is indexed as (session name, batch); the batch exposes `.inputs`.
for item in tqdm(train_loader, desc="Processing sessions", unit="session"):
    session_name = item[0]
    inputs = item[1].inputs.to(device)

    with torch.no_grad():
        embeddings, TP, SP = tokenizer(inputs)
        output = transformer(embeddings, TP, SP)
        output_cpu = output.cpu()

    outputs_dict[session_name] = output_cpu

    # Drop the device references, collect garbage, then release cached blocks.
    del inputs, embeddings, output
    gc.collect()
    if device == 'cuda':
        # Only touch the CUDA runtime when it is in use; synchronize() raises on
        # a machine without CUDA, which would defeat the CPU fallback above.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()  # ensure memory freed before next iteration

# summary
print(f"\nProcessed {len(outputs_dict)} sessions total.")


Processing batches:   0%|          | 0/134 [00:00<?, ?batch/s]

Processing batches: 100%|██████████| 134/134 [01:31<00:00,  1.47batch/s]


Processed 67 batches total.





In [2]:
import gc
import os
import torch
from pathlib import Path
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the TransformerCoreWrapper
# NOTE(review): TransformerCoreWrapper is not imported in any visible cell; it
# must be imported before this cell runs on a fresh kernel.
core = TransformerCoreWrapper(
    input_shape = (128, 2, 50, 72, 64),
    in_channels=2,
    img_size=(72, 64),
    patch_size=(8, 8),
    temporal_patch_size=5,
    emb_dim=128,
    ptoken=0.1,
    num_spatial_blocks=4,
    num_temporal_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    dropout=0.1,
    chunk_size=64,
    gamma_weights=0.001,
    gamma_attention=0.01,
).to(device)

# Put in eval mode
core.eval()

dummy_input = torch.randn(128, 2, 50, 72, 64).to(device)  # (B, C, T, H, W)

# Shape check only, so no gradients are needed.
with torch.no_grad():
    output = core(dummy_input)
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")

# Drop the tensors, collect garbage, then release the allocator's unused blocks.
# Running gc.collect() before empty_cache() also avoids leaving the collector's
# return value as a stray cell output.
del dummy_input, output
gc.collect()
torch.cuda.empty_cache()

Input shape: torch.Size([128, 2, 50, 72, 64])
Output shape: torch.Size([128, 128, 10, 9, 8])


177

In [16]:
import gc
import os
from pathlib import Path
from tqdm import tqdm
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
viz_folder = Path("/home/bethge/bkr618/open-retina/core_visualizations")
os.makedirs(viz_folder, exist_ok=True)

# NOTE(review): TransformerCoreWrapper is not imported in any visible cell; it
# must be imported before this cell runs on a fresh kernel.
core = TransformerCoreWrapper(
    input_shape=(128, 2, 50, 72, 64),
    in_channels=2,
    img_size=(72, 64),
    patch_size=(8, 8),
    temporal_patch_size=5,
    emb_dim=128,
    ptoken=0.1,
    num_spatial_blocks=4,
    num_temporal_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    dropout=0.1,
    chunk_size=64,
    gamma_weights=0.001,
    gamma_attention=0.01,
).to(device)

core.eval()

# Maps session name -> core output on the CPU. A later batch of the same session
# overwrites an earlier one (the recorded run iterates 134 times and keeps 67 entries).
outputs_dict = {}
VISUALIZE_EVERY = 20  # adjust frequency


def _release_device_memory():
    """Collect garbage, then release cached CUDA blocks when running on the GPU."""
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


for session_idx, item in enumerate(tqdm(train_loader, desc="Processing sessions", unit="session")):
    session_name = item[0]
    inputs = item[1].inputs.to(device, non_blocking=True)

    try:
        with torch.no_grad():
            # Store the CPU copy only. The device result is never bound to a name,
            # so there is nothing to delete conditionally if the forward pass fails.
            outputs_dict[session_name] = core(inputs).detach().cpu()
    finally:
        # drop GPU refs
        del inputs
        _release_device_memory()

    # Save a snapshot every VISUALIZE_EVERY processed batches.
    # NOTE(review): nothing in this loop updates the parameters of `core` (eval
    # mode, no_grad, no optimizer step), so every snapshot presumably shows the
    # same weights — confirm that this is intended.
    if (session_idx + 1) % VISUALIZE_EVERY == 0:
        core.save_weight_visualizations(
            folder_path=str(viz_folder),
            file_format="png",
            state_suffix=f"_session_{session_idx+1}"
        )
        _release_device_memory()

print(f"\nProcessed {len(outputs_dict)} sessions total.")



Processing sessions:   0%|          | 0/134 [00:00<?, ?session/s]

Processing sessions:  15%|█▍        | 20/134 [00:21<06:12,  3.26s/session]

Saved transformer visualizations to /home/bethge/bkr618/open-retina/core_visualizations/transformer_visualizations


Processing sessions:  29%|██▉       | 39/134 [00:34<01:07,  1.40session/s]

Saved transformer visualizations to /home/bethge/bkr618/open-retina/core_visualizations/transformer_visualizations


Processing sessions:  44%|████▍     | 59/134 [00:59<00:54,  1.38session/s]

Saved transformer visualizations to /home/bethge/bkr618/open-retina/core_visualizations/transformer_visualizations


Processing sessions:  59%|█████▉    | 79/134 [01:23<00:41,  1.32session/s]

Saved transformer visualizations to /home/bethge/bkr618/open-retina/core_visualizations/transformer_visualizations


Processing sessions:  74%|███████▍  | 99/134 [01:49<00:29,  1.20session/s]

Saved transformer visualizations to /home/bethge/bkr618/open-retina/core_visualizations/transformer_visualizations


Processing sessions:  90%|████████▉ | 120/134 [02:26<00:50,  3.63s/session]

Saved transformer visualizations to /home/bethge/bkr618/open-retina/core_visualizations/transformer_visualizations


Processing sessions: 100%|██████████| 134/134 [02:39<00:00,  1.19s/session]


Processed 67 sessions total.



