In [1]:
from pathlib import Path
from tqdm.notebook import tqdm
import json
import re
from collections import defaultdict
from math import ceil

import numpy as np
import pandas as pd
import torch
from torch import nn

from keybert import KeyBERT
from sklearn.feature_extraction.text import CountVectorizer

from neurovlm.retrieval_resources import _load_dataframe, _load_specter, _load_latent_text
from neurovlm.data import data_dir
from neurovlm.train import Trainer, which_device

# from neurovlm.models import ConceptClf

In [2]:
def extract_ngrams(docs, ngram_range, min_count=100):
    """Count n-grams across ``docs`` and keep only the frequent ones.

    Parameters
    ----------
    docs : iterable of str
        Documents to tokenize; English stop words are removed.
    ngram_range : tuple of int
        ``(min_n, max_n)`` passed to ``CountVectorizer``.
    min_count : int, default 100
        Minimum total count of an n-gram, summed over all documents, for it
        to be kept.

    Returns
    -------
    X : np.ndarray, shape (n_docs, n_kept)
        Dense per-document counts of the kept n-grams.
    feature_names : np.ndarray, shape (n_kept,)
        The kept n-gram strings, aligned with the columns of ``X``.
    """
    vectorizer = CountVectorizer(
        ngram_range=ngram_range,
        stop_words="english",
        min_df=1,
    )
    # Sparse (n_docs, n_features) count matrix, built in a single pass.
    counts = vectorizer.fit_transform(docs)
    feature_names = vectorizer.get_feature_names_out()

    # Total corpus frequency per n-gram. Summing a sparse matrix yields a
    # (1, n_features) matrix, so flatten it to 1-D before building the mask.
    keep = np.asarray(counts.sum(axis=0)).ravel() >= min_count

    return counts[:, keep].toarray(), feature_names[keep]

# Load paper metadata; each document is "<name> [SEP] <description>".
df = _load_dataframe()
# Missing fields become "" so CountVectorizer never receives NaN documents.
text = df["name"].fillna("") + " [SEP] " + df["description"].fillna("")

# Extract uni-, bi- and tri-grams, cached under data_dir. Both cache files
# must exist; otherwise everything is recomputed.
NGRAM_MATRIX_PATH = data_dir / "ngram_matrix.npy"
NGRAM_LABELS_PATH = data_dir / "ngram_labels.npy"

if NGRAM_MATRIX_PATH.exists() and NGRAM_LABELS_PATH.exists():
    X = np.load(NGRAM_MATRIX_PATH)
    features = np.load(NGRAM_LABELS_PATH)
else:
    X_uni, features_uni = extract_ngrams(text, (1, 1))
    X_bi, features_bi = extract_ngrams(text, (2, 2))
    X_tri, features_tri = extract_ngrams(text, (3, 3))

    # Columns of X line up with entries of `features` (uni, then bi, then tri).
    X = np.hstack((X_uni, X_bi, X_tri))
    # np.concatenate works on all NumPy versions (np.concat is a 2.0+ alias).
    # Cast to str so the in-memory dtype matches what np.load returns when cached.
    features = np.concatenate((features_uni, features_bi, features_tri)).astype(str)

    np.save(NGRAM_MATRIX_PATH, X)
    np.save(NGRAM_LABELS_PATH, features)

In [3]:
# Manual vocabulary cleaning: drop n-grams that are study/boilerplate language
# or too generic to be useful concept labels.
#
# NOTE(review): DROP_SUBSTRINGS match anywhere inside an n-gram (case-insensitive),
# so short entries also remove unrelated terms, e.g. "role" drops
# "control"/"controls", "brain" drops "brainstem", and "activation" drops
# "deactivation". Switch to word-boundary patterns if that is not intended.
DROP_SUBSTRINGS = [
    # study-like language
    "study", "studies", "result", "indicate", "show", "related",
    "differences", "significant", "effect", "role", "measure",
    "displayed", "involved", "examined", "associated", "altered",
    "performed", "demonstrated", "conclus", "correlate", "individuals",
    "common", "prior",
    # too general
    "brain", "neural", "neuroimaging", "mri", "fmri", "connectivity",
    "diagnosed", "patients", "little", "known", "activation", "blood",
    "alterations", "neuroscience", "people",
]

