# Objective of this notebook is to showcase different abilities of our neural foundation model

In [1]:
import pandas as pd
import os
from os.path import join
import numpy as np
import mne
from mne_bids import (
    BIDSPath,
    read_raw_bids,
    print_dir_tree,
    make_report,
    find_matching_paths,
    get_entity_vals,
)

import h5py
from os.path import join as opj
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
import tqdm
from versatile_diffusion_dual_guided_fake_images import *

from torchsummary import summary

import pandas as pd
import os
from os.path import join as opj
from PIL import Image
import h5py
import numpy as np
import nibabel as nib
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

from sklearn.metrics import accuracy_score,classification_report,confusion_matrix
from scipy.signal import stft

#import labelencoder
from sklearn.preprocessing import LabelEncoder
#import pipeline
from sklearn.pipeline import Pipeline
import tqdm
import glob
import yaml

In [2]:
import yaml

In [3]:
from fmri_nsd_models import ContrastiveModel as fMRIModule

In [4]:
def get_fmri_nsd_test_data():
    path="data_fmri_nsd"
    
    
    test_data=np.load(opj(path,"test_data.npy"))
    subject_test_ids=np.load(opj(path,"subject_test_ids.npy"))
    
    base_path= '/nsd_data' #"/srv/nfs-data/sisko/grigorii/data/"
    
    processed_data=opj(base_path,"processed_data","subj01")
    
    imgs_test_data=opj(processed_data,f"nsd_test_stim_sub1.npy")
    
    imgs=np.load(imgs_test_data)
    imgs=np.concatenate([imgs,imgs,imgs,imgs],0) #all subjects have same test images
    
    return test_data, imgs, subject_test_ids

In [5]:
def make_predictions(model,data,device,subject_ids=None):
    
    model.to(device)
    #create a tensordataset
    if subject_ids is not None:
        dataset=torch.utils.data.TensorDataset(data,subject_ids)
    else:
        dataset=torch.utils.data.TensorDataset(data)
    dataloader=torch.utils.data.DataLoader(dataset,batch_size=128,shuffle=False)
    
    embeddings=[]
    with torch.no_grad():
        for x in tqdm.tqdm(dataloader):
            if subject_ids is not None:
                x,k=x
                embeddings.append(model(x.to(device),k=k.to(device)).cpu())

            else:
                x=x[0]
                embeddings.append(model(x.to(device)).cpu())

    
    model.to("cpu")
    embeddings=torch.cat(embeddings,0)
    return embeddings
                      

In [6]:
def clip_2way(test_embeddings,pred_embeddings):
    rnd_idx=torch.randperm(len(pred_embeddings))
    pred_embeddings_random=pred_embeddings[rnd_idx]

    
    cosine=torch.nn.CosineSimilarity()

    acc=torch.stack([cosine(test_embeddings,pred_embeddings_random),cosine(test_embeddings,pred_embeddings)],dim=1)
    #l'indice 1 è quello giusto così posso fare argmax e sommare

    acc=torch.argmax(acc,dim=1)
    acc=torch.sum(acc)/len(acc)
    acc

    return acc

In [12]:
def load_model(base_name, exp_name, epoch, num_input_channels):
    for dn in os.listdir('wandb'):
        if exp_name in dn:
            break
    config_fn = opj('wandb', dn, 'files', 'config.yaml')
    with open(config_fn) as cyml:
        config = yaml.safe_load(cyml)

    act_fn_c = config['act_fn']['value']
    n_layers_c = config['n_encoder_net_layers']['value']
    act_fn_ = nn.ReLU
    if act_fn_c == "ReLU":
        act_fn_ = nn.ReLU
    elif act_fn_c == "GELU":
        act_fn_ = nn.GELU
    elif act_fn_c == "Identity":
        act_fn = nn.Identity
        n_layers_c =1
    
    fmri_model = fMRIModule(num_input_channels=fmri_nsd_data.shape[-1],
                        base_channel_size=[config['base_channel_size']['value']]*n_layers_c,
                        latent_dim=config['latent_dim']['value'],
                        act_fn=act_fn_,
                        loss_type=config['loss_type']['value'])
    
    epoch_str = 'epoch='+str(epoch)+'-'
    ckpt_dir = opj(base_name, exp_name, 'checkpoints')
    for fn in os.listdir(ckpt_dir):
        if epoch_str in fn:
            break
    ckpt_name = opj(ckpt_dir,fn)
    fmri_model.load_state_dict(torch.load(ckpt_name)['state_dict'])
    return fmri_model

In [13]:
fmri_nsd_data, fmri_nsd_img, fmri_subject_ids=get_fmri_nsd_test_data()

In [17]:
fmri_model = load_model(base_name='sweep-nsd-clip-3', exp_name='so12cmcn', epoch=18, num_input_channels=fmri_nsd_data.shape[-1])

In [18]:
fmri_nsd_subjective_ids=[int(i[-1]) for i in fmri_subject_ids]

In [23]:
fmri_nsd_subjective_ids

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [19]:
fmri_embeddings=make_predictions(fmri_model,torch.tensor(fmri_nsd_data).float(),device="cuda:2",subject_ids=torch.tensor(fmri_nsd_subjective_ids))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:00<00:00, 39.51it/s]


In [20]:
fmri_gt_image_embeddings=torch.load("data_fmri_nsd/test_clip_img_embeds.pt")

In [21]:
fmri_decoding_acc=clip_2way(fmri_gt_image_embeddings,fmri_embeddings)

In [22]:
fmri_decoding_acc

tensor(0.5298)