In [None]:
from pathlib import Path
import pickle, gzip
from tqdm.notebook import tqdm
import json
import re
from collections import defaultdict
from math import ceil
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn

# from keybert import KeyBERT
from sklearn.feature_extraction.text import CountVectorizer

from nilearn.image import resample_to_img
from nilearn.plotting import view_img

from neurovlm.retrieval_resources import (
    _load_dataframe, _load_specter, _load_latent_text,
    _load_masker, _load_autoencoder, _load_networks
)
from neurovlm.data import data_dir
from neurovlm.train import Trainer, which_device
# from neurovlm.models import ConceptClf

# Interpreting Brain Maps

## Corpus Extraction
Extract n-grams from the training corpus. N-grams are weighted by cosine similarity to article embeddings, e.g. if an n-gram is highly similar to an article it gets a value near 1, otherwise it gets a value near 0.

# Concept Classifier
The concept classifier predicts which concepts are present given a latent neuro embedding. The top-scoring concepts (the code below uses the top 20) are passed to an LLM to summarize the brain map. Here, Llama-3.1-8B-Instruct is used to generate interpretations. Any language model may be used. Larger models or models trained on neuroscience literature may provide better brain map interpretations.

In [None]:
def extract_ngrams(docs, ngram_range, min_count=100):
    """Build a dense document-by-n-gram count matrix, keeping frequent n-grams only.

    Parameters
    ----------
    docs : iterable of str
        Raw documents to tokenize.
    ngram_range : tuple of (int, int)
        Forwarded to ``CountVectorizer(ngram_range=...)``.
    min_count : int, optional
        Minimum number of occurrences, summed over all documents, an n-gram
        needs in order to be kept. Defaults to 100.

    Returns
    -------
    X : np.ndarray, shape (n_docs, n_kept)
        Per-document counts of the retained n-grams.
    feature_names : np.ndarray
        The retained n-grams, in the same column order as ``X``.
    """
    vectorizer = CountVectorizer(
        ngram_range=ngram_range,
        stop_words="english",
        min_df=1
    )

    # Fit and count in a single pass instead of tokenizing the corpus twice.
    X = vectorizer.fit_transform(docs)  # sparse, shape: (n_docs, n_features)
    feature_names = vectorizer.get_feature_names_out()

    # Flatten explicitly: the column sum is a (1, n_features) np.matrix for
    # scipy sparse matrices, but would already be 1-D for sparse arrays.
    totals = np.asarray(X.sum(axis=0)).ravel()
    mask = totals >= min_count

    X = X[:, mask].toarray()
    feature_names = feature_names[mask]

    return X, feature_names

# load text
df = _load_dataframe()
text = df["name"] + " [SEP] " + df["description"]

# extract n-grams (cached on disk; the count matrix and its labels are a pair)
ngram_matrix_path = data_dir / "ngram_matrix.npy"
ngram_labels_path = data_dir / "ngram_labels.npy"

# Recompute unless BOTH cache files are present: a matrix without its labels
# (or vice versa) cannot be used.
if not (ngram_matrix_path.exists() and ngram_labels_path.exists()):
    X_uni, features_uni = extract_ngrams(text, (1, 1))
    X_bi, features_bi = extract_ngrams(text, (2, 2))
    X_tri, features_tri = extract_ngrams(text, (3, 3))
    X = np.hstack((X_uni, X_bi, X_tri))
    # np.concatenate rather than np.concat: the latter only exists in NumPy >= 2.0.
    # Cast to str here so `features` has the same dtype as when loaded from cache.
    features = np.concatenate((features_uni, features_bi, features_tri)).astype(str)
    np.save(ngram_matrix_path, X)
    np.save(ngram_labels_path, features)
else:
    # load pre-computed
    X = np.load(ngram_matrix_path)
    features = np.load(ngram_labels_path)
    
# manual cleaning
# Any n-gram that CONTAINS one of these plain substrings is removed
# (matching is case-insensitive, see `case=False` below).
DROP_SUBSTRINGS = [
    # study-like language
    "study", "studies", "result", "indicate",
    "show", "related", "differences", "significant",
    "effect", "role", "measure", "displayed",
    "involved", "examined", "associated", "altered",
    "performed", "demonstrated", "conclus", "correlate",
    "individuals", "common", "prior",
    # too general
    "brain", "neural", "neuroimaging", "mri",
    "fmri", "connectivity", "diagnosed", "patients",
    "little", "known", "activation", "blood",
    "alterations", "neuroscience", "people",
    # disabled: "magnetic"
]