DROP_REGEXES = [
    # NOTE(review): no trailing `$`, so this drops every n-gram *starting* with "cortex".
    r"^cortex",
    # general single terms
    r"^ventral$", r"^frontal$", r"^neuronal$", r"^cognitive$",
    # NOTE(review): n-grams are space-joined, so this only matches a literal
    # "resting_state" token, not the bigram "resting state" — confirm intent.
    r"^cerebral$", r"^resting_state$", r"^disorder$",
    r"^neuropsychological$", r"^cognition$", r"^stimulus$",
    r"^dysfunction$", r"^imaging$", r"^functional$",
    r"^functional imaging$", r"^task performance$", r"^impairments$",
    r"^traits$", r"^cognitive abilities$", r"^imaging dti$",
    # [SEP] token
    r"\bsep\b",
]

pattern = "|".join(
    [re.escape(s) for s in DROP_SUBSTRINGS] +  # plain terms, matched as substrings
    DROP_REGEXES                               # regex terms, used verbatim
)

# Boolean numpy mask over the vocabulary; True = keep.
mask = ~pd.Series(features).str.contains(pattern, case=False, na=False, regex=True).to_numpy()
features = features[mask]
X = X[:, mask]

In [4]:
# SPECTER embeddings for the kept n-grams, unit-normalised and cached.
specter = _load_specter()
specter.specter = specter.specter.eval()

if not (data_dir / "ngram_emb.pt").exists():
    ngram_emb = []
    batch_size = 512
    # range(0, n, batch_size) yields ceil(n / batch_size) batches.
    n_batches = ceil(len(features) / batch_size)
    with torch.no_grad():
        for i in tqdm(range(0, len(features), batch_size), total=n_batches):
            ngram_emb.append(specter(features[i:i + batch_size].tolist()))
    ngram_emb = torch.vstack(ngram_emb)

    ngram_emb = ngram_emb / ngram_emb.norm(dim=1)[:, None]  # unit vector per row
    torch.save(ngram_emb, data_dir / "ngram_emb.pt")
else:
    ngram_emb = torch.load(data_dir / "ngram_emb.pt")

# The cache is keyed only by file name. Fail loudly if it was built for a
# different vocabulary, e.g. after editing the drop lists above.
assert ngram_emb.shape[0] == len(features), (
    "ngram_emb.pt does not match the current n-gram vocabulary; delete it and re-run"
)

There are adapters available but none are activated for the forward pass.


In [5]:
# Latent text embeddings and their PMIDs.
latent, pmids = _load_latent_text()

# Reorder the metadata (and the matching rows of the n-gram matrix) by
# ascending PMID so df lines up row-for-row with `pmids`.
inds = df["pmid"].argsort().values
df = df.iloc[inds].reset_index(drop=True)
X = X[inds]

assert (df["pmid"] == pmids).all()

In [6]:
# # cosine similarity
# sim = latent @ (ngram_emb / ngram_emb.norm(dim=1)[:, None]).T
# sim[sim < 0.] = 0.
# sim[X == 0] = 0.
# y = sim.numpy()

In [6]:
# Prediction target: 1.0 where an n-gram occurs in a paper's text, else 0.0.
# The raw counts in X are reduced to presence/absence.
y = np.greater(X, 0).astype(float)

In [7]:
# Keep only papers that also have a latent neuro embedding, and verify that
# those embeddings line up row-for-row with df.
latent_neuro, pmid = torch.load(
    data_dir / "latent_neuro_sparse.pt", weights_only=False, map_location="cpu"
).values()

assert df["pmid"].is_monotonic_increasing, "df must be sorted by PMID"

mask = df["pmid"].isin(pmid).to_numpy()
df = df[mask].reset_index(drop=True)
y = y[mask]

# split() selects rows of latent_neuro with a boolean mask built over df rows,
# so both must have the same length and PMID order.
assert len(latent_neuro) == len(df), "latent_neuro rows do not match the filtered df"
assert (df["pmid"].to_numpy() == np.asarray(pmid)).all(), (
    "latent_neuro is not in the same (ascending PMID) order as df"
)

