## In this notebook I'll train models that map brain activity to image and text features, and combine their predictions

In [31]:
import numpy as np
import nibabel as nib
import nilearn 
import matplotlib.pyplot as plt
import os
from os.path import join as opj
import pandas as pd
import seaborn as sns
import glob
from nilearn import plotting
from nilearn.image import *
import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from nilearn.plotting import plot_stat_map
from nilearn.image import mean_img
from nilearn.plotting import plot_img, plot_epi
from nilearn.maskers import NiftiMasker
from sklearn.preprocessing import StandardScaler
import wandb
import pickle
from torch.utils.data import Dataset, DataLoader
from dataset import fMRI_Dataset, fMRI_Multi_Dataset
import torch
from torch import nn
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint

from pytorch_lightning.loggers import WandbLogger
from network import Encoder, ContrastiveModel
import torch
import torch.nn as nn
import pytorch_lightning as pl

In [32]:
# Reload the local `dataset` module so that edits made to it on disk are
# picked up without restarting the kernel, then rebind the dataset classes
# from the freshly reloaded module.
import importlib
import dataset

dataset = importlib.reload(dataset)
from dataset import fMRI_Dataset, fMRI_Multi_Dataset

In [33]:
# Subject to train on; the same identifier is stored in the wandb run config
# and reused below for file paths and result file names.
sub = "CSI4"

# Metadata attached to the wandb run.
run_config = {"model": "multimodal", "single_subject": True, "sub": sub}

wandb.login()
wandb.init(project="BrainTuning", config=run_config)



0,1
epoch,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▅▆▆▇▇▇▇██▁▁▂▂▂▂▃▃▄▄▅▅
subject_3_identification_accuracy,▁
subject_3_top1_acc,▁
subject_3_top5_acc,▁
train_loss_epoch,█▆▅▅▄▃▃▂▂▁▁█▆▅▅▄▃
train_loss_step,▁
trainer/global_step,▁▁▂▂▂▂▃▃▄▄▅▅▅▅▅▆▆▇▇▇▇██▁▁▂▂▂▂▃▃▄▄▅▅
val_cosine_similarity,▁▅▇████████▁▄▅▅▆▅
val_loss,▇▅▃▃▂▂▁▁▁▁▁█▆▅▄▄▄
val_mse_loss,▅▃▂▁▁▁▁▁▁▁▁█▆▅▅▅▅

0,1
epoch,5.0
subject_3_identification_accuracy,0.89703
subject_3_top1_acc,0.04183
subject_3_top5_acc,0.12801
train_loss_epoch,3.60227
train_loss_step,3.53726
trainer/global_step,47.0
val_cosine_similarity,0.14572
val_loss,4.9928
val_mse_loss,1.73616


In [34]:
## Load the data: build one image-target and one text-target dataset per split

# One list per split and per target modality; each gets one entry per subject.
train_datasets_images, val_datasets_images, test_datasets_images = [], [], []
train_datasets_text, val_datasets_text, test_datasets_text = [], [], []

