In [1]:
#Install prereqs, if necessary
%pip install h5py nilearn transformers torch vit-pytorch

Collecting vit-pytorch
  Downloading vit_pytorch-1.6.5-py3-none-any.whl.metadata (697 bytes)
Collecting einops>=0.7.0 (from vit-pytorch)
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading vit_pytorch-1.6.5-py3-none-any.whl (100 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.3/100.3 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops, vit-pytorch
Successfully installed einops-0.7.0 vit-pytorch-1.6.5
Note: you may need to restart the kernel to use updated packages.


In [1]:
# Load IPython's autoreload extension and set it to mode 2, so edits to the
# project's .py modules are picked up without restarting the kernel.
# NOTE (from the original author): this magic reportedly does not work on
# lightning.ai; the cause is unknown.
%load_ext autoreload
%autoreload 2

### Enable access to dataset module

In [2]:
import sys
import os

# Put the project checkout on the import path (presumably where the dataset,
# model and util modules imported by later cells live — verify for your setup).
# The membership check keeps the cell idempotent: re-running it does not pile
# up duplicate sys.path entries.
MINDSCAPE_ROOT = "/teamspace/studios/this_studio/mindscape/"
if MINDSCAPE_ROOT not in sys.path:
    sys.path.append(MINDSCAPE_ROOT)

### Initialize the dataset & image pre-processing

In [3]:
from ImageVoxelsDataset import ImageVoxelsDataset
from torchvision import transforms
import torch

# Image pre-processing pipeline handed to the dataset as `transform`.
transform = transforms.Compose(
    [
        # the dataset yields numpy arrays; go through PIL so Resize can be applied
        transforms.ToPILImage(),
        # 224x224 is the input resolution the ViT below is configured for
        transforms.Resize((224, 224)),
        # back to a tensor for the model
        transforms.ToTensor(),
    ]
)

# Train on subject 1 only.
dataset = ImageVoxelsDataset(
    '/teamspace/studios/this_studio/mindscape/nsd',
    1,
    transform=transform,
    cache_size=1,
)

### Convert betas back to percent signal change

According to the [NSD Data Manual](https://cvnlab.slite.page/p/6CusMRYfk0/Untitled):


> ...for some of [the] NSD data files that we have prepared, the betas have been multiplied by 300 and converted to int16 format to reduce space usage. Upon loading the beta files, the values should be immediately converted back to percent signal change by casting to decimal format (e.g. single or double) and dividing by 300.


In [4]:
def to_percent_signal_change(betas):
    """Convert stored NSD betas back to percent signal change.

    Args:
        betas: tensor of stored beta values (anything exposing ``.float()``).

    Returns:
        A float32 tensor equal to ``betas`` divided by 300, the factor the
        NSD Data Manual says the betas were multiplied by before storage.
    """
    scale_factor = 300
    as_float = betas.float()
    return as_float / scale_factor

# Install the conversion as the dataset's target transform
# (presumably ImageVoxelsDataset applies it to each target it returns — verify in that class).
dataset.target_transform = to_percent_signal_change

### Normalize the betas

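Note: Computed on the raw stored betas (before converting them back to percent signal change), the mean and standard deviation are 293.4223405114772 and 1119.770929338537, respectively. Dividing each by 300 gives the percent-signal-change values used in the next cell.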

In [5]:
# Computing these statistics over the dataset takes ~10 minutes, so the
# resulting values are hard-coded below. To recompute them:
# from util import find_mean_sd
# MEAN, STD_DEV = find_mean_sd(dataset)
MEAN = 0.9780744683715907
STD_DEV = 3.7325697644617897


def z_norm(n):
    """Standardize ``n``: subtract MEAN, then divide by STD_DEV."""
    centered = n - MEAN
    return centered / STD_DEV

# Build the target pipeline explicitly from its two stages (percent signal
# change, then z-normalization) instead of wrapping whatever transform is
# currently installed on the dataset.
#
# Why: the previous form read `dataset.target_transform` into the global
# `old_transform` and closed over that global in a lambda. Running the cell a
# second time rebound `old_transform` to the lambda itself, so the next sample
# access recursed until RecursionError. Referring to the stage functions
# directly makes the cell idempotent.
#
# `old_transform` is still bound in case other cells refer to it.
old_transform = to_percent_signal_change
dataset.target_transform = lambda betas: z_norm(to_percent_signal_change(betas))

### Instantiate model

In [6]:
from matrixvit import MatrixViTModel
from transformers import ViTConfig

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
dtype = torch.float32

# Size the model's output from the second element (the target) of the first sample.
output_dims = dataset[0][1].shape

model = MatrixViTModel(
    output_dimensions=output_dims,
    image_size=(224, 224),
    patch_size=(16, 16),
    dim=768,       # hidden size
    depth=6,       # number of hidden layers
    heads=8,
    mlp_dim=1024,  # intermediate size
)
model = model.to(device=device, dtype=dtype)

### Train

In [7]:
from torch.utils.data import DataLoader, random_split
import torch
from tqdm import tqdm
from torch import nn, optim
from util import set_seed

# Seed the RNGs for reproducibility (this also fixes the random split below).
set_seed(47)

# Fractions of the dataset used for training and validation; whatever is left
# over becomes the held-out evaluation split.
P_TRAIN = 0.7
P_VAL = 0.1

train_size = int(P_TRAIN * len(dataset))
val_size = int(P_VAL * len(dataset))
eval_size = len(dataset) - val_size - train_size

train_dataset, val_dataset, eval_dataset = random_split(dataset, [train_size, val_size, eval_size])

# Not shuffling is intended to let the dataset's cache be used effectively.
# NOTE(review): random_split hands each subset randomly drawn indices, so even
# with shuffle=False the samples are not visited in dataset order — confirm the
# cache still gets hits (the recorded run shows ~175 s per batch).
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=64, num_workers=2)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=64, num_workers=1)

loss_function = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 300

# Use the `device` chosen when the model was instantiated rather than a
# hardcoded "cuda", so the CPU fallback selected there keeps working here.
model.to(device)

try:
    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch}")
        model.train()
        train_loss = 0.0

        # Train: one optimizer step per batch, accumulating the per-batch loss.
        for images, targets in tqdm(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Mean of the per-batch losses (each batch weighted equally).
        avg_train_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1}, Training Loss: {avg_train_loss}')

        # Validation: same loss, no gradient tracking, model in eval mode.
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for images, targets in val_loader:
                images, targets = images.to(device), targets.to(device)
                outputs = model(images)
                loss = loss_function(outputs, targets)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch+1}, Validation Loss: {avg_val_loss}')

        # Save a checkpoint after every epoch (model + optimizer state and the
        # epoch's losses) so training can be resumed or inspected later.
        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
except KeyboardInterrupt:
    # Interrupting the cell is the expected way to stop early, so do not
    # re-raise; say so explicitly instead of ending silently.
    print("Training interrupted by user; releasing the training DataLoader.")
    # Drop the loader reference (presumably so its worker processes can be
    # cleaned up — TODO confirm this is sufficient).
    train_loader = None

Starting epoch 0


  0%|          | 0/329 [00:00<?, ?it/s]

  1%|          | 4/329 [11:40<15:48:33, 175.12s/it]


In [8]:
# Count every element across all of the model's parameter tensors.
total_params = 0
for param in model.parameters():
    total_params += param.numel()

total_params

29009534

In [5]:
# Total number of samples in the full dataset (before it is split)
len(dataset)

30000