# Import packages & functions

In [1]:
#!pip install ipykernel -U --force-reinstall

In [1]:
# # in a notebook cell
# !pip install --upgrade --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# !pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
# Uncomment and run these lines if you need to install or upgrade packages
# %pip install --upgrade huggingface_hub
# %pip install --upgrade diffusers
# # In a notebook cell
# %pip install --upgrade accelerate
import torch
import torchaudio

# Fail fast if no GPU is visible: everything below assumes a CUDA device.
assert torch.cuda.is_available(), "CUDA is not available. Check PyTorch installation and NVIDIA drivers/CUDA Toolkit."
print(f"CUDA is available! Found {torch.cuda.device_count()} CUDA device(s).")
print("PyTorch CUDA version:", torch.version.cuda)
print("TorchAudio version:", torchaudio.__version__)

import diffusers
import huggingface_hub
from huggingface_hub import hf_hub_download # Use the new download API
print("diffusers:", diffusers.__version__)
print("huggingface_hub:", huggingface_hub.__version__)
# Import-only smoke test: raises ImportError here if the installed diffusers lacks this module path.
from diffusers.models.autoencoders.vae import Decoder  # should import with no error

# (A pasted copy of this cell's output used to sit here as a bare string literal; it was
#  evaluated as the cell's last expression and echoed back as Out[...], so it was removed.)

CUDA is available! Found 1 CUDA device(s).
PyTorch CUDA version: 12.1
TorchAudio version: 2.1.0+cu121
diffusers: 0.33.1
huggingface_hub: 0.30.2


'CUDA is available! Found 1 CUDA device(s).\nPyTorch CUDA version: 12.1\nTorchAudio version: 2.1.0+cu121'

In [2]:
import os
import sys

# --- Fix the CUDA runtime path issue for Lightning.ai ---
# cuda_path = "/system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages/nvidia/cuda_runtime/lib"
# os.environ["LD_LIBRARY_PATH"] = cuda_path + ":" + os.environ.get("LD_LIBRARY_PATH", "")

# Force-load libcudart so Python sees it early
import ctypes
# ctypes.CDLL(os.path.join(cuda_path, "libcudart.so.11.0"))

# --------------------------------------------------------
# Now safe to import other modules
import json
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder # bigG embedder

# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = False   # prevents long autotune

# custom functions #
import utils



In [3]:
# Cell 4 (Modified)

### Multi-GPU config ###
# Process rank as exported by the launcher; when RANK is unset (plain notebook run) use 0.
# NOTE(review): the value is read from RANK but stored as `local_rank` — confirm that
# LOCAL_RANK is not what is wanted for multi-node jobs.
local_rank = int(os.getenv('RANK', 0))
print("LOCAL RANK ", local_rank)

data_type = torch.float32 # Use full precision
# Never report zero devices, so later divisions by num_devices are safe.
num_devices = torch.cuda.device_count() or 1

# First use "accelerate config" in terminal and setup using deepspeed stage 2 with CPU offloading!
print("Initializing Accelerator...")
accelerator = Accelerator(split_batches=False, mixed_precision="no") # Disable AMP
# NOTE(review): this only stores a value on the accelerator object; confirm the training loop
# actually passes it to accelerator.clip_grad_norm_ (or that the DeepSpeed config clips).
accelerator.gradient_clipping = 1.0 # Keep clipping enabled
print(f"Gradient clipping enabled with max_norm={accelerator.gradient_clipping}")
# -----------------------------

# if utils.is_interactive(): # set batch size here if using interactive notebook instead of submitting job
#     global_batch_size = batch_size = 4
# else:
#     global_batch_size_env = os.environ.get("GLOBAL_BATCH_SIZE")
#     if global_batch_size_env is None:
#         print("Warning: GLOBAL_BATCH_SIZE environment variable not set. Using default batch_size=8.")
#         global_batch_size = batch_size = 4
#     else:
#         global_batch_size = int(global_batch_size_env)
#         batch_size = global_batch_size // num_devices

LOCAL RANK  0
Initializing Accelerator...
Gradient clipping enabled with max_norm=1.0


In [4]:
# %%
print("PID of this process =",os.getpid())

# The accelerator decides which device this process trains on; abort early if it is not a GPU.
device = accelerator.device
print("Accelerator selected device:", device)
assert str(device).startswith("cuda"), f"Accelerator did not select a CUDA device (selected: {device}). Ensure GPU is visible and accelerate is configured correctly."

world_size = accelerator.state.num_processes
distributed = accelerator.state.distributed_type != 'NO'

# A non-distributed run is treated as single-device regardless of how many GPUs are visible.
if not distributed or num_devices == 0:
    num_devices = 1

# One DataLoader worker per device.
num_workers = num_devices
print(f"DataLoader num_workers = {num_workers}")

print(accelerator.state)

print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

# From here on `print` is accelerator.print (shadows the builtin for the rest of the notebook).
print = accelerator.print # only print if local_rank=0

PID of this process = 709045
Accelerator selected device: cuda
DataLoader num_workers = 1
Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: no

distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float32


# Configurations

In [None]:
# When running interactively, build the argument list here; the argparse cell below parses
# `jupyter_args` instead of the real command line.
# NOTE(review): the saved output of this cell shows a different model_name / batch_size /
# num_sessions than the code below — re-run to refresh it.
if utils.is_interactive():
    model_name = "fmri_model_v1_5ses_50ep"
    print("model_name:", model_name)
    
    # batch size is passed explicitly via --batch_size below; the earlier cell that derived it from GLOBAL_BATCH_SIZE is commented out
    jupyter_args = f"--data_path={os.getcwd()}/data \
                --cache_dir={os.getcwd()}/cache \
                --model_name={model_name} \
                --no-multi_subject --subj=1 --batch_size=8 --num_sessions=5 \
                    --hidden_dim=1024 --clip_scale=1. \
                    --blurry_recon --blur_scale=.5  \
                    --use_prior --prior_scale=30 \
                    --n_blocks=4 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=50 --no-use_image_aug \
                    --ckpt_interval=999 --ckpt_saving --wandb_log"
    # --multisubject_ckpt=../train_logs/multisubject_subj01_1024_24bs_nolow
    # --use_prior --prior_scale=30 \

    print(jupyter_args)
    # Split the whitespace-separated string into the argv-style list argparse expects.
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: fmri_model_v1_1ses_50ep
--data_path=/workspace/MindEyeV2/MindEyeV2/src/data                 --cache_dir=/workspace/MindEyeV2/MindEyeV2/src/cache                 --model_name=fmri_model_v1_1ses_50ep                 --no-multi_subject --subj=1 --batch_size=4 --num_sessions=1                     --hidden_dim=1024 --clip_scale=1.                     --blurry_recon --blur_scale=.5                      --use_prior --prior_scale=30                     --n_blocks=4 --max_lr=3e-4 --mixup_pct=.33 --num_epochs=50 --no-use_image_aug                     --ckpt_interval=999 --ckpt_saving --wandb_log


In [6]:
# Show the directory that --data_path defaults to (current working directory + "/data").
print(os.getcwd()+"/data")

/workspace/MindEyeV2/MindEyeV2/src/data


