In [1]:
from tqdm.notebook import tqdm
import pickle, gzip
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch

import nibabel as nib
from nilearn import maskers
from nilearn.plotting import view_img
from nilearn.image import resample_img

from neurovlm.data import data_dir
from neurovlm.models import Specter

In [None]:
# Load the SPECTER2 text encoder with the ad-hoc query adapter.
# Instantiated once; loading it a second time only costs time and memory.
specter = Specter("allenai/specter2_aug2023refresh", adapter="adhoc_query")

# Load network atlases.
# TODO: absolute, user-specific path -- make this configurable (e.g. relative
# to `data_dir`) so the notebook runs on other machines.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
NETWORKS_PKL = "/Users/ryanhammonds/projects/cbig_network_correspondence/networks.pkl.gz"
with gzip.open(NETWORKS_PKL, "rb") as f:
    networks = pickle.load(f)

# Load trained models. weights_only=False unpickles whole module objects, so
# these checkpoints must come from a trusted source. eval() guards against the
# checkpoints having been saved in training mode (any dropout / batch-norm
# layers, if present, then behave deterministically at inference).
proj_head = torch.load(data_dir / "proj_head_mse_sparse_adhoc.pt", weights_only=False).cpu().eval()
autoencoder = torch.load(data_dir / "autoencoder_sparse.pt", weights_only=False).cpu().eval()

# Brain mask: its affine defines the target grid for resampling, and the
# fitted masker extracts the in-mask voxels as float32.
mask_arrays = np.load(data_dir / "mask.npz", allow_pickle=True)
mask_img = nib.Nifti1Image(mask_arrays["mask"].astype(float), mask_arrays["affine"])
masker = maskers.NiftiMasker(mask_img=mask_img, dtype=np.float32).fit()

There are adapters available but none are activated for the forward pass.
There are adapters available but none are activated for the forward pass.


In [None]:
# Resample every atlas network onto the mask grid, binarise it, then encode
# each one with the autoencoder.

# Flatten {atlas: {network_label: img}} into (atlas, network_label, img)
# tuples. Guarded so re-running this cell is safe: after the first run
# `networks` is already the flat list and has no `.keys()`.
if not isinstance(networks, list):
    networks = [
        (atlas_name, network_name, network_img)
        for atlas_name, atlas_networks in networks.items()
        for network_name, network_img in atlas_networks.items()
    ]

# Continuous maps are binarised by keeping voxels at or above this percentile
# of the (non-negative) resampled volume.
TOP_PERCENTILE = 95

networks_resampled = []
for _, _, img in tqdm(networks, total=len(networks)):
    img_arr = img.get_fdata()

    if len(np.unique(img_arr)) == 2:
        # Exactly two distinct values -> treat as a binary map; nearest-
        # neighbour interpolation keeps it binary after resampling.
        img_resampled = resample_img(img, mask_arrays["affine"], interpolation="nearest")
    else:
        img_resampled = resample_img(img, mask_arrays["affine"])
        img_resampled_arr = img_resampled.get_fdata()
        img_resampled_arr[img_resampled_arr < 0] = 0.
        thresh = np.percentile(img_resampled_arr, TOP_PERCENTILE)
        # If at least 95% of voxels are zero (e.g. background in the volume),
        # `thresh` is 0 and `>= thresh` alone would set every voxel --
        # background included -- to 1. Also requiring > 0 keeps zeros at 0 and
        # is identical to the plain threshold whenever thresh > 0.
        keep = (img_resampled_arr >= thresh) & (img_resampled_arr > 0)
        img_resampled = nib.Nifti1Image(keep.astype(np.float64), affine=mask_arrays["affine"])

    networks_resampled.append(img_resampled)

# Encode networks. Inference only, so autograd is disabled: no graphs are kept
# alive and the embeddings come back as plain tensors.
networks_embed = []
with torch.no_grad():
    for img in tqdm(networks_resampled, total=len(networks_resampled)):
        voxels = masker.transform(img)[0]
        networks_embed.append(autoencoder.encoder(torch.from_numpy(voxels)))

  0%|          | 0/185 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

In [4]:
# Cache the network embeddings. Detach first so the pickled tensors carry no
# autograd state; otherwise they can reload with requires_grad=True, which
# breaks e.g. `.numpy()` for downstream consumers.
# NOTE(review): filename is misspelled ("neworks"); kept unchanged in case
# other code reads this exact path.
with open("neworks_embed.pkl", "wb") as f:
    pickle.dump([embed.detach() for embed in networks_embed], f)