# Regular expressions, used verbatim (not escaped).
DROP_REGEXES = [
    r"^cortex",
    # general single terms
    r"^ventral$",
    r"^frontal$",
    r"^neuronal$",
    r"^cognitive$",
    r"^cerebral$",
    r"^resting_state$",
    r"^disorder$",
    r"^neuropsychological$",
    r"^cognition$",
    r"^stimulus$",
    r"^dysfunction$",
    r"^imaging$",
    r"^functional$",
    r"^functional imaging$",
    r"^task performance$",
    r"^impairments$",
    r"^traits$",
    r"^dysfunction$",
    r"^cognitive abilities$",
    r"^imaging dti$",
    # [SEP] token
    r"\bsep\b",
]

# One alternation: escaped plain terms first, then the raw regex terms.
escaped_substrings = [re.escape(s) for s in DROP_SUBSTRINGS]
pattern = "|".join(escaped_substrings + DROP_REGEXES)

# Keep only the n-grams that match none of the alternatives, and drop the
# corresponding columns of the count matrix.
is_dropped = pd.Series(features).str.contains(pattern, case=False, na=False, regex=True)
mask = ~is_dropped
features = features[mask]
X = X[:, mask]

# load latent text
latent, pmids = _load_latent_text()

# align df order, on pmid: one shared permutation reorders both the
# dataframe rows and the rows of the n-gram matrix
inds = df["pmid"].argsort().to_numpy()
df = df.iloc[inds].reset_index(drop=True)
X = X[inds]

# the reordered dataframe must now line up with the latent text pmids
assert (df['pmid'] == pmids).all()

In [52]:
# specter embeddings for ngrams
specter = _load_specter()
specter.specter = specter.specter.eval()

# NOTE(review): the cache is keyed on the file name only. If `features`
# changes (e.g. the cleaning lists are edited), delete ngram_emb.pt so the
# embeddings are recomputed.
if not (data_dir / "ngram_emb.pt").exists():
    ngram_emb = []
    batch_size = 512
    # True division before ceil: `//` would floor first and under-count the
    # last, partial batch in the progress bar.
    n_batches = ceil(len(features) / batch_size)
    with torch.no_grad():
        for i in tqdm(range(0, len(features), batch_size), total=n_batches):
            ngram_emb.append(specter(features[i:i+batch_size].tolist()))
    ngram_emb = torch.vstack(ngram_emb)

    ngram_emb = ngram_emb / ngram_emb.norm(dim=1)[:, None] # unit vector
    torch.save(ngram_emb, data_dir / "ngram_emb.pt")
else:
    ngram_emb = torch.load(data_dir / "ngram_emb.pt")

In [None]:
# cosine similarity as target
# n-gram rows are scaled to unit length here; `latent` is used as loaded, so
# this is a true cosine similarity only if its rows are unit-norm - TODO confirm.
unit_ngram_emb = ngram_emb / ngram_emb.norm(dim=1, keepdim=True)
y = latent @ unit_ngram_emb.T

# entries to zero out: negative similarity, or n-gram absent from the document
m = (y < 0.) | (torch.from_numpy(X) == 0.)
y[m] = 0.

# transform cosine similarity ~= probabilities
t = 0.03    # similarity that maps to 0.5
tau = 0.08  # scale of the sigmoid
y = torch.sigmoid((y - t) / tau)

# the sigmoid moved the masked zeros away from 0, so mask them again
y[m] = 0.
y = y.numpy()

In [None]:
# ensure latent neuro vectors align with df
# NOTE: unpacking relies on the saved dict storing the vectors first and the
# pmids second.
latent_neuro, pmid = torch.load(
    data_dir / "latent_neuro_sparse.pt", weights_only=False, map_location="cpu"
).values()

# df must already be in ascending pmid order
assert df["pmid"].is_monotonic_increasing

# keep only the articles that have a latent neuro vector
mask = df['pmid'].isin(pmid).to_numpy()
df, y = df[mask].reset_index(drop=True), y[mask]

# Row i of df / y must describe the same article as row i of latent_neuro;
# later cells index latent_neuro with masks built from df.
assert len(df) == len(y) == len(pmid) == len(latent_neuro)
assert (df["pmid"].to_numpy() == np.asarray(pmid)).all()