In [7]:
parser = argparse.ArgumentParser(description="Model Training Configuration")
parser.add_argument(
    "--model_name", type=str, default="testing",
    help="name of model, used for ckpt saving and wandb logging (if enabled)",
)
parser.add_argument(
    "--data_path", type=str, default=os.getcwd()+"/data",
    help="Path to where NSD data is stored / where to download it to",
)
parser.add_argument(
    "--cache_dir", type=str, default=os.getcwd(),
    help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.",
)
parser.add_argument(
    "--subj",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
    help="Validate on which subject?",
)
parser.add_argument(
    "--multisubject_ckpt", type=str, default=None,
    help="Path to pre-trained multisubject model to finetune a single subject from. multisubject must be False.",
)
parser.add_argument(
    "--num_sessions", type=int, default=1,
    help="Number of training sessions to include",
)
parser.add_argument(
    "--use_prior",action=argparse.BooleanOptionalAction,default=True,
    help="whether to train diffusion prior (True) or just rely on retrieval part of the pipeline (False)",
)
parser.add_argument(
    "--batch_size", type=int, default=16,
    help="Batch size can be increased by 10x if only training retreival submodule and not diffusion prior",
)
parser.add_argument(
    "--wandb_log",action=argparse.BooleanOptionalAction,default=True,
    help="whether to log to wandb",
)
parser.add_argument(
    "--wandb_project",type=str,default="stability",
    help="wandb project name",
)
parser.add_argument(
    "--mixup_pct",type=float,default=.33,
    help="proportion of way through training when to switch from BiMixCo to SoftCLIP",
)
parser.add_argument(
    "--blurry_recon",action=argparse.BooleanOptionalAction,default=True,
    help="whether to output blurry reconstructions",
)
parser.add_argument(
    "--blur_scale",type=float,default=.5,
    help="multiply loss from blurry recons by this number",
)
parser.add_argument(
    "--clip_scale",type=float,default=1.,
    help="multiply contrastive loss by this number",
)
parser.add_argument(
    "--prior_scale",type=float,default=30,
    help="multiply diffusion prior loss by this",
)
parser.add_argument(
    "--use_image_aug",action=argparse.BooleanOptionalAction,default=False,
    help="whether to use image augmentation",
)
parser.add_argument(
    "--num_epochs",type=int,default=3,
    help="number of epochs of training",
)
parser.add_argument(
    "--multi_subject",action=argparse.BooleanOptionalAction,default=False,
)
parser.add_argument(
    "--new_test",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--n_blocks",type=int,default=4,
)
parser.add_argument(
    "--hidden_dim",type=int,default=1024,
)
parser.add_argument(
    "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
)
parser.add_argument(
    "--ckpt_saving",action=argparse.BooleanOptionalAction,default=True,
)
parser.add_argument(
    "--ckpt_interval",type=int,default=5,
    help="save backup ckpt and reconstruct every x epochs",
)
parser.add_argument(
    "--seed",type=int,default=42,
)
parser.add_argument(
    "--max_lr",type=float,default=1e-1,
)

# In a notebook parse the hand-built `jupyter_args`; as a script parse the real command line.
args = parser.parse_args(jupyter_args) if utils.is_interactive() else parser.parse_args()

# create global variables without the args prefix (model_name, subj, batch_size, ...)
globals().update(vars(args))

# seed all random functions
utils.seed_everything(seed)

# Checkpoint directory for this run; only created when checkpoint saving is enabled.
outdir = os.path.abspath(f'../train_logs/{model_name}')
if ckpt_saving and not os.path.exists(outdir):
    os.makedirs(outdir,exist_ok=True)

# kornia is only needed for image augmentation and/or the blurry-reconstruction branch.
if use_image_aug or blurry_recon:
    import kornia
    from kornia.augmentation.container import AugmentationSequential
if use_image_aug:
    img_augment = AugmentationSequential(
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.3),
        same_on_batch=False,
        data_keys=["input"],
    )

# Multi-subject training uses every subject 1..8 except `subj`; otherwise just `subj`.
subj_list = np.array([s for s in range(1, 9) if s != subj]) if multi_subject else [subj]

print("subj_list", subj_list, "num_sessions", num_sessions)

subj_list [1] num_sessions 1


# Prep data, models, and dataloaders

### Creating wds dataloader, preload betas and all 73k possible images

In [8]:

def my_split_by_node(urls):
    """Identity node splitter passed as `nodesplitter=` to wds.WebDataset.

    Returns the shard URLs unchanged, i.e. this function applies no per-node partitioning.
    """
    return urls
num_voxels_list = []

# 750 is used as the number of training samples per session (presumably NSD trials per
# session — TODO confirm); each device handles an equal share of an epoch.
if multi_subject:
    # Sessions available for subjects 1..8, indexed by subject number - 1.
    nsessions_allsubj=np.array([40, 40, 32, 30, 40, 32, 40, 30])
    num_samples_per_epoch = (750*40) // num_devices
else:
    num_samples_per_epoch = (750*num_sessions) // num_devices

print("dividing batch size by subj_list, which will then be concatenated across subj during training...")
# Derive the per-subject batch size from the parsed argument rather than from the current
# global, so re-running this cell does not keep shrinking batch_size.
batch_size = args.batch_size // len(subj_list)
if batch_size < 1:
    raise ValueError(
        f"--batch_size={args.batch_size} is smaller than the number of subjects "
        f"({len(subj_list)}); the per-subject batch size would be 0."
    )

num_iterations_per_epoch = num_samples_per_epoch // (batch_size*len(subj_list))

print("batch_size =", batch_size, "num_iterations_per_epoch =",num_iterations_per_epoch, "num_samples_per_epoch =",num_samples_per_epoch)

dividing batch size by subj_list, which will then be concatenated across subj during training...
batch_size = 4 num_iterations_per_epoch = 187 num_samples_per_epoch = 750


In [9]:
# Per-subject containers, all keyed by 'subj0{N}'.
train_data = {}
train_dl = {}
num_voxels = {}
voxels = {}
num_voxels_list = [] # Reset this list

for s in subj_list:
    subj_key = f'subj0{s}'
    print(f"Processing subj0{s}...")
    print(f"Training with {num_sessions} sessions")
    # Brace-range shard pattern "{0..K}.tar": all of the subject's sessions in multi-subject
    # mode, otherwise the first `num_sessions` shards.
    if multi_subject:
        train_url = f"{data_path}/wds/subj0{s}/train/" + "{0.." + f"{nsessions_allsubj[s-1]-1}" + "}.tar"
    else:
        train_url = f"{data_path}/wds/subj0{s}/train/" + "{0.." + f"{num_sessions-1}" + "}.tar"
    print("Train URL:", train_url)

    # Resampled (infinite) behavioural stream; shuffle=False on the DataLoader because the
    # WebDataset pipeline does its own shuffling.
    train_data[subj_key] = wds.WebDataset(train_url,resampled=True,nodesplitter=my_split_by_node)\
                        .shuffle(750, initial=1500, rng=random.Random(42))\
                        .decode("torch")\
                        .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")\
                        .to_tuple(*["behav", "past_behav", "future_behav", "olds_behav"])
    train_dl[subj_key] = torch.utils.data.DataLoader(train_data[subj_key], batch_size=batch_size, shuffle=False, drop_last=True, pin_memory=True)

    # Read the subject's full betas array into memory; the context manager closes the HDF5
    # file even if the read itself fails.
    beta_file_path = f'{data_path}/betas_all_subj0{s}_fp32_renorm.hdf5'
    print(f"Loading betas from {beta_file_path}")
    try:
        with h5py.File(beta_file_path, 'r') as f:
            betas = f['betas'][:] # Load all voxels directly
    except Exception as e:
        print(f"ERROR loading {beta_file_path}: {e}")
        raise

    betas = torch.Tensor(betas).to("cpu").to(data_type) # Use full betas, convert dtype
    current_num_voxels = betas.shape[1]
    num_voxels_list.append(current_num_voxels)
    num_voxels[subj_key] = current_num_voxels
    voxels[subj_key] = betas
    print(f"num_voxels for subj0{s}: {num_voxels[subj_key]}")

