## In this notebook I'll train brain-to-image and brain-to-text feature models and combine their predictions

In [1]:
import numpy as np
import nibabel as nib
import nilearn 
import matplotlib.pyplot as plt
import os
from os.path import join as opj
import pandas as pd
import seaborn as sns
import glob
from nilearn import plotting
from nilearn.image import *
import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from nilearn.plotting import plot_stat_map
from nilearn.image import mean_img
from nilearn.plotting import plot_img, plot_epi
from nilearn.maskers import NiftiMasker
from sklearn.preprocessing import StandardScaler
import wandb
import pickle
from torch.utils.data import Dataset, DataLoader
from dataset import fMRI_Dataset, fMRI_Multi_Dataset
import torch
from torch import nn
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint

from pytorch_lightning.loggers import WandbLogger
from network import Encoder, ContrastiveModel
import torch
import torch.nn as nn
import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# Development helper: re-import the local `dataset` module so that edits made to
# dataset.py are picked up without restarting the kernel, then re-bind the two
# dataset classes so the names below point at the reloaded definitions.
import dataset
import importlib
importlib.reload(dataset)
from dataset import fMRI_Dataset, fMRI_Multi_Dataset

In [2]:
# Authenticate with Weights & Biases, then open the run that the loggers and
# the evaluation step further down report to.
wandb.login()
wandb.init(
    project="BrainTuning",
    config={"model": "multimodal_model"},
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmatteoferrante[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
## load the data
## For every subject, read the fMRI arrays, stimuli, captions and the image/text
## feature arrays of each split, then build one dataset per split twice: once
## with feature_type="image" and once with feature_type="text".

train_datasets_images, val_datasets_images, test_datasets_images = [], [], []
train_datasets_text, val_datasets_text, test_datasets_text = [], [], []

SPLIT_NAMES = ("train", "val", "test")

# Map each split name to the list that collects its per-subject datasets.
image_lists = dict(zip(SPLIT_NAMES, (train_datasets_images, val_datasets_images, test_datasets_images)))
text_lists = dict(zip(SPLIT_NAMES, (train_datasets_text, val_datasets_text, test_datasets_text)))

for subj in tqdm.tqdm(["CSI1","CSI2","CSI3","CSI4"]):

    subj_id = int(subj.split("CSI")[1])

    # NOTE: `data_path` is deliberately a plain cell-level name; the training
    # cells below read it (after the loop it holds the last subject's folder).
    data_path =  f"/home/matteo/storage/brain_tuning/{subj}"

    # Arrays are stored in the positional order expected by fMRI_Multi_Dataset:
    # (fmri, images, captions, image features, text features).
    # Images and captions are object arrays, hence allow_pickle=True; only load
    # such files from a trusted location.
    arrays_by_split = {
        split: (
            np.load(opj(data_path, f"{split}_fmri_top.npy")),
            np.load(opj(data_path, f"img_{split}.npy"), allow_pickle=True),
            np.load(opj(data_path, f"{split}_captions.npy"), allow_pickle=True),
            np.load(opj(data_path, f"{split}_image_features.npy")),
            np.load(opj(data_path, f"{split}_text_features.npy")),
        )
        for split in SPLIT_NAMES
    }

    # Image-feature datasets first, then text-feature datasets; both variants
    # of a split share the very same arrays.
    for feature_type, target_lists in (("image", image_lists), ("text", text_lists)):
        for split in SPLIT_NAMES:
            target_lists[split].append(
                fMRI_Multi_Dataset(*arrays_by_split[split], subj_id, feature_type=feature_type)
            )







100%|██████████| 4/4 [00:07<00:00,  1.97s/it]


In [14]:
## Concatenate the per-subject datasets into a single dataset per split.
## The image-feature lists built by the loading cell are named `*_datasets_images`;
## the previous `train_datasets` / `val_datasets` / `test_datasets` names are not
## defined anywhere in this notebook and only resolved through stale kernel state.
train_dataset_images = torch.utils.data.ConcatDataset(train_datasets_images)
val_dataset_images = torch.utils.data.ConcatDataset(val_datasets_images)
test_dataset_images = torch.utils.data.ConcatDataset(test_datasets_images)

## Same for the text-feature variant of each split.
train_dataset_text = torch.utils.data.ConcatDataset(train_datasets_text)
val_dataset_text = torch.utils.data.ConcatDataset(val_datasets_text)
test_dataset_text = torch.utils.data.ConcatDataset(test_datasets_text)




In [15]:
BATCH_SIZE = 512

# Only the training split is shuffled; validation and test keep dataset order.
# Loaders over the image-feature datasets.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split_dataset, do_shuffle in (
        (train_dataset_images, True),
        (val_dataset_images, False),
        (test_dataset_images, False),
    )
)

