In [1]:
#Install prereqs, if necessary
%pip install h5py nilearn transformers torch vit-pytorch torchmetrics

Note: you may need to restart the kernel to use updated packages.


In [1]:
#Magic that doesn't work on lightning.ai for some reason
%load_ext autoreload
%autoreload 2

### Enable access to dataset module

In [2]:
import sys
import os

sys.path.append("/teamspace/studios/this_studio/mindscape/")

### Initialize the dataset & image pre-processing

In [3]:
from ImageVoxelsDataset import ImageVoxelsDataset
from torchvision import transforms
import torch

#image pre-processing
transform = transforms.Compose([
    transforms.ToPILImage(), # convert np array from dataset to PIL
    transforms.Resize((224, 224)), # resize to ViT dimensions
    transforms.ToTensor() # convert back to tensor
])

#train for subject 1
dataset = ImageVoxelsDataset('/teamspace/studios/this_studio/mindscape/nsd',
                                subject=1, 
                                transform=transform, 
                                preload_imgs = True)

Loading stimuli images...: 100%|██████████| 73000/73000 [00:15<00:00, 4578.35it/s]


### Convert betas back to percent signal change

According to the [NSD Data Manual](https://cvnlab.slite.page/p/6CusMRYfk0/Untitled):


> ...for some of [the] NSD data files that we have prepared, the betas have been multiplied by 300 and converted to int16 format to reduce space usage. Upon loading the beta files, the values should be immediately converted back to percent signal change by casting to decimal format (e.g. single or double) and dividing by 300.


In [4]:
def to_percent_signal_change(betas):
    return betas.float() / 300

dataset.target_transform = to_percent_signal_change

### Normalize the betas

Note: Mean and standard deviation before conversion of betas back to percent signal change are 1.762125820278688 and 3.0391348517898913, respectively.

In [5]:
# This takes ~10 minutes to calculate. The values are provided below to save time
# from util import find_mean_sd
# MEAN, STD_DEV = find_mean_sd(dataset)
# print(f"mean: {MEAN}\nstd. dev: {STD_DEV}")
# Before masking: MEAN, STD_DEV = 0.9780744683715907, 3.7325697644617897
MEAN, STD_DEV  = 1.762125820278688, 3.0391348517898913
def z_norm(n):
    return (n-MEAN) / STD_DEV

old_transform = dataset.target_transform
dataset.target_transform = lambda betas: z_norm(old_transform(betas))

### Instantiate model

In [6]:
from matrixvit import MatrixViTModel
from transformers import ViTConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32



output_dims = dataset[0][1].shape

model = MatrixViTModel(
    output_dimensions=output_dims,
    image_size = (224, 224),
    patch_size = (16, 16),
    dim = 768, #hidden size
    depth = 6, #num hidden layers
    heads = 8,
    mlp_dim = 1024 #intermediate size    
).to(device=device, dtype=dtype)

### Train

In [11]:
from torch.utils.data import DataLoader, random_split
import torch
from tqdm import tqdm
from torch import nn, optim
from util import set_seed
from torchmetrics import R2Score
import time

#seed for reproducibility:
set_seed(47)

def seed_worker(worker_id):
    set_seed(47 + worker_id)




P_TRAIN = 0.7
P_VAL = 0.1

train_size = int(P_TRAIN * len(dataset))
val_size = int(P_VAL * len(dataset))
eval_size = len(dataset) - val_size - train_size

train_dataset, val_dataset, eval_dataset = random_split(dataset, [train_size, val_size, eval_size])

train_loader = DataLoader(train_dataset,  batch_size=32, num_workers=2, worker_init_fn=seed_worker)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=1, worker_init_fn=seed_worker)


loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True, threshold=0.001)

num_betas = dataset[0][1].shape[0]
r2_score = R2Score(num_outputs = num_betas, multioutput="raw_values").to("cuda")

num_epochs = 300


model.to("cuda")
import time
try:
    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}")
        model.train()
        train_loss = 0.0

        #Train
        for images, targets in tqdm(train_loader):
            images, targets = images.to('cuda'), targets.to('cuda')            

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

        #Validation
        model.eval()
        val_loss = 0.0
        r2_score.reset()

        with torch.no_grad():
            for images, targets in val_loader:
                
                images, targets = images.to('cuda'), targets.to('cuda')
                outputs = model(images)

                loss = loss_function(outputs, targets)
                val_loss += loss.item()
                r2_score.update(outputs, targets)

        avg_val_loss = val_loss / len(val_loader)
        val_r2 = r2_score.compute()  # Compute final R2 score for this epoch

        print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss}, Median R^2 Score: {torch.median(val_r2)}, Mean R^2 Score: {torch.mean(val_r2)}')

        scheduler.step(avg_val_loss)

        #Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_r2_score': val_r2  # Save the R2 score in the checkpoint

        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
except KeyboardInterrupt:
    train_loader = None



Starting epoch 1


  0%|          | 0/657 [00:00<?, ?it/s]

100%|██████████| 657/657 [00:51<00:00, 12.75it/s]

