In [1]:
from tqdm.notebook import tqdm
import pickle, gzip
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch

import nibabel as nib
from nilearn import maskers
from nilearn.plotting import view_img
from nilearn.image import resample_img

from neurovlm.data import get_data_dir
from neurovlm.models import Specter

from neurovlm.brain_input import search_papers_from_brain, search_wiki_from_brain

In [2]:
# Show the directory returned by get_data_dir(); the cells below load every
# atlas, checkpoint and mask file from it.
get_data_dir()

PosixPath('/Users/borng/code/lab_work/neurovlm/src/neurovlm/neurovlm_data')

In [3]:
# Load network atlases from the gzip-compressed pickle in the data directory.
# NOTE(security): pickle.load can execute arbitrary code while unpickling, so
# only open this archive when it comes from a trusted source.
with gzip.open(get_data_dir() / "networks_arrays.pkl.gz", "rb") as f:
    networks = pickle.load(f)

In [4]:
# Load models (everything is moved to CPU).
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# these checkpoints must come from a trusted source.
data_dir = get_data_dir()

# NOTE(review): proj_head_mse_adhoc is read from the same checkpoint file as
# proj_head_img ("proj_head_image_infonce.pt") — confirm whether a different
# checkpoint was intended for it.
proj_head_mse_adhoc = torch.load(data_dir / "proj_head_image_infonce.pt", weights_only=False).cpu()
proj_head_img = torch.load(data_dir / "proj_head_image_infonce.pt", weights_only=False).cpu()
proj_head_text = torch.load(data_dir / "proj_head_text_infonce.pt", weights_only=False).cpu()

# NOTE(review): the recorded output of this cell warns that adapters are
# available but none are activated for the forward pass — confirm that Specter
# actually activates the "adhoc_query" adapter.
specter = Specter("allenai/specter2_aug2023refresh", adapter="adhoc_query")
autoencoder = torch.load(data_dir / "autoencoder_sparse.pt", weights_only=False).cpu()

# Build the masker from the stored mask array and its affine; it is used later
# to turn NIfTI images into the flat arrays fed to the autoencoder.
mask_arrays = np.load(data_dir / "mask.npz", allow_pickle=True)
mask_img = nib.Nifti1Image(mask_arrays["mask"].astype(float), mask_arrays["affine"])
masker = maskers.NiftiMasker(mask_img=mask_img, dtype=np.float32).fit()

There are adapters available but none are activated for the forward pass.


In [5]:
# Flatten the two-level mapping into a single {label: entry} dict. When the
# same label appears under several top-level keys, the last one wins.
# NOTE: this rebinds `networks`, so re-run the atlas-loading cell before
# re-running this one.
networks = {
    label: entry
    for atlas in networks.values()
    for label, entry in atlas.items()
}

In [6]:
# Resample every network map onto the mask's affine and binarize it.
# `mask_arrays` is reused from the model-loading cell above, so the mask file
# is not loaded (and the masker is not refit) a second time here.
TOP_PERCENTILE = 95  # continuous maps keep voxels at or above this percentile

networks_resampled = {}

for k, net in tqdm(networks.items(), total=len(networks)):
    img = nib.Nifti1Image(net["array"], affine=net["affine"])

    if len(np.unique(net["array"])) == 2:
        # Exactly two distinct values: treat the map as binary and use
        # nearest-neighbour interpolation so no intermediate values appear.
        img_resampled = resample_img(img, mask_arrays["affine"], interpolation="nearest")
    else:
        # NOTE(review): only the target affine is passed (no target_shape), so
        # the resampled grid may not match the mask's shape — confirm that the
        # masker's own handling of this is what is wanted.
        img_resampled = resample_img(img, mask_arrays["affine"])
        img_resampled_arr = img_resampled.get_fdata()
        # Discard negative values before thresholding.
        img_resampled_arr[img_resampled_arr < 0] = 0.
        thresh = np.percentile(img_resampled_arr.flatten(), TOP_PERCENTILE)
        # Keep voxels at or above the percentile that are also strictly
        # positive. Without the positivity check, a percentile of 0 (map that
        # is mostly zeros) would make `>= thresh` select every voxel.
        keep = (img_resampled_arr >= thresh) & (img_resampled_arr > 0)
        img_resampled_arr = keep.astype(img_resampled_arr.dtype)
        img_resampled = nib.Nifti1Image(img_resampled_arr, affine=mask_arrays["affine"])

    networks_resampled[k] = img_resampled

  0%|          | 0/152 [00:00<?, ?it/s]