for subj in tqdm.tqdm([sub]):

    # Numeric subject id, e.g. "CSI4" -> 4
    subj_id = int(subj.split("CSI")[1])

    # NOTE: `data_path` stays defined after the loop and is reused by the
    # training cells to build the checkpoint directory.
    data_path = f"/home/matteo/storage/brain_tuning/{subj}"

    def _load(filename, allow_pickle=False):
        """Read one .npy file from the current subject's data folder."""
        return np.load(opj(data_path, filename), allow_pickle=allow_pickle)

    # fMRI responses
    train_fmri = _load("train_fmri_top.npy")
    val_fmri = _load("val_fmri_top.npy")
    test_fmri = _load("test_fmri_top.npy")

    # Images and captions are read with allow_pickle=True.
    # NOTE(review): pickle-enabled loading can execute arbitrary code, so
    # these files must come from a trusted source.
    img_train = _load("img_train.npy", allow_pickle=True)
    img_val = _load("img_val.npy", allow_pickle=True)
    img_test = _load("img_test.npy", allow_pickle=True)

    train_captions = _load("train_captions.npy", allow_pickle=True)
    val_captions = _load("val_captions.npy", allow_pickle=True)
    test_captions = _load("test_captions.npy", allow_pickle=True)

    # Precomputed image features
    train_features = _load("train_image_features.npy")
    val_features = _load("val_image_features.npy")
    test_features = _load("test_image_features.npy")

    # Precomputed text features
    train_text_features = _load("train_text_features.npy")
    val_text_features = _load("val_text_features.npy")
    test_text_features = _load("test_text_features.npy")

    # Positional arguments shared by the image- and text-target datasets of
    # each split; only `feature_type` differs between the two.
    train_arrays = (train_fmri, img_train, train_captions, train_features, train_text_features)
    val_arrays = (val_fmri, img_val, val_captions, val_features, val_text_features)
    test_arrays = (test_fmri, img_test, test_captions, test_features, test_text_features)

    # Datasets built with feature_type="image"
    train_dataset = fMRI_Multi_Dataset(*train_arrays, subj_id, feature_type="image")
    val_dataset = fMRI_Multi_Dataset(*val_arrays, subj_id, feature_type="image")
    test_dataset = fMRI_Multi_Dataset(*test_arrays, subj_id, feature_type="image")

    train_datasets_images.append(train_dataset)
    val_datasets_images.append(val_dataset)
    test_datasets_images.append(test_dataset)

    # Datasets built with feature_type="text"
    train_dataset_text = fMRI_Multi_Dataset(*train_arrays, subj_id, feature_type="text")
    val_dataset_text = fMRI_Multi_Dataset(*val_arrays, subj_id, feature_type="text")
    test_dataset_text = fMRI_Multi_Dataset(*test_arrays, subj_id, feature_type="text")

    train_datasets_text.append(train_dataset_text)
    val_datasets_text.append(val_dataset_text)
    test_datasets_text.append(test_dataset_text)





100%|██████████| 1/1 [00:01<00:00,  1.65s/it]


In [35]:
## Concatenate the per-subject datasets into a single dataset per split

# Image-target splits
train_dataset_images, val_dataset_images, test_dataset_images = (
    torch.utils.data.ConcatDataset(subject_datasets)
    for subject_datasets in (train_datasets_images, val_datasets_images, test_datasets_images)
)

# Text-target splits (these names previously held the last subject's
# individual datasets and are rebound to the concatenated versions here)
train_dataset_text, val_dataset_text, test_dataset_text = (
    torch.utils.data.ConcatDataset(subject_datasets)
    for subject_datasets in (train_datasets_text, val_datasets_text, test_datasets_text)
)




In [36]:
BATCH_SIZE = 512


def _make_loader(split_dataset, shuffle):
    """Wrap a dataset in a DataLoader using the notebook-wide BATCH_SIZE."""
    return DataLoader(split_dataset, batch_size=BATCH_SIZE, shuffle=shuffle)


# Image-target loaders: only the training split is shuffled.
train_loader = _make_loader(train_dataset_images, shuffle=True)
val_loader = _make_loader(val_dataset_images, shuffle=False)
test_loader = _make_loader(test_dataset_images, shuffle=False)

# Text-target loaders, same convention.
train_loader_text = _make_loader(train_dataset_text, shuffle=True)
val_loader_text = _make_loader(val_dataset_text, shuffle=False)
test_loader_text = _make_loader(test_dataset_text, shuffle=False)

## Train a Brain2ImageModel

In [37]:
## Optimal parameters obtained from the hyperparameter search

# --- Optimisation / loss settings ---
loss_type = "contrastive"
lr = 1e-4
wd = 1e-5
alpha = 0.8
# NOTE(review): `temperature` is defined here but is not passed to either
# ContrastiveModel(...) call in this notebook; confirm whether it should be.
temperature = 0.1

# --- Encoder architecture ---
latent_dim = 768
hidden_dims = [1024]
base_channel_size = [2048]
act_fn = nn.Identity

In [38]:
# Brain -> image-feature model, configured with the shared hyperparameters
# defined in the hyperparameter cell.
# NOTE(review): num_input_channels=10000 presumably matches the length of the
# "*_fmri_top.npy" vectors; confirm against the data.
brain_image_model = ContrastiveModel(
    num_input_channels=10000,
    base_channel_size=base_channel_size,
    hidden_dims=hidden_dims,
    latent_dim=latent_dim,
    act_fn=act_fn,
    loss_type=loss_type,
    lr=lr,
    wd=wd,
    alpha=alpha,
)