Epoch 1, Training Loss: 0.9876493509501627





Epoch 1, Validation Loss: 0.6246555358805554, Median R^2 Score: -0.14247524738311768
Starting epoch 2


100%|██████████| 657/657 [00:51<00:00, 12.79it/s]

Epoch 2, Training Loss: 0.937866181634151





Epoch 2, Validation Loss: 0.6071892253896023, Median R^2 Score: -0.07089722156524658
Starting epoch 3


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 3, Training Loss: 0.8610368322019708





Epoch 3, Validation Loss: 0.6021114270737831, Median R^2 Score: -0.050757765769958496
Starting epoch 4


100%|██████████| 657/657 [00:51<00:00, 12.76it/s]

Epoch 4, Training Loss: 0.8419924426477975





Epoch 4, Validation Loss: 0.5991158941958813, Median R^2 Score: -0.0399165153503418
Starting epoch 5


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 5, Training Loss: 0.7921791441364375





Epoch 5, Validation Loss: 0.5958746709722154, Median R^2 Score: -0.02740764617919922
Starting epoch 6


100%|██████████| 657/657 [00:51<00:00, 12.77it/s]

Epoch 6, Training Loss: 0.7631024257414598





Epoch 6, Validation Loss: 0.5955803381635788, Median R^2 Score: -0.02594304084777832
Starting epoch 7


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 7, Training Loss: 0.7294833965860363





Epoch 7, Validation Loss: 0.5942013352475268, Median R^2 Score: -0.021456360816955566
Starting epoch 8


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 8, Training Loss: 0.7222612160888799





Epoch 8, Validation Loss: 0.591855673079795, Median R^2 Score: -0.013492465019226074
Starting epoch 9


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 9, Training Loss: 0.7063023976902257





Epoch 9, Validation Loss: 0.5907912596743158, Median R^2 Score: -0.009730696678161621
Starting epoch 10


100%|██████████| 657/657 [00:51<00:00, 12.70it/s]

Epoch 10, Training Loss: 0.6898867928818481





Epoch 10, Validation Loss: 0.5909087835474217, Median R^2 Score: -0.009861111640930176
Starting epoch 11


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 11, Training Loss: 0.6785065482740533





Epoch 11, Validation Loss: 0.5896788011205957, Median R^2 Score: -0.006906986236572266
Starting epoch 12


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 12, Training Loss: 0.6674859303317658





Epoch 12, Validation Loss: 0.5926982100973738, Median R^2 Score: -0.011166810989379883
Starting epoch 13


100%|██████████| 657/657 [00:51<00:00, 12.75it/s]

Epoch 13, Training Loss: 0.6619334485795763





Epoch 13, Validation Loss: 0.5896742140993159, Median R^2 Score: -0.0068247318267822266
Starting epoch 14


100%|██████████| 657/657 [00:51<00:00, 12.79it/s]

Epoch 14, Training Loss: 0.6538495726237014





Epoch 14, Validation Loss: 0.5885327750063957, Median R^2 Score: -0.0038884878158569336
Starting epoch 15


100%|██████████| 657/657 [00:51<00:00, 12.76it/s]

Epoch 15, Training Loss: 0.6475083892385346





Epoch 15, Validation Loss: 0.588391088424845, Median R^2 Score: -0.003548741340637207
Starting epoch 16


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 16, Training Loss: 0.6419743649672881





Epoch 16, Validation Loss: 0.5884502834462105, Median R^2 Score: -0.0037424564361572266
Starting epoch 17


100%|██████████| 657/657 [00:51<00:00, 12.79it/s]

Epoch 17, Training Loss: 0.6368375134794679





Epoch 17, Validation Loss: 0.5879348858873895, Median R^2 Score: -0.00296175479888916
Starting epoch 18


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 18, Training Loss: 0.6334773766577153





Epoch 18, Validation Loss: 0.5878688185772998, Median R^2 Score: -0.0029566287994384766
Starting epoch 19


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 19, Training Loss: 0.6288444617567541





Epoch 19, Validation Loss: 0.587842880411351, Median R^2 Score: -0.0028524398803710938
Starting epoch 20


100%|██████████| 657/657 [00:51<00:00, 12.77it/s]

Epoch 20, Training Loss: 0.6246922523282253





Epoch 20, Validation Loss: 0.5879026232881749, Median R^2 Score: -0.0024698972702026367
Starting epoch 21


100%|██████████| 657/657 [00:51<00:00, 12.80it/s]

Epoch 21, Training Loss: 0.6218861697108415





Epoch 21, Validation Loss: 0.5877209153581173, Median R^2 Score: -0.002411484718322754
Starting epoch 22


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 22, Training Loss: 0.6191015686255793





Epoch 22, Validation Loss: 0.5882710408657155, Median R^2 Score: -0.002777576446533203
Starting epoch 23


100%|██████████| 657/657 [00:51<00:00, 12.82it/s]

Epoch 23, Training Loss: 0.6167219814644557





Epoch 23, Validation Loss: 0.5875280106321294, Median R^2 Score: -0.0019606351852416992
Starting epoch 24