In [None]:
# load data splits
splits = torch.load(data_dir / "pmids_split.pt", weights_only=False)

# relies on the stored order of the values: train, test, val
train_ids, test_ids, val_ids = splits.values()

# presumably sorts in place (list / ndarray semantics) - the sortedness
# asserts further down depend on it
for split_ids in (train_ids, val_ids, test_ids):
    split_ids.sort()

def split(df, latent, y, pmids, device):
    """Select the rows belonging to one data split.

    Parameters
    ----------
    df : pd.DataFrame
        Must have a ``pmid`` column; row ``i`` corresponds to ``latent[i]``
        and ``y[i]``.
    latent : torch.Tensor
        Input vectors, one row per row of ``df``.
    y : np.ndarray
        Targets, one row per row of ``df``.
    pmids : np.ndarray
        Identifiers that make up the split.
    device : str or torch.device
        Where the returned tensors are placed.

    Returns
    -------
    X : torch.Tensor
        Copy of the rows of ``latent`` whose pmid is in ``pmids`` (df order).
    y : torch.Tensor
        Matching rows of ``y`` as float32.
    pmids : np.ndarray
        The entries of ``pmids`` that occur in ``df`` (original order).
    """
    in_split = df["pmid"].isin(pmids).to_numpy()
    row_selector = torch.from_numpy(in_split)

    X = latent[row_selector].clone().to(device)
    targets = torch.from_numpy(y[in_split].copy()).float().to(device)

    # drop identifiers that have no row in df
    present = pd.Series(pmids).isin(df["pmid"])
    return X, targets, pmids[present]

device = which_device()

# Build each partition on the target device. Every call also narrows the id
# array to the pmids that actually occur in df.
X_train, y_train, train_ids = split(df, latent_neuro, y, train_ids, device)
X_val, y_val, val_ids = split(df, latent_neuro, y, val_ids, device)
X_test, y_test, test_ids = split(df, latent_neuro, y, test_ids, device)

# ensure sorted
sorted_pmid = df['pmid'].sort_values()
assert (df['pmid'] == sorted_pmid).all()
for ids in (train_ids, val_ids, test_ids):
    assert (ids == np.sort(ids)).all()

In [58]:
class ConceptClf(nn.Module):
    """Predict concepts from latent neuro embeddings.

    A three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) that returns one raw
    logit per concept; no sigmoid is applied inside the module.

    Parameters
    ----------
    d_out : int
        Number of concepts (output logits).
    d_in : int, optional
        Size of the input embedding. Defaults to 384.
    d_hidden : tuple of (int, int), optional
        Widths of the two hidden layers. Defaults to (768, 1526).
        NOTE(review): 1526 is kept as-is so existing weights still load;
        confirm whether 1536 was intended.
    """
    def __init__(self, d_out, d_in=384, d_hidden=(768, 1526)):
        super().__init__()
        d_h1, d_h2 = d_hidden
        self.seq = nn.Sequential(
            nn.Linear(d_in, d_h1),
            nn.ReLU(),
            nn.Linear(d_h1, d_h2),
            nn.ReLU(),
            nn.Linear(d_h2, d_out)
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return concept logits; the last dimension goes from d_in to d_out."""
        return self.seq(X)
    
# Seed torch's global RNG so the weight initialization is repeatable.
# NOTE(review): whether Trainer draws its batch order from this RNG is not
# visible in this notebook - confirm if exact reproducibility matters.
SEED = 0
torch.manual_seed(SEED)

# One output logit per concept, i.e. per column of the training targets.
# (Read from y_train rather than the global `X`, a name this notebook reuses.)
clf = ConceptClf(y_train.shape[1]).to(device)

# Targets are soft labels in [0, 1]; the model outputs raw logits.
loss_fn = nn.BCEWithLogitsLoss()

trainer = Trainer(
    clf,
    loss_fn,
    lr=5e-5,
    n_epochs=200,
    batch_size=1028,
    optimizer=torch.optim.AdamW,
    X_val=X_val,
    y_val=y_val,
    interval=20
)

trainer.fit(X_train, y_train)

Epoch: -1, val loss: 0.69402
Epoch: 0, val loss: 0.48088
Epoch: 20, val loss: 0.03878
Epoch: 40, val loss: 0.037303
Epoch: 60, val loss: 0.036246
Epoch: 80, val loss: 0.035679
Epoch: 100, val loss: 0.035344
Epoch: 120, val loss: 0.035209
Epoch: 140, val loss: 0.035075
Epoch: 160, val loss: 0.035017
Epoch: 180, val loss: 0.034976


## Network Correspondence

Test generation on the network dataset. Predicted concepts are passed to the LLM to summarize.

In [None]:
# Resources for the network experiments, loaded in the same order as before:
# the brain masker, the autoencoder and the network dataset.
masker, autoencoder, networks = (
    _load_masker(),
    _load_autoencoder(),
    _load_networks(),
)

In [66]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load LLM
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Pad on the left so that generated tokens directly follow each prompt when
# several prompts are padded to a common length.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

# Fall back to the end-of-sequence token when the tokenizer has no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model.config.pad_token_id = tokenizer.pad_token_id

# Eval mode and disable gradients
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

# NOTE: this turns autograd off for the WHOLE kernel, not only for the LLM.
# Re-running the classifier training cell after this point presumably
# requires torch.set_grad_enabled(True) first.
# (A second, equivalent call through torch.autograd.grad_mode was redundant
# and has been removed.)
torch.set_grad_enabled(False)

In [105]:
# System prompt sent with every request. The user message of each request is
# a comma-separated list of predicted concept terms. Leading and trailing
# whitespace of the literal is removed by .strip().
system_prompt = """
You are a neuroscience editor writing a short wiki-style article from a list of terms.

INPUT: a list of neuroscience terms (networks, brain regions, cognitive functions, disorders).
OUTPUT: ONE article that uses the terms to form a coherent theme.

Rules:
1) Title (required): 6–12 words. Make it specific and content-based.
   - Use 1–2 of the most informative terms (prefer: network/circuit + region + cognition; add disorder only if strongly supported).
   - DO NOT use generic titles like: "Summary", "Overview", "Brain Network Analysis", "A Summary of Terms".

