In [None]:
from tqdm.notebook import tqdm
import pickle, gzip
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch

import nibabel as nib
from nilearn import maskers
from nilearn.plotting import view_img
from nilearn.image import resample_img

from neurovlm.data import data_dir
from neurovlm.models import Specter

In [2]:
# Load network atlases
# NOTE: pickle.load can execute arbitrary code while deserializing; only open
# archives from a trusted source. The path below is machine-specific.
networks_pkl = "/Users/ryanhammonds/projects/cbig_network_correspondence/networks.pkl.gz"
with gzip.open(networks_pkl, "rb") as fh:
    networks = pickle.load(fh)

In [62]:
# Load models
# NOTE: weights_only=False unpickles full Python objects, so these checkpoints
# must come from a trusted source.
proj_head = torch.load(data_dir / "proj_head_mse_sparse.pt", weights_only=False).cpu()
specter = Specter("allenai/specter2_aug2023refresh", adapter="adhoc_query")
autoencoder = torch.load(data_dir / "autoencoder_sparse.pt", weights_only=False).cpu()

# Brain mask: the archive stores the voxel array ("mask") and its affine
# ("affine"); rebuild a NIfTI image from them and fit the masker on it.
mask_arrays = np.load(data_dir / "mask.npz", allow_pickle=True)
mask_img = nib.Nifti1Image(mask_arrays["mask"].astype(float), mask_arrays["affine"])
masker = maskers.NiftiMasker(mask_img=mask_img, dtype=np.float32).fit()

There are adapters available but none are activated for the forward pass.


In [11]:
# Flatten the two-level mapping {group: {name: img}} into {name: img}.
# This cell rebinds `networks` from itself, so entries that are already flat
# (no `.items()`) are passed through unchanged; re-running the cell is then a
# no-op instead of an AttributeError. On a name clash the later entry wins.
networks = {
    name: img
    for group_key, group_val in networks.items()
    for name, img in (
        group_val.items() if hasattr(group_val, "items") else [(group_key, group_val)]
    )
}

In [None]:
# Resample every network map onto the mask grid and binarize it.
TOP_PERCENTILE = 95  # continuous maps keep voxels at or above this percentile

# Loop-invariant lookups: read the grid definition from the archive once.
target_affine = mask_arrays["affine"]
target_shape = mask_arrays["mask"].shape

networks_resampled = {}

for name, img in tqdm(networks.items(), total=len(networks)):
    if len(np.unique(img.get_fdata())) == 2:
        # Two-valued (binary) map: nearest-neighbour keeps the original values.
        img_resampled = resample_img(
            img,
            target_affine=target_affine,
            target_shape=target_shape,
            interpolation="nearest",
        )
    else:
        # Continuous map: resample, drop negatives, then keep the top voxels.
        # target_shape pins the output to the mask grid so the masker does not
        # have to re-interpolate the binarized result later.
        img_resampled = resample_img(
            img, target_affine=target_affine, target_shape=target_shape
        )
        arr = np.clip(img_resampled.get_fdata(), 0., None)
        thresh = np.percentile(arr, TOP_PERCENTILE)
        # `arr > 0` guards the degenerate case thresh == 0 (>= 95% of voxels are
        # zero), where `arr >= thresh` alone would mark the whole volume.
        keep = (arr >= thresh) & (arr > 0)
        img_resampled = nib.Nifti1Image(keep.astype(arr.dtype), affine=target_affine)

    networks_resampled[name] = img_resampled

  0%|          | 0/152 [00:00<?, ?it/s]

In [60]:
# Encode each resampled network map with the autoencoder's encoder.
networks_embed = {}

# Inference only: no_grad avoids building and retaining an autograd graph for
# every stored embedding.
# NOTE(review): if the encoder contains dropout, `autoencoder.eval()` may also
# be needed for deterministic output — confirm against the model definition.
with torch.no_grad():
    for name, img in tqdm(networks_resampled.items(), total=len(networks_resampled)):
        # masker.transform returns a 2D array; take its first row
        voxels = torch.from_numpy(masker.transform(img)[0])
        networks_embed[name] = autoencoder.encoder(voxels)

  0%|          | 0/152 [00:00<?, ?it/s]

In [122]:
# Candidate text labels that each atlas network is compared against.
network_labels = [
    "default mode network",
    "frontoparietal network",
    "control network",
    "cognitive control network",
    "dorsal attention network",
    "salience network attention",
    "somatosensory network",
    "motor network",
    "somatomotor network",
    "visual network",
    "language network",
    "executive control network",
    "auditory network",
    "limbic emotional network",
    "multiple demand network",
    "subcortical network",
    "cerebellar network",
]

# Text -> SPECTER embedding -> projection head, L2-normalising along dim 1
# both before and after the projection.
text_embeddings = specter(network_labels)
text_embeddings = text_embeddings / text_embeddings.norm(dim=1, keepdim=True)

projected = proj_head(text_embeddings)
label_embeddings = projected / projected.norm(dim=1, keepdim=True)

In [123]:
# For every atlas network, record the two text labels whose embeddings are most
# similar to the network's embedding, together with their similarity scores.
atlas_label = []
primary_label = []
secondary_label = []
primary_label_sim_score = []
secondary_label_sim_score = []

with torch.no_grad():
    for name, embed in networks_embed.items():
        unit_embed = embed / embed.norm()
        # Dot product against every label embedding; with unit-norm rows in
        # label_embeddings this is the cosine similarity.
        cossim = label_embeddings @ unit_embed

        # Only the two best matches are needed, so use top-k (returned in
        # descending order) rather than sorting every label.
        top_scores, top_inds = torch.topk(cossim, k=2)
        first_ind, second_ind = top_inds.tolist()
        first_score, second_score = top_scores.tolist()

        atlas_label.append(name)
        primary_label.append(network_labels[first_ind])
        primary_label_sim_score.append(first_score)
        secondary_label.append(network_labels[second_ind])
        secondary_label_sim_score.append(second_score)

In [126]:
# One row per atlas network: its name, the two best-matching text labels, and
# their similarity scores.
df = pd.DataFrame({
    "atlas_label": atlas_label,
    "predicted_label_primary": primary_label,
    "predicted_label_secondary": secondary_label,
    "primary_cossim_score": primary_label_sim_score,
    # column name was previously misspelled as "secondar_cossim_score"
    "secondary_cossim_score": secondary_label_sim_score,
})

In [127]:
# Write the similarity table to disk without the index column.
sim_csv_path = "~/Desktop/sim.csv"
df.to_csv(sim_csv_path, index=False)