In [105]:
import os
import pickle, gzip

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm

import nibabel as nib
from nilearn import maskers
from nilearn.plotting import view_img
from nilearn.image import resample_img

from neurovlm.data import data_dir
from neurovlm.models import Specter
from neurovlm.retrieval_resources import (
    _load_autoencoder, _proj_head_image_infonce, _proj_head_text_infonce, _load_masker
)

In [106]:
# Load network atlases
# The archive location can be overridden with the NEUROVLM_NETWORKS_PATH
# environment variable so the notebook is not tied to one machine; the
# fallback is the path originally hard-coded in this cell.
NETWORKS_PATH = os.environ.get(
    "NEUROVLM_NETWORKS_PATH",
    "/Users/ryanhammonds/projects/cbig_network_correspondence/networks.pkl.gz",
)

# NOTE: pickle.load can execute arbitrary code -- only point this at a trusted archive.
with gzip.open(NETWORKS_PATH, "rb") as f:
    networks_by_atlas = pickle.load(f)

# Flatten the nested mapping into (outer key, inner key, value) triples.
# Later cells read element 0 as the atlas name, element 1 as the atlas label,
# and call .get_fdata() on element 2 (the network image).
networks = [
    (atlas_name, network_name, network_img)
    for atlas_name, atlas_networks in networks_by_atlas.items()
    for network_name, network_img in atlas_networks.items()
]

# Load models
# NOTE(review): the recorded output of this cell warns "There are adapters
# available but none are activated for the forward pass." -- confirm that the
# "adhoc_query" adapter is really applied by Specter.
specter = Specter("allenai/specter2_aug2023refresh", adapter="adhoc_query")
proj_head_text = _proj_head_text_infonce()
proj_head_image = _proj_head_image_infonce()
autoencoder = _load_autoencoder()
masker = _load_masker()

There are adapters available but none are activated for the forward pass.


In [None]:
# Resample networks
# Every network image is resampled onto the masker's affine and reduced to a
# binary map, then encoded with the autoencoder and cached to disk.

BINARIZE_PERCENTILE = 95  # non-binary maps keep voxels at/above this percentile


def _binarize_top_percentile(values, percentile=BINARIZE_PERCENTILE):
    """Binarize a non-negative-clipped array at the given percentile.

    Negative values are clipped to zero, the percentile is taken over all
    voxels, and voxels at or above it become 1 (everything else 0).

    A voxel must also be strictly positive to be kept. Without that guard a
    sparse map (>= `percentile` % zeros) yields a threshold of 0 and every
    voxel -- background included -- would be set to 1. When the threshold is
    positive the guard changes nothing.
    """
    clipped = np.clip(values, 0.0, None)
    thresh = np.percentile(clipped, percentile)
    keep = (clipped >= thresh) & (clipped > 0)
    return keep.astype(clipped.dtype)


networks_resampled = []

for _, _, network_img in tqdm(networks, total=len(networks)):
    if len(np.unique(network_img.get_fdata())) == 2:
        # binary data: nearest-neighbour keeps the two original values
        img_resampled = resample_img(network_img, masker.affine_, interpolation="nearest")
    else:
        # continuous data: default interpolation, then binarize
        img_continuous = resample_img(network_img, masker.affine_)
        img_resampled = nib.Nifti1Image(
            _binarize_top_percentile(img_continuous.get_fdata()),
            affine=masker.affine_,
        )

    networks_resampled.append(img_resampled)

# Encode networks
# Inference only: no_grad avoids building an autograd graph per image and
# ensures the pickled tensors carry no gradient history.
networks_embed = []
with torch.no_grad():
    for network_resampled in tqdm(networks_resampled, total=len(networks_resampled)):
        voxels = torch.from_numpy(masker.transform(network_resampled)[0])
        networks_embed.append(autoencoder.encoder(voxels))