# Early stopping on 'val_loss' ('min' because we want to minimize it).
# With patience=1, training stops after the first validation run that does
# not improve on the best score. The callbacks and logger are created once
# (the original cell built them twice).
early_stop_callback = EarlyStopping(monitor='val_loss', patience=1, verbose=True, mode='min')
wandb_logger = WandbLogger()  # Logs the model and metrics to wandb

# Checkpoint directory for this model. `data_path` is the value left behind
# by the data-loading loop, so it already ends with the subject folder.
run_name = "multimodal_model_IMAGE"
checkpoint_dir = os.path.join(data_path, "models_multi", sub, run_name)
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep the three checkpoints with the lowest 'val_loss'.
# The subject id is interpolated here with an f-string; the doubled braces
# leave `{epoch:02d}` and `{val_loss:.2f}` in the template for Lightning to
# fill in. In the original plain string, `{sub}` was treated by Lightning as
# a metric name, so the subject id never appeared in the file name.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=checkpoint_dir,
    filename=f'{sub}_brain_image_model-{{epoch:02d}}-{{val_loss:.2f}}',
    save_top_k=3,
    mode='min',
)

# Train for at most 20 epochs on the device with index 1, logging to wandb.
trainer = pl.Trainer(
    max_epochs=20,
    devices=[1],
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=wandb_logger,
)

