In [1]:
# Standard Python libraries
import os
import sys
import json
import yaml
import copy
import math
import time
import random

# Scientific computing libraries
import numpy as np
import matplotlib.pyplot as plt
import h5py

# Deep learning frameworks
import torch
import torch.nn as nn
from torchvision import transforms

# Custom utilities
import utils

# Progress bar library
from tqdm import tqdm

# Interactive environment setup (only if running interactively)
# In a notebook session: auto-reload edited project modules (e.g. `utils`) before each cell
# runs, and rebind the `tqdm` name imported above to the notebook-widget variant.
if utils.is_interactive():
    %load_ext autoreload
    %autoreload 2
    from tqdm.notebook import tqdm
# GPU configuration and setup
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True  # fixes Conv3D if used

# Multi-GPU configuration
device_count = torch.cuda.device_count()
print(f"Number of available CUDA devices: {device_count}")


def _env_int(name, default):
    """Return environment variable `name` parsed as int, or `default` when it is unset."""
    value = os.getenv(name)
    return default if value is None else int(value)


# Rank / world-size values presumably come from the launcher (torchrun-style LOCAL_RANK,
# RANK, WORLD_SIZE; SLURM_NODEID from SLURM). The defaults describe a single-process run.
local_rank = _env_int('LOCAL_RANK', 0)
print(f"LOCAL RANK={local_rank}")

num_devices = _env_int('NUM_GPUS', 1)
print(f"NUM GPUS={num_devices}")
distributed = num_devices > 1
if distributed:
    assert device_count == num_devices, (
        f"NUM_GPUS={num_devices} but {device_count} CUDA devices are visible")

node = _env_int('SLURM_NODEID', 0)
print(f"NODE={node}")

global_rank = _env_int('RANK', 0)
print(f"GLOBAL RANK={global_rank}")

world_size = _env_int('WORLD_SIZE', 1)
print(f"WORLD_SIZE={world_size}")