with open(data_dir / "networks_embed.pkl", "wb") as f:
    pickle.dump(networks_embed, f)

  0%|          | 0/185 [00:00<?, ?it/s]

  0%|          | 0/185 [00:00<?, ?it/s]

In [28]:
# Skip above and load
# Reads the embeddings cached by the encoding cell instead of recomputing them.
with open(data_dir / "networks_embed.pkl", "rb") as f:
    networks_embed = pickle.load(f)

# Drop any autograd history carried by the stored tensors.
networks_embed = [embed.detach() for embed in networks_embed]

In [None]:

# Network list
# Candidate text labels; each network image is assigned the two labels whose
# projected text embeddings are most similar to its projected image embedding.
network_labels = [
    'default mode network',
    'frontoparietal network',
    "control network",
    'cognitive control network',
    'dorsal attention network',
    'salience network',
    'somatosensory network',
    'motor network',
    "primary motor network",
    'somatomotor network',
    'visual network',
    'language network',
    'executive control network',
    'auditory network',
    'limbic emotional network',
    'multiple demand network',
    'subcortical network',
    'cerebellar network',
    'cingulo-opercular network'
]

# zip() below would silently drop entries if the cached embeddings were
# produced from a different set of networks.
assert len(networks) == len(networks_embed), (
    f"{len(networks)} networks but {len(networks_embed)} embeddings"
)

# Label networks
atlas = []
atlas_label = []

primary_label = []
secondary_label = []
primary_label_sim_score = []
secondary_label_sim_score = []

# Inference only: no autograd graph is needed for any of the steps below.
with torch.no_grad():
    # Specter
    label_embeddings = specter(network_labels)
    label_embeddings = label_embeddings / label_embeddings.norm(dim=1)[:, None]

    # Project
    label_embeddings = proj_head_text(label_embeddings)
    label_embeddings = label_embeddings / label_embeddings.norm(dim=1)[:, None]

    for names, img_embed in zip(networks, networks_embed):

        # Project the image embedding and unit-normalize it so the dot
        # product with the unit-norm label embeddings is a cosine similarity.
        im_embed = proj_head_image(img_embed)
        im_embed = im_embed / im_embed.norm()

        cossim = label_embeddings @ im_embed
        inds = torch.argsort(cossim, descending=True)

        # Best and second-best matching labels with their scores.
        primary_label.append(network_labels[inds[0]])
        primary_label_sim_score.append(float(cossim[inds[0]]))

        secondary_label.append(network_labels[inds[1]])
        secondary_label_sim_score.append(float(cossim[inds[1]]))

        atlas.append(names[0])
        atlas_label.append(names[1])

# Results
# NOTE(review): "secondar_cossim_score" is misspelled; the key is kept as-is
# because later cells may index the frame by this name.
df = pd.DataFrame({
    "atlas": atlas,
    "atlas_label": atlas_label,
    "predicted_label_primary": primary_label,
    "predicted_label_secondary": secondary_label,
    "primary_cossim_score": primary_label_sim_score,
    "secondar_cossim_score": secondary_label_sim_score,
})

df.iloc[:10]

Unnamed: 0,atlas,atlas_label,predicted_label_primary,predicted_label_secondary,primary_cossim_score,secondar_cossim_score
0,Du,VIS-P,visual network,dorsal attention network,0.169895,0.156872
1,Du,CG-OP,executive control network,somatosensory network,0.113601,0.077232
2,Du,DN-B,default mode network,frontoparietal network,0.201371,0.160549
3,Du,SMOT-B,somatosensory network,primary motor network,0.175995,0.156031
4,Du,AUD,auditory network,language network,0.226339,0.075596
5,Du,PM-PPr,primary motor network,somatosensory network,0.172838,0.172285
6,Du,dATN-B,visual network,dorsal attention network,0.362174,0.274632
7,Du,SMOT-A,primary motor network,somatosensory network,0.324312,0.323186
8,Du,LANG,language network,auditory network,0.135207,0.042561
9,Du,FPN-B,default mode network,cognitive control network,0.19456,0.182519