trainer.fit(brain_image_model, train_loader, val_loader)

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/matteo/anaconda3/envs/borg/lib/python3.8/site- ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | Encoder | 9.7 K  | train
------------------------------------------
9.7 K     Trainable params
0        

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.


                                                                            

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 397. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=255` in the `DataLoader` to improve performance.
/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 5/5 [00:02<00:00,  1.88it/s, v_num=g4pz, train_loss_step=4.860]

/home/matteo/anaconda3/envs/borg/lib/python3.8/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 196. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 0: 100%|██████████| 5/5 [00:03<00:00,  1.61it/s, v_num=g4pz, train_loss_step=4.860, val_loss=5.450, val_mse_loss=1.810, val_cosine_similarity=0.0784, train_loss_epoch=6.000]

Metric val_loss improved. New best score: 5.454


Epoch 1: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s, v_num=g4pz, train_loss_step=4.180, val_loss=5.210, val_mse_loss=1.740, val_cosine_similarity=0.115, train_loss_epoch=5.200] 

Metric val_loss improved by 0.248 >= min_delta = 0.0. New best score: 5.206


Epoch 2: 100%|██████████| 5/5 [00:03<00:00,  1.56it/s, v_num=g4pz, train_loss_step=3.750, val_loss=5.040, val_mse_loss=1.690, val_cosine_similarity=0.139, train_loss_epoch=4.740]

Metric val_loss improved by 0.162 >= min_delta = 0.0. New best score: 5.043


Epoch 3: 100%|██████████| 5/5 [00:03<00:00,  1.48it/s, v_num=g4pz, train_loss_step=3.410, val_loss=4.930, val_mse_loss=1.660, val_cosine_similarity=0.154, train_loss_epoch=4.380]

Metric val_loss improved by 0.114 >= min_delta = 0.0. New best score: 4.929


Epoch 4: 100%|██████████| 5/5 [00:03<00:00,  1.62it/s, v_num=g4pz, train_loss_step=3.120, val_loss=4.840, val_mse_loss=1.640, val_cosine_similarity=0.164, train_loss_epoch=4.060]

Metric val_loss improved by 0.085 >= min_delta = 0.0. New best score: 4.844


Epoch 5: 100%|██████████| 5/5 [00:03<00:00,  1.42it/s, v_num=g4pz, train_loss_step=2.830, val_loss=4.780, val_mse_loss=1.630, val_cosine_similarity=0.167, train_loss_epoch=3.760]

Metric val_loss improved by 0.061 >= min_delta = 0.0. New best score: 4.783


Epoch 6: 100%|██████████| 5/5 [00:03<00:00,  1.65it/s, v_num=g4pz, train_loss_step=2.570, val_loss=4.740, val_mse_loss=1.630, val_cosine_similarity=0.169, train_loss_epoch=3.470]

Metric val_loss improved by 0.047 >= min_delta = 0.0. New best score: 4.736


Epoch 7: 100%|██████████| 5/5 [00:03<00:00,  1.61it/s, v_num=g4pz, train_loss_step=2.270, val_loss=4.700, val_mse_loss=1.620, val_cosine_similarity=0.171, train_loss_epoch=3.210]

Metric val_loss improved by 0.040 >= min_delta = 0.0. New best score: 4.696


Epoch 8: 100%|██████████| 5/5 [00:03<00:00,  1.45it/s, v_num=g4pz, train_loss_step=2.080, val_loss=4.660, val_mse_loss=1.620, val_cosine_similarity=0.173, train_loss_epoch=2.960]

Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 4.665


Epoch 9: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s, v_num=g4pz, train_loss_step=1.930, val_loss=4.640, val_mse_loss=1.620, val_cosine_similarity=0.175, train_loss_epoch=2.730]

Metric val_loss improved by 0.025 >= min_delta = 0.0. New best score: 4.640


Epoch 10: 100%|██████████| 5/5 [00:03<00:00,  1.61it/s, v_num=g4pz, train_loss_step=1.720, val_loss=4.620, val_mse_loss=1.620, val_cosine_similarity=0.176, train_loss_epoch=2.520]

Metric val_loss improved by 0.022 >= min_delta = 0.0. New best score: 4.618


Epoch 11: 100%|██████████| 5/5 [00:03<00:00,  1.56it/s, v_num=g4pz, train_loss_step=1.580, val_loss=4.600, val_mse_loss=1.610, val_cosine_similarity=0.177, train_loss_epoch=2.320]

Metric val_loss improved by 0.015 >= min_delta = 0.0. New best score: 4.603


Epoch 12: 100%|██████████| 5/5 [00:03<00:00,  1.57it/s, v_num=g4pz, train_loss_step=1.390, val_loss=4.590, val_mse_loss=1.610, val_cosine_similarity=0.177, train_loss_epoch=2.140]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 4.589


Epoch 13: 100%|██████████| 5/5 [00:03<00:00,  1.44it/s, v_num=g4pz, train_loss_step=1.260, val_loss=4.580, val_mse_loss=1.610, val_cosine_similarity=0.178, train_loss_epoch=1.980]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 4.580


Epoch 14: 100%|██████████| 5/5 [00:03<00:00,  1.47it/s, v_num=g4pz, train_loss_step=1.150, val_loss=4.570, val_mse_loss=1.610, val_cosine_similarity=0.178, train_loss_epoch=1.830]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 4.574


Epoch 15: 100%|██████████| 5/5 [00:03<00:00,  1.56it/s, v_num=g4pz, train_loss_step=1.050, val_loss=4.560, val_mse_loss=1.610, val_cosine_similarity=0.179, train_loss_epoch=1.700]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 4.565


Epoch 16: 100%|██████████| 5/5 [00:03<00:00,  1.49it/s, v_num=g4pz, train_loss_step=0.987, val_loss=4.560, val_mse_loss=1.610, val_cosine_similarity=0.179, train_loss_epoch=1.590]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 4.560


Epoch 17: 100%|██████████| 5/5 [00:03<00:00,  1.56it/s, v_num=g4pz, train_loss_step=0.893, val_loss=4.560, val_mse_loss=1.610, val_cosine_similarity=0.178, train_loss_epoch=1.480]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 4.558


Epoch 18: 100%|██████████| 5/5 [00:03<00:00,  1.59it/s, v_num=g4pz, train_loss_step=0.821, val_loss=4.560, val_mse_loss=1.610, val_cosine_similarity=0.179, train_loss_epoch=1.380]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 4.556


Epoch 19: 100%|██████████| 5/5 [00:03<00:00,  1.57it/s, v_num=g4pz, train_loss_step=0.710, val_loss=4.560, val_mse_loss=1.610, val_cosine_similarity=0.178, train_loss_epoch=1.300]

Monitored metric val_loss did not improve in the last 1 records. Best score: 4.556. Signaling Trainer to stop.
`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 5/5 [00:03<00:00,  1.41it/s, v_num=g4pz, train_loss_step=0.710, val_loss=4.560, val_mse_loss=1.610, val_cosine_similarity=0.178, train_loss_epoch=1.300]


## Train the Text model

In [39]:
# Brain -> text-feature model. Same configuration as the image model except
# for the latent size, which is set to 512 instead of the shared `latent_dim`.
# NOTE(review): 512 presumably matches the dimensionality of the text
# features; confirm against "*_text_features.npy".
brain_text_model = ContrastiveModel(
    num_input_channels=10000,
    base_channel_size=base_channel_size,
    hidden_dims=hidden_dims,
    latent_dim=512,
    act_fn=act_fn,
    loss_type=loss_type,
    lr=lr,
    wd=wd,
    alpha=alpha,
)

