In [1]:
from tqdm.notebook import tqdm
import pickle, gzip
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch

import nibabel as nib
from nilearn import maskers
from nilearn.plotting import view_img
from nilearn.image import resample_img

from neurovlm.data import data_dir
from neurovlm.models import Specter

In [None]:
# Load network atlases
with gzip.open("/Users/ryanhammonds/projects/cbig_network_correspondence/networks.pkl.gz", "rb") as f:
    networks = pickle.load(f)
networks = [(_k, k, v) for _k in networks.keys() for k, v in networks[_k].items()]

# Load models
specter = Specter("allenai/specter2_aug2023refresh", adapter="adhoc_query")
autoencoder = torch.load(data_dir / "autoencoder_sparse.pt", weights_only=False).cpu()
proj_head_text = torch.load(data_dir / "proj_head_text_infonce.pt", weights_only=False).to('cpu')
proj_head_image = torch.load(data_dir / "proj_head_image_infonce.pt", weights_only=False).to('cpu')

mask_arrays = np.load(f"{data_dir}/mask.npz", allow_pickle=True)
mask_img = nib.Nifti1Image(mask_arrays["mask"].astype(float),  mask_arrays["affine"])
masker = maskers.NiftiMasker(mask_img=mask_img, dtype=np.float32).fit()

There are adapters available but none are activated for the forward pass.


In [None]:
# Resample networks
networks_resampled = []

for img in tqdm(networks, total=len(networks)):
    img = img[2]

    img_arr = img.get_fdata()

    if len(np.unique(img_arr)) == 2:
        # binary data
        img_resampled = resample_img(img, mask_arrays["affine"], interpolation="nearest")
    else:
        img_resampled = resample_img(img, mask_arrays["affine"])
        img_resampled_arr = img_resampled.get_fdata()
        img_resampled_arr[img_resampled_arr < 0] = 0.
        thresh = np.percentile(img_resampled_arr.flatten(), 95)
        img_resampled_arr[img_resampled_arr < thresh] = 0.
        img_resampled_arr[img_resampled_arr >= thresh] = 1.
        img_resampled = nib.Nifti1Image(img_resampled_arr, affine=mask_arrays["affine"])

    networks_resampled.append(img_resampled)

# Encode networks
networks_embed = []
for v in tqdm(networks_resampled, total=len(networks_resampled)):
    networks_embed.append(autoencoder.encoder(torch.from_numpy(masker.transform(v)[0])))

with open("neworks_embed.pkl", "wb") as f:
    pickle.dump(networks_embed, f)

In [3]:
# Skip above and load
with open("neworks_embed.pkl", "rb") as f:
    networks_embed = pickle.load(f)

for i in range(len(networks_embed)):
    networks_embed[i] = networks_embed[i].detach()

In [None]:
# Label networks
network_labels = [
    'default mode network',
    'frontoparietal network',
    "control network",
    'cognitive control network',
    'dorsal attention network',
    'salience network',
    'somatosensory network',
    'motor network',
    'somatomotor network',
    'visual network',
    'language network',
    'executive control network',
    'auditory network',
    'limbic emotional network',
    'multiple demand network',
    'subcortical network',
    'cerebellar network',
    'cingulo-opercular network'
]

label_embeddings = specter(network_labels)
label_embeddings = label_embeddings / label_embeddings.norm(dim=1)[:, None]

label_embeddings = proj_head_text(label_embeddings)
label_embeddings = label_embeddings / label_embeddings.norm(dim=1)[:, None]

atlas = []
atlas_label = []

primary_label = []
secondary_label = []
primary_label_sim_score = []
secondary_label_sim_score = []

for names, img_embed in zip(networks, networks_embed):

    img_embed = proj_head_image(img_embed)
    img_embed = img_embed / img_embed.norm()

    cossim = label_embeddings @ img_embed
    inds = torch.argsort(cossim, descending=True)

    primary_label.append(network_labels[inds[0]])
    primary_label_sim_score.append(float(cossim[inds[0]]))

    secondary_label.append(network_labels[inds[1]])
    secondary_label_sim_score.append(float(cossim[inds[1]]))

    atlas.append(names[0])
    atlas_label.append(names[1])

Consider using tensor.detach() first. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  primary_label_sim_score.append(float(cossim[inds[0]]))


In [6]:
df = pd.DataFrame({
    "atlas": atlas,
    "atlas_label": atlas_label,
    "predicted_label_primary": primary_label,
    "predicted_label_secondary": secondary_label,
    "primary_cossim_score": primary_label_sim_score,
    "secondar_cossim_score": secondary_label_sim_score,
})

df.iloc[:10]

Unnamed: 0,atlas,atlas_label,predicted_label_primary,predicted_label_secondary,primary_cossim_score,secondar_cossim_score
0,Du,VIS-P,visual network,cerebellar network,0.221639,0.206852
1,Du,CG-OP,executive control network,cingulo-opercular network,0.10184,0.094298
2,Du,DN-B,default mode network,frontoparietal network,0.15698,0.155471
3,Du,SMOT-B,somatosensory network,motor network,0.267073,0.236592
4,Du,AUD,auditory network,somatosensory network,0.18634,0.036143
5,Du,PM-PPr,motor network,somatosensory network,0.285581,0.258426
6,Du,dATN-B,visual network,dorsal attention network,0.355336,0.294363
7,Du,SMOT-A,somatosensory network,motor network,0.33254,0.295238
8,Du,LANG,language network,auditory network,0.095445,0.091745
9,Du,FPN-B,frontoparietal network,default mode network,0.128733,0.114357


In [7]:
df.to_csv("~/Desktop/sim_updated.csv", index=False)