# Algonauts Challenge

Modified Algonauts challenge template for training a VAE-based model.

## Environment configuration

### User parameters

In [None]:
# Seed passed to np.random.seed() before the train/validation shuffle.
rand_seed = 5 #@param {allow-input: true}

# Selects how the data/submission directories are resolved ('colab' mounts Google Drive).
platform = 'colab' #@param ['colab', 'jupyter_notebook'] {allow-input: true}

# Device name; it is converted to a torch.device in a later cell.
device = 'cuda' #@param ['cpu', 'cuda'] {allow-input: true}

# Subject number; it is zero-padded to build the 'subjXX' directory names.
subj = 6 #@param ["1", "2", "3", "4", "5", "6", "7", "8"] {type:"raw", allow-input: true}

batch_size = 64 #@param {type:"integer", allow-input: true}

# Size of the VAE latent space (used only when a new model is created).
latent_dim = 100 #@param {type:"integer", allow-input: true}

# True: load the previously saved model from `model_path` and keep training it.
# False: start from a freshly initialised VanillaVAE.
retrain = True #@param {type:"boolean", allow-input: true}

# Location of the saved model, relative to '/content/drive/MyDrive/'.
# It must NOT start with '/': os.path.join() discards every earlier component when a
# later one is absolute, so a leading slash would silently bypass the Drive folder.
model_path = 'NeuroAI/vae_100' #@param {type:"string", allow-input: true}

### Set up environment

In [None]:
!git clone https://github.com/AntixK/PyTorch-VAE

fatal: destination path 'PyTorch-VAE' already exists and is not an empty directory.


In [None]:
# NOTE(review): `!cd` runs in a temporary subshell, so it does not change the notebook's
# working directory. The following cells still refer to the repository through the
# 'PyTorch-VAE/...' relative path and through sys.path.
!cd PyTorch-VAE

In [None]:
!pip install -r PyTorch-VAE/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import os
import sys

# Make the cloned PyTorch-VAE repository importable (the `models` package is imported
# from it below). The path is resolved against the current working directory, where the
# repository was cloned, instead of being hard-coded to Colab's '/content' folder.
vae_repo_dir = os.path.abspath('PyTorch-VAE')
if vae_repo_dir not in sys.path:  # avoid stacking duplicates when the cell is re-run
    sys.path.insert(0, vae_repo_dir)

In [None]:
import os
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import matplotlib
from matplotlib import pyplot as plt

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
from torchvision import transforms
from scipy.stats import pearsonr as corr
from torchsummary import summary

from models import VanillaVAE

In [None]:
# Resolve the data and submission directories for the selected platform.
if platform == 'colab':
    # On Colab the data lives on Google Drive, which has to be mounted first.
    from google.colab import drive
    drive.mount('/content/drive/', force_remount=True)
    data_dir = '/content/drive/MyDrive/algonauts_2023_tutorial_data' #@param {type:"string"}
    parent_submission_dir = '/content/drive/MyDrive/algonauts_2023_challenge_submission' #@param {type:"string"}
elif platform == 'jupyter_notebook':
    data_dir = '../algonauts_2023_challenge_data'
    parent_submission_dir = '../algonauts_2023_challenge_submission'
else:
    # Fail here rather than with a NameError on `data_dir` in a later cell.
    raise ValueError(
        "Unsupported platform %r: expected 'colab' or 'jupyter_notebook'" % (platform,))

Mounted at /content/drive/


In [None]:
# Turn the user-selected device name into a torch.device. When CUDA was requested but
# is not available on this machine, fall back to the CPU instead of failing later on
# the first `.to(device)` call.
if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
    print('CUDA was requested but is not available; falling back to CPU.')
    device = 'cpu'
device = torch.device(device)