# Early stopping on 'val_loss' ('min' because we want to minimize it).
# With patience=1, training stops after the first validation run that does
# not improve on the best score. The callbacks and logger are created once
# (the original cell built them twice).
early_stop_callback = EarlyStopping(monitor='val_loss', patience=1, verbose=True, mode='min')
wandb_logger = WandbLogger()  # Logs the model and metrics to wandb

# Checkpoint directory for this model. `data_path` is the value left behind
# by the data-loading loop, so it already ends with the subject folder.
run_name = "multimodal_model_TEXT"
checkpoint_dir = os.path.join(data_path, "models_multi", sub, run_name)
os.makedirs(checkpoint_dir, exist_ok=True)

# Keep the three checkpoints with the lowest 'val_loss'.
# The subject id is interpolated here with an f-string; the doubled braces
# leave `{epoch:02d}` and `{val_loss:.2f}` in the template for Lightning to
# fill in. In the original plain string, `{sub}` was treated by Lightning as
# a metric name, so the subject id never appeared in the file name.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=checkpoint_dir,
    filename=f'{sub}_TEXT_brain_text_model-{{epoch:02d}}-{{val_loss:.2f}}',
    save_top_k=3,
    mode='min',
)

# Train for at most 20 epochs on the device with index 1, logging to wandb.
trainer = pl.Trainer(
    max_epochs=20,
    devices=[1],
    callbacks=[early_stop_callback, checkpoint_callback],
    logger=wandb_logger,
)

trainer.fit(brain_text_model, train_loader_text, val_loader_text)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | Encoder | 9.2 K  | train
------------------------------------------
9.2 K     Trainable params
0         Non-trainable params
9.2 K     Total params
0.037     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode


Epoch 0:  20%|██        | 1/5 [00:00<00:02,  1.63it/s, v_num=g4pz, train_loss_step=6.340]

Epoch 0: 100%|██████████| 5/5 [00:04<00:00,  1.23it/s, v_num=g4pz, train_loss_step=4.940, val_loss=5.560, val_mse_loss=1.880, val_cosine_similarity=0.0719, train_loss_epoch=6.040]

Metric val_loss improved. New best score: 5.557


Epoch 1: 100%|██████████| 5/5 [00:03<00:00,  1.50it/s, v_num=g4pz, train_loss_step=4.080, val_loss=5.370, val_mse_loss=1.820, val_cosine_similarity=0.101, train_loss_epoch=5.210] 

Metric val_loss improved by 0.189 >= min_delta = 0.0. New best score: 5.367


Epoch 2: 100%|██████████| 5/5 [00:03<00:00,  1.41it/s, v_num=g4pz, train_loss_step=3.730, val_loss=5.240, val_mse_loss=1.790, val_cosine_similarity=0.119, train_loss_epoch=4.730]

Metric val_loss improved by 0.126 >= min_delta = 0.0. New best score: 5.241


Epoch 3: 100%|██████████| 5/5 [00:03<00:00,  1.66it/s, v_num=g4pz, train_loss_step=3.380, val_loss=5.140, val_mse_loss=1.760, val_cosine_similarity=0.132, train_loss_epoch=4.340]

Metric val_loss improved by 0.096 >= min_delta = 0.0. New best score: 5.145


Epoch 4: 100%|██████████| 5/5 [00:03<00:00,  1.66it/s, v_num=g4pz, train_loss_step=2.940, val_loss=5.080, val_mse_loss=1.750, val_cosine_similarity=0.137, train_loss_epoch=3.980]

Metric val_loss improved by 0.064 >= min_delta = 0.0. New best score: 5.081


Epoch 5: 100%|██████████| 5/5 [00:03<00:00,  1.52it/s, v_num=g4pz, train_loss_step=2.680, val_loss=5.050, val_mse_loss=1.750, val_cosine_similarity=0.138, train_loss_epoch=3.650]

Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 5.053


Epoch 6: 100%|██████████| 5/5 [00:03<00:00,  1.58it/s, v_num=g4pz, train_loss_step=2.450, val_loss=5.040, val_mse_loss=1.750, val_cosine_similarity=0.136, train_loss_epoch=3.350]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 5.044


