In [None]:
import pickle, gzip
from tqdm.notebook import tqdm
from math import ceil
import pandas as pd
import numpy as np
import torch

from nilearn.image import resample_to_img
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from neurovlm.retrieval_resources import (
    _load_masker, _load_autoencoder, _load_networks
)
from neurovlm.data import data_dir
from neurovlm.models import ConceptClf
from neurovlm.train import which_device
# Pick the torch device via neurovlm's helper; network embeddings are moved
# onto it before being passed to the concept classifier further down.
device = which_device()

# Interpreting Brain Maps

## Corpus Extraction
Extract n-grams for the training corpus. N-grams are weighted by cosine similarity to article embeddings, e.g., if an n-gram is highly similar to the articles, it gets a value near 1; otherwise, it gets a value near 0.

# Concept Classifier
The concept classifier predicts which concepts are present given a latent neuro embedding. The top-20 related concepts are passed to an LLM to summarize the brain map. Here, Llama-3.1-8B-Instruct is used to generate interpretations, but any language model may be used. Larger models, or models trained on neuroscience literature, may provide better brain map interpretations.

In [2]:
# Load from 10_n_grams.ipynb
# NOTE(review): the n-gram label cell further down cites 06_n_grams.ipynb —
# confirm which notebook actually produces these artifacts.
#
# weights_only=False unpickles arbitrary Python objects, so only load a
# checkpoint that you produced yourself.
#
# map_location puts the classifier on the same device its inputs are moved to
# later on, even if the checkpoint was saved from a different device.
concept_clf = torch.load(
    data_dir / "concept_clf.pt",
    map_location=device,
    weights_only=False,
)

## Network Correspondence

Test generation on the network dataset. Predicted concepts are passed to the LLM to summarize.

In [3]:
# Load the brain masker, autoencoder and network data through neurovlm's loaders.
# The masker and autoencoder are used further down to encode network images.
masker = _load_masker()
autoencoder = _load_autoencoder()
# NOTE(review): `networks` is reassigned from networks_arrays.pkl.gz in a later
# cell before it is ever read, so this load looks redundant — confirm the
# loader has no side effect that the later cell relies on before removing it.
networks = _load_networks()

In [None]:
# Load the instruction-tuned LLM that writes the brain-map summaries.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Pad on the left so that, in a padded batch, generated tokens directly follow
# each prompt.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

# Fall back to the EOS token when the tokenizer ships without a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Keep the model config in sync with the tokenizer's padding token.
model.config.pad_token_id = tokenizer.pad_token_id

# Inference only: eval mode, frozen parameters, and autograd disabled globally
# for the rest of the notebook. (A single global switch is enough; the second,
# equivalent torch.autograd.grad_mode call was dropped.)
model.eval()
for p in model.parameters():
    p.requires_grad_(False)
torch.set_grad_enabled(False)

In [5]:
# System prompt sent with every request: it tells the LLM to turn the list of
# predicted terms into a titled, sectioned wiki-style article. Leading and
# trailing whitespace is stripped once here.
system_prompt = """
You are a neuroscience editor writing a short wiki-style article from a list of terms.

INPUT: a list of neuroscience terms (networks, brain regions, cognitive functions, disorders).
OUTPUT: ONE article that uses the terms to form a coherent theme.

Rules:
1) Title (required): 6–12 words. Make it specific and content-based.
   - Use 1–2 of the most informative terms (prefer: network/circuit + region + cognition; add disorder only if strongly supported).
   - DO NOT use generic titles like: "Summary", "Overview", "Brain Network Analysis", "A Summary of Terms".

2) Lead paragraph (2–3 sentences):
   - State the unifying theme directly (what the terms collectively describe).
   - Name 3–5 “anchor” terms that drive the theme.
   - Do NOT say “the provided list”, “top-ranked”, “these terms appear”, or anything about scoring/ranking.

3) Body sections:
   - Networks
   - Key Regions
   - Cognitive Functions
   - Clinical Relevance

4) Be concrete:
   - Prefer specific mechanisms, pathways, and canonical associations over vague statements.
   - If a term is too vague/ambiguous/unrelated, ignore it in the main text.

No references. Do not mention this prompt.
""".strip()

# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

Device set to use cuda:0


In [6]:
import nibabel as nib

# NOTE: pickle can execute arbitrary code while loading; only open archives
# from a source you trust.
with gzip.open(data_dir / "networks_arrays.pkl.gz", "rb") as fh:
    networks = pickle.load(fh)

# Flatten the nested {atlas: {network: {"array": ..., "affine": ...}}} mapping
# into (atlas, network, NIfTI image) tuples.
network_imgs = [
    (
        atlas_name,
        network_name,
        nib.Nifti1Image(entry["array"], affine=entry["affine"]),
    )
    for atlas_name, by_network in networks.items()
    for network_name, entry in by_network.items()
]

# Later cells iterate over this filtered list, which leaves out the two ICA atlases.
networks = [item for item in network_imgs if item[0] not in ["UKBICA", "HCPICA"]]

In [9]:
# Encode every network image with the autoencoder, caching the embeddings on
# disk so the encoding loop only has to run when the cache file is missing.
emb_path = data_dir / "networks_emb.pt"

if emb_path.exists():
    networks_emb = torch.load(emb_path)
else:
    # One 384-wide row per network.
    networks_emb = torch.zeros((len(networks), 384))

    with torch.no_grad():
        for idx, entry in tqdm(enumerate(networks), total=len(networks)):
            # Resample the image onto the masker's grid before masking it.
            resampled = resample_to_img(
                entry[2],
                masker.mask_img,
                interpolation="nearest",
                force_resample=True,
                copy_header=True,
            )
            masked = masker.transform(resampled)
            networks_emb[idx] = autoencoder.encoder(torch.from_numpy(masked))

    torch.save(networks_emb, emb_path)

In [None]:
# Load from 06_n_grams.ipynb
mask = np.load(data_dir / "ngram_mask.npy")
features = np.load(data_dir / "ngram_labels.npy")[mask]

# How many of the highest-scoring concepts are handed to the LLM per network.
n_concepts = 20

# Build one chat (system + user message) per network.
messages = []

for idx in tqdm(range(len(networks)), total=len(networks)):

    # Concept classifier: per-concept probabilities for this network embedding
    logits = concept_clf(networks_emb[idx].to(device))
    scores = torch.sigmoid(logits.cpu().detach())

    # Rank concepts from highest to lowest score and keep the leading ones
    ranked = scores.argsort().flip(0).numpy()
    user_prompt = ", ".join(features[ranked[:n_concepts]])

    chat = [
        {"role": "system", "content": system_prompt.strip("\n")},
        {"role": "user", "content": user_prompt},
    ]
    messages.append(chat)

  0%|          | 0/145 [00:00<?, ?it/s]

In [None]:
# Generate one summary per network and write each to its own text file in the
# current working directory. The summaries are also collected, in network
# order, in `generated_summaries`.
generated_summaries = []

batch_size = 5
n_batches = ceil(len(networks) / batch_size)

for start in tqdm(range(0, len(networks), batch_size), total=n_batches):
    with torch.inference_mode():
        outputs = pipe(
            messages[start:start + batch_size],
            # Without an explicit batch_size the pipeline runs the chats one
            # at a time, so the chunking above would not batch anything.
            batch_size=batch_size,
            max_new_tokens=1000,
            do_sample=True,
            temperature=0.2,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.1,
            return_full_text=True,
            padding=True,
            truncation=True,
        )

    for offset, output in enumerate(outputs):
        # With chat input and return_full_text=True the generated text is the
        # whole conversation; the last message is the model's reply.
        generated_summary = output[0]["generated_text"][-1]["content"].strip()
        generated_summaries.append(generated_summary)

        # Build the file name outside the f-string: reusing double quotes
        # inside a replacement field is a SyntaxError before Python 3.12.
        atlas_name, network_name = networks[start + offset][:2]
        atlas_slug = atlas_name.lower().replace("/", "-")
        network_slug = network_name.lower().replace("/", "-")
        out_path = f"generated_summaries_{atlas_slug}_{network_slug}.txt"

        # Explicit UTF-8 so non-ASCII characters in the reply are written the
        # same way on every platform.
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(generated_summary)

  0%|          | 0/29 [00:00<?, ?it/s]