In [None]:
class argObj:
  """Per-subject paths used throughout the notebook.

  Attributes (all set in __init__):
    subj: subject number as a zero-padded two-character string (e.g. '06').
    data_dir: `<data_dir>/subjXX`.
    parent_submission_dir: the submission root, stored unchanged.
    subject_submission_dir: `<parent_submission_dir>/subjXX`; created if missing.
  """

  def __init__(self, data_dir, parent_submission_dir, subj):
    """Build the subject paths and make sure the submission directory exists.

    Args:
      data_dir: root directory that contains the 'subjXX' data folders.
      parent_submission_dir: root directory for the submission files.
      subj: subject number, given as an int or as a numeric string.
    """
    # int() makes numeric strings such as "6" behave like the integer 6; formatting
    # a str with '02' would pad on the wrong side (or raise on older Pythons).
    self.subj = format(int(subj), '02')
    self.data_dir = os.path.join(data_dir, 'subj'+self.subj)
    self.parent_submission_dir = parent_submission_dir
    self.subject_submission_dir = os.path.join(self.parent_submission_dir,
        'subj'+self.subj)

    # Create the submission directory if not existing
    os.makedirs(self.subject_submission_dir, exist_ok=True)

# Resolve the paths of the selected subject; constructing argObj also creates the
# subject's submission directory when it does not exist yet.
args = argObj(data_dir, parent_submission_dir, subj)

## Preprocess data

### Load voxel data

In [None]:
# Load the training fMRI responses of the left (LH) and right (RH) hemispheres
fmri_dir = os.path.join(args.data_dir, 'training_split', 'training_fmri')
lh_fmri_path = os.path.join(fmri_dir, 'lh_training_fmri.npy')
rh_fmri_path = os.path.join(fmri_dir, 'rh_training_fmri.npy')
lh_fmri, rh_fmri = np.load(lh_fmri_path), np.load(rh_fmri_path)

# Report both array shapes. Only the shape tuples are put in the loop so that no extra
# reference to the (large) arrays is kept alive.
for header, fmri_shape, footer in (
    ('LH training fMRI data shape:', lh_fmri.shape,
     '(Training stimulus images × LH vertices)'),
    ('\nRH training fMRI data shape:', rh_fmri.shape,
     '(Training stimulus images × RH vertices)'),
):
    print(header)
    print(fmri_shape)
    print(footer)

LH training fMRI data shape:
(9082, 18978)
(Training stimulus images × LH vertices)

RH training fMRI data shape:
(9082, 20220)
(Training stimulus images × RH vertices)


### Load images

In [None]:
# Directories holding the training and test stimulus images
train_img_dir = os.path.join(args.data_dir, 'training_split', 'training_images')
test_img_dir = os.path.join(args.data_dir, 'test_split', 'test_images')

# Create lists with all training and test image file names, sorted
train_img_list = sorted(os.listdir(train_img_dir))
test_img_list = sorted(os.listdir(test_img_dir))

num_train_imgs = len(train_img_list)
num_test_imgs = len(test_img_list)
print('Training images: ' + str(num_train_imgs))
print('Test images: ' + str(num_test_imgs))

Training images: 9082
Test images: 293


In [None]:
# Look at the first training image: the five characters in front of the 4-character
# file extension are printed as the image's ID in the 73k NSD image set.
train_img_file = train_img_list[0]
nsd_img_id = train_img_file[-9:-4]
print('Training image file name: ' + train_img_file)
print('73k NSD images ID: ' + nsd_img_id)

Training image file name: train-0001_nsd-00004.png
73k NSD images ID: 00004


### Split into train/validation/test

In [None]:
# Seed NumPy's global generator so the shuffle below is reproducible
np.random.seed(rand_seed)

num_imgs = len(train_img_list)
# Number of stimulus images that corresponds to 90% of the training data
num_train = int(np.round(num_imgs / 100 * 90))

# Shuffle the indices of all training stimulus images
idxs = np.arange(num_imgs)
np.random.shuffle(idxs)

# The first 90% of the shuffled indices form the training partition and the
# remaining 10% form the validation partition
idxs_train = idxs[:num_train]
idxs_val = idxs[num_train:]

# The test stimulus images are neither shuffled nor split
idxs_test = np.arange(len(test_img_list))

print('Training stimulus images: ' + str(len(idxs_train)))
print('\nValidation stimulus images: ' + str(len(idxs_val)))
print('\nTest stimulus images: ' + str(len(idxs_test)))

Training stimulus images: 8174

Validation stimulus images: 908

Test stimulus images: 293


### Create data pipeline