In [None]:
# Label networks: assign each atlas network the two candidate labels whose
# (projected) text embeddings are most cosine-similar to its map embedding.
network_labels = [
    'default mode network',
    'frontoparietal network',
    "control network",
    'cognitive control network',
    'dorsal attention network',
    'salience network',
    'somatosensory network',
    'motor network',
    'somatomotor network',
    'visual network',
    'language network',
    'executive control network',
    'auditory network',
    'limbic emotional network',
    'multiple demand network',
    'subcortical network',
    'cerebellar network',
    'cingulo-opercular network'
]

atlas = []
atlas_label = []

primary_label = []
secondary_label = []

primary_label_sim_score = []
secondary_label_sim_score = []

# Inference only: with autograd enabled the scores carry graph history and
# float() on them emits the "Consider using tensor.detach() first" warning.
with torch.no_grad():
    # Text -> SPECTER embedding -> projection head, L2-normalising after each
    # step so the dot products below are cosine similarities.
    label_embeddings = specter(network_labels)
    label_embeddings = label_embeddings / label_embeddings.norm(dim=1)[:, None]
    label_embeddings = proj_head(label_embeddings)
    label_embeddings = label_embeddings / label_embeddings.norm(dim=1)[:, None]

    # `networks` holds (atlas, network_label, img) tuples in the same order
    # that was used to build `networks_embed`.
    for (atlas_name, network_name, _), embed in zip(networks, networks_embed):
        atlas.append(atlas_name)
        atlas_label.append(network_name)

        cossim = label_embeddings @ (embed / embed.norm())
        # Two highest-similarity labels, best first.
        top_scores, top_inds = torch.topk(cossim, k=2)
        primary_label.append(network_labels[int(top_inds[0])])
        primary_label_sim_score.append(float(top_scores[0]))
        secondary_label.append(network_labels[int(top_inds[1])])
        secondary_label_sim_score.append(float(top_scores[1]))

Consider using tensor.detach() first. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  primary_label_sim_score.append(float(cossim[inds[0]]))


In [10]:
# Tabulate, per atlas network, its best and second-best matching labels with
# their cosine-similarity scores, then preview the first ten rows.
label_columns = {
    "atlas": atlas,
    "atlas_label": atlas_label,
    "predicted_label_primary": primary_label,
    "predicted_label_secondary": secondary_label,
    "primary_cossim_score": primary_label_sim_score,
    "secondar_cossim_score": secondary_label_sim_score,
}
df = pd.DataFrame(label_columns)

df.head(10)

Unnamed: 0,atlas,atlas_label,predicted_label_primary,predicted_label_secondary,primary_cossim_score,secondar_cossim_score
0,Du,VIS-P,salience network,cerebellar network,0.799473,0.766441
1,Du,CG-OP,somatosensory network,executive control network,0.670064,0.666779
2,Du,DN-B,default mode network,frontoparietal network,0.699849,0.654484
3,Du,SMOT-B,auditory network,motor network,0.810977,0.798625
4,Du,AUD,auditory network,language network,0.889523,0.783747
5,Du,PM-PPr,motor network,somatosensory network,0.801772,0.801299
6,Du,dATN-B,visual network,dorsal attention network,0.824848,0.819269
7,Du,SMOT-A,somatosensory network,motor network,0.849443,0.847434
8,Du,LANG,language network,auditory network,0.755044,0.731132
9,Du,FPN-B,default mode network,cognitive control network,0.766769,0.757961


In [31]:
# Compact view of the label table without the `atlas` column.
# Kept under its own name: re-binding `df` here would make a sequential
# Restart & Run All export this reduced frame to CSV (losing `atlas`),
# unlike the recorded run, which exported the full table.
df_compact = pd.DataFrame({
    "atlas_label": atlas_label,
    "predicted_label_primary": primary_label,
    "predicted_label_secondary": secondary_label,
    "primary_cossim_score": primary_label_sim_score,
    "secondar_cossim_score": secondary_label_sim_score,
})

df_compact.iloc[:10]

Unnamed: 0,atlas_label,predicted_label_primary,predicted_label_secondary,primary_cossim_score,secondar_cossim_score
0,VIS-P,visual network,cerebellar network,0.221639,0.206852
1,CG-OP,executive control network,somatosensory network,0.10184,0.091646
2,DN-B,default mode network,frontoparietal network,0.156981,0.155471
3,SMOT-B,somatosensory network,motor network,0.267073,0.236592
4,AUD,auditory network,somatosensory network,0.18634,0.036143
5,PM-PPr,motor network,somatosensory network,0.285581,0.258426
6,dATN-B,visual network,dorsal attention network,0.355336,0.294363
7,SMOT-A,somatosensory network,motor network,0.33254,0.295238
8,LANG,language network,auditory network,0.095445,0.091745
9,FPN-B,frontoparietal network,default mode network,0.128733,0.114357


In [11]:
# Export the label table without the index.
# TODO: user-specific output location; make configurable before sharing.
csv_path = "~/Desktop/sim.csv"
df.to_csv(csv_path, index=False)