# Load parameters from yaml config.
# The file is opened in a `with` block so the handle is closed after parsing.
# NOTE(review): FullLoader can build some Python objects from tags; config.yaml is a local
# file here, but switch to yaml.safe_load if it could ever come from an untrusted source.
with open('config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Create global variables from the config (e.g. batch_size, seed, max_lr) so later cells
# can refer to them by name.
print("\n__CONFIG__")
for attribute_name, attribute_value in config.items():
    print(f"{attribute_name} = {attribute_value}")
    globals()[attribute_name] = attribute_value
print("\n")

# Set up data type and device
data_type = torch.float32  # change depending on your mixed_precision
global_batch_size = batch_size * world_size  # batch_size comes from config.yaml
device = torch.device('cuda')

# Print setup information
print("device =", device, "distributed =", distributed, "num_devices =", num_devices,
      "local rank =", local_rank, "world size =", world_size, "data_type =", data_type)

# Seed all random functions; offset by rank so each process draws a different stream.
utils.seed_everything(seed + global_rank)


Number of available CUDA devices: 1
LOCAL RANK=0
NUM GPUS=1
NODE=0
GLOBAL RANK=0
WORLD_SIZE=1

__CONFIG__
model_name = ridge_CLIPbase
data_path = /teamspace/studios/this_studio/nsd
clip_model_name = ViT-B/32
batch_size = 128
wandb_log = True
ckpt_interval = 10
ckpt_saving = False
seed = 0
max_lr = 0.0003
num_epochs = 30
temperature = 0.006


device = cuda distributed = False num_devices = 1 local rank = 0 world size = 1 data_type = torch.float32


## Prep data, models, and dataloaders

#### Load all voxels for a given NSD subject

In [2]:
subject_id = 1
# `[:]` copies the whole dataset into memory, so the file can be closed right away
# instead of leaking the handle.
with h5py.File(f'{data_path}/betas_all_subj0{subject_id}_fp32_renorm.hdf5', 'r') as f:
    voxels = f['betas'][:]  # preloading all voxels for given subject
print("voxels", voxels.shape)

voxels (30000, 15724)


#### Load all images

In [3]:
# `[:]` copies the whole dataset into memory, so the file can be closed right away
# instead of leaking the handle.
with h5py.File(f'{data_path}/coco_images_224_float16.hdf5', 'r') as f:
    images = f['images'][:]  # preloading all images for easy indexing
print("images", images.shape)

images (73000, 3, 224, 224)


#### Pair voxels to images/image captions and define train/test splits

In [None]:
shared1000 = np.load(f"{data_path}/shared1000.npy")

# Per-trial COCO-73k image indices for this subject. Only this subject's key is read. The key
# is built from `subject_id` the same way as the voxel file name, so changing `subject_id`
# keeps voxels and image indices aligned (it was previously hard-coded to "subj01").
# NOTE(review): assumes every subject's key follows the "subj0<N>" pattern — confirm.
subject_key = f"subj0{subject_id}"
with h5py.File(f"{data_path}/COCO_73k_subj_indices.hdf5", 'r') as f:
    image_indices = f[subject_key][()]
assert len(image_indices) == len(voxels), (
    f"{len(image_indices)} image indices vs {len(voxels)} voxel rows for {subject_key}")

# Trials whose image is in the shared-1000 set form the test split; the rest form the train
# split. One mask drives both the image and the voxel split so they stay paired row-for-row.
# astype(bool) keeps `~` a logical NOT even if shared1000 were stored as 0/1 integers.
is_shared = shared1000[image_indices].astype(bool)

image_indices_train = image_indices[~is_shared]
image_indices_test = image_indices[is_shared]

# Fancy indexing already returns a fresh array; torch.from_numpy wraps it without a second
# copy before the float16 -> float32 conversion. Shapes in comments are for subj01.
images_train = torch.from_numpy(images[image_indices_train]).float()  # (27000, 3, 224, 224)
images_test = torch.from_numpy(images[image_indices_test]).float()  # (3000, 3, 224, 224)

print("images_train", images_train.shape)
print("images_test", images_test.shape)

# load image captions
annots = np.load(f'{data_path}/COCO_73k_annots_curated.npy')
print("annots", annots.shape)

# check that annotations line up with image order
display(utils.torch_to_Image(images_train[:1]))
print(utils.get_annotations(annots[image_indices_train][:1], random=False))

# voxel train/test split (same mask as the image split above)
voxels_train = torch.from_numpy(voxels[~is_shared])  # (27000, 15724)
voxels_test = torch.from_numpy(voxels[is_shared])  # (3000, 15724)

num_voxels = voxels_train.shape[-1]

print("voxels_train", voxels_train.shape)
print("voxels_test", voxels_test.shape)

del images  # free up memory for images we aren't using

#### Create torch dataloaders

In [None]:
def _paired_dataset(images_split, voxels_split, indices_split):
    """Wrap one split as (image, voxel, COCO-73k image index) samples."""
    return torch.utils.data.TensorDataset(images_split, voxels_split, torch.tensor(indices_split))


train_data = _paired_dataset(images_train, voxels_train, image_indices_train)
test_data = _paired_dataset(images_test, voxels_test, image_indices_test)

In [None]:
train_dl = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
)
# Test batches have a fixed size of 300, so in-batch retrieval is always scored against
# the same number of candidates; drop_last discards any incomplete remainder batch.
test_dl = torch.utils.data.DataLoader(
    test_data,
    batch_size=300,
    shuffle=False,
    drop_last=True,
    pin_memory=True,
)

num_samples_per_epoch = len(train_dl.dataset)
print("num_samples_per_epoch", num_samples_per_epoch)
num_iterations_per_epoch = len(train_dl)  # number of optimizer steps per epoch
print("num_iterations_per_epoch", num_iterations_per_epoch)

## Create/load models

In [None]:
import clip
clip_model, _ = clip.load(clip_model_name, device=device)
# CLIP is only a frozen target encoder; the optimizer never sees its weights. Turning off
# its gradients stops backward() from building a graph through the vision tower. It also
# stops .grad buffers from accumulating on CLIP weights, which nothing ever zeroes.
clip_model.requires_grad_(False)
clip_model.eval()