# Loaders over the text-feature datasets.
train_loader_text, val_loader_text, test_loader_text = (
    DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split_dataset, do_shuffle in (
        (train_dataset_text, True),
        (val_dataset_text, False),
        (test_dataset_text, False),
    )
)

## Train a Brain2ImageModel

In [5]:
## Best configuration found by the hyperparameter search.
## These names are shared by the image and the text training cells.

# --- encoder architecture ---
act_fn = nn.Identity                      # activation class handed to the model
base_channel_size, hidden_dims = [2048], [1024]
latent_dim = 768                          # output size used for the image model

# --- loss and optimisation ---
loss_type = "contrastive"
temperature, alpha = 0.1, 0.8
lr, wd = 1e-4, 1e-5                       # learning rate, weight decay

In [9]:
# Model trained on the image-feature loaders; all settings come from the
# hyperparameter cell above.
# NOTE(review): `temperature` is defined in the hyperparameter cell but is not
# passed here, so the model presumably falls back to its own default - confirm.
brain_image_model = ContrastiveModel(num_input_channels= 10000,
                                base_channel_size=base_channel_size, 
                                hidden_dims=hidden_dims,
                                latent_dim=latent_dim,
                                act_fn=act_fn,
                                loss_type=loss_type,
                                lr = lr,
                                wd = wd,
                                alpha=alpha)

# Early stopping on 'val_loss' ('min' because lower is better). With patience=1
# training stops after the first validation epoch that does not improve.
# (This setup used to be duplicated verbatim; it is now created once.)
early_stop_callback = EarlyStopping(monitor='val_loss', patience=1, verbose=True, mode='min')
wandb_logger = WandbLogger()  # Sends the metrics logged by the model to wandb


# Checkpoint directory for this model.
# NOTE(review): `data_path` is the variable left behind by the subject-loading
# loop, i.e. the folder of the last subject processed there; the multi-subject
# checkpoints therefore end up inside that subject's directory - confirm this
# is intended.
run_name = "multimodal_model_IMAGE"
checkpoint_dir = os.path.join(data_path, "models_multi", run_name)
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep the three checkpoints with the lowest val_loss.
checkpoint_callback = ModelCheckpoint(monitor='val_loss',dirpath=checkpoint_dir,filename='brain_image_model-{epoch:02d}-{val_loss:.2f}',save_top_k=3,mode='min',)


# Train on GPU index 1 for at most 20 epochs, reporting to wandb.
trainer = pl.Trainer(max_epochs=20, devices=[1], callbacks=[early_stop_callback,checkpoint_callback],logger=wandb_logger )

