In [1]:
#Install prereqs, if necessary
%pip install h5py nilearn transformers torch vit-pytorch torchmetrics

Note: you may need to restart the kernel to use updated packages.


In [1]:
#Magic that doesn't work on lightning.ai for some reason
%load_ext autoreload
%autoreload 2

### Enable access to dataset module

In [2]:
import sys
import os

sys.path.append("/teamspace/studios/this_studio/mindscape/")

### Initialize the dataset & image pre-processing

In [3]:
from ImageVoxelsDataset import ImageVoxelsDataset
from torchvision import transforms
import torch

#image pre-processing
transform = transforms.Compose([
    transforms.ToPILImage(), # convert np array from dataset to PIL
    transforms.Resize((224, 224)), # resize to ViT dimensions
    transforms.ToTensor() # convert back to tensor
])

#train for subject 1
dataset = ImageVoxelsDataset('/teamspace/studios/this_studio/mindscape/nsd',
                                subject=1, 
                                transform=transform, 
                                preload_imgs = True)

### Convert betas back to percent signal change

According to the [NSD Data Manual](https://cvnlab.slite.page/p/6CusMRYfk0/Untitled):


> ...for some of [the] NSD data files that we have prepared, the betas have been multiplied by 300 and converted to int16 format to reduce space usage. Upon loading the beta files, the values should be immediately converted back to percent signal change by casting to decimal format (e.g. single or double) and dividing by 300.


In [4]:
def to_percent_signal_change(betas, idx):
    return betas.float() / 300

dataset.target_transform = to_percent_signal_change

### Calculate STD and mean per session per voxel

In [5]:
from util import mean_sd_map
from tqdm import tqdm

session_size = 750

#each entry is of the form mean, std_dev
session_metrics = []

for i in tqdm(range(0, len(dataset), session_size)):
    session_set = torch.stack([dataset[i][1] for i in range(i, i+session_size)])
    session_metrics.append(mean_sd_map(session_set))

session_metrics = torch.stack(session_metrics, dim=0)

  0%|          | 0/40 [00:00<?, ?it/s]

100%|██████████| 40/40 [00:53<00:00,  1.35s/it]


In [6]:
session_metrics.shape

torch.Size([40, 2, 5246])

### Normalize the betas

Note: Mean and standard deviation before conversion of betas back to percent signal change are 1.762125820278688 and 3.0391348517898913, respectively.

In [6]:
# This takes ~10 minutes to calculate. The values are provided below to save time
# from util import find_mean_sd
# MEAN, STD_DEV = find_mean_sd(dataset)
# print(f"mean: {MEAN}\nstd. dev: {STD_DEV}")
# Before masking: MEAN, STD_DEV = 0.9780744683715907, 3.7325697644617897
# MEAN, STD_DEV  = 1.762125820278688, 3.0391348517898913
# def z_norm(n):
#     return (n-MEAN) / STD_DEV

# old_transform = dataset.target_transform
# dataset.target_transform = lambda betas: z_norm(old_transform(betas))


def norm(betas, idx):
    sess = idx // 750
    means, stds = session_metrics[sess]
    
    return (betas - means) / stds

old_transform = dataset.target_transform
dataset.target_transform = lambda betas, idx: norm(old_transform(betas, idx), idx)


### Instantiate model

In [7]:
from matrixvit import MatrixViTModel
from transformers import ViTConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32



output_dims = dataset[0][1].shape

model = MatrixViTModel(
    output_dimensions=output_dims,
    image_size = (224, 224),
    patch_size = (16, 16),
    dim = 768, #hidden size
    depth = 12, #num hidden layers
    heads = 12,
    mlp_dim = 1024, #intermediate size
    drop_rate = 0.2  
).to(device=device, dtype=dtype)

### Filter out duplicate images
In practice these should really be averaged, but this is a proof of concept to see if this is what is confusing the model.

In [8]:
from tqdm import tqdm

seen_images = set()

unique_indices = []


pbar = tqdm(dataset)

for i, (image, _) in enumerate(pbar):
    image = image.numpy().tobytes()
    if image not in seen_images:
        seen_images.add(image)
        unique_indices.append(i)

  0%|          | 0/30000 [00:00<?, ?it/s]

100%|██████████| 30000/30000 [01:05<00:00, 456.67it/s]


In [9]:
len(unique_indices)

10000

In [10]:
from util import FilteredDataset

dataset = FilteredDataset(dataset, unique_indices)

### Train

In [13]:
from torch.utils.data import DataLoader, random_split
import torch
from tqdm import tqdm
from torch import nn, optim
from util import set_seed
from torchmetrics import R2Score
import time

#seed for reproducibility:
set_seed(47)

def seed_worker(worker_id):
    set_seed(47 + worker_id)




P_TRAIN = 0.7
P_VAL = 0.1

train_size = int(P_TRAIN * len(dataset))
val_size = int(P_VAL * len(dataset))
eval_size = len(dataset) - val_size - train_size

train_dataset, val_dataset, eval_dataset = random_split(dataset, [train_size, val_size, eval_size])

train_loader = DataLoader(train_dataset,  batch_size=32, num_workers=2, worker_init_fn=seed_worker)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=1, worker_init_fn=seed_worker)


loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True, threshold=0.001)

num_betas = dataset[0][1].shape[0]
r2_score = R2Score(num_outputs = num_betas, multioutput="raw_values").to("cuda")

num_epochs = 300


model.to("cuda")
import time


do_val = True

try:
    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}")
        model.train()
        train_loss = 0.0

        #Train
        for images, targets in tqdm(train_loader):
            images, targets = images.to('cuda'), targets.to('cuda')            
            
            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

        if not do_val:
            continue

        #Validation
        model.eval()
        val_loss = 0.0
        r2_score.reset()

        with torch.no_grad():
            for images, targets in val_loader:
                
                images, targets = images.to('cuda'), targets.to('cuda')
                outputs = model(images)

                loss = loss_function(outputs, targets)
                val_loss += loss.item()
                r2_score.update(outputs, targets)

        avg_val_loss = val_loss / len(val_loader)
        val_r2 = r2_score.compute()  # Compute final R2 score for this epoch

        print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss}, Median R^2 Score: {torch.median(val_r2)}, Mean R^2 Score: {torch.mean(val_r2)}')
        scheduler.step(avg_val_loss)
        #Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_r2_score': val_r2  # Save the R2 score in the checkpoint

        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
except KeyboardInterrupt:
    train_loader = None



Starting epoch 1


  0%|          | 0/219 [00:00<?, ?it/s]

100%|██████████| 219/219 [00:39<00:00,  5.55it/s]

Epoch 1, Training Loss: 0.9637892883117885





Epoch 1, Validation Loss: 1.521146073937416, Median R^2 Score: -0.36218559741973877, Mean R^2 Score: -0.5272493362426758
Starting epoch 2


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 2, Training Loss: 1.0109483381928919





Epoch 2, Validation Loss: 2.400041453540325, Median R^2 Score: -0.5865364074707031, Mean R^2 Score: -1.3933160305023193
Starting epoch 3


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 3, Training Loss: 0.9981727257166824





Epoch 3, Validation Loss: 2.782250866293907, Median R^2 Score: -0.7537497282028198, Mean R^2 Score: -1.7738702297210693
Starting epoch 4


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 4, Training Loss: 0.9955018532874922





Epoch 4, Validation Loss: 2.240910343825817, Median R^2 Score: -0.5120524168014526, Mean R^2 Score: -1.238886833190918
Starting epoch 5


100%|██████████| 219/219 [00:39<00:00,  5.52it/s]

Epoch 5, Training Loss: 0.9922841651254592





Epoch 5, Validation Loss: 3.016848161816597, Median R^2 Score: -0.9200742244720459, Mean R^2 Score: -2.0089287757873535
Starting epoch 6


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 6, Training Loss: 0.9893394480012867





Epoch 6, Validation Loss: 2.4144762605428696, Median R^2 Score: -0.6490968465805054, Mean R^2 Score: -1.4190950393676758
Starting epoch 7


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 7, Training Loss: 0.9872315215737852





Epoch 7, Validation Loss: 2.03278748691082, Median R^2 Score: -0.413366436958313, Mean R^2 Score: -1.033522605895996
Starting epoch 8


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 8, Training Loss: 0.9949686222424791





Epoch 8, Validation Loss: 3.341104194521904, Median R^2 Score: -0.9111486673355103, Mean R^2 Score: -2.3422353267669678
Starting epoch 9


100%|██████████| 219/219 [00:39<00:00,  5.52it/s]

Epoch 9, Training Loss: 0.9892295310486398





Epoch 9, Validation Loss: 3.7484567761421204, Median R^2 Score: -1.0915157794952393, Mean R^2 Score: -2.7467265129089355
Starting epoch 10


100%|██████████| 219/219 [00:39<00:00,  5.48it/s]

Epoch 10, Training Loss: 0.9862001026602096





Epoch 10, Validation Loss: 4.105484053492546, Median R^2 Score: -1.294489860534668, Mean R^2 Score: -3.103586435317993
Starting epoch 11


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 11, Training Loss: 0.985757214021465