# save
np.save(data_dir / "y_ngram.npy", y)
np.save(data_dir / "features_ngram.npy", features)
# NOTE(review): scratch export into the current working directory; consider
# removing it or writing under data_dir.
pd.Series(features).to_csv("tmp.csv", index=False, sep="\t")

In [8]:
# load data splits
train_ids, test_ids, val_ids = torch.load(data_dir / "pmids_split.pt", weights_only=False).values()
train_ids.sort()
val_ids.sort()
test_ids.sort()

def split(df, latent, y, pmids, device):
    """Select the rows of ``latent`` and ``y`` that belong to one data split.

    Parameters
    ----------
    df : pd.DataFrame
        Metadata with a ``"pmid"`` column; its row order is the row order of
        ``latent`` and ``y``.
    latent : torch.Tensor
        Inputs, one row per row of ``df``.
    y : np.ndarray
        Targets, one row per row of ``df``.
    pmids : array-like
        PMIDs in this split; may include PMIDs that are absent from ``df``.
    device : torch.device or str
        Device the returned tensors are moved to.

    Returns
    -------
    X : torch.Tensor
        Rows of ``latent`` whose PMID is in ``pmids``, in ``df`` order.
    y : torch.Tensor
        The matching rows of ``y`` as float32.
    pmids : np.ndarray
        Entries of ``pmids`` that occur in ``df``, in their input order. They
        line up with ``X``/``y`` only when ``pmids`` and ``df["pmid"]`` are
        both sorted ascending.

    Raises
    ------
    ValueError
        If ``latent``/``y`` do not have one row per row of ``df``, or if the
        number of selected rows differs from the number of kept PMIDs
        (duplicate PMIDs).
    """
    if len(latent) != len(df) or len(y) != len(df):
        raise ValueError("latent and y must have exactly one row per row of df")

    in_split = df["pmid"].isin(pmids).to_numpy()
    X = latent[torch.from_numpy(in_split)].clone().to(device)
    y = torch.from_numpy(y[in_split].copy()).float().to(device)

    # np.asarray lets plain lists be boolean-indexed too.
    pmids = np.asarray(pmids)
    pmids = pmids[pd.Series(pmids).isin(df["pmid"]).to_numpy()]

    if len(pmids) != len(X):
        raise ValueError(
            "selected rows and split PMIDs differ in count (duplicate PMIDs?)"
        )
    return X, y, pmids

device = which_device()
X_train, y_train, train_ids = split(df, latent_neuro, y, train_ids, device)
X_val, y_val, val_ids = split(df, latent_neuro, y, val_ids, device)
X_test, y_test, test_ids = split(df, latent_neuro, y, test_ids, device)

# Row i of X_*/y_* corresponds to *_ids[i] only because df and every id array
# are sorted ascending by PMID; check that, plus matching lengths per split.
# (is_monotonic_increasing avoids comparing Series with differently ordered
# indexes, which raises ValueError instead of failing the assert.)
assert df["pmid"].is_monotonic_increasing
for _X, _y, _ids in (
    (X_train, y_train, train_ids),
    (X_val, y_val, val_ids),
    (X_test, y_test, test_ids),
):
    assert (_ids == np.sort(_ids)).all()
    assert len(_X) == len(_y) == len(_ids), "split rows and PMIDs are misaligned"

In [9]:
# from neurovlm.retrieval_resources import _proj_head_image_infonce
# proj_head = _proj_head_image_infonce()
# proj_head = proj_head.to(device)
# X_train = proj_head(X_train).detach()
# X_test = proj_head(X_test).detach()
# X_val = proj_head(X_val).detach()

# norm = lambda x : x / x.norm(dim=1)[:, None]
# X_train = norm(X_train)
# X_test = norm(X_test)
# X_val = norm(X_val)

