In [1]:
#Install prereqs, if necessary
%pip install h5py nilearn transformers torch vit-pytorch torchmetrics

Note: you may need to restart the kernel to use updated packages.


In [1]:
#Magic that doesn't work on lightning.ai for some reason
%load_ext autoreload
%autoreload 2

### Enable access to dataset module

In [2]:
import sys
import os

sys.path.append("/teamspace/studios/this_studio/mindscape/")

### Initialize the dataset & image pre-processing

In [4]:
from ImageVoxelsDataset import ImageVoxelsDataset
from torchvision import transforms
import torch

#image pre-processing
transform = transforms.Compose([
    transforms.ToPILImage(), # convert np array from dataset to PIL
    transforms.Resize((224, 224)), # resize to ViT dimensions
    transforms.ToTensor() # convert back to tensor
])

#train for subject 1
dataset = ImageVoxelsDataset('/teamspace/studios/this_studio/mindscape/nsd',
                                subject=1, 
                                transform=transform, 
                                preload_imgs = True)

Loading stimuli images...: 100%|██████████| 73000/73000 [00:22<00:00, 3264.42it/s]


### Convert betas back to percent signal change

According to the [NSD Data Manual](https://cvnlab.slite.page/p/6CusMRYfk0/Untitled):


> ...for some of [the] NSD data files that we have prepared, the betas have been multiplied by 300 and converted to int16 format to reduce space usage. Upon loading the beta files, the values should be immediately converted back to percent signal change by casting to decimal format (e.g. single or double) and dividing by 300.


In [5]:
def to_percent_signal_change(betas):
    return betas.float() / 300

dataset.target_transform = to_percent_signal_change

### Normalize the betas

Note: Mean and standard deviation before conversion of betas back to percent signal change are 293.4223405114772 and 1119.770929338537, respectively.

In [6]:
# This takes ~10 minutes to calculate. The values are provided below to save time
# from util import find_mean_sd
# MEAN, STD_DEV = find_mean_sd(dataset)
# print(f"mean: {MEAN}\nstd. dev: {STD_DEV}")
MEAN, STD_DEV = 0.9780744683715907, 3.7325697644617897

def z_norm(n):
    return (n-MEAN) / STD_DEV

old_transform = dataset.target_transform
dataset.target_transform = lambda betas: z_norm(old_transform(betas))

### Instantiate model

In [7]:
from matrixvit import MatrixViTModel
from transformers import ViTConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32



output_dims = dataset[0][1].shape

model = MatrixViTModel(
    output_dimensions=output_dims,
    image_size = (224, 224),
    patch_size = (16, 16),
    dim = 768, #hidden size
    depth = 6, #num hidden layers
    heads = 8,
    mlp_dim = 1024 #intermediate size    
).to(device=device, dtype=dtype)

### Train

In [15]:
from torch.utils.data import DataLoader, random_split
import torch
from tqdm import tqdm
from torch import nn, optim
from util import set_seed
from torchmetrics import R2Score
import time

#seed for reproducibility:
set_seed(47)

def seed_worker(worker_id):
    set_seed(47 + worker_id)




P_TRAIN = 0.7
P_VAL = 0.1

train_size = int(P_TRAIN * len(dataset))
val_size = int(P_VAL * len(dataset))
eval_size = len(dataset) - val_size - train_size

train_dataset, val_dataset, eval_dataset = random_split(dataset, [train_size, val_size, eval_size])

train_loader = DataLoader(train_dataset,  batch_size=32, num_workers=2, worker_init_fn=seed_worker)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=1, worker_init_fn=seed_worker)


loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True)

num_betas = dataset[0][1].shape[0]
r2_score = R2Score(num_outputs = num_betas, multioutput="uniform_average").to("cuda")

num_epochs = 300


model.to("cuda")
import time
try:
    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}")
        model.train()
        train_loss = 0.0

        #Train
        for images, targets in tqdm(train_loader):
            images, targets = images.to('cuda'), targets.to('cuda')            

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

        #Validation
        model.eval()
        val_loss = 0.0
        r2_score.reset()

        with torch.no_grad():
            for images, targets in val_loader:
                
                images, targets = images.to('cuda'), targets.to('cuda')
                outputs = model(images)

                loss = loss_function(outputs, targets)
                val_loss += loss.item()
                r2_score.update(outputs, targets)

        avg_val_loss = val_loss / len(val_loader)
        val_r2 = r2_score.compute()  # Compute final R2 score for this epoch

        print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss}, R^2 Score: {val_r2}')

        scheduler.step(avg_val_loss)

        #Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_r2_score': val_r2  # Save the R2 score in the checkpoint

        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
except KeyboardInterrupt:
    train_loader = None



Starting epoch 1


  0%|          | 0/657 [00:00<?, ?it/s]

100%|██████████| 657/657 [00:50<00:00, 12.95it/s]

Epoch 1, Training Loss: 5.337504886056735





Epoch 1, Validation Loss: 0.4587599717556162, R^2 Score: -0.3577214777469635
Starting epoch 2


100%|██████████| 657/657 [00:50<00:00, 12.91it/s]

Epoch 2, Training Loss: 0.49092680646767173





Epoch 2, Validation Loss: 0.4183196799552187, R^2 Score: -0.10443037748336792
Starting epoch 3


100%|██████████| 657/657 [00:51<00:00, 12.85it/s]

Epoch 3, Training Loss: 0.442193719213956





Epoch 3, Validation Loss: 0.4197661109427188, R^2 Score: -0.10890914499759674
Starting epoch 4


100%|██████████| 657/657 [00:51<00:00, 12.82it/s]

Epoch 4, Training Loss: 0.4317242557146052





Epoch 4, Validation Loss: 0.40348075171734427, R^2 Score: -0.05296771973371506
Starting epoch 5


100%|██████████| 657/657 [00:51<00:00, 12.77it/s]

Epoch 5, Training Loss: 0.42209928688030446





Epoch 5, Validation Loss: 0.41671973783919153, R^2 Score: -0.07539434731006622
Starting epoch 6


100%|██████████| 657/657 [00:51<00:00, 12.87it/s]

Epoch 6, Training Loss: 0.4177285785272241





In [8]:
total_params = sum(p.numel() for p in model.parameters())

total_params

29009534

In [5]:
len(dataset)

30000