100%|██████████| 657/657 [00:51<00:00, 12.79it/s]

Epoch 24, Training Loss: 0.6144766330356104





Epoch 24, Validation Loss: 0.5865081700872867, Median R^2 Score: -0.000701904296875
Starting epoch 25


100%|██████████| 657/657 [00:51<00:00, 12.71it/s]

Epoch 25, Training Loss: 0.6134743457334045





Epoch 25, Validation Loss: 0.5864313003864694, Median R^2 Score: -0.00041854381561279297
Starting epoch 26


100%|██████████| 657/657 [00:51<00:00, 12.76it/s]

Epoch 26, Training Loss: 0.6130251228537189





Epoch 26, Validation Loss: 0.5863635476599348, Median R^2 Score: -0.00036025047302246094
Starting epoch 27


100%|██████████| 657/657 [00:51<00:00, 12.76it/s]

Epoch 27, Training Loss: 0.6127871453671332





Epoch 27, Validation Loss: 0.5862621538182522, Median R^2 Score: -0.00024449825286865234
Starting epoch 28


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 28, Training Loss: 0.6124325647746047





Epoch 28, Validation Loss: 0.5861276071122352, Median R^2 Score: -0.00014710426330566406
Starting epoch 29


100%|██████████| 657/657 [00:51<00:00, 12.75it/s]

Epoch 29, Training Loss: 0.6116818421870425





Epoch 29, Validation Loss: 0.5861455747421752, Median R^2 Score: -1.6689300537109375e-06
Starting epoch 30


100%|██████████| 657/657 [00:51<00:00, 12.73it/s]

Epoch 30, Training Loss: 0.6119976044608396





Epoch 30, Validation Loss: 0.586189349915119, Median R^2 Score: -0.00020682811737060547
Starting epoch 31


100%|██████████| 657/657 [00:51<00:00, 12.79it/s]

Epoch 31, Training Loss: 0.6114513613680544





Epoch 31, Validation Loss: 0.586047852292974, Median R^2 Score: 7.003545761108398e-05
Starting epoch 32


100%|██████████| 657/657 [00:51<00:00, 12.77it/s]

Epoch 32, Training Loss: 0.6112824715617222





Epoch 32, Validation Loss: 0.5859822148972369, Median R^2 Score: 0.0001856088638305664
Starting epoch 33


100%|██████████| 657/657 [00:51<00:00, 12.68it/s]

Epoch 33, Training Loss: 0.6113129817006069





Epoch 33, Validation Loss: 0.5859804115396865, Median R^2 Score: 0.0002009272575378418
Starting epoch 34


100%|██████████| 657/657 [00:51<00:00, 12.76it/s]

Epoch 34, Training Loss: 0.6107709244506





Epoch 34, Validation Loss: 0.5859918873360817, Median R^2 Score: 0.00024133920669555664
Starting epoch 35


100%|██████████| 657/657 [00:51<00:00, 12.75it/s]

Epoch 35, Training Loss: 0.6111370282448952





Epoch 35, Validation Loss: 0.5859990005797528, Median R^2 Score: 0.000225067138671875
Starting epoch 36


100%|██████████| 657/657 [00:51<00:00, 12.74it/s]

Epoch 36, Training Loss: 0.6110005366929226





Epoch 36, Validation Loss: 0.5859723154534685, Median R^2 Score: 0.00024235248565673828
Starting epoch 37


100%|██████████| 657/657 [00:51<00:00, 12.71it/s]

Epoch 37, Training Loss: 0.6106319146250663





Epoch 37, Validation Loss: 0.5859750483898406, Median R^2 Score: 0.0002338886260986328
Starting epoch 38


100%|██████████| 657/657 [00:51<00:00, 12.74it/s]

Epoch 38, Training Loss: 0.6110705082819342





Epoch 38, Validation Loss: 0.5859808439904071, Median R^2 Score: 0.00023359060287475586
Starting epoch 39


100%|██████████| 657/657 [00:51<00:00, 12.78it/s]

Epoch 39, Training Loss: 0.6108981071541842





Epoch 39, Validation Loss: 0.5859865647681216, Median R^2 Score: 0.00024712085723876953
Starting epoch 40


100%|██████████| 657/657 [00:51<00:00, 12.76it/s]

Epoch 40, Training Loss: 0.6113773138918651





Epoch 40, Validation Loss: 0.5859872622692839, Median R^2 Score: 0.00023955106735229492
Starting epoch 41


100%|██████████| 657/657 [00:51<00:00, 12.73it/s]

Epoch 41, Training Loss: 0.6113821763426202





Epoch 41, Validation Loss: 0.5859881068797822, Median R^2 Score: 0.00023674964904785156
Starting epoch 42


100%|██████████| 657/657 [00:51<00:00, 12.67it/s]

Epoch 42, Training Loss: 0.6112435723185721





Epoch 42, Validation Loss: 0.5859881220979893, Median R^2 Score: 0.00023478269577026367
Starting epoch 43


 56%|█████▋    | 371/657 [00:29<00:22, 12.51it/s]
