In [1]:
#Install prereqs, if necessary
%pip install h5py nilearn transformers torch vit-pytorch torchmetrics

Note: you may need to restart the kernel to use updated packages.


In [1]:
#Magic that doesn't work on lightning.ai for some reason
%load_ext autoreload
%autoreload 2

### Enable access to dataset module

In [1]:
import sys
import os

sys.path.append("/teamspace/studios/this_studio/mindscape/")

### Initialize the dataset & image pre-processing

In [2]:
from ImageVoxelsDataset import ImageVoxelsDataset
from torchvision import transforms
import torch

#image pre-processing
transform = transforms.Compose([
    transforms.ToPILImage(), # convert np array from dataset to PIL
    transforms.Resize((224, 224)), # resize to ViT dimensions
    transforms.ToTensor() # convert back to tensor
])

#train for subject 1
dataset = ImageVoxelsDataset('/teamspace/studios/this_studio/nsd',
                                subject=1, 
                                transform=transform
                                preload_imgs=True)

### Convert betas back to percent signal change

According to the [NSD Data Manual](https://cvnlab.slite.page/p/6CusMRYfk0/Untitled):


> ...for some of [the] NSD data files that we have prepared, the betas have been multiplied by 300 and converted to int16 format to reduce space usage. Upon loading the beta files, the values should be immediately converted back to percent signal change by casting to decimal format (e.g. single or double) and dividing by 300.


In [3]:
def to_percent_signal_change(betas):
    return betas.float() / 300

dataset.target_transform = to_percent_signal_change

### Normalize the betas

Note: Mean and standard deviation before conversion of betas back to percent signal change are 1.762125820278688 and 3.0391348517898913, respectively.

In [4]:
# This takes ~10 minutes to calculate. The values are provided below to save time
# from util import find_mean_sd
# MEAN, STD_DEV = find_mean_sd(dataset)
# print(f"mean: {MEAN}\nstd. dev: {STD_DEV}")
# Before masking: MEAN, STD_DEV = 0.9780744683715907, 3.7325697644617897
MEAN, STD_DEV  = 1.762125820278688, 3.0391348517898913
def z_norm(n):
    return (n-MEAN) / STD_DEV

old_transform = dataset.target_transform
dataset.target_transform = lambda betas: z_norm(old_transform(betas))

### Instantiate model

In [5]:
from IA_ViT import MatrixViTModel
from transformers import ViTConfig
import nilearn.image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

prf_atlas = nilearn.image.get_data('/teamspace/studios/this_studio/nsd/nsddata/ppdata/subj01/func1pt8mm/roi/prf-visualrois.nii.gz')
floc_atlas = nilearn.image.get_data('/teamspace/studios/this_studio/nsd/nsddata/ppdata/subj01/func1pt8mm/roi/floc-faces.nii.gz')

output_dims = [
    np.sum(prf_atlas==1),
    np.sum(prf_atlas==2),
    np.sum(prf_atlas==3),
    np.sum(prf_atlas==4),
    np.sum(prf_atlas==5),
    np.sum(prf_atlas==6),
    np.sum(floc_atlas==1),
    np.sum(floc_atlas==2),
    np.sum(floc_atlas==3),
    np.sum(floc_atlas==4),
    np.sum(floc_atlas==5),
]

model = MatrixViTModel(
    output_dimensions=output_dims,
    num_rois=11
    #image_size = (224, 224),
    #patch_size = (16, 16),
    #dim = 768, #hidden size
    #depth = 6, #num hidden layers
    #heads = 8,
    #mlp_dim = 1024 #intermediate size    
).to(device=device, dtype=dtype)



### Train

In [19]:
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
import torch
from tqdm import tqdm
from torch import nn, optim
from util import set_seed
from torchmetrics import R2Score
import time

#seed for reproducibility:
set_seed(47)

def seed_worker(worker_id):
    set_seed(47 + worker_id)




P_TRAIN = 0.7
P_VAL = 0.1

train_size = int(P_TRAIN * len(dataset))
val_size = int(P_VAL * len(dataset))
eval_size = len(dataset) - val_size - train_size

train_dataset, val_dataset, eval_dataset = random_split(dataset, [train_size, val_size, eval_size])

train_loader = DataLoader(train_dataset,  batch_size=32, num_workers=2, worker_init_fn=seed_worker)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=1, worker_init_fn=seed_worker)


loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True, threshold=0.001)

num_betas = dataset[0][1].shape[0]
r2_score = R2Score(num_outputs = num_betas, multioutput="raw_values").to("cuda")

num_epochs = 300

#check = torch.load('/teamspace/studios/this_studio/mindscape/IA_ViT/checkpoint_epoch_19.pth')
#model.load.state_dict(check['optimizer_state_dict'])
#epoch = check['epoch']
#loss = checkpoint['loss']

model.to("cuda")
import time
try:
    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}")
        model.train()
        train_loss = 0.0

        #Train
        i = 0
        for images, targets in tqdm(train_loader):
            i = i + 1
            images, targets = images.to('cuda'), targets.to('cuda')            
            print(targets.shape)
            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

        #Validation
        model.eval()
        val_loss = 0.0
        r2_score.reset()

        with torch.no_grad():
            for images, targets in val_loader:
                
                images, targets = images.to('cuda'), targets.to('cuda')
                outputs = model(images)

                loss = loss_function(outputs, targets)
                val_loss += loss.item()
                r2_score.update(outputs, targets)

        avg_val_loss = val_loss / len(val_loader)
        val_r2 = r2_score.compute()  # Compute final R2 score for this epoch

        print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss}, Median R^2 Score: {torch.median(val_r2)}, Mean R^2 Score: {torch.mean(val_r2)}')

        scheduler.step(avg_val_loss)

        #Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'val_r2_score': val_r2  # Save the R2 score in the checkpoint

        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
except KeyboardInterrupt:
    train_loader = None

Starting epoch 1


  0%|          | 0/657 [00:00<?, ?it/s]

torch.Size([32, 5278])


  0%|          | 1/657 [00:00<06:49,  1.60it/s]

torch.Size([32, 5278])


  0%|          | 2/657 [00:01<06:51,  1.59it/s]

torch.Size([32, 5278])