In [None]:
# Per-channel mean and standard deviation used to normalize the colour channels
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in this order. The same pipeline (random rotation
# included) is handed to the training, validation and test dataloaders.
preprocessing_steps = [
    transforms.Resize((224, 224)),                    # resize to 224x224 pixels
    transforms.ToTensor(),                            # PIL image -> PyTorch tensor
    transforms.Normalize(channel_mean, channel_std),  # normalize the colour channels
    transforms.RandomRotation(90),                    # random rotation augmentation
    #transforms.ColorJitter(0.5, 0.5, 0.5, 0.5)
]
transform = transforms.Compose(preprocessing_steps)

### Define class for easy image manipulation

In [None]:
class ImageDataset(Dataset):
    """Dataset that serves the images selected by `idxs` from a list of image paths.

    Items are loaded lazily from disk as RGB images. When a transform is given, the
    transformed result is moved to the module-level `device`; without a transform the
    PIL image is returned as is.
    """

    def __init__(self, imgs_paths, idxs, transform):
        # Keep only the requested subset of paths, in the order given by `idxs`
        all_paths = np.array(imgs_paths)
        self.imgs_paths = all_paths[idxs]
        self.transform = transform

    def __len__(self):
        return len(self.imgs_paths)

    def __getitem__(self, idx):
        # Read the image from disk and force three colour channels
        rgb_img = Image.open(self.imgs_paths[idx]).convert('RGB')
        if not self.transform:
            return rgb_img
        # Preprocess the image and send it to the chosen device ('cpu' or 'cuda')
        return self.transform(rgb_img).to(device)

In [None]:
# Collect the paths of all image files, sorted
train_imgs_paths = sorted(Path(train_img_dir).iterdir())
test_imgs_paths = sorted(Path(test_img_dir).iterdir())

# One dataset per partition. The training and validation partitions both index into
# the training image folder; the test partition uses the test image folder.
train_dataset = ImageDataset(train_imgs_paths, idxs_train, transform)
val_dataset = ImageDataset(train_imgs_paths, idxs_val, transform)
test_dataset = ImageDataset(test_imgs_paths, idxs_test, transform)

# Wrap every dataset in a DataLoader that yields batches of `batch_size` images
train_imgs_dataloader = DataLoader(train_dataset, batch_size=batch_size)
val_imgs_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_imgs_dataloader = DataLoader(test_dataset, batch_size=batch_size)

### Delete fMRI data (not needed for training)

In [None]:
# Drop the references to the fMRI arrays, which the VAE training below does not use
del lh_fmri
del rh_fmri

## Model training

### Load model

In [None]:
# Either resume from a previously saved model (retrain=True) or start from a freshly
# initialised VanillaVAE.
if retrain:
  if platform == 'colab':
    # Google Drive only exists (and only needs mounting) on Colab
    from google.colab import drive
    drive.mount('/content/drive/', force_remount=True)
  # NOTE: os.path.join() ignores the Drive prefix when `model_path` is absolute.
  # NOTE: torch.load() unpickles the whole model object; only load trusted files.
  # map_location lets a model that was saved on a GPU be loaded on a CPU-only machine.
  model = torch.load(os.path.join('/content/drive/MyDrive/', model_path),
                     map_location=device)
else:
  model = VanillaVAE(in_channels=3, latent_dim=latent_dim)

model.to(device)

# Summarise the network on the device the model actually lives on ('cuda' or 'cpu')
# instead of always assuming CUDA.
summary(model, (3, 64, 64), device=torch.device(device).type)

Mounted at /content/drive/
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 32, 32]             896
       BatchNorm2d-2           [-1, 32, 32, 32]              64
         LeakyReLU-3           [-1, 32, 32, 32]               0
            Conv2d-4           [-1, 64, 16, 16]          18,496
       BatchNorm2d-5           [-1, 64, 16, 16]             128
         LeakyReLU-6           [-1, 64, 16, 16]               0
            Conv2d-7            [-1, 128, 8, 8]          73,856
       BatchNorm2d-8            [-1, 128, 8, 8]             256
         LeakyReLU-9            [-1, 128, 8, 8]               0
           Conv2d-10            [-1, 256, 4, 4]         295,168
      BatchNorm2d-11            [-1, 256, 4, 4]             512
        LeakyReLU-12            [-1, 256, 4, 4]               0
           Conv2d-13            [-1, 512, 2, 2]       1,180,160
      BatchN

### Define training loop

