# Training Spatial Contrast (SC) Models

This notebook demonstrates how to train the Spatial Contrast model on the Sridhar 2025 marmoset dataset.

The SC model extends the Linear-Nonlinear (LN) model by incorporating a local spatial contrast term:
```
y = nonlinearity(imean + w * lsc)
```

Where:
- `imean` = spatial mean weighted by the spatial filter (from STA)
- `lsc` = local spatial contrast
- `w` = learnable weight for the contrast term
- `nonlinearity(x) = a * log(1 + exp(b*x + c))` with learnable a, b, c
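The sketch below makes the formula concrete for a single stimulus frame that has already been temporally filtered. It rests on two assumptions that are not stated above: `imean` is the filter-weighted mean intensity, and `lsc` is the filter-weighted standard deviation of intensities around that mean, with non-negative spatial weights normalized to sum to 1. The helper names are hypothetical, and this is not the openretina implementation.

```python
import math
from typing import Sequence, Tuple


def weighted_mean_and_contrast(pixels: Sequence[float], weights: Sequence[float]) -> Tuple[float, float]:
    """Return (imean, lsc) for one temporally filtered frame.

    Assumptions: `weights` are non-negative spatial-filter values (normalized
    here to sum to 1), and lsc is the weighted standard deviation around imean.
    """
    total = sum(weights)
    norm = [w / total for w in weights]
    imean = sum(w * s for w, s in zip(norm, pixels))
    lsc = math.sqrt(sum(w * (s - imean) ** 2 for w, s in zip(norm, pixels)))
    return imean, lsc


def softplus_nonlinearity(x: float, a: float, b: float, c: float) -> float:
    """a * log(1 + exp(b*x + c)), computed without overflowing exp for large arguments."""
    z = b * x + c
    if z > 0:
        # log(1 + e^z) = z + log(1 + e^-z), and e^-z cannot overflow for z > 0
        return a * (z + math.log1p(math.exp(-z)))
    return a * math.log1p(math.exp(z))


def sc_response(pixels, weights, w, a, b, c):
    """y = nonlinearity(imean + w * lsc) for a single frame."""
    imean, lsc = weighted_mean_and_contrast(pixels, weights)
    return softplus_nonlinearity(imean + w * lsc, a, b, c)


# Two patches with the same mean intensity but different contrast
weights = [1.0, 1.0, 1.0, 1.0]
uniform = [0.5, 0.5, 0.5, 0.5]
contrasty = [0.0, 1.0, 0.0, 1.0]
print(weighted_mean_and_contrast(uniform, weights))    # (0.5, 0.0)
print(weighted_mean_and_contrast(contrasty, weights))  # (0.5, 0.5)
```

With `w = 0` the model reduces to an LN model and gives both patches the same response. With `w > 0`, the contrasty patch drives a larger response.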

The SC model uses **pre-computed STAs** for the spatial and temporal filters, with only **4 learnable parameters** (a, b, c, w).

In [1]:
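import os

# Set the cache directory before importing openretina, in case the variable is read at import time.
# Point it to a location with enough free disk space for the dataset.
os.environ["OPENRETINA_CACHE_DIRECTORY"] = "/mnt/big_storage/openretina_cache/"

import torch

import hydra
import lightning

from openretina.data_io.cyclers import LongCycler
from openretina.data_io.sridhar_2025.dataloader_utils import make_file_name
from openretina.utils.file_utils import get_local_file_path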

## Configuration

Load the Hydra configuration for the SC model. Use:
- `sridhar_2025_wn_sc.yaml` for the white noise dataset
- `sridhar_2025_nm_sc.yaml` for the natural movie dataset

In [None]:
config_name = "sridhar_2025_wn_sc.yaml"  # use sridhar_2025_nm_sc.yaml for the natural movie dataset

with hydra.initialize(config_path="../configs", version_base="1.3"):
    cfg = hydra.compose(config_name=config_name)

# Show the config structure
print("Top-level config keys:", list(cfg.keys()))
print("Model config keys:", list(cfg.model.keys()))
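The natural movie dataset needs a different config and also `flip_sta=True` in the model setup below, so it can help to keep both choices together. The sketch below uses a small lookup table for this. `SC_CONFIGS` and `sc_config_for` are hypothetical names, not openretina API. The `False` entry for white noise assumes that the notebook relies on the default there.

```python
from typing import Dict, Tuple

# Hypothetical lookup: stimulus type -> (Hydra config name, flip_sta)
SC_CONFIGS: Dict[str, Tuple[str, bool]] = {
    "white_noise": ("sridhar_2025_wn_sc.yaml", False),
    "natural_movie": ("sridhar_2025_nm_sc.yaml", True),
}


def sc_config_for(stimulus: str) -> Tuple[str, bool]:
    """Return (config_name, flip_sta) for a Sridhar 2025 stimulus type."""
    if stimulus not in SC_CONFIGS:
        raise ValueError(f"Unknown stimulus {stimulus!r}; expected one of {sorted(SC_CONFIGS)}")
    return SC_CONFIGS[stimulus]


name, flip = sc_config_for("natural_movie")
print(name, flip)  # sridhar_2025_nm_sc.yaml True
```

You would pass the name to `hydra.compose`. When the flag is `True`, set `cfg.model.flip_sta = True`, as in the commented-out line of the model setup.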

## Dataloader

The first time you instantiate the dataloader, the `sridhar_2025` dataset is downloaded to `cfg.paths.cache_dir`. Later runs reuse the cached copy.

In [3]:
cfg.paths.cache_dir

'/mnt/big_storage/openretina_cache/'

In [4]:
dataloader = hydra.utils.instantiate(cfg.dataloader)

Random seed 1000 has been set.
train idx: [7 6 9 1 2 8 0 3]
val idx: [ 4 10  5]


## Model Setup

The SC model requires:
- `in_shape`: Input stimulus shape (channels, time, height, width)
- `sta_dir`: Directory containing pre-computed STA files
- `sta_file_name`: Name of the STA file for the specific cell

The STA file naming convention is: `cell_data_{retina_index}_WN_stas_cell_{cell_index}.npy`
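The naming convention can be written as a small formatter. This `sta_file_name` is a hypothetical stand-in for illustration, not openretina's `make_file_name`. It assumes that integer retina indices are zero-padded to two digits, as in the `01` printed below.

```python
from typing import Union


def sta_file_name(cell_index: int, retina_index: Union[int, str]) -> str:
    """Format cell_data_{retina_index}_WN_stas_cell_{cell_index}.npy.

    Assumption: integer retina indices are zero-padded to two digits.
    """
    if isinstance(retina_index, int):
        retina_index = f"{retina_index:02d}"
    return f"cell_data_{retina_index}_WN_stas_cell_{cell_index}.npy"


print(sta_file_name(272, "01"))  # cell_data_01_WN_stas_cell_272.npy
print(sta_file_name(272, 1))     # same name, built from an integer index
```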

In [5]:
# ============================================================
# Customize filter sizes here (access via cfg.model.*)
# ============================================================
# Temporal filter: number of frames to crop from the STA
cfg.model.temporal_crop_frames = 30  # Default: 30

# Spatial crop: size of the spatial patch around the RF center [height, width]
cfg.model.spat_crop_size = [15, 15]  # Default: [15, 15]

# Sigma contour: number of standard deviations for spatial filter mask
cfg.model.sigma_contour = 3.0  # Default: 3.0
# ============================================================

print(f"Temporal crop frames: {cfg.model.temporal_crop_frames}")
print(f"Spatial crop size: {cfg.model.spat_crop_size}")
print(f"Sigma contour: {cfg.model.sigma_contour}")

Temporal crop frames: 30
Spatial crop size: [15, 15]
Sigma contour: 3.0
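The spatial settings above describe how the STA becomes a fixed spatial filter. The rough sketch below reads `spat_crop_size` as a fixed-size window centred on the RF and kept inside the image. It reads `sigma_contour` as the radius, in standard deviations, of an elliptical mask around a fitted Gaussian. Both helpers are hypothetical. The axis-aligned Gaussian (the real fit may be rotated) and the border handling are assumptions.

```python
from typing import List, Tuple


def crop_window(center: int, size: int, limit: int) -> Tuple[int, int]:
    """(start, stop) of a `size`-pixel window centred on `center`, shifted to stay within [0, limit)."""
    start = center - size // 2
    start = max(0, min(start, limit - size))
    return start, start + size


def sigma_contour_mask(height: int, width: int, cy: float, cx: float,
                       sy: float, sx: float, n_sigma: float) -> List[List[bool]]:
    """True for pixels inside the n_sigma ellipse of an axis-aligned 2D Gaussian."""
    return [
        [((y - cy) / sy) ** 2 + ((x - cx) / sx) ** 2 <= n_sigma ** 2 for x in range(width)]
        for y in range(height)
    ]


# 15 x 15 patch around the RF location (50, 52) reported below, treated as (row, column),
# in an 80 x 90 stimulus
rows = crop_window(50, 15, 80)
cols = crop_window(52, 15, 90)
print(rows, cols)  # (43, 58) (45, 60)

# 3-sigma mask for a hypothetical round Gaussian (sigma = 2 px) centred in the patch
mask = sigma_contour_mask(15, 15, cy=7, cx=7, sy=2.0, sx=2.0, n_sigma=3.0)
print(sum(map(sum, mask)))  # 113 pixels kept
```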


In [6]:
# Get input shape from dataloader
retina_index = cfg.dataloader.retina_index
cell_index = cfg.dataloader.cell_index
input_shape = next(iter(dataloader["train"][retina_index]))[0].shape[1:]

# Set model configuration (note: access via cfg.model.*)
cfg.model.in_shape = (input_shape[0], cfg.dataloader.num_of_frames, *input_shape[2:])

# Get the actual basepath where the dataset was downloaded
# The dataloader uses get_local_file_path to resolve the dataset location
basepath = get_local_file_path(str(cfg.dataloader.basepath))

# Set STA directory and file name
# The STA files are stored in the 'stas' subdirectory of the downloaded dataset
sta_dir = os.path.join(basepath, cfg.dataloader.sta_dir)
sta_file_name = make_file_name(cell_index, retina_index)

cfg.model.sta_dir = sta_dir
cfg.model.sta_file_name = sta_file_name

# IMPORTANT: The STA must be cropped the same way as the stimulus images!
# The dataloader crops stimulus images using big_crops[retina_index].
# We must apply the same crop to the STA so that RF coordinates match.
if cfg.dataloader.retina_specific_crops:
    sta_crop = list(cfg.dataloader.big_crops[retina_index])
else:
    sta_crop = cfg.dataloader.get("crop", 0)
cfg.model.sta_crop = sta_crop

# For the natural movie dataset, set flip_sta=True
# cfg.model.flip_sta = True  # Uncomment for the natural movie dataset

print(f"Retina index: {retina_index}")
print(f"Cell index: {cell_index}")
print(f"Input shape: {cfg.model.in_shape}")
print(f"Dataset basepath: {basepath}")
print(f"STA directory: {sta_dir}")
print(f"STA file: {sta_file_name}")
print(f"STA crop (must match stimulus crop): {cfg.model.sta_crop}")

# Check if STA file exists
sta_full_path = os.path.join(sta_dir, sta_file_name)
if os.path.exists(sta_full_path):
    print("STA file found!")
else:
    print(f"STA file NOT found at: {sta_full_path}")
    print("  Please check if STAs are included in the dataset download.")

Retina index: 01
Cell index: 272
Input shape: [1, 30, 80, 90]
Dataset basepath: /mnt/big_storage/openretina_cache/gollisch_lab/sridhar_2025/marmoset/whitenoise
STA directory: /mnt/big_storage/openretina_cache/gollisch_lab/sridhar_2025/marmoset/whitenoise/stas
STA file: cell_data_01_WN_stas_cell_272.npy
STA crop (must match stimulus crop): [35, 35, 55, 55]
STA file found!
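Here `sta_crop` is `[35, 35, 55, 55]`. The sketch below assumes that the four numbers are pixels removed from the (top, bottom, left, right) edges and that an integer crop applies equally to all sides. Both assumptions should be checked against the dataloader's real convention. Under them, a pixel at (r, c) in the uncropped frame ends up at (r - top, c - left). That shift is why RF coordinates only line up when the STA and the stimulus are cropped identically. The helpers are hypothetical.

```python
from typing import Sequence, Union


def crop_frame(frame, crop: Union[int, Sequence[int]]):
    """Crop a 2D frame (list of rows) by (top, bottom, left, right) pixels.

    Assumption: a 4-element crop is ordered (top, bottom, left, right) and an
    int crops the same number of pixels from every side.
    """
    if isinstance(crop, int):
        crop = (crop, crop, crop, crop)
    top, bottom, left, right = crop
    height, width = len(frame), len(frame[0])
    return [row[left:width - right] for row in frame[top:height - bottom]]


def crop_sta(sta, crop):
    """Apply the same spatial crop to every time frame of an STA (time x height x width)."""
    return [crop_frame(frame, crop) for frame in sta]


# Hypothetical 150 x 200 frame whose entries record their own (row, column)
frame = [[(r, c) for c in range(200)] for r in range(150)]
cropped = crop_frame(frame, [35, 35, 55, 55])
print(len(cropped), len(cropped[0]))  # 80 90
print(cropped[0][0])                  # (35, 55)
```

For this hypothetical frame size, the crop yields 80 x 90 pixels, the spatial size reported above. That agreement is consistent with the assumed ordering but does not prove it.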


In [7]:
# Instantiate the model (note: use cfg.model, not cfg)
model = hydra.utils.instantiate(cfg.model)

print(f"\nModel RF location (from Gaussian fit): {model.rf_location}")
print(f"Gaussian fit success: {model.gaussian_params['success']}")
print(f"Temporal filter length: {model.temporal_filter.shape[0]}")
print(f"Spatial filter shape: {model.spatial_filter.shape}")


Model RF location (from Gaussian fit): (50, 52)
Gaussian fit success: True
Temporal filter length: 30
Spatial filter shape: torch.Size([15, 15])
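The model gets the RF location from a 2D Gaussian fit and takes its temporal filter from the STA. For sanity-checking such a fit, a much cruder stand-in is to take the STA entry with the largest magnitude as the RF peak and read that pixel's time course. This peak-pixel heuristic replaces the Gaussian fit and is not what the model does. The sketch also assumes that the last STA frames are the ones closest to the spike, which may not match the file layout. Both helpers are hypothetical.

```python
def peak_index(sta):
    """Return (t, row, col) of the STA entry with the largest absolute value."""
    candidates = (
        (t, r, c)
        for t, frame in enumerate(sta)
        for r, row in enumerate(frame)
        for c in range(len(row))
    )
    return max(candidates, key=lambda idx: abs(sta[idx[0]][idx[1]][idx[2]]))


def temporal_filter_at(sta, row, col, n_frames):
    """Time course of one pixel, keeping the last `n_frames` frames.

    Assumption: the time axis runs from oldest to most recent frame.
    """
    trace = [frame[row][col] for frame in sta]
    return trace[-n_frames:]


# Toy STA: 5 frames of 4 x 4 pixels with one strong OFF-like peak
sta = [[[0.0] * 4 for _ in range(4)] for _ in range(5)]
sta[3][1][2] = -2.0
sta[4][0][0] = 1.0
t, r, c = peak_index(sta)
print((t, r, c))                          # (3, 1, 2)
print(temporal_filter_at(sta, r, c, 3))   # [0.0, -2.0, 0.0]
```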


In [8]:
# Verify the model has exactly 4 learnable parameters
num_learnable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of learnable parameters: {num_learnable}")
print("\nLearnable parameters:")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"  {name}: {param.data.item():.4f}")

Number of learnable parameters: 4

Learnable parameters:
  w: 0.0000
  nonlinearity.a: 4.0000
  nonlinearity.b: 1.0000
  nonlinearity.c: -5.0000


## Training Setup

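Set up the Lightning callbacks and the trainer. Early stopping and checkpointing both track `val_correlation` (higher is better), and the learning rate is logged once per epoch.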

In [9]:
os.makedirs(
    cfg.paths.log_dir,
    exist_ok=True,
)

early_stopping = lightning.pytorch.callbacks.EarlyStopping(
    monitor="val_correlation",
    patience=10,
    mode="max",
    verbose=False,
    min_delta=0.001,
)

lr_monitor = lightning.pytorch.callbacks.LearningRateMonitor(logging_interval="epoch")

model_checkpoint = lightning.pytorch.callbacks.ModelCheckpoint(
    monitor="val_correlation", mode="max", save_weights_only=False
)
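With `mode="max"`, `min_delta=0.001`, and `patience=10`, training stops once `val_correlation` has failed to beat its best value by more than 0.001 on 10 consecutive validation checks. The function below is a simplified sketch of that rule, written to match our understanding of it. It is not Lightning's `EarlyStopping` code.

```python
from typing import Optional, Sequence


def early_stopping_check(scores: Sequence[float], patience: int = 10,
                         min_delta: float = 0.001) -> Optional[int]:
    """Index of the validation check at which max-mode early stopping would stop, else None."""
    best = float("-inf")
    wait = 0
    for i, score in enumerate(scores):
        if score - min_delta > best:
            # Counts as an improvement: remember it and reset the patience counter
            best = score
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


history = [0.10, 0.20, 0.30] + [0.3005] * 10
print(early_stopping_check(history))  # 12
```

The 0.0005 gains on the plateau are smaller than `min_delta`, so they never reset the counter.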

In [10]:
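trainer = lightning.Trainer(
    max_epochs=10000,  # upper bound; training normally ends via early stopping
    logger=None,  # results in Lightning's default logger, which LearningRateMonitor needs
    default_root_dir=cfg.paths.log_dir,  # write logs and checkpoints to the directory created above
    callbacks=[early_stopping, lr_monitor, model_checkpoint],
    accelerator="gpu",  # use "auto" or "cpu" if no GPU is available
    log_every_n_steps=10,
)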

train_loader = LongCycler(dataloader["train"])
val_loader = LongCycler(dataloader["validation"])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
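`LongCycler` merges the per-retina dataloaders (keyed by retina index) into one iterable. Our reading is that it interleaves the session loaders and restarts the shorter ones until the longest is exhausted. The sketch below is a simplified, hypothetical version that uses plain lists in place of DataLoaders. It is not the openretina implementation. With a single session, it yields that session's batches once per pass.

```python
from itertools import cycle
from typing import Any, Dict, Iterator, Sequence, Tuple


def long_cycle(loaders: Dict[str, Sequence[Any]]) -> Iterator[Tuple[str, Any]]:
    """Yield (key, batch) pairs, alternating over loaders until the longest has been seen once.

    Shorter loaders are restarted, so each key contributes max_len batches.
    """
    if not loaders:
        return
    keys = list(loaders)
    max_len = max(len(loader) for loader in loaders.values())
    iterators = [cycle(loaders[k]) for k in keys]
    for step in range(len(keys) * max_len):
        i = step % len(keys)
        yield keys[i], next(iterators[i])


print(list(long_cycle({"01": [1, 2, 3], "02": [10]})))
# [('01', 1), ('02', 10), ('01', 2), ('02', 10), ('01', 3), ('02', 10)]
```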


## Train the Model
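The model summary printed by `trainer.fit` shows the loss setup. Training minimizes a Poisson loss (`PoissonLoss3d`), validation uses a correlation loss (`CorrelationLoss3d`), and the callbacks monitor `val_correlation`. The sketch below gives the standard definitions of these two quantities for a single response trace in plain Python. openretina's 3D versions may differ in details such as averaging over time and neurons or the epsilon used. Treat this as an illustration of the idea, not the library's code.

```python
import math
from typing import Sequence


def poisson_loss(pred: Sequence[float], target: Sequence[float], eps: float = 1e-8) -> float:
    """Poisson negative log-likelihood, dropping the constant log(target!) term."""
    return sum(p - t * math.log(p + eps) for p, t in zip(pred, target))


def pearson_correlation(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation between a predicted and a recorded response trace."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


# The Poisson loss is smallest when the predicted rate equals the observed count
print([round(poisson_loss([p], [2.0]), 3) for p in (1.0, 2.0, 3.0)])
print(pearson_correlation([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))
```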

In [11]:
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                        | Params | Mode  | FLOPs
--------------------------------------------------------------------------------
0 | loss            | PoissonLoss3d               | 0      | train | 0    
1 | validation_loss | CorrelationLoss3d           | 0      | train | 0    
2 | nonlinearity    | SpatialContrastNonlinearity | 3      | train | 0    
  | other params    | n/a                         | 1      | n/a   | n/a  
--------------------------------------------------------------------------------
4         Trainable params
0         Non-trainable params
4         Total params
0.000     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode
0         Total Flops


/home/shash/miniconda3/envs/openretina/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:317: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



## View Trained Parameters

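After training, inspect the learned parameter values. Note that `model` now holds the weights from the last training epoch. The best-scoring weights (by `val_correlation`) are stored in the checkpoint at `model_checkpoint.best_model_path`.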

In [12]:
print("Trained parameters:")
print(f"  w (spatial contrast weight): {model.w.data.item():.4f}")
print(f"  a (output scaling): {model.nonlinearity.a.data.item():.4f}")
print(f"  b (input scaling): {model.nonlinearity.b.data.item():.4f}")
print(f"  c (offset): {model.nonlinearity.c.data.item():.4f}")

Trained parameters:
  w (spatial contrast weight): 0.2819
  a (output scaling): 4.4086
  b (input scaling): 0.4983
  c (offset): -4.3630
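These numbers can be read from the nonlinearity's formula alone. For `b > 0`, `a * log(1 + exp(b*x + c))` bends upward around `b*x + c = 0`, i.e. at the soft threshold x = -c/b. For large x it approaches a straight line with slope `a*b`. The sketch below computes both quantities for the initial and trained values printed in this notebook. The helper name is hypothetical.

```python
from typing import Tuple


def softplus_summary(a: float, b: float, c: float) -> Tuple[float, float]:
    """Return (soft threshold, asymptotic slope) of a * log(1 + exp(b*x + c)), assuming b > 0.

    For b*x + c >> 0, log(1 + exp(b*x + c)) ~= b*x + c, so the curve approaches
    slope a*b; the transition is centred where b*x + c = 0, i.e. x = -c / b.
    """
    return -c / b, a * b


initial = softplus_summary(a=4.0, b=1.0, c=-5.0)
trained = softplus_summary(a=4.4086, b=0.4983, c=-4.3630)
print(f"initial: threshold={initial[0]:.2f}, slope={initial[1]:.2f}")  # threshold=5.00, slope=4.00
print(f"trained: threshold={trained[0]:.2f}, slope={trained[1]:.2f}")  # threshold=8.76, slope=2.20
```

Because the input to the nonlinearity is `imean + w * lsc` and the trained `w` is positive, a patch with more contrast inside the RF reaches the threshold at a lower mean intensity.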