In [None]:
from nilearn import datasets, maskers

# Number of highest-mean atlas regions recorded per network map.
TOP_K_REGIONS = 5

atlas0 = datasets.fetch_atlas_harvard_oxford("cort-maxprob-thr25-2mm")

atlas1 = datasets.fetch_atlas_schaefer_2018(
    n_rois=400, yeo_networks=7, resolution_mm=2
)


def _top_atlas_regions(region_means, atlas_labels, top_k=TOP_K_REGIONS):
    """Return the names of the ``top_k`` regions with the largest mean value.

    Parameters
    ----------
    region_means : array-like
        One mean per atlas region; either 1-D or with a leading singleton
        axis (it is flattened before use).
    atlas_labels : sequence of str or bytes
        Region names from the atlas. A leading "Background" entry, if present,
        is dropped because it has no column in ``region_means``.
    top_k : int
        Number of region names to return.

    Returns
    -------
    list of str
        Region names ordered from highest to lowest mean.

    Raises
    ------
    ValueError
        If the number of region names does not match the number of means.
    """
    values = np.ravel(np.asarray(region_means))
    names = [
        name.decode() if isinstance(name, bytes) else str(name)
        for name in atlas_labels
    ]
    if names and names[0].lower() == "background":
        names = names[1:]
    if len(names) != values.size:
        raise ValueError(
            f"Atlas has {len(names)} region names but the masker returned "
            f"{values.size} region means."
        )

    ranked = pd.DataFrame(
        {"label": names, "mean_activation": values}
    ).sort_values("mean_activation", ascending=False)
    # Every row is a real region here, so none is skipped before taking the top.
    return ranked.head(top_k)["label"].tolist()


# labels0 / labels1 hold, per network, the top regions under each atlas.
labels0 = []
labels1 = []
for row in tqdm(networks, total=len(networks)):
    img = row[-1]
    for atlas, labels in [(atlas0, labels0), (atlas1, labels1)]:
        # Use a separate name so the brain `masker` loaded at the top of the
        # notebook (needed by the network-encoding cell) is not overwritten.
        labels_masker = maskers.NiftiLabelsMasker(
            labels_img=atlas.maps, standardize=False
        )
        region_means = labels_masker.fit_transform(img)
        labels.append(_top_atlas_regions(region_means, atlas.labels, TOP_K_REGIONS))

## PubMed Test Set
**TODO:** Add test set results here.

In [None]:
# TODO: disabled scratch code comparing true vs. predicted n-gram scores for one example.
# NOTE(review): it refers to trainer, X_val, y_val and val_ids, which are not
# defined in the cells visible here — confirm their source before re-enabling.
# y_val_pred = trainer.model(X_val) # <- change to test set, not val

# # ii = 1 # 6
# # print(df[df['pmid'].isin(val_ids).to_numpy()].iloc[ii]["name"])
# # print(df[df['pmid'].isin(val_ids).to_numpy()].iloc[ii]["description"])

# # true
# _df = pd.DataFrame({
#     "phrase": features,
#     "score": y_val[ii].to("cpu").detach().numpy()
# })

# _df.sort_values(by="score", ascending=False)

# # predicted
# _df = pd.DataFrame({
#     "phrase": features,
#     "score": torch.sigmoid(y_val_pred[ii].to("cpu").detach()).numpy()
# })
# _df.sort_values(by="score", ascending=False)