In [None]:
def Training(model, dataloader, epochs=1, lr=0.005):
  """Optimise `model` on every batch of `dataloader` and return the mean step loss.

  Each batch is moved to the module-level `device`, downsampled to 64x64 and passed
  through the model; the loss comes from `model.loss_function`.

  Args:
    model: module whose forward pass returns `(reconstruction, input, mu, log_var)`
      and whose `loss_function` returns a dict with a 'loss' entry.
    dataloader: sized iterable that yields batches of image tensors.
    epochs: number of consecutive optimisation steps taken on EACH batch before
      moving on to the next one (not the number of passes over the dataset).
    lr: learning rate of the Adam optimizer.

  Returns:
    The batch losses averaged over all optimisation steps, as a float
    (0.0 when no step was performed).
  """
  optimizer = torch.optim.Adam(model.parameters(),
                               lr=lr,
                               weight_decay=1e-5)

  # Make sure the model is in training mode
  model.train()

  total_loss = 0.0
  num_steps = 0

  # iterate the dataloader
  for x in tqdm(dataloader, total=len(dataloader)):
    x = x.to(device)

    # The model is fed 64x64 images; this does not change between the steps on a batch
    downsampled = torch.nn.functional.interpolate(x, size=[64, 64], mode='bilinear')

    for _ in range(epochs):
      # run data through model
      out, target, mu, log_var = model(downsampled)

      # evaluate loss (M_N is forwarded unchanged to the model's loss_function)
      loss = model.loss_function(out, target, mu, log_var, M_N=1)

      # backward pass
      optimizer.zero_grad()
      loss['loss'].backward()
      optimizer.step()

      # print batch loss
      step_loss = loss['loss'].item()
      print('\t train loss: %f' % (step_loss))
      total_loss += step_loss
      num_steps += 1

  # Each accumulated value is one loss per step, so average over the number of steps
  return total_loss / num_steps if num_steps else 0.0

In [None]:
def Validate(model, dataloader, epochs=1, retrain=False, lr=0.005):
  """Evaluate `model` on `dataloader`, optionally updating it, and return the mean loss.

  Each batch is moved to the module-level `device`, downsampled to 64x64 and passed
  through the model; the loss comes from `model.loss_function`.

  Args:
    model: module whose forward pass returns `(reconstruction, input, mu, log_var)`
      and whose `loss_function` returns a dict with a 'loss' entry.
    dataloader: sized iterable that yields batches of image tensors.
    epochs: number of consecutive passes made over EACH batch before moving on to
      the next one.
    retrain: when true, every pass also runs a backward pass and an Adam step, i.e.
      the model is trained on this data. When false the model is only evaluated:
      it is switched to eval mode and no gradients are tracked.
    lr: learning rate of the Adam optimizer (only used when `retrain` is true).

  Returns:
    The batch losses averaged over all passes, as a float (0.0 when there were none).
  """
  retrain = bool(retrain)

  # The optimizer is only needed when the model is updated
  optimizer = None
  if retrain:
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 weight_decay=1e-5)

  # Pure evaluation must not touch e.g. BatchNorm statistics, so switch to eval mode
  # in that case; the previous mode is restored before returning.
  was_training = model.training
  model.train(retrain)

  total_loss = 0.0
  num_steps = 0

  try:
    # iterate the dataloader
    for x in tqdm(dataloader, total=len(dataloader)):
      x = x.to(device)

      # The model is fed 64x64 images; this does not change between the passes
      downsampled = torch.nn.functional.interpolate(x, size=[64, 64], mode='bilinear')

      for _ in range(epochs):
        # run data through model; gradients are tracked only when the model is updated
        with torch.set_grad_enabled(retrain):
          out, target, mu, log_var = model(downsampled)

          # evaluate loss (M_N is forwarded unchanged to the model's loss_function)
          loss = model.loss_function(out, target, mu, log_var, M_N=1)

        # backward pass
        if retrain:
          optimizer.zero_grad()
          loss['loss'].backward()
          optimizer.step()

        # print batch loss
        step_loss = loss['loss'].item()
        print('\t validation loss: %f' % (step_loss))
        total_loss += step_loss
        num_steps += 1
  finally:
    model.train(was_training)

  # Each accumulated value is one loss per pass, so average over the number of passes
  return total_loss / num_steps if num_steps else 0.0

### Train variational autoencoder

