In [None]:
# Install prerequisites into the kernel's environment, if necessary.
# %pip (rather than !pip) targets the environment the running kernel uses.
# NOTE(review): versions are not pinned; presumably these packages are needed
# by the dataset module rather than by this notebook directly — confirm.
%pip install h5py nilearn

In [2]:
# Automatically reload edited modules before each cell is executed.
# NOTE: these magics do not work on lightning.ai (reason unknown).
%load_ext autoreload
%autoreload 2

### Enable access to dataset module

In [None]:
import sys
import os

# Make the parent of the notebook's directory importable so that modules
# living there can be imported by the cells below.
# `__file__` is not defined inside a Jupyter kernel, so fall back to the
# current working directory (normally the folder containing the notebook).
# The `__file__` branch keeps the code working if it is exported to a script.
try:
    notebook_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    notebook_dir = os.getcwd()

parent_dir = os.path.dirname(notebook_dir)

# Avoid piling up duplicate entries when this cell is re-run.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

### Initialize the dataset & image pre-processing

In [3]:
from ImageVoxelDataset import ImageVoxelsDataset
from torchvision import transforms
import torch

# Image pre-processing pipeline, applied in order:
#   1. numpy array (as returned by the dataset) -> PIL image
#   2. resize to the 224x224 resolution used by the ViT
#   3. PIL image -> tensor
VIT_INPUT_SIZE = (224, 224)
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize(VIT_INPUT_SIZE),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Dataset used for training on subject 1, read from ./nsd with a cache size of 1.
dataset = ImageVoxelsDataset('./nsd', 1, transform=transform, cache_size=1)

### Convert betas back to percent signal change

According to the [NSD Data Manual](https://cvnlab.slite.page/p/6CusMRYfk0/Untitled):


> ...for some of [the] NSD data files that we have prepared, the betas have been multiplied by 300 and converted to int16 format to reduce space usage. Upon loading the beta files, the values should be immediately converted back to percent signal change by casting to decimal format (e.g. single or double) and dividing by 300.


In [None]:
def to_percent_signal_change(betas):
    """Convert stored betas back to percent signal change.

    Per the NSD Data Manual, the betas were multiplied by 300 and stored as
    int16; this casts them to floating point and undoes that scaling.

    Args:
        betas: tensor of stored beta values.

    Returns:
        A float tensor holding ``betas / 300``.
    """
    scale = 300  # factor the betas were multiplied by before storage
    as_float = betas.float()
    return as_float / scale

# Register the conversion as the dataset's target transform
# (presumably applied to each sample's betas when it is accessed — confirm
# against ImageVoxelsDataset).
dataset.target_transform = to_percent_signal_change

### Normalize the betas

Note: computed on the stored betas (i.e. *before* converting back to percent signal change), the mean and standard deviation are 293.4223405114772 and 1119.770929338537, respectively. Dividing both by 300 gives the values used in the cell below.

In [14]:
# Computing these statistics takes ~10 minutes, so the results are
# hard-coded below to save time. To recompute them:
# from util import find_mean_sd
# MEAN, STD_DEV = find_mean_sd(dataset)
MEAN = 0.9780744683715907
STD_DEV = 3.7325697644617897


def z_norm(n):
    """Standardize ``n`` by subtracting MEAN and dividing by STD_DEV."""
    centered = n - MEAN
    return centered / STD_DEV

def normalized_percent_signal_change(betas):
    """Convert stored betas to percent signal change, then z-normalize them."""
    return z_norm(to_percent_signal_change(betas))


# Install the complete pipeline explicitly instead of wrapping whatever
# transform is currently set on the dataset: wrapping is not idempotent, so
# re-running this cell would stack a second normalization on top of the first.
old_transform = to_percent_signal_change  # name kept for any code that refers to it
dataset.target_transform = normalized_percent_signal_change

### Instantiate model

In [None]:
from matrixvit import MatrixViTModel

# Size the model's output from the target (index 1) of the first sample,
# i.e. the betas after the dataset's target transform has been applied.
first_sample = dataset[0]
output_dims = first_sample[1].shape

model = MatrixViTModel(output_dimensions=output_dims)

### Train

In [None]:
from torch.utils.data import DataLoader, random_split
import torch
from torch import nn, optim

# --- Configuration -----------------------------------------------------------
P_TRAIN = 0.8          # fraction of the dataset used for training
SPLIT_SEED = 42        # fixes the train/validation split across kernel restarts
BATCH_SIZE = 64
NUM_WORKERS = 4
LEARNING_RATE = 0.001

# --- Train / validation split ------------------------------------------------
train_size = int(P_TRAIN * len(dataset))
val_size = len(dataset) - train_size

# Seed the split explicitly. Without a fixed generator the split differs on
# every run, so validation losses are not comparable between runs and samples
# held out in one run can end up in the training set when resuming from a
# checkpoint in another.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Not shuffling allows the cache to be effectively used.
# NOTE(review): random_split assigns indices at random, so even with
# shuffle=False the underlying dataset is visited in permuted order — confirm
# that the dataset cache still helps with this access pattern.
train_loader = DataLoader(
    train_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
)
val_loader = DataLoader(
    val_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
)

# --- Optimisation setup ------------------------------------------------------
loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

num_epochs = 300

# --- Training loop -----------------------------------------------------------
for epoch in range(num_epochs):
    # Train: one pass over the training loader with gradient updates.
    model.train()
    train_loss = 0.0

    for images, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Reported loss is the mean of the per-batch losses.
    avg_train_loss = train_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

    # Validation: same loss, no gradient tracking and no parameter updates.
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, targets in val_loader:
            outputs = model(images)
            loss = loss_function(outputs, targets)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss}')

    # Save a checkpoint after every epoch (1-based epoch number in the file
    # name) holding the model and optimizer state plus both average losses.
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss
    }
    torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')