In [1]:
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import re
from sklearn.model_selection import train_test_split

import torch
from torch import nn
import torch.nn.functional as F
from safetensors.torch import save_file
import warnings

from tqdm.notebook import tqdm


from sentence_transformers import SentenceTransformer, SimilarityFunction


import nibabel as nib
from nilearn.datasets import fetch_atlas_harvard_oxford,load_mni152_template
from nilearn.plotting import plot_glass_brain, view_img
from nilearn.image import load_img, smooth_img, resample_img, coord_transform,resample_to_img
from nilearn import datasets
from nilearn.maskers import NiftiMasker

from neurovlm.data import fetch_data
from neurovlm.coords import coords_to_vectors
from neurovlm.models import NeuroAutoEncoder, TextAligner
from neurovlm.train import Trainer, which_device
device = which_device()

import os


from neuroquery import datasets
from neuroquery.img_utils import gaussian_coord_smoothing, coords_to_peaks_img
from neuroquery._compat import maskers, load_mni152_brain_mask


from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import pipeline


from neurovlm.data import fetch_data
from neurovlm.models import NeuroAutoEncoder, TextAligner

warnings.filterwarnings("ignore")

In [2]:
data_dir = fetch_data()
df_pubs = pd.read_parquet(f"{data_dir}/publications.parquet")
df_coords = pd.read_parquet(f"{data_dir}/coordinates.parquet")

In [3]:
# Load pre-trained
neuro_encoder = torch.load(
    "docs/autoencoder.pt", weights_only=False
).to("cpu")

neuro_decoder = torch.load(
    "docs/decoder_half.pt", weights_only=False
).to("cpu")

text_aligner_half = torch.load(
    "specter/aligner_half.pt", weights_only=False
).to("cpu")

text_aligner = torch.load(
    "specter/aligner.pt", weights_only=False
).to("cpu")


latent_titles = torch.load(
    "specter/latent_text.pt", weights_only=False
).to("cpu")







In [4]:
latent_titles.size()


torch.Size([28757, 768])

In [5]:
df_pubs.__len__()

28393

In [11]:
# Remove duplicates from latent_titles
unique_titles = torch.unique(latent_titles, dim=0)


In [13]:
unique_titles.size()

torch.Size([28472, 768])

In [3]:
# Get total number of batches
total_batches = 3550
chunk_size = total_batches // 4

def process_quarter(start_idx, end_idx, quarter_name):
    print(f"Processing {quarter_name}...")
    latent_text_batches = []
    for i in tqdm(range(start_idx, end_idx + 1)):
        batch_tensor = torch.load(f"batches/latent_text_batch_{i}.pt")
        latent_text_batches.append(batch_tensor)
    
    # Concatenate batches
    quarter_tensor = torch.cat(latent_text_batches, dim=0)
    
    # Remove duplicates within the quarter
    unique_tensor, inverse_indices = torch.unique(quarter_tensor, dim=0, return_inverse=True)
    print(f"{quarter_name} unique vectors: {len(unique_tensor)} (removed {len(quarter_tensor) - len(unique_tensor)} duplicates)")
    
    torch.save(unique_tensor, f"latent_text_{quarter_name}.pt")
    return unique_tensor



In [None]:
# Process each quarter
latent_text_first = process_quarter(1, chunk_size, "first_quarter")


Processing first_quarter...


  0%|          | 0/887 [00:00<?, ?it/s]

In [None]:
latent_text_second = process_quarter(chunk_size + 1, chunk_size * 2, "second_quarter")
latent_text_third = process_quarter(chunk_size * 2 + 1, chunk_size * 3, "third_quarter")
latent_text_fourth = process_quarter(chunk_size * 3 + 1, total_batches, "fourth_quarter")

# Combine all quarters and remove any remaining duplicates
print("Combining quarters and removing final duplicates...")
latent_text = torch.cat([
    latent_text_first,
    latent_text_second,
    latent_text_third,
    latent_text_fourth
], dim=0)

# Remove any duplicates that might exist between quarters
unique_latent_text, inverse_indices = torch.unique(latent_text, dim=0, return_inverse=True)
print(f"Final unique vectors: {len(unique_latent_text)} (removed {len(latent_text) - len(unique_latent_text)} duplicates)")
torch.save(unique_latent_text, "latent_text.pt")

print("Complete!")

# Map to brain

In [4]:
titles = df_pubs['name'].values

In [10]:
# # Align latent text using aligner_half
# aligned_titles_half_list = []
# for title_vec in latent_titles:
#     aligned_vec = text_aligner_half(title_vec.unsqueeze(0))
#     aligned_titles_half_list.append(aligned_vec)
# aligned_titles_half = torch.cat(aligned_titles_half_list, dim=0)
# torch.save(aligned_titles_half, "specter/aligned_titles_half.pt")




In [15]:
aligned_text = torch.load(
    "specter/aligned_text.pt", weights_only=False
).to("cpu")
# aligned_titles_half



# Cosine similarity between user input and aligned tensor

Compute query vector, take the top-k most related studies