trainer.fit(brain_image_model, train_loader, val_loader)

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/matteo/anaconda3/envs/borg/lib/python3.8/site- ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:477: The total number of parameters detected may be inaccurate because the model contains an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array` in your LightningModule.

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | Encoder | 9.7 K  | train
------------------------------------------
9.7 K     Trainable params
0         Non-trainable params
9.7 K     Total params
0.039     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  3.52it/s]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 512. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


                                                                           

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (27) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 27/27 [00:24<00:00,  1.09it/s, v_num=9fzm, train_loss_step=5.120]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 317. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 0: 100%|██████████| 27/27 [00:28<00:00,  0.95it/s, v_num=9fzm, train_loss_step=5.120, val_loss=5.440, val_mse_loss=1.720, val_cosine_similarity=0.120, train_loss_epoch=5.800]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 359. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
Metric val_loss improved. New best score: 5.441


Epoch 1: 100%|██████████| 27/27 [00:27<00:00,  0.98it/s, v_num=9fzm, train_loss_step=4.420, val_loss=5.220, val_mse_loss=1.660, val_cosine_similarity=0.153, train_loss_epoch=5.010]

Metric val_loss improved by 0.223 >= min_delta = 0.0. New best score: 5.218


Epoch 2: 100%|██████████| 27/27 [00:25<00:00,  1.05it/s, v_num=9fzm, train_loss_step=4.120, val_loss=5.100, val_mse_loss=1.630, val_cosine_similarity=0.167, train_loss_epoch=4.610]

Metric val_loss improved by 0.117 >= min_delta = 0.0. New best score: 5.101


Epoch 3: 100%|██████████| 27/27 [00:36<00:00,  0.74it/s, v_num=9fzm, train_loss_step=3.830, val_loss=5.040, val_mse_loss=1.620, val_cosine_similarity=0.174, train_loss_epoch=4.270]

Metric val_loss improved by 0.060 >= min_delta = 0.0. New best score: 5.041


Epoch 4: 100%|██████████| 27/27 [00:43<00:00,  0.62it/s, v_num=9fzm, train_loss_step=3.520, val_loss=5.000, val_mse_loss=1.610, val_cosine_similarity=0.177, train_loss_epoch=3.980]

Metric val_loss improved by 0.038 >= min_delta = 0.0. New best score: 5.003


Epoch 5: 100%|██████████| 27/27 [00:48<00:00,  0.56it/s, v_num=9fzm, train_loss_step=3.320, val_loss=4.990, val_mse_loss=1.610, val_cosine_similarity=0.177, train_loss_epoch=3.730]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 4.987


Epoch 6: 100%|██████████| 27/27 [00:35<00:00,  0.76it/s, v_num=9fzm, train_loss_step=3.120, val_loss=4.980, val_mse_loss=1.610, val_cosine_similarity=0.175, train_loss_epoch=3.520]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 4.984


Epoch 7: 100%|██████████| 27/27 [00:25<00:00,  1.07it/s, v_num=9fzm, train_loss_step=2.960, val_loss=4.990, val_mse_loss=1.620, val_cosine_similarity=0.173, train_loss_epoch=3.330]

Monitored metric val_loss did not improve in the last 1 records. Best score: 4.984. Signaling Trainer to stop.


Epoch 7: 100%|██████████| 27/27 [00:25<00:00,  1.05it/s, v_num=9fzm, train_loss_step=2.960, val_loss=4.990, val_mse_loss=1.620, val_cosine_similarity=0.173, train_loss_epoch=3.330]


## Train the Text model

In [16]:
# Model trained on the text-feature loaders. It shares the hyperparameter cell
# with the image model except for latent_dim, which is 512 here.
# NOTE(review): `temperature` is defined in the hyperparameter cell but is not
# passed here, so the model presumably falls back to its own default - confirm.
brain_text_model = ContrastiveModel(num_input_channels= 10000,
                                base_channel_size=base_channel_size, 
                                hidden_dims=hidden_dims,
                                latent_dim=512,
                                act_fn=act_fn,
                                loss_type=loss_type,
                                lr = lr,
                                wd = wd,
                                alpha=alpha)

# Early stopping on 'val_loss' ('min' because lower is better). With patience=1
# training stops after the first validation epoch that does not improve.
# (This setup used to be duplicated verbatim; it is now created once.)
early_stop_callback = EarlyStopping(monitor='val_loss', patience=1, verbose=True, mode='min')
# NOTE(review): a wandb run is already active at this point, so this logger
# reuses it and the text metrics are logged under the same run and metric names
# as the image model. Call wandb.finish() and start a new run first if separate
# runs are wanted.
wandb_logger = WandbLogger()  # Sends the metrics logged by the model to wandb


# Checkpoint directory for this model.
# NOTE(review): `data_path` is the variable left behind by the subject-loading
# loop, i.e. the folder of the last subject processed there - confirm this is
# the intended location for a multi-subject model.
run_name = "multimodal_model_TEXT"

checkpoint_dir = os.path.join(data_path, "models_multi", run_name)
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep the three checkpoints with the lowest val_loss.
checkpoint_callback = ModelCheckpoint(monitor='val_loss',dirpath=checkpoint_dir,filename='TEXT_brain_text_model-{epoch:02d}-{val_loss:.2f}',save_top_k=3,mode='min',)


# Train on GPU index 1 for at most 20 epochs, reporting to wandb.
trainer = pl.Trainer(max_epochs=20, devices=[1], callbacks=[early_stop_callback,checkpoint_callback],logger=wandb_logger )

trainer.fit(brain_text_model, train_loader_text, val_loader_text)

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/matteo/anaconda3/envs/borg/lib/python3.8/site- ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | Encoder | 9.2 K  | train
------------------------------------------
9.2 K     Trainable params
0        

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


                                                                            

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (27) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 27/27 [00:28<00:00,  0.94it/s, v_num=9fzm, train_loss_step=5.160, val_loss=5.530, val_mse_loss=1.800, val_cosine_similarity=0.111, train_loss_epoch=5.860]

Metric val_loss improved. New best score: 5.531


Epoch 1: 100%|██████████| 27/27 [00:26<00:00,  1.01it/s, v_num=9fzm, train_loss_step=4.530, val_loss=5.350, val_mse_loss=1.750, val_cosine_similarity=0.138, train_loss_epoch=5.080]

Metric val_loss improved by 0.184 >= min_delta = 0.0. New best score: 5.347


Epoch 2: 100%|██████████| 27/27 [00:23<00:00,  1.14it/s, v_num=9fzm, train_loss_step=4.150, val_loss=5.270, val_mse_loss=1.730, val_cosine_similarity=0.148, train_loss_epoch=4.650]

Metric val_loss improved by 0.077 >= min_delta = 0.0. New best score: 5.270


Epoch 3: 100%|██████████| 27/27 [00:26<00:00,  1.02it/s, v_num=9fzm, train_loss_step=3.860, val_loss=5.240, val_mse_loss=1.720, val_cosine_similarity=0.152, train_loss_epoch=4.290]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 5.243


Epoch 4: 100%|██████████| 27/27 [00:30<00:00,  0.89it/s, v_num=9fzm, train_loss_step=3.650, val_loss=5.240, val_mse_loss=1.720, val_cosine_similarity=0.152, train_loss_epoch=3.990]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 5.241


Epoch 5: 100%|██████████| 27/27 [00:31<00:00,  0.86it/s, v_num=9fzm, train_loss_step=3.360, val_loss=5.250, val_mse_loss=1.730, val_cosine_similarity=0.150, train_loss_epoch=3.730]

Monitored metric val_loss did not improve in the last 1 records. Best score: 5.241. Signaling Trainer to stop.


Epoch 5: 100%|██████████| 27/27 [00:31<00:00,  0.85it/s, v_num=9fzm, train_loss_step=3.360, val_loss=5.250, val_mse_loss=1.730, val_cosine_similarity=0.150, train_loss_epoch=3.730]


## Multimodal Evaluation:

I'll wrap both models in a class, run predictions and concatenate outputs for subsequent evaluations

In [40]:
class MultimodalModelWrapper(nn.Module):
    """Run the image and the text model on the same batch and join their outputs.

    Both wrapped models are called as ``model(x, target_features, k=subject_id)``
    and must return a pair whose first element is the prediction.
    """

    def __init__(self, image_model, text_model):
        super().__init__()
        self.image_model = image_model
        self.text_model = text_model

    def forward(self, batch):
        """Return ``(combined_output, concatenated_embeddings)`` for one batch.

        ``batch`` is a mapping with the keys ``"data"``, ``"text_features"``,
        ``"image_features"`` and ``"subject_id"``.

        ``combined_output`` is the image prediction followed by the text
        prediction along the last dimension; ``concatenated_embeddings`` is the
        target image features followed by the target text features, in the
        same order.
        """
        fmri = batch["data"]
        subject = batch["subject_id"]
        # Image branch first, text branch second - this fixes the layout of
        # both concatenations below.
        targets = (batch["image_features"], batch["text_features"])

        predictions = []
        for branch, target in zip((self.image_model, self.text_model), targets):
            # Each model returns a pair; only its first element is used here.
            prediction, _ = branch(fmri, target, k=subject)
            predictions.append(prediction)

        combined_output = torch.cat(predictions, dim=-1)
        concatenated_embeddings = torch.cat(targets, dim=-1)
        return combined_output, concatenated_embeddings

In [41]:
# Bundle the two trained models so they can be evaluated as a single model.
multimodal_model = MultimodalModelWrapper(
    image_model=brain_image_model,
    text_model=brain_text_model,
)

In [48]:
# Re-import the local evaluation module so that edits to multi_evaluation.py are
# picked up without restarting the kernel.
import importlib 
import multi_evaluation
importlib.reload(multi_evaluation)
# Import only the name that is used instead of `import *`, which would silently
# overwrite any notebook variable sharing a name with something in the module.
from multi_evaluation import evaluate_and_log

# Evaluate the combined model on the test split.
# NOTE(review): the text-variant test loader is passed in, yet the wrapper reads
# both "image_features" and "text_features" from every batch - presumably the
# dataset provides both keys regardless of feature_type; confirm in dataset.py.
# NOTE(review): the wrapper is never switched to eval mode in this notebook -
# confirm evaluate_and_log does so if the encoders have train-only layers.
results_df, similarity_matrices, results = evaluate_and_log(test_loader_text,multimodal_model)


100%|██████████| 6/6 [00:05<00:00,  1.17it/s]


Starting evaluation...
Evaluating metrics for subject 1...
Computed similarity matrix for subject 1.
Top-1 Accuracy: 0.0583, Top-5 Accuracy: 0.1939 for subject 1.
Identification accuracy for subject 1: 0.9229
Logged top-5 retrievals for subject 1.
Evaluating metrics for subject 2...
Computed similarity matrix for subject 2.
Top-1 Accuracy: 0.0368, Top-5 Accuracy: 0.1191 for subject 2.
Identification accuracy for subject 2: 0.8924
Logged top-5 retrievals for subject 2.
Evaluating metrics for subject 3...
Computed similarity matrix for subject 3.
Top-1 Accuracy: 0.0266, Top-5 Accuracy: 0.1305 for subject 3.
Identification accuracy for subject 3: 0.8857
Logged top-5 retrievals for subject 3.
Evaluating metrics for subject 4...
Computed similarity matrix for subject 4.
Top-1 Accuracy: 0.0343, Top-5 Accuracy: 0.1370 for subject 4.
Identification accuracy for subject 4: 0.8635
Logged top-5 retrievals for subject 4.
Evaluation complete. Results loaded to wandb.


In [50]:
# Display the per-subject summary table returned by evaluate_and_log.
results_df

Unnamed: 0,Subject,Identification Accuracy (%),ID Accuracy Baseline (%),Top-1 Accuracy (%),Top1 Baseline (%),Top1 Improvement Over Baseline,Top-5 Accuracy (%),Top5 Baseline (%),Top5 Improvement Over Baseline
0,1,92.293142,50,5.830165,0.126743,46.0,19.391635,0.633714,30.6
1,2,89.242793,50,3.675539,0.126743,29.0,11.913815,0.633714,18.8
2,3,88.574337,50,2.661597,0.126743,21.0,13.054499,0.633714,20.6
3,4,86.345131,50,3.426124,0.214133,16.0,13.704497,1.070664,12.8


In [52]:
# Persist the per-subject summary next to the subject folders.
output_path = "/home/matteo/storage/brain_tuning/"
results_csv_path = opj(output_path, "results_multi_contrastive.csv")
results_df.to_csv(results_csv_path)