print("Updated num_voxels_list:", num_voxels_list) # one entry per subject, in subj_list order
print("Loaded all subj train dls and subsetted betas!\n")


# Validate only on one subject
if multi_subject:
    subj = subj_list[0] # cant validate on the actual held out person so picking first in subj_list

# Test-set size depends on the subject and on which test split is used:
#   new_test -> larger test set from after the full dataset was released
#   otherwise -> old test set from before the full release (used in original MindEye paper)
if new_test:
    _test_split = "new_test"
    _num_test_by_subj, _num_test_default = {3: 2371, 4: 2188, 6: 2371, 8: 2188}, 3000
else:
    _test_split = "test"
    _num_test_by_subj, _num_test_default = {3: 2113, 4: 1985, 6: 2113, 8: 1985}, 2770
num_test = _num_test_by_subj.get(subj, _num_test_default)
test_url = f"{data_path}/wds/subj0{subj}/{_test_split}/" + "0.tar"
print(test_url)

# Single non-resampled shard; the batch size equals num_test, so one batch is the whole test set.
test_data = (
    wds.WebDataset(test_url, resampled=False, nodesplitter=my_split_by_node)
    .shuffle(750, initial=1500, rng=random.Random(42))
    .decode("torch")
    .rename(behav="behav.npy", past_behav="past_behav.npy", future_behav="future_behav.npy", olds_behav="olds_behav.npy")
    .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
)
test_dl = torch.utils.data.DataLoader(test_data, batch_size=num_test, shuffle=False, drop_last=True, pin_memory=True)
print(f"Loaded test dl for subj{subj}!\n")

Processing subj01...
Training with 1 sessions
Train URL: /workspace/MindEyeV2/MindEyeV2/src/data/wds/subj01/train/{0..0}.tar
Loading betas from /workspace/MindEyeV2/MindEyeV2/src/data/betas_all_subj01_fp32_renorm.hdf5
num_voxels for subj01: 15724
Updated num_voxels_list: [15724]
Loaded all subj train dls and subsetted betas!

/workspace/MindEyeV2/MindEyeV2/src/data/wds/subj01/new_test/0.tar
Loaded test dl for subj1!



In [10]:
# Load 73k NSD images
# `f['images']` is not sliced here, so `images` is an h5py dataset handle backed by the open
# file rather than an in-memory array (despite the message below); `f` must therefore stay
# open for as long as `images` is indexed. Note this rebinds `f` from the betas-loading cell.
f = h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r')
images = f['images']
print("Loaded all 73k possible NSD images to cpu!", images.shape)

Loaded all 73k possible NSD images to cpu! (73000, 3, 224, 224)


## Load models

### CLIP image embeddings model

In [11]:
# Cell 16 (modified to use the smaller OpenAI ViT-B-32 CLIP image encoder)
# NOTE(review): huggingface_hub is already imported in the first cell, so setting HF_HOME
# this late may not redirect its cache — confirm, or export it before the imports.
os.environ["HF_HOME"] = "/workspace/hf_cache"   # or set HF_HUB_CACHE

# Single source of truth for the architecture so the log line cannot drift from the model.
clip_arch = "ViT-B-32"
print(f"Initializing **SMALLER** FrozenOpenCLIPImageEmbedder ({clip_arch})...")

clip_img_embedder = FrozenOpenCLIPImageEmbedder(
    arch=clip_arch,
    version="openai",
    output_tokens=True,
    only_tokens=True,
)

# Convert the CLIP model to float32 explicitly BEFORE moving to device
clip_img_embedder.to(data_type)
print(f"CLIP model explicitly set to {data_type}")

# Now move to device
clip_img_embedder.to(device)

# Token-grid dimensions expected from this encoder for 224x224 inputs; they are checked
# against a real forward pass below because the backbone and prior are sized from them.
clip_seq_dim = 49    # 7x7 grid of 32px patches
clip_emb_dim = 768   # token width
total_dim = clip_seq_dim * clip_emb_dim
print(f"Total CLIP feature dimension: {total_dim}")

# param_size = sum(p.numel() * p.element_size() for p in clip_img_embedder.parameters()) / 1024**2
# print(f"CLIP Model Size: {param_size:.1f} MB")

# Probe the embedder once with a random image to measure its actual output shape.
print("Testing CLIP output dimensions...")
with torch.no_grad():
    test_img = torch.randn(1, 3, 224, 224).to(device)
    test_out = clip_img_embedder(test_img)
    print(f"CLIP output shape: {test_out.shape}")
    actual_seq_dim = test_out.shape[1]
    actual_emb_dim = test_out.shape[2]
    print(f"Actual seq_dim: {actual_seq_dim}")
    print(f"Actual emb_dim: {actual_emb_dim}")

# Fail here, with a clear message, instead of with a shape mismatch deep inside training.
assert (actual_seq_dim, actual_emb_dim) == (clip_seq_dim, clip_emb_dim), (
    f"CLIP output is ({actual_seq_dim}, {actual_emb_dim}) but clip_seq_dim/clip_emb_dim are "
    f"({clip_seq_dim}, {clip_emb_dim}); update them to match {clip_arch}."
)

# --- IMPORTANT: Re-run the cell defining model.backbone (Cell 20) ---
# BrainNetwork's out_dim is clip_emb_dim * clip_seq_dim, so it must be re-initialized
# whenever these dimensions change; otherwise shape mismatch errors appear later.


Initializing **SMALLER** FrozenOpenCLIPImageEmbedder (ViT-L-14)...
CLIP model explicitly set to torch.float32
Total CLIP feature dimension: 37632
Testing CLIP output dimensions...
Testing CLIP output dimensions...
CLIP output shape: torch.Size([1, 49, 768])
Actual seq_dim: 49
Actual emb_dim: 768


## SD VAE

In [12]:
if blurry_recon:
    from diffusers import AutoencoderKL
    # Build the VAE architecture locally, then fill it with weights from cache_dir.
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
        up_block_types=['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    print(f'{cache_dir}')
    # map_location='cpu' makes loading independent of the device the checkpoint was saved
    # from; the module is moved to `device` explicitly below.
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth', map_location='cpu')
    autoenc.load_state_dict(ckpt)

    # Frozen, inference-only VAE.
    autoenc.eval()
    autoenc.requires_grad_(False)
    autoenc.to(device)
    utils.count_params(autoenc)

    # Frozen ConvNeXt-XL loaded from a checkpoint in cache_dir.
    from autoencoder.convnext import ConvnextXL
    cnx = ConvnextXL(f'{cache_dir}/convnext_xlarge_alpha0.75_fullckpt.pth')
    cnx.requires_grad_(False)
    cnx.eval()
    cnx.to(device)

    # Per-channel normalisation constants, shaped to broadcast over (batch, 3, H, W).
    # NOTE(review): std[0] is 0.228 whereas the commonly used ImageNet value is 0.229; left
    # unchanged because existing checkpoints may have been trained with it — confirm.
    mean = torch.tensor([0.485, 0.456, 0.406]).to(device).reshape(1,3,1,1)
    std = torch.tensor([0.228, 0.224, 0.225]).to(device).reshape(1,3,1,1)

    # Augmentations for the blurry-reconstruction branch (kornia / AugmentationSequential are
    # imported in the argument-parsing cell when blurry_recon is set).
    blur_augs = AugmentationSequential(
        kornia.augmentation.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8),
        kornia.augmentation.RandomGrayscale(p=0.1),
        kornia.augmentation.RandomSolarize(p=0.1),
        kornia.augmentation.RandomResizedCrop((224,224), scale=(.9,.9), ratio=(1,1), p=1.0),
        data_keys=["input"],
    )