In [None]:
# Fit the VAE on the training partition (one optimisation step per batch, since
# `epochs` defaults to 1).
train_loss = Training(model, train_imgs_dataloader)

  1%|          | 1/128 [00:58<2:04:50, 58.98s/it]

	 train loss: 1.251184


  2%|▏         | 2/128 [01:57<2:03:11, 58.66s/it]

	 train loss: 1.343740


  2%|▏         | 3/128 [02:54<2:00:36, 57.89s/it]

	 train loss: 1.259307


  3%|▎         | 4/128 [03:51<1:59:14, 57.70s/it]

	 train loss: 1.285219


  4%|▍         | 5/128 [04:49<1:58:38, 57.88s/it]

	 train loss: 1.316059


  5%|▍         | 6/128 [05:48<1:57:50, 57.95s/it]

	 train loss: 1.261640


  5%|▌         | 7/128 [06:44<1:55:58, 57.51s/it]

	 train loss: 1.264335


  6%|▋         | 8/128 [07:42<1:55:09, 57.58s/it]

	 train loss: 1.340415


  7%|▋         | 9/128 [08:40<1:54:27, 57.71s/it]

	 train loss: 1.269695


  8%|▊         | 10/128 [09:38<1:53:33, 57.74s/it]

	 train loss: 1.193209


  9%|▊         | 11/128 [10:35<1:52:19, 57.60s/it]

	 train loss: 1.418529


  9%|▉         | 12/128 [11:32<1:50:53, 57.36s/it]

	 train loss: 1.258581


 10%|█         | 13/128 [12:28<1:49:23, 57.07s/it]

	 train loss: 1.184211


 11%|█         | 14/128 [13:25<1:48:13, 56.96s/it]

	 train loss: 1.215200


 12%|█▏        | 15/128 [14:22<1:47:29, 57.07s/it]

	 train loss: 1.233599


 12%|█▎        | 16/128 [15:20<1:47:06, 57.38s/it]

	 train loss: 1.238379


 13%|█▎        | 17/128 [16:17<1:45:55, 57.26s/it]

	 train loss: 1.291192


 14%|█▍        | 18/128 [17:14<1:44:29, 57.00s/it]

	 train loss: 1.230030


 15%|█▍        | 19/128 [18:10<1:43:08, 56.77s/it]

	 train loss: 1.256905


 16%|█▌        | 20/128 [19:06<1:41:58, 56.65s/it]

	 train loss: 1.224156


 16%|█▋        | 21/128 [20:03<1:41:03, 56.67s/it]

	 train loss: 1.157879


 17%|█▋        | 22/128 [21:01<1:41:03, 57.20s/it]

	 train loss: 1.192342


 18%|█▊        | 23/128 [21:58<1:39:36, 56.92s/it]

	 train loss: 1.231669


 19%|█▉        | 24/128 [22:54<1:38:31, 56.84s/it]

	 train loss: 1.230034


 20%|█▉        | 25/128 [23:51<1:37:14, 56.65s/it]

	 train loss: 1.124402


 20%|██        | 26/128 [24:47<1:36:20, 56.67s/it]

	 train loss: 1.385329


 20%|██        | 26/128 [25:04<1:38:22, 57.86s/it]


KeyboardInterrupt: ignored

In [None]:
# NOTE: this calls Training(), so the model is also optimised on the validation images
# (two steps per batch); `val_loss` is therefore a training loss on that partition.
val_loss = Training(model, val_imgs_dataloader, epochs=2)

In [None]:
# NOTE: retrain=True makes Validate() run the backward pass and the optimizer step, so
# the model is updated on the test images as well (two steps per batch).
test_loss = Validate(model, test_imgs_dataloader, epochs=2, retrain=True)

## Save VAE

In [None]:
# Google Drive only exists (and only needs mounting) on Colab
if platform == 'colab':
  from google.colab import drive
  drive.mount('/content/drive/', force_remount=True)

# NOTE: os.path.join() ignores the Drive prefix when `model_path` is absolute, so the
# resolved destination is printed to show where the file really ends up.
model_save_path = os.path.join('/content/drive/MyDrive/', model_path)

# The whole model object is saved (not only its state dict), which is what the
# "Load model" cell expects to read back with torch.load().
torch.save(model, model_save_path)
print('Model saved to: ' + model_save_path)

Mounted at /content/drive/