In [12]:
from nilearn.datasets import fetch_neurovault
maps = fetch_neurovault(collection_ids=[1039], get_data=True)



[fetch_neurovault] fetch_neurovault: using default value of 100 for max_images. Set max_images to another value or None if you want more images.
[get_dataset_dir] Dataset found in /Users/borng/nilearn_data/neurovault
[fetch_neurovault] Reading local neurovault data.
[fetch_neurovault] Already fetched 1 image
[fetch_neurovault] Already fetched 2 images
[fetch_neurovault] Already fetched 3 images
[fetch_neurovault] Already fetched 4 images
[fetch_neurovault] Already fetched 5 images
[fetch_neurovault] Already fetched 6 images
[fetch_neurovault] Already fetched 7 images
[fetch_neurovault] Already fetched 8 images
[fetch_neurovault] Already fetched 9 images
[fetch_neurovault] Already fetched 10 images
[fetch_neurovault] Already fetched 11 images
[fetch_neurovault] Already fetched 12 images
[fetch_neurovault] Already fetched 13 images
[fetch_neurovault] Already fetched 14 images
[fetch_neurovault] Already fetched 15 images
[fetch_neurovault] Already fetched 16 images
[fetch_neurovault] Alre

In [13]:
def transfrom_nifti_to_2d(nifti_img, mask_img_path=None):
    """
    Transforms a NIfTI image into a 2D flattened array using a mask image.
    
    Parameters:
    - nifti_img: nibabel NIfTI image object.
    - mask_img_path: str, optional path to a mask image. If not provided, 
      the function uses the NeuroQuery model mask.
      
    Returns:
    - out: np.ndarray, the 2D flattened array representation of the NIfTI image.
    """
    region_img = (nifti_img.get_fdata()).astype(float)

    region_nii = nib.Nifti1Image(region_img.astype(np.int32), nifti_img.affine, dtype=np.int32)

    # Load the mask image; if no path is provided, load from the NeuroQuery model.
    if mask_img_path is None:
        mask_img = load_img(f"{datasets.fetch_neuroquery_model()}/mask_img.nii.gz", dtype=np.float32)
    else:
        mask_img = load_img(mask_img_path, dtype=np.float32)

    masker = NiftiMasker(mask_img=mask_img, dtype=np.float32).fit()

    region_nii_resampled = resample_to_img(
        region_nii, mask_img, interpolation='nearest', force_resample=True, copy_header=False
    )
    
    # Transform the resampled region image into a 2D array (flattened) using the masker.
    out = masker.transform(region_nii_resampled)

    out[out > 0] = 1

    out_tensor = torch.tensor(out).squeeze(0)

    return out_tensor
    
    

In [14]:
# Path to the NeuroVault collection folder
path = "/Users/borng/nilearn_data/neurovault/collection_1039"

# Make sure the path exists
if not os.path.exists(path):
    raise FileNotFoundError(f"Path does not exist: {path}")

# Loop through and load all .nii.gz files
nifti_files = [f for f in os.listdir(path) if f.endswith(".nii.gz")]

nifti_flattened = []
for file in tqdm(nifti_files):
    full_path = os.path.join(path, file)
    img = nib.load(full_path)
    nifti_flattened.append(transfrom_nifti_to_2d(img))

print(f"\nTotal NIfTI files loaded: {len(nifti_flattened)}")


  0%|          | 0/41 [00:00<?, ?it/s]


Total NIfTI files loaded: 41


In [16]:
top_k = 10

numpy_text = aligned_text.detach().cpu().numpy().astype(np.float32)

query_vec = nifti_flattened[0]

# make sure query_vec has a batch dim:
q = query_vec.unsqueeze(0)          # shape (1, 28542)

# run it *through the encoder only*:
with torch.no_grad():
    z = neuro_encoder.encoder(q)    # → shape (1, 384)

z = z.squeeze(0)                    # → shape (384,)
encoded_vec = z.cpu().numpy().astype(np.float16)

# now both are (·,384) and (384,) and you can do:
cos_sim = (
    (numpy_text / np.linalg.norm(numpy_text, axis=1, keepdims=True)) @
    (encoded_vec / np.linalg.norm(encoded_vec))
)


top_inds = np.argsort(cos_sim)[::-1][:top_k]

for i in top_inds:
    print(titles[i])