In [10]:
class ConceptClf(nn.Module):
    """Predict concepts from latent neuro embeddings.

    A plain MLP ``d_in -> d_hidden[0] -> ... -> d_out`` with ReLU between the
    linear layers. It returns raw logits (no final activation); this notebook
    pairs it with ``nn.BCEWithLogitsLoss`` and ``torch.sigmoid``.

    Parameters
    ----------
    d_out : int
        Number of concepts (output logits).
    d_in : int, default 384
        Size of the input embedding.
    d_hidden : sequence of int, default (768, 1526)
        Hidden-layer widths. The default reproduces the original layout, so
        existing checkpoints (keys ``seq.0``, ``seq.2``, ``seq.4``) still load.
        NOTE(review): 1526 may have been intended as 1536; kept for
        checkpoint compatibility.
    """

    def __init__(self, d_out, d_in=384, d_hidden=(768, 1526)):
        super().__init__()
        widths = [d_in, *d_hidden]
        layers = []
        for w_in, w_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(w_in, w_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], d_out))
        self.seq = nn.Sequential(*layers)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map inputs of shape ``(..., d_in)`` to logits of shape ``(..., d_out)``."""
        return self.seq(X)

In [11]:
# Fixed seed so weight initialisation (and any torch-RNG shuffling inside
# Trainer) is repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

# One output logit per n-gram: size the head from the target width, not from
# the raw count matrix `X` (easily confused with the X_train inputs).
clf = ConceptClf(y_train.shape[1]).to(device)

# Multi-label binary targets -> independent sigmoid + BCE per n-gram.
loss_fn = nn.BCEWithLogitsLoss()

trainer = Trainer(
    clf,
    loss_fn,
    lr=3e-5,
    n_epochs=200,
    batch_size=1028,  # NOTE(review): possibly meant 1024; kept to match the logged run
    optimizer=torch.optim.AdamW,
    X_val=X_val,
    y_val=y_val,
    interval=10
)

trainer.fit(X_train, y_train)

# Validation loss from an earlier run, kept for reference:
# Epoch: -1, val loss: 0.69304
# Epoch: 0, val loss: 0.63102
# Epoch: 10, val loss: 0.060834
# Epoch: 20, val loss: 0.060065
# Epoch: 30, val loss: 0.059569
# Epoch: 40, val loss: 0.059233
# Epoch: 50, val loss: 0.059019
# Epoch: 60, val loss: 0.058871
# Epoch: 70, val loss: 0.05874
# Epoch: 80, val loss: 0.05863
# Epoch: 90, val loss: 0.058547
# Epoch: 100, val loss: 0.058478
# Epoch: 110, val loss: 0.058422
# Epoch: 120, val loss: 0.058379
# Epoch: 130, val loss: 0.058337
# Epoch: 140, val loss: 0.0583

Epoch: -1, val loss: 0.69408
Epoch: 0, val loss: 0.60878
Epoch: 10, val loss: 0.065173
Epoch: 20, val loss: 0.064304
Epoch: 30, val loss: 0.06368
Epoch: 40, val loss: 0.063078
Epoch: 50, val loss: 0.062456
Epoch: 60, val loss: 0.061853
Epoch: 70, val loss: 0.061275
Epoch: 80, val loss: 0.060775
Epoch: 90, val loss: 0.060334
Epoch: 100, val loss: 0.059996
Epoch: 110, val loss: 0.059714
Epoch: 120, val loss: 0.059508
Epoch: 130, val loss: 0.059316
Epoch: 140, val loss: 0.0592
Epoch: 150, val loss: 0.059047
Epoch: 160, val loss: 0.058971
Epoch: 170, val loss: 0.058892
Epoch: 180, val loss: 0.05882
Epoch: 190, val loss: 0.058751


## PubMed Test Set
**TODO:** add PubMed test-set results here (evaluate on the held-out test split, not the validation split).

In [None]:
# y_val_pred = trainer.model(X_val) # <- change to test set, not val

# # ii = 1 # 6
# # print(df[df['pmid'].isin(val_ids).to_numpy()].iloc[ii]["name"])
# # print(df[df['pmid'].isin(val_ids).to_numpy()].iloc[ii]["description"])

# # true
# _df = pd.DataFrame({
#     "phrase": features,
#     "score": y_val[ii].to("cpu").detach().numpy()
# })

# _df.sort_values(by="score", ascending=False)

# # predicted
# _df = pd.DataFrame({
#     "phrase": features,
#     "score": torch.sigmoid(y_val_pred[ii].to("cpu").detach()).numpy()
# })
# _df.sort_values(by="score", ascending=False)

## Network Correspondence Test Set

Test article generation on an atlas network map, here the default mode network (DMN). The concept classifier ranks the n-gram terms associated with the map, and the LLM summarizes the top terms into a wiki-style article.

In [12]:
import gzip
import pickle

from neurovlm.retrieval_resources import _load_masker, _load_autoencoder
from nilearn.plotting import view_img

masker = _load_masker()

# Network atlases are stored as a nested mapping {atlas: {network: map}};
# flatten it into (atlas, network, map) triples in iteration order.
# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source.
with gzip.open(data_dir / "networks.pkl.gz", "rb") as fh:
    networks = pickle.load(fh)
networks = [
    (atlas, name, net_map)
    for atlas, atlas_networks in networks.items()
    for name, net_map in atlas_networks.items()
]

In [13]:
# Which (atlas, network, map) entry to describe.
# NOTE(review): index 14 is assumed to be a default-mode-network map (the
# narrative and the ranked terms below suggest so); the print makes it checkable.
NETWORK_INDEX = 14
print(networks[NETWORK_INDEX][:2])

x = masker.transform(networks[NETWORK_INDEX][2])

autoencoder = _load_autoencoder()
with torch.no_grad():
    # Project the masked map into the autoencoder latent space that clf consumes.
    x = autoencoder.encoder(torch.from_numpy(x))

  return self.func(*args, **kwargs)


  return self.func(*args, **kwargs)


In [14]:
# x = proj_head.cpu()(x)
# x = x/x.norm()

In [15]:
# Score every n-gram for the encoded network map and rank them.
with torch.no_grad():
    logits = clf(x.to(device))
# Flatten so ranking works whether x encodes the map as (d,) or (1, d);
# assumes x holds a single map.
scores = torch.sigmoid(logits.cpu()).flatten()
order = torch.argsort(scores, descending=True)

_df = pd.DataFrame(dict(
    features=features[order.numpy()],
    scores=scores[order].numpy(),
))

# Already sorted by descending score.
_df.head(20)

Unnamed: 0,features,scores
0,posterior cingulate,0.744396
1,precuneus,0.662985
2,default,0.650656
3,posterior cingulate cortex,0.60569
4,default mode,0.600453
5,default mode network,0.592126
6,mode network,0.581262
7,mode,0.579798
8,cingulate,0.57335
9,regions,0.558416


In [16]:
# Rename to the vocabulary used in the LLM prompt: n-gram -> term, score -> weight.
_df = _df.rename(columns={"features": "terms", "scores": "weights"})

In [17]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Instruction-tuned LLM that turns the ranked terms into an article.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Left padding keeps each prompt's last real token at the end of the sequence,
# which decoder-only generation continues from.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")

# If the tokenizer defines no pad token, reuse EOS so padding works.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model.config.pad_token_id = tokenizer.pad_token_id

# Inference only: eval mode and frozen weights.
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

# Disable autograd globally for the rest of the session. torch.set_grad_enabled
# is the same object as torch.autograd.grad_mode.set_grad_enabled, so one call
# suffices. NOTE: re-running the training cell afterwards needs
# torch.set_grad_enabled(True) (or a kernel restart) first.
torch.set_grad_enabled(False);

In [None]:
# System instruction for the LLM; the user turn carries the ranked terms.
system_prompt = (
    "\n"
    "Write a wiki-style article related to the following set of terms. "
    "Ignore terms that are unrelated or unspecific. Find a common theme.\n"
)

# Chat-capable text-generation pipeline around the already-loaded model.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
)

Device set to use cuda:0


In [94]:
# Preview the top-10 (term, weight) table. The prompt actually sent to the LLM
# is rebuilt in the next cell as a plain newline-separated list of terms.
user_prompt = (
    _df.sort_values("weights", ascending=False)
    .head(10)
    .to_string()
)

print(user_prompt)

                        terms   weights
0         posterior cingulate  0.744396
1                   precuneus  0.662985
2                     default  0.650656
3  posterior cingulate cortex  0.605690
4                default mode  0.600453
5        default mode network  0.592126
6                mode network  0.581262
7                        mode  0.579798
8                   cingulate  0.573350
9                     regions  0.558416


In [95]:
# Prompt sent to the LLM: the 10 highest-weighted terms, one per line (no weights).
user_prompt = "\n".join(
    _df.sort_values("weights", ascending=False)["terms"].head(10).tolist()
)

In [98]:
messages = [
    {"role": "system", "content": system_prompt.strip("\n")},
    {"role": "user", "content": user_prompt},
]

# do_sample=True makes generation stochastic; seed torch's global RNG so the
# article is repeatable on the same hardware and library versions.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)

with torch.inference_mode():
    outputs = pipe(
        messages,
        max_new_tokens=1000,
        do_sample=True,
        temperature=0.2,
        top_p=0.95,
        top_k=50,
        repetition_penalty=1.1,
        return_full_text=True,
        padding=True,
        truncation=True
    )
# With chat input and return_full_text=True, generated_text is the whole message
# list; its last entry is the newly generated assistant reply.
out = outputs[0]["generated_text"][-1]["content"].strip()

In [99]:
# Show the article with real line breaks rather than the escaped string repr.
print(out)

'**Default Mode Network**\n\nThe Default Mode Network (DMN) is a set of brain regions that are active when an individual is not focused on the external environment and the brain is at "wakeful rest." This network is characterized by its high activity during tasks such as daydreaming, mind-wandering, and recalling past events.\n\n**Key Regions of the Default Mode Network**\n\n1. **Posterior Cingulate Cortex**: The posterior cingulate cortex (PCC) is a region in the brain that plays a crucial role in the DMN. It is involved in error detection, conflict monitoring, and attentional control.\n2. **Precuneus**: The precuneus is a region located in the parietal lobe that is also part of the DMN. It is involved in self-referential processing, memory retrieval, and spatial awareness.\n3. **Medial Prefrontal Cortex**: Although not explicitly mentioned in the given terms, the medial prefrontal cortex (mPFC) is another key region of the DMN. It is involved in self-referential thinking, emotion reg

'**Default Mode Network**

The Default Mode Network (DMN) is a set of brain regions that are active when an individual is not focused on the external environment and the brain is at "wakeful rest." This network is characterized by its high activity during tasks such as daydreaming, mind-wandering, and recalling past events.

**Key Regions of the Default Mode Network**

1. **Posterior Cingulate Cortex**: The posterior cingulate cortex (PCC) is a region in the brain that plays a crucial role in the DMN. It is involved in error detection, conflict monitoring, and attentional control.
2. **Precuneus**: The precuneus is a region located in the parietal lobe that is also part of the DMN. It is involved in self-referential processing, memory retrieval, and spatial awareness.
3. **Medial Prefrontal Cortex**: Although not explicitly mentioned in the given terms, the medial prefrontal cortex (mPFC) is another key region of the DMN. It is involved in self-referential thinking, emotion regulation, and decision-making.

**Functions of the Default Mode Network**

The DMN is thought to be involved in various cognitive processes, including:

* **Self-referential thinking**: The ability to reflect on oneself, one\'s thoughts, and emotions.
* **Memory retrieval**: The process of recalling past events and experiences.
* **Mind-wandering**: The tendency for the mind to wander away from the present moment and engage in daydreaming or fantasy.
* **Theory of mind**: The ability to attribute mental states to oneself and others.

**Abnormalities in the Default Mode Network**

Dysfunction in the DMN has been implicated in various neurological and psychiatric disorders, including:

* **Alzheimer\'s disease**: Abnormalities in the PCC have been linked to early stages of Alzheimer\'s disease.
* **Depression**: Altered activity in the mPFC and PCC has been observed in individuals with depression.
* **Schizophrenia**: Abnormalities in the DMN have been linked to symptoms of schizophrenia, such as hallucinations and delusions.

In conclusion, the default mode network is a complex system of brain regions that play a critical role in various cognitive processes, including self-referential thinking, memory retrieval, and mind-wandering. Dysfunction in this network has been implicated in various neurological and psychiatric disorders.'