/workspace/MindEyeV2/MindEyeV2/src/cache
param counts:
83,653,863 total
0 trainable


### MindEye modules

In [13]:
class MindEyeModule(nn.Module):
    """Empty container module; sub-networks are attached to it as attributes in later cells."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Identity: return the input unchanged."""
        return x


model = MindEyeModule()
model

MindEyeModule()

In [14]:
class RidgeRegression(torch.nn.Module):
    """One linear projection per subject, mapping that subject's input size to `out_features`.

    The ridge penalty is not applied here: make sure to add weight_decay when initializing
    the optimizer to enable regularization.

    Args:
        input_sizes: iterable of input feature counts, one per subject.
        out_features: shared output width of every per-subject layer.
    """

    def __init__(self, input_sizes, out_features):
        super().__init__()
        self.out_features = out_features
        per_subject_layers = [torch.nn.Linear(n_inputs, out_features) for n_inputs in input_sizes]
        self.linears = torch.nn.ModuleList(per_subject_layers)

    def forward(self, x, subj_idx):
        """Project x[:, 0] with the layer selected by `subj_idx` and re-insert dim 1."""
        first_slice = x[:, 0]
        projected = self.linears[subj_idx](first_slice)
        return projected.unsqueeze(1)
        
# Attach the per-subject input layers: one Linear per entry of num_voxels_list, each mapping
# to hidden_dim.
model.ridge = RidgeRegression(num_voxels_list, out_features=hidden_dim)
utils.count_params(model.ridge)
utils.count_params(model)

# Smoke test with random data for the first subject in num_voxels_list (index 0).
b = torch.randn((2,1,num_voxels_list[0]))
print(b.shape, model.ridge(b,0).shape)

param counts:
16,102,400 total
16,102,400 trainable
param counts:
16,102,400 total
16,102,400 trainable
torch.Size([2, 1, 15724]) torch.Size([2, 1, 1024])


In [15]:
# Cell 20 (Re-run this cell AFTER running the modified Cell 16)

from models import BrainNetwork
# from diffusers.models.autoencoders.vae import Decoder # Already imported if needed

# The backbone is sized from the globals set in the CLIP cell (clip_emb_dim = 768 and
# clip_seq_dim = 49 there), so it must be rebuilt whenever those values change.
model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=1, n_blocks=n_blocks,
                          clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim,
                          blurry_recon=blurry_recon, clip_scale=clip_scale)

utils.count_params(model.backbone)
utils.count_params(model)

# test that the model works on some fake data
# Input mimics the ridge output: (batch, seq_len=1, hidden_dim)
b = torch.randn((2,1,hidden_dim))
print("b.shape",b.shape)

# With the current settings the saved output shows backbone_ and clip_ as (2, 49, 768),
# i.e. (batch, clip_seq_dim, clip_emb_dim).
backbone_, clip_, blur_ = model.backbone(b)
print("Backbone Output Shape:", backbone_.shape)
print("Clip Output Shape:", clip_.shape)
if blurry_recon:
    # blur_ is indexed as a pair here; both shapes are printed.
    print("Blurry Recon Shapes:", blur_[0].shape, blur_[1].shape)
else:
    print("Blurry recon disabled.")

param counts:
54,279,036 total
54,279,036 trainable
param counts:
70,381,436 total
70,381,436 trainable
b.shape torch.Size([2, 1, 1024])
Backbone Output Shape: torch.Size([2, 49, 768])
Clip Output Shape: torch.Size([2, 49, 768])
Blurry Recon Shapes: torch.Size([2, 4, 28, 28]) torch.Size([2, 49, 512])


### Adding diffusion prior + unCLIP if use_prior=True

In [16]:
if use_prior:
    # NOTE(review): wildcard import — PriorNetwork and BrainDiffusionPrior are the names used
    # below; other names pulled in from `models` may shadow notebook globals.
    from models import *

    # setup diffusion prior network
    out_dim = clip_emb_dim
    depth = 6
    dim_head = 52
    # heads * dim_head equals clip_emb_dim only when clip_emb_dim is a multiple of 52.
    # With clip_emb_dim = 768 (set in the CLIP cell) this is 14 heads, i.e. 14*52 = 728.
    # NOTE(review): confirm this attention width is intended for the smaller CLIP model.
    heads = clip_emb_dim//52
    timesteps = 100

    prior_network = PriorNetwork(
            dim=out_dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            causal=False,
            num_tokens = clip_seq_dim,
            learned_query_mode="pos_emb"
        )

    # Wrap the network in the diffusion prior and attach it to the shared container.
    model.diffusion_prior = BrainDiffusionPrior(
        net=prior_network,
        image_embed_dim=out_dim,
        condition_on_text_encodings=False,
        timesteps=timesteps,
        cond_drop_prob=0.2,
        image_embed_scale=None,
    )
    
    utils.count_params(model.diffusion_prior)
    utils.count_params(model)

param counts:
55,096,640 total
55,096,624 trainable
param counts:
125,478,076 total
125,478,060 trainable


### Setup optimizer / lr / ckpt saving

In [17]:
# Parameters whose name contains any of these substrings are exempt from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

def _decay_groups(module):
    """Return two AdamW groups for `module`: decayed (1e-2) first, then exempt (0.0)."""
    with_decay, without_decay = [], []
    for param_name, param in module.named_parameters():
        exempt = any(nd in param_name for nd in no_decay)
        (without_decay if exempt else with_decay).append(param)
    return [
        {'params': with_decay, 'weight_decay': 1e-2},
        {'params': without_decay, 'weight_decay': 0.0},
    ]

# Group order matters for optimizer checkpoints: ridge, backbone (decay / no decay), then
# the diffusion prior (decay / no decay) when it is enabled.
# Every ridge parameter, biases included, is decayed.
opt_grouped_parameters = [{'params': list(model.ridge.parameters()), 'weight_decay': 1e-2}]
opt_grouped_parameters += _decay_groups(model.backbone)
if use_prior:
    opt_grouped_parameters += _decay_groups(model.diffusion_prior)

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)
#print("Trying SGD optimizer...")
#optimizer = torch.optim.SGD(opt_grouped_parameters, lr=1e-3) # Try SGD

# The scheduler is stepped once per iteration, hence num_epochs * num_iterations_per_epoch.
if lr_scheduler_type == 'cycle':
    total_steps = int(np.floor(num_epochs*num_iterations_per_epoch))
    print("total_steps", total_steps)
    # pct_start=2/num_epochs: the rising phase spans the first two epochs.
    # NOTE(review): for num_epochs < 2 this exceeds 1 — confirm such runs are not expected.
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        final_div_factor=1000,
        last_epoch=-1,
        pct_start=2/num_epochs,
    )