Epoch 11, Validation Loss: 4.204228654503822, Median R^2 Score: -1.372962236404419, Mean R^2 Score: -3.2020347118377686
Starting epoch 12


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 12, Training Loss: 0.9835700605013599





Epoch 12, Validation Loss: 4.311469659209251, Median R^2 Score: -1.476231575012207, Mean R^2 Score: -3.309481382369995
Starting epoch 13


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 13, Training Loss: 0.9822901784012851





Epoch 13, Validation Loss: 4.262280076742172, Median R^2 Score: -1.5460174083709717, Mean R^2 Score: -3.257892608642578
Starting epoch 14


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 14, Training Loss: 0.9805828444489605





Epoch 14, Validation Loss: 4.354208379983902, Median R^2 Score: -1.536623477935791, Mean R^2 Score: -3.349933385848999
Starting epoch 15


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 15, Training Loss: 0.9801158583871851





Epoch 15, Validation Loss: 4.345411166548729, Median R^2 Score: -1.5399086475372314, Mean R^2 Score: -3.3408355712890625
Starting epoch 16


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 16, Training Loss: 0.9801074979512114





Epoch 16, Validation Loss: 4.385260656476021, Median R^2 Score: -1.576338529586792, Mean R^2 Score: -3.3817362785339355
Starting epoch 17


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 17, Training Loss: 0.9799796933452832





Epoch 17, Validation Loss: 4.393986389040947, Median R^2 Score: -1.5964796543121338, Mean R^2 Score: -3.3902156352996826
Starting epoch 18


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 18, Training Loss: 0.9794497021801396





Epoch 18, Validation Loss: 4.414315283298492, Median R^2 Score: -1.6244428157806396, Mean R^2 Score: -3.411229372024536
Starting epoch 19


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 19, Training Loss: 0.9799609627897881





Epoch 19, Validation Loss: 4.390752762556076, Median R^2 Score: -1.6319246292114258, Mean R^2 Score: -3.3876733779907227
Starting epoch 20


100%|██████████| 219/219 [00:39<00:00,  5.49it/s]

Epoch 20, Training Loss: 0.9796371258557115





Epoch 20, Validation Loss: 4.389676854014397, Median R^2 Score: -1.6406753063201904, Mean R^2 Score: -3.386960029602051
Starting epoch 21


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 21, Training Loss: 0.9800040694676577





Epoch 21, Validation Loss: 4.3881246745586395, Median R^2 Score: -1.6447696685791016, Mean R^2 Score: -3.385690927505493
Starting epoch 22


100%|██████████| 219/219 [00:39<00:00,  5.49it/s]

Epoch 22, Training Loss: 0.9791244611348191





Epoch 22, Validation Loss: 4.390263020992279, Median R^2 Score: -1.6560544967651367, Mean R^2 Score: -3.3880252838134766
Starting epoch 23


100%|██████████| 219/219 [00:39<00:00,  5.49it/s]

Epoch 23, Training Loss: 0.979466653305646





Epoch 23, Validation Loss: 4.392697736620903, Median R^2 Score: -1.6590919494628906, Mean R^2 Score: -3.3905112743377686
Starting epoch 24


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 24, Training Loss: 0.9792242825847782





Epoch 24, Validation Loss: 4.391322091221809, Median R^2 Score: -1.6634764671325684, Mean R^2 Score: -3.389288902282715
Starting epoch 25


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 25, Training Loss: 0.9799347114345255





Epoch 25, Validation Loss: 4.395066916942596, Median R^2 Score: -1.6636638641357422, Mean R^2 Score: -3.3931148052215576
Starting epoch 26


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 26, Training Loss: 0.9799565958650145





Epoch 26, Validation Loss: 4.3952890783548355, Median R^2 Score: -1.663811206817627, Mean R^2 Score: -3.3933417797088623
Starting epoch 27


100%|██████████| 219/219 [00:39<00:00,  5.48it/s]

Epoch 27, Training Loss: 0.9793758522974302





Epoch 27, Validation Loss: 4.3950018882751465, Median R^2 Score: -1.664095401763916, Mean R^2 Score: -3.393051862716675
Starting epoch 28


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 28, Training Loss: 0.9791837538758369





Epoch 28, Validation Loss: 4.394755333662033, Median R^2 Score: -1.6639573574066162, Mean R^2 Score: -3.392792224884033
Starting epoch 29


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 29, Training Loss: 0.9793508390313415





Epoch 29, Validation Loss: 4.395254164934158, Median R^2 Score: -1.664729118347168, Mean R^2 Score: -3.3932971954345703
Starting epoch 30


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 30, Training Loss: 0.979110502216914