# Resize(224) is effectively a no-op for the 224x224 images loaded above; Normalize applies
# CLIP's channel mean/std. NOTE(review): assumes pixel values are in [0, 1] — confirm.
preprocess = transforms.Compose([
    transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])


@torch.no_grad()
def clip_img_embedder(image):
    """Return CLIP image embeddings for a batch of image tensors, without tracking gradients.

    The output dtype is whatever `clip_model.encode_image` produces; callers cast it with
    `.to(...)` as needed.
    """
    preproc_img = preprocess(image)
    return clip_model.encode_image(preproc_img)


# Embedding width of each supported CLIP backbone. Fail fast on anything else instead of
# leaving `clip_emb_dim` undefined (which used to surface later as a NameError).
_CLIP_EMB_DIMS = {"ViT-L/14": 768, "ViT-B/32": 512}
if clip_model_name not in _CLIP_EMB_DIMS:
    raise ValueError(
        f"Unsupported clip_model_name {clip_model_name!r}; expected one of {sorted(_CLIP_EMB_DIMS)}")
clip_emb_dim = _CLIP_EMB_DIMS[clip_model_name]

In [None]:
class RidgeRegression(torch.nn.Module):
    """A single affine layer mapping voxel vectors to CLIP-embedding space.

    The module has no explicit L2 penalty; any "ridge" regularization presumably comes
    from the optimizer's weight decay.

    Args:
        input_size: number of input features (voxels).
        final_out: output dimensionality (CLIP embedding width).
    """

    def __init__(self, input_size, final_out):
        super().__init__()
        # The attribute name ends up in state_dict keys ("linears.weight", "linears.bias"),
        # so keep it stable for checkpoint compatibility.
        self.linears = torch.nn.Linear(input_size, final_out)

    def forward(self, x):
        return self.linears(x)

# Linear map from every voxel of this subject to the CLIP image-embedding space.
model = RidgeRegression(num_voxels, final_out=clip_emb_dim)
utils.count_params(model)

# Shape smoke test with random inputs: 2 fake voxel vectors -> (2, clip_emb_dim).
fake_voxels = torch.randn((2, num_voxels))
print(fake_voxels.shape, model(fake_voxels).shape)

## Set up optimizer, LR schedule, and checkpoint saving

In [None]:
# Parameters whose names contain any of these substrings get no weight decay. For this model
# that is `linears.bias`: as in classical ridge regression, the intercept is not penalized.
# The LayerNorm entries only matter for architectures that have LayerNorm layers.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

# `no_decay` used to be defined but ignored, so every parameter (bias included) was decayed.
# NOTE: this produces two param groups, so optimizer states saved by the old single-group
# version cannot be restored with load_optimizer=True.
opt_grouped_parameters = [
    group
    for group in (
        {'params': decay_params, 'weight_decay': 1e-2},
        {'params': no_decay_params, 'weight_decay': 0.0},
    )
    if group['params']  # skip empty groups
]

optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=max_lr)

total_steps = int(np.floor(num_epochs * num_iterations_per_epoch))
print("total_steps", total_steps)
# One-cycle schedule stepped once per batch. The LR rises over the first 2 epochs' worth of
# steps (pct_start) and then anneals; the final LR is max_lr / (div_factor * final_div_factor).
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=max_lr,
    total_steps=total_steps,
    final_div_factor=1000,
    last_epoch=-1,
    pct_start=2 / num_epochs,
)

outdir = os.path.abspath(f'../ckpts/{model_name}')
    