Epoch 7: 100%|██████████| 5/5 [00:03<00:00,  1.60it/s, v_num=g4pz, train_loss_step=2.250, val_loss=5.030, val_mse_loss=1.750, val_cosine_similarity=0.136, train_loss_epoch=3.060]

Metric val_loss improved by 0.009 >= min_delta = 0.0. New best score: 5.035


Epoch 8: 100%|██████████| 5/5 [00:03<00:00,  1.40it/s, v_num=g4pz, train_loss_step=1.920, val_loss=5.040, val_mse_loss=1.750, val_cosine_similarity=0.135, train_loss_epoch=2.800]

Monitored metric val_loss did not improve in the last 1 records. Best score: 5.035. Signaling Trainer to stop.


Epoch 8: 100%|██████████| 5/5 [00:03<00:00,  1.27it/s, v_num=g4pz, train_loss_step=1.920, val_loss=5.040, val_mse_loss=1.750, val_cosine_similarity=0.135, train_loss_epoch=2.800]


## Multimodal Evaluation:

I'll wrap both models in a class, run predictions, and concatenate their outputs for subsequent evaluation.

In [40]:
class MultimodalModelWrapper(nn.Module):
    """Run an image model and a text model on the same batch and join the results.

    Both wrapped models are called as ``model(x, target_embedding, k=subject_id)``
    and are expected to return a 2-tuple whose first element is the prediction.
    """

    def __init__(self, image_model, text_model):
        """Store the two already-built models as sub-modules."""
        super().__init__()
        self.image_model = image_model
        self.text_model = text_model

    def forward(self, batch):
        """Predict with both models and concatenate along the last dimension.

        Args:
            batch: mapping with the keys "data", "text_features",
                "image_features" and "subject_id".

        Returns:
            A tuple ``(predictions, targets)``:
            ``predictions`` is the image-model output followed by the
            text-model output, and ``targets`` is the batch's image features
            followed by its text features, both joined on the last dimension.
        """
        brain_data = batch["data"]
        text_targets = batch["text_features"]
        image_targets = batch["image_features"]
        subject = batch["subject_id"]

        # Only the first element of each model's return value is used.
        image_pred, _ = self.image_model(brain_data, image_targets, k=subject)
        text_pred, _ = self.text_model(brain_data, text_targets, k=subject)

        # Targets and predictions use the same (image, text) ordering so the
        # two concatenated tensors line up column for column.
        targets = torch.cat((image_targets, text_targets), dim=-1)
        predictions = torch.cat((image_pred, text_pred), dim=-1)

        return predictions, targets

In [41]:
# Combine the two trained models so they can be evaluated as a single module.
multimodal_model = MultimodalModelWrapper(
    image_model=brain_image_model,
    text_model=brain_text_model,
)

In [42]:
# Reload the local evaluation module so on-disk edits are picked up, then
# re-import its public names.
import importlib
import multi_evaluation

multi_evaluation = importlib.reload(multi_evaluation)
from multi_evaluation import *

# Evaluate the combined model on the test split. The text loader is used
# here; the wrapper reads both "image_features" and "text_features" from
# each batch.
results_df, similarity_matrices, results = evaluate_and_log(
    test_loader_text,
    multimodal_model,
)


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.32it/s]


Starting evaluation...
Evaluating metrics for subject 4...
Computed similarity matrix for subject 4.
Top-1 Accuracy: 0.0300, Top-5 Accuracy: 0.1435 for subject 4.
Identification accuracy for subject 4: 0.8806
Logged top-5 retrievals for subject 4.
Evaluation complete. Results loaded to wandb.


In [43]:
# Display the metrics table returned by evaluate_and_log.
results_df

Unnamed: 0,Subject,Identification Accuracy (%),ID Accuracy Baseline (%),Top-1 Accuracy (%),Top1 Baseline (%),Top1 Improvement Over Baseline,Top-5 Accuracy (%),Top5 Baseline (%),Top5 Improvement Over Baseline
0,4,88.057274,50,2.997859,0.214133,14.0,14.346895,1.070664,13.4


In [44]:
# Write the metrics table to a per-subject CSV in the shared output folder.
output_path = "/home/matteo/storage/brain_tuning/"
results_csv_path = opj(output_path, f"results_multi_contrastive_{sub}.csv")
results_df.to_csv(results_csv_path)



In [45]:
# Completion marker: prints the subject that was processed.
print("Done", sub)

Done CSI4