2) Lead paragraph (2–3 sentences):
   - State the unifying theme directly (what the terms collectively describe).
   - Name 3–5 “anchor” terms that drive the theme.
   - Do NOT say “the provided list”, “top-ranked”, “these terms appear”, or anything about scoring/ranking.

3) Body sections:
   - Networks
   - Key Regions
   - Cognitive Functions
   - Clinical Relevance

4) Be concrete:
   - Prefer specific mechanisms, pathways, and canonical associations over vague statements.
   - If a term is too vague/ambiguous/unrelated, ignore it in the main text.

No references. Do not mention this prompt.
""".strip()

# Text-generation pipeline built around the model and tokenizer loaded above
# in this notebook (no second copy of the weights is requested here).
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

Device set to use cuda:0


In [None]:
import nibabel as nib

# Nested mapping read from disk: atlas name -> network name -> entry with
# "array" and "affine".
# NOTE: pickle can execute arbitrary code while loading; only open files
# from a trusted source.
with gzip.open(data_dir / "networks_arrays.pkl.gz", "rb") as f:
    networks = pickle.load(f)

# Flatten into (atlas name, network name, NIfTI image) triples.
network_imgs = [
    (atlas_name, network_name, nib.Nifti1Image(entry["array"], affine=entry["affine"]))
    for atlas_name, atlas_networks in networks.items()
    for network_name, entry in atlas_networks.items()
]

# From here on `networks` is the list of triples, without the two ICA atlases.
networks = [triple for triple in network_imgs if triple[0] not in ["UKBICA", "HCPICA"]]

In [None]:
# Encode networks
# One 384-dimensional row per network, filled in below.
networks_emb = torch.zeros((len(networks), 384))

with torch.no_grad():
    for i, row in tqdm(enumerate(networks), total=len(networks)):
        # Put the network image on the masker's grid (nearest-neighbour
        # interpolation), then extract the masked values.
        resampled = resample_to_img(
            row[2], masker.mask_img, interpolation="nearest", force_resample=True, copy_header=True
        )
        voxels = masker.transform(resampled)

        # Encode network image
        networks_emb[i] = autoencoder.encoder(torch.from_numpy(voxels))

# Written relative to the current working directory.
torch.save(networks_emb, "networks_emb.pt")

  0%|          | 0/145 [00:00<?, ?it/s]

In [149]:
# One chat conversation (system + user message) per network.
messages = []

for i, row in tqdm(enumerate(networks), total=len(networks)):

    # Concept classifier: logits -> per-concept scores on the CPU
    logits = clf(networks_emb[i].to(device))
    scores = torch.sigmoid(logits.cpu().detach())

    # Pass concepts to LLM: the 20 highest-scoring concept labels, best first
    ranked = scores.argsort().flip(0).numpy()
    top_terms = features[ranked[:20]]
    user_prompt = ", ".join(top_terms)

    conversation = [
        {"role": "system", "content": system_prompt.strip("\n")},
        {"role": "user", "content": user_prompt},
    ]
    messages.append(conversation)


  0%|          | 0/145 [00:00<?, ?it/s]

In [None]:
# Generated article for each network, in the same order as `networks`.
generated_summaries = []

batch_size = 5
for i in tqdm(range(0, len(networks), batch_size), total=ceil(len(networks)/batch_size)):
    with torch.inference_mode():
        outputs = pipe(
            messages[i:i+batch_size],
            max_new_tokens=1000,
            do_sample=True,
            temperature=0.2,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.1,
            return_full_text=True,
            padding=True,
            truncation=True
        )
    for idx, output in enumerate(outputs):
        # With return_full_text=True the whole conversation comes back; the
        # last message is the newly generated reply.
        generated_summary = output[0]["generated_text"][-1]["content"].strip()
        generated_summaries.append(generated_summary)

        row = networks[i + idx]
        # Build the name parts outside the f-string: reusing double quotes
        # inside a double-quoted f-string is a SyntaxError before Python 3.12.
        atlas_name = row[0].lower().replace("/", "-")
        network_name = row[1].lower().replace("/", "-")
        out_path = f"generated_summaries_{atlas_name}_{network_name}.txt"

        # Explicit encoding: the text may contain non-ASCII characters.
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(generated_summary)

  0%|          | 0/29 [00:00<?, ?it/s]

In [None]:
import pandas as pd
from nilearn import datasets, maskers

atlas0 = datasets.fetch_atlas_harvard_oxford("cort-maxprob-thr25-2mm")

atlas1 = datasets.fetch_atlas_schaefer_2018(
    n_rois=400, yeo_networks=7, resolution_mm=2
)

# Number of strongest regions reported per network and atlas.
top_k = 5

# labels0[i] / labels1[i]: top regions of networks[i] in atlas0 / atlas1.
labels0 = []
labels1 = []
for row in tqdm(networks, total=len(networks)):
    img = row[-1]
    for atlas, labels in [(atlas0, labels0), (atlas1, labels1)]:
        labels_img = atlas.maps
        _labels = list(atlas.labels)
        # Local names on purpose: this cell must not overwrite the global
        # `masker` and `df` that earlier cells rely on.
        labels_masker = maskers.NiftiLabelsMasker(labels_img=labels_img, standardize=False)
        # ravel: accept both a 1-D and a (1, n_regions) result
        region_means = np.ravel(labels_masker.fit_transform(img))
        # _labels[1:] skips the first (background) entry so the names line up
        # with the region means - assumes atlas.labels starts with the
        # background label, TODO confirm for the installed nilearn version.
        region_df = pd.DataFrame({
            "label": _labels[1:],
            "mean_activation": region_means,
        })
        # The background is already excluded above; an additional .iloc[1:]
        # here used to discard the first real region as well.
        region_df = region_df.sort_values("mean_activation", ascending=False)
        top_regions = region_df.head(top_k)
        labels.append(top_regions["label"].tolist())

## PubMed Test Set
**TODO:** Add the test-set results here.

In [None]:
# y_val_pred = trainer.model(X_val) # <- change to test set, not val

# # ii = 1 # 6
# # print(df[df['pmid'].isin(val_ids).to_numpy()].iloc[ii]["name"])
# # print(df[df['pmid'].isin(val_ids).to_numpy()].iloc[ii]["description"])

# # true
# _df = pd.DataFrame({
#     "phrase": features,
#     "score": y_val[ii].to("cpu").detach().numpy()
# })

# _df.sort_values(by="score", ascending=False)

# # predicted
# _df = pd.DataFrame({
#     "phrase": features,
#     "score": torch.sigmoid(y_val_pred[ii].to("cpu").detach()).numpy()
# })
# _df.sort_values(by="score", ascending=False)