def save_ckpt(tag):
    """Save epoch, model, optimizer and LR-scheduler state to `<outdir>/<tag>.pth`.

    Reads the globals `epoch`, `model`, `optimizer`, `lr_scheduler` and `outdir`, and
    creates `outdir` if it does not exist.
    """
    ckpt_path = os.path.join(outdir, f'{tag}.pth')
    os.makedirs(outdir, exist_ok=True)
    # `unwrapped_model` was never defined in this notebook, so saving raised NameError.
    # Unwrap a `.module` wrapper (e.g. DistributedDataParallel) if one is ever added;
    # otherwise this is just `model`.
    unwrapped_model = getattr(model, 'module', model)
    torch.save({
        'epoch': epoch,
        'model_state_dict': unwrapped_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        }, ckpt_path)
    print(f"\n---saved {outdir}/{tag} ckpt!---\n")

def load_ckpt(tag,load_lr=True,load_optimizer=True,load_epoch=True,strict=True,outdir=outdir):
    """Restore a checkpoint written by save_ckpt from `<outdir>/<tag>.pth` into the globals.

    Args:
        tag: checkpoint name without extension (e.g. 'last').
        load_lr: also restore `lr_scheduler` state.
        load_optimizer: also restore `optimizer` state.
        load_epoch: set the global `epoch` to the saved epoch.
        strict: passed to `model.load_state_dict`.
        outdir: checkpoint directory; the default is bound when this cell runs.
    """
    print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
    # Previously this always read 'last.pth', silently ignoring `tag`.
    checkpoint = torch.load(os.path.join(outdir, f'{tag}.pth'), map_location='cpu')
    state_dict = checkpoint['model_state_dict']
    model.load_state_dict(state_dict, strict=strict)
    if load_epoch:
        globals()["epoch"] = checkpoint['epoch']
        print("Epoch", checkpoint['epoch'])
    if load_optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if load_lr:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    del checkpoint


print("\nDone with model preparations!")
num_params = utils.count_params(model)  # presumably returns the count; logged to wandb below

## Weights and Biases

In [None]:
if local_rank != 0 or not wandb_log:
    # Only the main process (local rank 0) logs to wandb; every other case disables logging.
    wandb_log = False
else:
    import wandb
    wandb_project = 'brain_filtering'
    print(f"wandb {wandb_project} run {model_name}")
    # wandb must already be configured in a terminal ("wandb init") before this runs.
    wandb_config = dict(
        model_name=model_name,
        global_batch_size=global_batch_size,
        batch_size=batch_size,
        num_epochs=num_epochs,
        num_params=num_params,
        max_lr=max_lr,
        num_samples_per_epoch=num_samples_per_epoch,
        ckpt_interval=ckpt_interval,
        ckpt_saving=ckpt_saving,
        seed=seed,
        distributed=distributed,
        num_devices=num_devices,
        world_size=world_size,
    )
    print("wandb_config:\n", wandb_config)
    print("wandb_id:", model_name)
    # The run id is the model name and resume="never". Re-running under an existing
    # model_name will presumably be rejected by wandb. Switching to resume="allow" would
    # also need training-code changes (state restore).
    wandb.init(
        project=wandb_project,
        id=model_name,
        name=model_name,
        resume="never",
        config=wandb_config,
    )

## Main training loop

In [None]:
# Fresh training state. The lists below grow across the whole run and feed the logs and plots.
epoch = 0
losses = []
test_losses = []
lrs = []
best_test_loss = 1e9  # NOTE(review): not read anywhere else in this notebook
torch.cuda.empty_cache()
model.to(device)  # nn.Module.to moves parameters in place

In [None]:
print(f"{model_name} starting with epoch {epoch} / {num_epochs}")
progress_bar = tqdm(range(epoch,num_epochs), ncols=1000, disable=(local_rank!=0))

if distributed:
    # Needed for dist.barrier() below; it was previously used without ever being imported.
    import torch.distributed as dist