elif lr_scheduler_type == 'linear':
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        total_iters=int(np.floor(num_epochs*num_iterations_per_epoch)),
        last_epoch=-1,
    )
    
def save_ckpt(tag):
    """Write a training checkpoint to `{outdir}/{tag}.pth` (main process only).

    Reads the notebook globals `epoch`, `losses`, `test_losses` and `lrs` at call time, so
    these must exist before the first call. The confirmation message goes through the
    notebook's `print`.
    """
    if accelerator.is_main_process:
        # Save the unwrapped model so the state-dict keys carry no accelerator wrapper prefix.
        bare_model = accelerator.unwrap_model(model)
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': bare_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'train_losses': losses,
            'test_losses': test_losses,
            'lrs': lrs,
        }
        torch.save(checkpoint, f'{outdir}/{tag}.pth')
    print(f"\n---saved {outdir}/{tag} ckpt!---\n")

def load_ckpt(tag,load_lr=True,load_optimizer=True,load_epoch=True,strict=True,outdir=outdir,multisubj_loading=False): 
    """Restore model (and optionally optimizer / scheduler / epoch) from `{outdir}/{tag}.pth`.

    Args:
        tag: checkpoint name without extension (e.g. "last").
        load_lr: also restore the LR scheduler state.
        load_optimizer: also restore the optimizer state.
        load_epoch: set the notebook global `epoch` from the checkpoint.
        strict: forwarded to `model.load_state_dict`.
        outdir: directory containing the checkpoint; defaults to this run's output dir.
        multisubj_loading: drop the first ridge weight from the state dict before loading
            (its shape depends on the subject's voxel count); normally paired with strict=False.
    """
    # Fixed: the file was previously hard-coded to 'last.pth', ignoring `tag`.
    ckpt_path = f"{outdir}/{tag}.pth"
    print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    if multisubj_loading: # remove incompatible ridge layer that will otherwise error
        state_dict.pop('ridge.linears.0.weight',None)
    model.load_state_dict(state_dict, strict=strict)
    if load_epoch:
        globals()["epoch"] = checkpoint['epoch']
        print("Epoch",epoch)
    if load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if load_lr:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    # Release the (potentially large) checkpoint dict promptly.
    del checkpoint

print("\nDone with model preparations!")
num_params = utils.count_params(model)

total_steps 9350

Done with model preparations!
param counts:
125,478,076 total
125,478,060 trainable


# Weights and Biases

In [18]:
import wandb

# NOTE(review): global_batch_size is set to the per-process batch_size; for runs
# spanning several devices this may need to be batch_size * num_devices — confirm.
global_batch_size = batch_size

# A single wandb.init call, made only by the main process when logging is enabled.
# (An extra unconditional wandb.init(project="fmri-small") used to run here on every
# process, starting a second, unconfigured run next to the real one.)
if local_rank==0 and wandb_log: # only use main process for wandb logging
    wandb_project = 'mindeye'
    print(f"wandb {wandb_project} run {model_name}")
    # need to configure wandb beforehand in terminal with "wandb init"!
    wandb_config = {
      "model_name": model_name,
      "global_batch_size": global_batch_size,
      "batch_size": batch_size,
      "num_epochs": num_epochs,
      "num_sessions": num_sessions,
      "num_params": num_params,
      "clip_scale": clip_scale,
      "prior_scale": prior_scale,
      "blur_scale": blur_scale,
      "use_image_aug": use_image_aug,
      "max_lr": max_lr,
      "mixup_pct": mixup_pct,
      "num_samples_per_epoch": num_samples_per_epoch,
      "num_test": num_test,
      "ckpt_interval": ckpt_interval,
      "ckpt_saving": ckpt_saving,
      "seed": seed,
      "distributed": distributed,
      "num_devices": num_devices,
      "world_size": world_size,
      "train_url": train_url,
      "test_url": test_url,
    }
    print("wandb_config:\n",wandb_config)
    print("wandb_id:",model_name)
    # The run id is the model name and resume="allow" is passed, so re-running with
    # the same model_name targets the same run id.
    wandb.init(
        id=model_name,
        project=wandb_project,
        name=model_name,
        config=wandb_config,
        resume="allow",
    )