In [7]:
# Encode each resampled network map with the autoencoder's encoder.
networks_embed = {}

# Inference only: torch.no_grad() keeps the embeddings from each holding on to
# an autograd graph.
# NOTE(review): if the autoencoder contains dropout or batch-norm layers it
# should also be switched to eval mode before encoding — confirm.
with torch.no_grad():
    for k, v in tqdm(networks_resampled.items(), total=len(networks_resampled)):
        # masker.transform turns the image into an array the encoder can consume.
        networks_embed[k] = autoencoder.encoder(torch.from_numpy(masker.transform(v)))

  0%|          | 0/152 [00:00<?, ?it/s]

## Brain to text (Network vs Paper)

In [8]:
# Empty results table: one label column plus five ranked title columns.
# (The previous mid-cell `.head()` call was dropped: its result was discarded
# and never displayed.)
df_similar_titles = pd.DataFrame(
    columns=[
        "atlas_label",
        "similar_title1",
        "similar_title2",
        "similar_title3",
        "similar_title4",
        "similar_title5",
    ]
)

def add2df(df, atlas_label, similar_titles):
    """Return a new frame equal to ``df`` with one row appended.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to extend; it is not modified.
    atlas_label : object
        Value stored in the ``atlas_label`` column of the new row.
    similar_titles : sequence
        Ranked titles. The first five fill ``similar_title1`` to
        ``similar_title5``; ranks beyond the sequence length are NaN.

    Returns
    -------
    pandas.DataFrame
        Concatenation of ``df`` and the new row, with a fresh integer index.
    """
    title_columns = (
        "similar_title1",
        "similar_title2",
        "similar_title3",
        "similar_title4",
        "similar_title5",
    )

    record = {"atlas_label": atlas_label}
    for position, column in enumerate(title_columns):
        if position < len(similar_titles):
            record[column] = similar_titles[position]
        else:
            record[column] = np.nan

    appended = pd.concat([df, pd.DataFrame([record])], ignore_index=True)
    return appended

In [None]:
# TODO: switch to non-aligned latent text embeddings; the wiki-aligned ones
# are not producing good matches.
# TODO: obtain the most recent publications parquet file (not available here yet).
# Query the paper search with the "AUD" network embedding and print the top titles.
abstract, titles = search_papers_from_brain(networks_embed["AUD"], show_titles=True)

Top matches:
1. Remembering with gains and losses: effects of monetary reward and punishment on successful encoding activation of source memories.
2. Neuroimaging of valence decisions in children and adults
3. Bacterial and Archaea Community Present in the Pine Barrens Forest of Long Island, NY: Unusually High Percentage of Ammonia Oxidizing Bacteria


In [11]:
# Query the wiki search with the "AUD" network embedding, requesting the
# top 10 matches and printing their titles.
related_wiki = search_wiki_from_brain(
    networks_embed["AUD"],
    top_k=10,
    show_titles=True,
)

Top matches:
1. GNA12
2. PSMD12
3. Suicide attempt
4. GNB1
5. BCL10
6. Zaspopathy
7. GRB10
8. Pituitary disease
9. TDP-43
10. Clobenpropit


In [None]:
# Collect the matching paper titles for every network embedding.
# Start from an empty frame with the same columns so that re-running this cell
# rebuilds the table instead of appending duplicate rows. The table is only
# rebound once the whole loop has finished.
results = df_similar_titles.iloc[:0]
for key, item in tqdm(networks_embed.items(), total=len(networks_embed)):
    abstract, titles = search_papers_from_brain(item)
    results = add2df(results, key, titles)
df_similar_titles = results

In [None]:
# Display the collected table (one row per network label).
df_similar_titles

In [20]:
# Write the results table to CSV without the index column.
# NOTE(review): the destination is a user-specific Desktop path — consider a
# location under the project data directory instead.
results_csv = "~/Desktop/brain2text_results.csv"
df_similar_titles.to_csv(results_csv, index=False)