The interplay of prefrontal and sensorimotor cortices during inhibitory control of learned motor behavior
Social cognition, behaviour and therapy adherence in frontal lobe epilepsy: a study combining neuroeconomic and neuropsychological methods
Neurofunctional Signature of Hyperfamiliarity for Unknown Faces.
Mind-wandering in Parkinson's disease hallucinations reflects primary visual and default network coupling.
Corrigendum to “Fear Processing in Dental Phobia during Crossmodal Symptom Provocation: An fMRI Study”
Structural and functional connectivity of the subthalamic nucleus during vocal emotion decoding.
Uncovering a context-specific connectional fingerprint of human dorsal premotor cortex.
Depressive mood in pre-dialytic chronic kidney disease: Statistical parametric mapping analysis of Tc-99m ECD brain SPECT
The effect of stimulus context on pitch representations in the human auditory cortex.
Modulation of neuronal activity after spinal cord stimulation for neuropathic pain; H(2

In [17]:
top_k = 10

numpy_text = aligned_text.detach().cpu().numpy().astype(np.float32)

query_vec = nifti_flattened[2]

# make sure query_vec has a batch dim:
q = query_vec.unsqueeze(0)          # shape (1, 28542)

# run it *through the encoder only*:
with torch.no_grad():
    z = neuro_encoder.encoder(q)    # → shape (1, 384)

z = z.squeeze(0)                    # → shape (384,)
encoded_vec = z.cpu().numpy().astype(np.float16)

# now both are (·,384) and (384,) and you can do:
cos_sim = (
    (numpy_text / np.linalg.norm(numpy_text, axis=1, keepdims=True)) @
    (encoded_vec / np.linalg.norm(encoded_vec))
)


top_inds = np.argsort(cos_sim)[::-1][:top_k]

for i in top_inds:
    print(titles[i])


Social cognition, behaviour and therapy adherence in frontal lobe epilepsy: a study combining neuroeconomic and neuropsychological methods
Mind-wandering in Parkinson's disease hallucinations reflects primary visual and default network coupling.
An optimized voxel-based morphometric study of gray matter changes in patients with left-sided and right-sided mesial temporal lobe epilepsy and hippocampal sclerosis (MTLE/HS).
Depressive mood in pre-dialytic chronic kidney disease: Statistical parametric mapping analysis of Tc-99m ECD brain SPECT
Insular and Anterior Cingulate Circuits in Smokers with Schizophrenia
Modulation of neuronal activity after spinal cord stimulation for neuropathic pain; H(2)15O PET study.
A balancing act of the brain: activations and deactivations driven by cognitive load.
Corrigendum to “Fear Processing in Dental Phobia during Crossmodal Symptom Provocation: An fMRI Study”
When Action Observation Facilitates Visual Perception: Activation in Visuo-Motor  Areas Cont

In [23]:
df_pubs

Unnamed: 0,pmid,pmcid,doi,name,description
0,24911975,,10.1371/journal.pone.0099222,Acute aerobic exercise increases cortical acti...,There is increasing evidence that acute aerobi...
1,22884992,,10.1016/j.dcn.2012.07.001,Developmental differences in the neural correl...,Despite vast knowledge on the behavioral proce...
2,15722210,,10.1016/j.cogbrainres.2004.09.011,The neural substrate of arithmetic operations ...,Recent functional neuroimaging studies have be...
3,21930137,,10.1016/j.neuropsychologia.2011.09.006,Neural processing associated with comprehensio...,"In daily communication, we often use indirect ..."
4,21930160,,10.1097/gme.0b013e3181cc49e9,Postmenopausal hormone use impact on emotion p...,Despite considerable evidence for potential ef...
...,...,...,...,...,...
28388,11923438,,10.1523/JNEUROSCI.22-07-02730.2002,The neural correlates of moral sensitivity: a ...,Humans are endowed with a natural sense of fai...
28389,12873805,,10.1016/S0006-3223(02)01749-3,Abnormalities in emotion processing within cor...,BACKGROUND: Neurobiology of psychopathy is imp...
28390,19925196,,10.1162/jocn.2009.21387,Virus and epidemic: causal knowledge activates...,Knowledge about cause and effect relationships...
28391,16001111,,10.1007/s00213-005-0077-5,A functional MRI study of the effects of bromo...,RATIONALE: Dopamine is abundant in the prefron...


In [20]:
aligned_text.size()

torch.Size([28757, 384])

# Brain map to text. return top 10 papers with summarized abstract

In [4]:
facebook_bart = pipeline("summarization", model="facebook/bart-large-cnn")


Device set to use mps:0


In [6]:

def summarize(article):
    """
    Summarizes an article using facebook/bart-large-cnn model
    
    Args:
        article (str): Text to summarize
        
    Returns:
        str: Summarized text
    """
    dic = facebook_bart(article, max_length=256, min_length=50, do_sample=False)
    summary_text = dict(dic[0])['summary_text']
    return summary_text


In [7]:
top_k = 10

numpy_titles = aligned_titles_half.detach().cpu().numpy().astype(np.float16)

query_vec = nifti_flattened[2]

# make sure query_vec has a batch dim:
q = query_vec.unsqueeze(0)          # shape (1, 28542)

# run it *through the encoder only*:
with torch.no_grad():
    z = neuro_encoder.encoder(q)    # → shape (1, 384)

z = z.squeeze(0)                    # → shape (384,)
encoded_vec = z.cpu().numpy().astype(np.float16)

# now both are (·,384) and (384,) and you can do:
cos_sim = (
    (numpy_titles / np.linalg.norm(numpy_titles, axis=1, keepdims=True)) @
    (encoded_vec / np.linalg.norm(encoded_vec))
)


top_inds = np.argsort(cos_sim)[::-1][:top_k]

for i in top_inds:
    print(titles[i])

NameError: name 'aligned_titles_half' is not defined