Epoch 30, Validation Loss: 4.395416587591171, Median R^2 Score: -1.6658599376678467, Mean R^2 Score: -3.3934741020202637
Starting epoch 31


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 31, Training Loss: 0.9792667594674516





Epoch 31, Validation Loss: 4.395986989140511, Median R^2 Score: -1.6661696434020996, Mean R^2 Score: -3.3940489292144775
Starting epoch 32


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 32, Training Loss: 0.9796525765227401





Epoch 32, Validation Loss: 4.395302519202232, Median R^2 Score: -1.6669490337371826, Mean R^2 Score: -3.393373727798462
Starting epoch 33


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 33, Training Loss: 0.9794817991452675





Epoch 33, Validation Loss: 4.395122990012169, Median R^2 Score: -1.6674883365631104, Mean R^2 Score: -3.393190383911133
Starting epoch 34


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 34, Training Loss: 0.9791626233488457





Epoch 34, Validation Loss: 4.3948091715574265, Median R^2 Score: -1.667285442352295, Mean R^2 Score: -3.3928847312927246
Starting epoch 35


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 35, Training Loss: 0.9788017101483802





Epoch 35, Validation Loss: 4.395048558712006, Median R^2 Score: -1.6671323776245117, Mean R^2 Score: -3.393126964569092
Starting epoch 36


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 36, Training Loss: 0.9783339015969402





Epoch 36, Validation Loss: 4.3947475254535675, Median R^2 Score: -1.6668341159820557, Mean R^2 Score: -3.392838716506958
Starting epoch 37


100%|██████████| 219/219 [00:39<00:00,  5.55it/s]

Epoch 37, Training Loss: 0.9793514122157336





Epoch 37, Validation Loss: 4.395013898611069, Median R^2 Score: -1.6673157215118408, Mean R^2 Score: -3.3930983543395996
Starting epoch 38


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 38, Training Loss: 0.9792564220080092





Epoch 38, Validation Loss: 4.395497918128967, Median R^2 Score: -1.6678006649017334, Mean R^2 Score: -3.3935885429382324
Starting epoch 39


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 39, Training Loss: 0.978376051606653





Epoch 39, Validation Loss: 4.39545302093029, Median R^2 Score: -1.6681764125823975, Mean R^2 Score: -3.3935353755950928
Starting epoch 40


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 40, Training Loss: 0.9797724458180606





Epoch 40, Validation Loss: 4.395103722810745, Median R^2 Score: -1.6680908203125, Mean R^2 Score: -3.393176555633545
Starting epoch 41


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 41, Training Loss: 0.9798733064028771





Epoch 41, Validation Loss: 4.394871637225151, Median R^2 Score: -1.6687147617340088, Mean R^2 Score: -3.3929524421691895
Starting epoch 42


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 42, Training Loss: 0.9797297446695092





Epoch 42, Validation Loss: 4.3951345682144165, Median R^2 Score: -1.6689410209655762, Mean R^2 Score: -3.39320969581604
Starting epoch 43


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 43, Training Loss: 0.9795325836634527





Epoch 43, Validation Loss: 4.394900619983673, Median R^2 Score: -1.6692836284637451, Mean R^2 Score: -3.3929829597473145
Starting epoch 44


100%|██████████| 219/219 [00:39<00:00,  5.53it/s]

Epoch 44, Training Loss: 0.9795267407752607





Epoch 44, Validation Loss: 4.395301327109337, Median R^2 Score: -1.669940710067749, Mean R^2 Score: -3.393397331237793
Starting epoch 45


100%|██████████| 219/219 [00:39<00:00,  5.52it/s]

Epoch 45, Training Loss: 0.9796932630887315





Epoch 45, Validation Loss: 4.3963883221149445, Median R^2 Score: -1.6698765754699707, Mean R^2 Score: -3.39448618888855
Starting epoch 46


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 46, Training Loss: 0.979146372782041





Epoch 46, Validation Loss: 4.396534830331802, Median R^2 Score: -1.6694822311401367, Mean R^2 Score: -3.3946406841278076
Starting epoch 47


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 47, Training Loss: 0.9800410055678729





Epoch 47, Validation Loss: 4.396549075841904, Median R^2 Score: -1.6699309349060059, Mean R^2 Score: -3.394657611846924
Starting epoch 48


100%|██████████| 219/219 [00:39<00:00,  5.54it/s]

Epoch 48, Training Loss: 0.9794975293281416





Epoch 48, Validation Loss: 4.396066322922707, Median R^2 Score: -1.6705286502838135, Mean R^2 Score: -3.3941762447357178
Starting epoch 49


  0%|          | 0/219 [00:00<?, ?it/s]