def _contrastive_step(clip_brain, clip_image, temperature):
    """Symmetric CLIP-style contrastive loss plus top-1 retrieval scores for one batch.

    Row i of `clip_brain` and row i of `clip_image` form the positive pair; all other rows
    in the batch act as negatives. Both inputs are L2-normalized here.

    Returns:
        (loss, fwd_ret, bwd_ret): `loss` is the sum of the image->brain and brain->image
        cross-entropy terms (differentiable w.r.t. `clip_brain`). `fwd_ret` / `bwd_ret`
        are the `.item()` values of utils.topk for brain->image and image->brain retrieval.
    """
    labels = torch.arange(len(clip_image), device=clip_image.device)

    clip_image = nn.functional.normalize(clip_image, dim=-1)
    clip_brain = nn.functional.normalize(clip_brain, dim=-1)

    logits_image_brain = (clip_image @ clip_brain.T) / temperature
    logits_brain_image = (clip_brain @ clip_image.T) / temperature
    loss = (nn.functional.cross_entropy(logits_image_brain, labels)
            + nn.functional.cross_entropy(logits_brain_image, labels))

    # Retrieval scores are for logging only; keep them out of the autograd graph.
    with torch.no_grad():
        fwd_ret = utils.topk(utils.batchwise_cosine_similarity(clip_brain, clip_image), labels, k=1).item()
        bwd_ret = utils.topk(utils.batchwise_cosine_similarity(clip_image, clip_brain), labels, k=1).item()
    return loss, fwd_ret, bwd_ret


for epoch in progress_bar:
    fwd_percent_correct, test_fwd_percent_correct = 0., 0.
    bwd_percent_correct, test_bwd_percent_correct = 0., 0.

    model.train()
    for train_i, (image, voxel, _) in enumerate(train_dl):
        optimizer.zero_grad()

        image = image.to(device)
        voxel = voxel.to(device)

        clip_brain = model(voxel)
        # CLIP is a frozen target: no gradients are needed through the image encoder.
        with torch.no_grad():
            clip_image = clip_img_embedder(image).to(clip_brain.dtype)

        loss, fwd_ret, bwd_ret = _contrastive_step(clip_brain, clip_image, temperature)
        fwd_percent_correct += fwd_ret
        bwd_percent_correct += bwd_ret

        utils.check_loss(loss)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        lr_scheduler.step()  # OneCycleLR is stepped once per batch

    model.eval()
    with torch.no_grad():
        for test_i, (image, voxel, _) in enumerate(test_dl):
            image = image.to(device)
            voxel = voxel.to(device)

            clip_brain = model(voxel)
            clip_image = clip_img_embedder(image).to(clip_brain.dtype)

            loss, fwd_ret, bwd_ret = _contrastive_step(clip_brain, clip_image, temperature)
            test_fwd_percent_correct += fwd_ret
            test_bwd_percent_correct += bwd_ret

            utils.check_loss(loss)
            test_losses.append(loss.item())

    # Per-epoch averages over this epoch's train/test batches.
    logs = {"train/loss": np.mean(losses[-(train_i+1):]),
        "test/loss": np.mean(test_losses[-(test_i+1):]),
        "train/lr": lrs[-1],
        "train/num_steps": len(losses),
        "test/num_steps": len(test_losses),
        "train/fwd_percent_correct": fwd_percent_correct / (train_i + 1),
        "test/test_fwd_percent_correct": test_fwd_percent_correct / (test_i + 1),
        "train/bwd_percent_correct": bwd_percent_correct / (train_i + 1),
        "test/test_bwd_percent_correct": test_bwd_percent_correct / (test_i + 1),
        }
    progress_bar.set_postfix(**logs)

    if wandb_log:
        wandb.log(logs)

    # Save model checkpoint
    if ckpt_saving and epoch % ckpt_interval == 0:
        save_ckpt('last')

    # wait for other GPUs to catch up if needed
    if distributed:
        dist.barrier()

print("\n===Finished!===\n")

if ckpt_saving:
    save_ckpt('last')

In [None]:
# Loss curves: train loss per optimizer step, test loss per test batch (across all epochs).
fig, (ax_train, ax_test) = plt.subplots(1, 2, figsize=(12, 4))
ax_train.plot(losses)
ax_train.set(title="Train loss", xlabel="training step", ylabel="contrastive loss")
ax_test.plot(test_losses)
ax_test.set(title="Test loss", xlabel="test batch (all epochs)", ylabel="contrastive loss")
plt.show()