else:
    # Disable logging everywhere else so later `if wandb_log:` checks are skipped.
    wandb_log = False

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malbst[0m ([33mfranek-liszka-it-university-of-copenhagen[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


wandb mindeye run fmri_model_v1_1ses_50ep
wandb_config:
 {'model_name': 'fmri_model_v1_1ses_50ep', 'global_batch_size': 4, 'batch_size': 4, 'num_epochs': 50, 'num_sessions': 1, 'num_params': 125478060, 'clip_scale': 1.0, 'prior_scale': 30.0, 'blur_scale': 0.5, 'use_image_aug': False, 'max_lr': 0.0003, 'mixup_pct': 0.33, 'num_samples_per_epoch': 750, 'num_test': 3000, 'ckpt_interval': 999, 'ckpt_saving': True, 'seed': 42, 'distributed': False, 'num_devices': 1, 'world_size': 1, 'train_url': '/workspace/MindEyeV2/MindEyeV2/src/data/wds/subj01/train/{0..0}.tar', 'test_url': '/workspace/MindEyeV2/MindEyeV2/src/data/wds/subj01/new_test/0.tar'}
wandb_id: fmri_model_v1_1ses_50ep


# Main

In [19]:
# Fresh training state: start at epoch 0 with empty loss / learning-rate histories.
epoch = 0
losses = []
test_losses = []
lrs = []
best_test_loss = 1e9  # large sentinel value
torch.cuda.empty_cache()

In [20]:
# Optionally warm-start from a multisubject stage-1 checkpoint: only the model
# weights are loaded (non-strict); optimizer, scheduler and epoch are left untouched.
if multisubject_ckpt is not None:
    load_ckpt(
        "last",
        outdir=multisubject_ckpt,
        load_lr=False,
        load_optimizer=False,
        load_epoch=False,
        strict=False,
        multisubj_loading=True,
    )

In [21]:
# Hand the model, optimizer, scheduler and each training DataLoader to Accelerate.

# One plain PyTorch DataLoader per subject, in subj_list order.
original_train_dls = [train_dl[f'subj0{s}'] for s in subj_list]
print(f"Created list of {len(original_train_dls)} PyTorch DataLoaders.")

# Cast the model to the training dtype before it is wrapped.
print(f"Explicitly converting model to {data_type} before prepare...")
model.to(data_type)

# Model / optimizer / scheduler are prepared together in one call.
print("Preparing model, optimizer, and scheduler with Accelerator...")
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
print("Model, optimizer, and scheduler prepared.")

# Each DataLoader is prepared on its own, and its wrapped type is reported.
print(f"Preparing {len(original_train_dls)} DataLoader(s) individually...")
prepared_train_dls = []
for dl_number, dl in enumerate(original_train_dls, start=1):
    wrapped_dl = accelerator.prepare(dl)
    prepared_train_dls.append(wrapped_dl)
    print(f"  DataLoader {dl_number} prepared. Type: {type(wrapped_dl)}")

# The training loop iterates over `train_dls`, i.e. the prepared loaders.
train_dls = prepared_train_dls

print("\nAll components prepared.")
# test_dl is intentionally not prepared: only the local_rank 0 device does evals

Created list of 1 PyTorch DataLoaders.
Explicitly converting model to torch.float32 before prepare...
Preparing model, optimizer, and scheduler with Accelerator...
Model, optimizer, and scheduler prepared.
Preparing 1 DataLoader(s) individually...
  DataLoader 1 prepared. Type: <class 'accelerate.data_loader.DataLoaderDispatcher'>

All components prepared.


In [None]:
# Training cell: per-epoch preloading, optimisation with gradient inspection, logging.

print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch,num_epochs), ncols=1200, disable=(local_rank!=0))
test_image = test_voxel = None
mse, l1 = nn.MSELoss(), nn.L1Loss()
# Temperature schedule for the soft CLIP loss, one value per epoch after the mixup phase.
soft_loss_temps = utils.cosine_anneal(0.004, 0.0075, num_epochs - int(mixup_pct * num_epochs))

for epoch in progress_bar:
    epoch_start_time = time.time()
    model.train()

    # --- Per-epoch running sums (all restart from zero) ---
    fwd_percent_correct = bwd_percent_correct = 0.
    loss_clip_total = 0.
    loss_blurry_total = loss_blurry_cont_total = 0.
    loss_prior_total = 0.
    recon_cossim = recon_mse = 0.
    blurry_pixcorr = 0.

    # --- Data Preloading Loop ---
    # Every subject is preloaded independently with its own iteration counter, so that
    # subject k's batch for iteration i is stored in column block k of image_iters[i]
    # and under the key "subj0{k}_iter{i}". (A single counter shared by all subjects
    # lets the first subject use up every iteration slot and leaves the remaining
    # subjects without any preloaded keys.)
    print(f"\n--- Starting Epoch {epoch} Preloading ---")
    voxel_iters = {} # empty dict because diff subjects have differing # of voxels
    # Ensure image_iters uses the correct data_type for storage from the start
    image_iters = torch.zeros(num_iterations_per_epoch, batch_size*len(subj_list), 3, 224, 224, dtype=data_type)
    annot_iters = {}
    perm_iters, betas_iters, select_iters = {}, {}, {}
    iters_per_subject = {} # subject id -> number of iterations successfully preloaded

    # --- Loop over subjects/dataloaders ---
    # The loop variable is deliberately not named `train_dl`, so the per-subject
    # DataLoader dict of that name is not overwritten by this loop.
    for s, subj_train_dl in enumerate(train_dls):
        subject_id = subj_list[s] # Get the actual subject number
        subj_iters_processed = 0 # Successfully stored iterations for THIS subject
        print(f"  Preloading for subj0{subject_id}...")
        # try-finally so the closing print is emitted even if the loop exits early
        try:
            # --- Loop over batches in dataloader ---
            for batch_idx, (behav0, past_behav0, future_behav0, old_behav0) in enumerate(subj_train_dl):
                # Stop once this subject has enough valid iterations
                if subj_iters_processed >= num_iterations_per_epoch:
                    print(f"    subj0{subject_id}: Reached max iterations ({num_iterations_per_epoch}). Breaking inner batch loop.")
                    break # Exit inner batch loop for this subject

                # Load image indices and check for duplicates
                image_idx = behav0[:,0,0].cpu().long().numpy()
                image0, image_sorted_idx = np.unique(image_idx, return_index=True)

                # Skip batches that contain duplicate images or are not full-sized
                if len(image0) != len(image_idx) or len(image_idx) != batch_size:
                    if len(image0) != len(image_idx):
                        print(f"      Duplicate image indices found in batch {batch_idx}. Skipping.")
                    else: # Only possible if drop_last=False and it's the last batch
                        print(f"      Incomplete batch {batch_idx} (size {len(image_idx)}). Skipping.")
                    continue # Skip this batch entirely

                # From here on the batch has batch_size unique images.
                current_iter_index = subj_iters_processed # Slot / key index for this subject

                # --- Load Image Data ---
                try:
                    # Ensure indices are within bounds (using unique image0)
                    max_index = images.shape[0] - 1
                    if np.any(image0 >= images.shape[0]) or np.any(image0 < 0):
                        print(f"      ERROR: Image indices out of bounds! Min: {np.min(image0)}, Max: {np.max(image0)}, Allowed: 0-{max_index}. Skipping batch {batch_idx}.")
                        continue
                    image_data_np = images[image0] # Use unique indices image0
                    image_data_torch = torch.tensor(image_data_np, dtype=data_type)
                except Exception as e:
                    print(f"      ERROR loading images for batch {batch_idx}, indices {image0}: {e}. Skipping.")
                    continue

                # Store image data in this subject's column block of the iteration slot
                image_iters[current_iter_index, s*batch_size : s*batch_size + batch_size] = image_data_torch

                # --- Load Voxel Data ---
                voxel_idx = behav0[:,0,5].cpu().long().numpy()
                # Index voxels corresponding to the *unique* images loaded
                voxel_sorted_idx = voxel_idx[image_sorted_idx]
                try:
                    max_voxel_idx = voxels[f'subj0{subject_id}'].shape[0] - 1
                    if np.any(voxel_sorted_idx >= voxels[f'subj0{subject_id}'].shape[0]) or np.any(voxel_sorted_idx < 0):
                        print(f"      ERROR: Voxel indices out of bounds! Min: {np.min(voxel_sorted_idx)}, Max: {np.max(voxel_sorted_idx)}, Allowed: 0-{max_voxel_idx}. Skipping batch {batch_idx}.")
                        continue
                    voxel0_np = voxels[f'subj0{subject_id}'][voxel_sorted_idx]
                    voxel0 = torch.Tensor(voxel0_np).unsqueeze(1)
                except Exception as e:
                    print(f"      ERROR loading voxels for batch {batch_idx}, indices {voxel_sorted_idx}: {e}. Skipping.")
                    continue

                # --- Mixco Logic (only during the mixup phase of training) ---
                if epoch < int(mixup_pct * num_epochs):
                    voxel0, perm, betas, select = utils.mixco(voxel0)
                    perm_iters[f"subj0{subject_id}_iter{current_iter_index}"] = perm
                    betas_iters[f"subj0{subject_id}_iter{current_iter_index}"] = betas
                    select_iters[f"subj0{subject_id}_iter{current_iter_index}"] = select

                # --- Store Voxel Data ---
                voxel_key = f"subj0{subject_id}_iter{current_iter_index}"
                voxel_iters[voxel_key] = voxel0

                # Increment the counter *only* after successfully processing and storing
                subj_iters_processed += 1

            # End of inner batch loop for this subject's dataloader
            iters_per_subject[subject_id] = subj_iters_processed
            if subj_iters_processed >= num_iterations_per_epoch:
                print(f"  Preloading complete for subj0{subject_id}. Total iters: {subj_iters_processed}")

        finally:
            print(f"  Finished preloading attempts for subj0{subject_id}.")

    # An iteration is usable only if every subject has data for it.
    effective_iters_processed = min(iters_per_subject.values(), default=0)

    print(f"--- Epoch {epoch} Preloading Complete ---")
    # Ensure we actually processed the expected number of iterations
    if effective_iters_processed < num_iterations_per_epoch:
        print(f"  WARNING: Only preloaded {effective_iters_processed} / {num_iterations_per_epoch} iterations due to skipped batches.")
        # num_iterations_per_epoch is left unchanged on purpose; the training loop
        # raises a KeyError if it reaches an iteration that was not preloaded.
    print(f"  Total voxel iters stored: {len(voxel_iters)}") # One entry per subject per preloaded iteration


    # --- Main Training Loop ---
    # One optimisation step per preloaded iteration: build the batch from the preloaded
    # tensors, accumulate the enabled losses under autocast, run backward, and step the
    # optimizer only when every gradient is finite.
    print(f"\n--- Starting Epoch {epoch} Training Loop ---")
    nan_or_inf_grad_detected_this_epoch = False

    # Runs for the configured num_iterations_per_epoch; a KeyError is raised below if
    # preloading produced fewer iterations than that.
    for train_i in range(num_iterations_per_epoch):
        try:
            # --- Autocast and Optimizer Zero Grad ---
            with torch.cuda.amp.autocast(dtype=data_type):
                optimizer.zero_grad()
                loss=torch.tensor(0.0, device=device)

                # --- Load Batch Data ---
                voxel_list = []
                for si, s_num in enumerate(subj_list):
                    voxel_key = f"subj0{s_num}_iter{train_i}"
                    if voxel_key not in voxel_iters:
                        # This might happen if preloading skipped too many batches
                        print(f"\nFATAL ERROR: Missing preloaded key '{voxel_key}' during training loop!")
                        print(f"  Check preloading warnings. Effective preloaded iters: {effective_iters_processed}")
                        raise KeyError(f"Preloaded key '{voxel_key}' missing!")
                    voxel_list.append(voxel_iters[voxel_key].detach().to(device))

                # Rows of this iteration's batch: per-subject batch size times number of subjects
                current_iter_batch_size = voxel_list[0].shape[0] * len(subj_list)

                # Load corresponding image data
                image = image_iters[train_i, :current_iter_batch_size].detach().to(device)

                # --- Image Augmentation ---
                if use_image_aug:
                    image = img_augment(image)

                # --- Get CLIP Target ---
                clip_target = clip_img_embedder(image)
                assert not torch.any(torch.isnan(clip_target)), f"NaN detected in clip_target at iter {train_i}"
                assert torch.isfinite(clip_target).all(), f"Inf detected in clip_target at iter {train_i}"

                # --- Prepare Mixco Inputs ---
                if epoch < int(mixup_pct * num_epochs):
                    perm_list = []
                    betas_list = []
                    select_list = []
                    for s_num in subj_list:
                        # perm / betas / select were all stored under the same
                        # "subj0{s}_iter{i}" key, so one key indexes all three dicts.
                        mixco_key = f"subj0{s_num}_iter{train_i}"
                        # Check keys exist (important if preloading was incomplete)
                        if mixco_key not in perm_iters: raise KeyError(f"Missing Mixco perm key: {mixco_key}")
                        if mixco_key not in betas_iters: raise KeyError(f"Missing Mixco betas key: {mixco_key}")
                        if mixco_key not in select_iters: raise KeyError(f"Missing Mixco select key: {mixco_key}")
                        perm_list.append(perm_iters[mixco_key].detach().to(device))
                        betas_list.append(betas_iters[mixco_key].detach().to(device))
                        select_list.append(select_iters[mixco_key].detach().to(device))
                    perm = torch.cat(perm_list, dim=0)
                    betas = torch.cat(betas_list, dim=0)
                    select = torch.cat(select_list, dim=0)

                # --- Model Forward Pass ---
                voxel_ridge_list = [model.ridge(voxel_list[si],si) for si,s in enumerate(subj_list)]
                voxel_ridge = torch.cat(voxel_ridge_list, dim=0)
                backbone, clip_voxels, blurry_image_enc_ = model.backbone(voxel_ridge)
                assert not torch.any(torch.isnan(clip_voxels)), f"NaN detected in clip_voxels (backbone output) at iter {train_i}"
                assert torch.isfinite(clip_voxels).all(), f"Inf detected in clip_voxels (backbone output) at iter {train_i}"

                # --- Normalize Features & Calculate Loss ---
                if clip_scale > 0:
                    clip_voxels_norm = nn.functional.normalize(clip_voxels.flatten(1), dim=-1, p=2.0, eps=1e-12)
                    clip_target_norm = nn.functional.normalize(clip_target.flatten(1), dim=-1, p=2.0, eps=1e-12)
                    assert not torch.any(torch.isnan(clip_voxels_norm)), f"NaN detected after normalizing clip_voxels at iter {train_i}"
                    assert not torch.any(torch.isnan(clip_target_norm)), f"NaN detected after normalizing clip_target at iter {train_i}"
                    assert torch.isfinite(clip_voxels_norm).all(), f"Inf detected after normalizing clip_voxels at iter {train_i}"
                    assert torch.isfinite(clip_target_norm).all(), f"Inf detected after normalizing clip_target at iter {train_i}"

                    if epoch < int(mixup_pct * num_epochs):
                        # Mixup phase: MixCo contrastive loss at a fixed temperature
                        current_temp = 0.006
                        loss_clip = utils.mixco_nce(clip_voxels_norm, clip_target_norm, temp=current_temp, perm=perm, betas=betas, select=select)
                    else:
                        # Afterwards: soft CLIP loss with the per-epoch annealed temperature
                        epoch_temp = soft_loss_temps[epoch-int(mixup_pct*num_epochs)]
                        current_temp = epoch_temp
                        loss_clip = utils.soft_clip_loss(clip_voxels_norm, clip_target_norm, temp=epoch_temp)

                    assert not torch.any(torch.isnan(loss_clip)), f"NaN detected in calculated loss_clip at iter {train_i} with temp={current_temp}"
                    assert torch.isfinite(loss_clip).all(), f"Inf detected in calculated loss_clip at iter {train_i} with temp={current_temp}"

                    loss_clip_total += loss_clip.item()
                    loss_clip = loss_clip * clip_scale
                    loss = loss + loss_clip

                    # Top-1 retrieval accuracy in both directions (voxel->image, image->voxel)
                    labels = torch.arange(len(clip_voxels_norm)).to(clip_voxels_norm.device)
                    fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_voxels_norm, clip_target_norm), labels, k=1).item()
                    bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_target_norm, clip_voxels_norm), labels, k=1).item()

                # --- Blurry Recon Loss (if enabled) ---
                if blurry_recon:
                    image_enc_pred, transformer_feats = blurry_image_enc_
                    if autoenc is not None and cnx is not None: # Check if models are loaded
                        image_enc = autoenc.encode(2*image-1).latent_dist.mode() * 0.18215
                        loss_blurry = l1(image_enc_pred, image_enc)
                        loss_blurry_total += loss_blurry.item()

                        image_norm = (image - mean)/std
                        image_aug = (blur_augs(image) - mean)/std
                        _, cnx_embeds = cnx(image_norm)
                        _, cnx_aug_embeds = cnx(image_aug)

                        cont_loss = utils.soft_cont_loss(
                            nn.functional.normalize(transformer_feats.reshape(-1, transformer_feats.shape[-1]), dim=-1),
                            nn.functional.normalize(cnx_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),
                            nn.functional.normalize(cnx_aug_embeds.reshape(-1, cnx_embeds.shape[-1]), dim=-1),
                            temp=0.2)
                        loss_blurry_cont_total += cont_loss.item()

                        loss_blurry_combined = (loss_blurry + 0.1*cont_loss) * blur_scale
                        loss = loss + loss_blurry_combined

                        # Pixel correlation on a random subset of the batch (metric only)
                        with torch.no_grad():
                            random_samps = np.random.choice(np.arange(len(image)), size=max(1, len(image)//5), replace=False) # Ensure at least 1 sample
                            blurry_recon_images = (autoenc.decode(image_enc_pred[random_samps]/0.18215).sample/ 2 + 0.5).clamp(0,1)
                            pixcorr = utils.pixcorr(image[random_samps], blurry_recon_images)
                            blurry_pixcorr += pixcorr.item()

                # --- Diffusion Prior Loss (if enabled) ---
                if use_prior:
                    if model.diffusion_prior is not None: # Check if model is loaded
                        loss_prior, prior_out = model.diffusion_prior(text_embed=backbone, image_embed=clip_target)
                        loss_prior_total += loss_prior.item()
                        loss_prior = loss_prior * prior_scale
                        loss = loss + loss_prior

                        # Reconstruction metrics
                        recon_cossim += nn.functional.cosine_similarity(prior_out, clip_target).mean().item()
                        recon_mse += mse(prior_out, clip_target).item()

            # --- End of Autocast ---

            # --- Final Loss Check ---
            utils.check_loss(loss) # Check loss is finite before backward

            # --- Backward Pass ---
            accelerator.backward(loss)

            # --- Gradient inspection & conditional optimizer step ---
            # No gradient clipping is applied here; gradients are only checked for
            # NaN/Inf and the optimizer step is skipped when any is found.
            found_inf_or_nan_grad = False
            for name, param in model.named_parameters():
                if param.grad is None:
                    continue
                if not torch.isfinite(param.grad).all():
                    print(f"  WARNING: Found non-finite gradients (NaN or Inf) in param: {name} at iter {train_i}")
                    found_inf_or_nan_grad = True
                    nan_or_inf_grad_detected_this_epoch = True # Set epoch flag
                    break # Stop checking after first bad grad found

            if not found_inf_or_nan_grad:
                optimizer.step() # Only step if gradients are valid
            else:
                print(f"  Skipping optimizer step for iter {train_i} due to non-finite gradients.")

            optimizer.zero_grad() # Zero grad regardless of step

            # --- Logging ---
            losses.append(loss.item()) # Log loss even if step was skipped
            lrs.append(optimizer.param_groups[0]['lr'])

            if lr_scheduler_type is not None:
                lr_scheduler.step()

        # --- Error Catching: print context, then re-raise the same exception ---
        except KeyError as e:
            print(f"\n\n*** KeyError Detected at Epoch {epoch}, Training Iteration {train_i} ***")
            print(f"Likely missing preloaded data for key: {e}")
            print("Check the preloading loop logs for warnings or errors.")
            raise
        except ValueError as e:
            if 'NaN loss' in str(e):
                print(f"\n\n*** NaN Loss Detected at Epoch {epoch}, Iteration {train_i} ***")
            elif 'Attempting to unscale FP16 gradients' in str(e):
                print(f"\n\n*** FP16 Unscale Error at Epoch {epoch}, Iteration {train_i} ***")
                print("This happened DESPITE inspecting gradients. Check if inspection missed something or if clipping is active/effective.")
            else:
                print(f"\n\n*** ValueError at Epoch {epoch}, Iteration {train_i}: {e} ***")
            raise
        except Exception as e:
            print(f"\n\n*** An unexpected error occurred at Epoch {epoch}, Iteration {train_i} ***")
            print(f"Error Type: {type(e)}")
            print(f"Error Details: {e}")
            raise


    # Check if training was unstable in this epoch
    if nan_or_inf_grad_detected_this_epoch:
         print(f"\n*** WARNING: Non-finite gradients were detected and optimizer steps skipped during Epoch {epoch}. Training might be unstable. Consider lowering LR or checking model/data further. ***\n")


    # --- Eval Section ---
    model.eval()
    if local_rank==0: # Only evaluate on main process
        pass # Placeholder: no evaluation is run here yet, so test_losses stays empty

    # --- Logging and Saving ---
    if local_rank == 0: # Only log and save on main process
        # Calculate epoch duration
        epoch_duration = time.time() - epoch_start_time
        hours, remainder = divmod(epoch_duration, 3600)
        minutes, seconds = divmod(remainder, 60)

        # Average the per-iteration sums over this epoch's iterations. The training
        # loop either completes all num_iterations_per_epoch iterations or raises, so
        # that is the iteration count; zero iterations yields zeros instead of an error.
        iters_this_epoch = num_iterations_per_epoch
        if iters_this_epoch > 0:
            avg_loss = np.mean(losses[-iters_this_epoch:])
            avg_fwd_acc = fwd_percent_correct / iters_this_epoch
            avg_bwd_acc = bwd_percent_correct / iters_this_epoch
            avg_loss_clip = loss_clip_total / iters_this_epoch
        else:
            avg_loss = avg_fwd_acc = avg_bwd_acc = avg_loss_clip = 0

        logs = {"train/loss": avg_loss,
                "train/lr": lrs[-1] if lrs else 0,
                "train/fwd_pct_correct": avg_fwd_acc,
                "train/bwd_pct_correct": avg_bwd_acc,
                "train/loss_clip_total": avg_loss_clip,
                "train/epoch_time_hours": f"{hours:02.0f}:{minutes:02.0f}:{seconds:05.2f}"
                # Add other averaged train metrics here if needed
                # Add test metrics here if calculated
                }
        # Add formatted time string to progress bar
        time_str = f"Time: {hours:02.0f}h {minutes:02.0f}m {seconds:05.2f}s"
        progress_bar.set_postfix(**logs, time=time_str)

        if wandb_log:
            # Also log the raw seconds for plotting
            logs["train/epoch_time_seconds"] = epoch_duration
            wandb.log(logs, step=epoch) # Log per epoch

        # Save the 'last' checkpoint once, every ckpt_interval epochs and on the final epoch.
        if (ckpt_saving) and (epoch % ckpt_interval == 0 or epoch == num_epochs - 1):
            save_ckpt('last')
            # Add best ckpt saving logic if needed based on a validation metric


    # wait for other GPUs to catch up if needed
    accelerator.wait_for_everyone()
    torch.cuda.empty_cache()


# --- Training finished: announce it and write the final checkpoint ---

print("\n===Finished!===\n")
if ckpt_saving:
    if accelerator.is_main_process: # only the main process writes the file
        save_ckpt('final')

fmri_model_v1_1ses_50ep starting with epoch 0 / 50


  0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   


--- Starting Epoch 0 Preloading ---
  Preloading for subj01...
      Duplicate image indices found in batch 103. Skipping.
      Duplicate image indices found in batch 174. Skipping.
    subj01: Reached max iterations (187). Breaking inner batch loop.
  Preloading complete for epoch after processing subj01. Total iters: 187
  Finished preloading attempts for subj01.
--- Epoch 0 Preloading Complete ---
  Total voxel iters stored: 187

--- Starting Epoch 0 Training Loop ---




In [None]:
# Training loss: one value is appended per training iteration.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="training iteration", ylabel="loss")
plt.show()

# Test loss history (one point per recorded entry in test_losses).
fig, ax = plt.subplots()
ax.plot(test_losses)
ax.set(title="Test loss", xlabel="entry index", ylabel="loss")
plt.show()