In [4]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from neurovlm.retrieval_resources import (
    _load_dataframe, _load_latent_text
)
from neurovlm.data import data_dir
from neurovlm.train import Trainer, which_device
from neurovlm.models import ConceptClf

# Concept Classifier

The concept classifier predicts which concepts are present given a latent neuro embedding. The top-10 related concepts are passed to an LLM to summarize the brain map. Here, Llama-3.1-8B-Instruct is used to generate interpretations. Any language model may be used. Larger models, or models trained on neuroscience literature, may provide better brain map interpretations.

In [None]:
# N-gram embeddings and keep-mask, pre-computed in 06_n_grams.ipynb.
# map_location="cpu" makes the load independent of the device the tensor was
# saved from; the targets derived from ngram_emb are later converted with
# .numpy(), which only works for CPU tensors.
ngram_emb = torch.load(data_dir / "ngram_emb.pt", map_location="cpu")
mask = np.load(data_dir / "ngram_mask.npy")

# load publication metadata
df = _load_dataframe()

# load pre-computed ngrams from 06_n_grams.ipynb, keeping only the
# columns / labels selected by `mask`
X = np.load(data_dir / "ngram_matrix.npy")[:, mask]
features = np.load(data_dir / "ngram_labels.npy")[mask]

# load latent text
latent, pmids = _load_latent_text()

# align df order, on pmid; X is permuted with the same indices
# (assumes the rows of X are stored in df's original row order — TODO confirm)
inds = df["pmid"].argsort().values
df = df.iloc[inds].reset_index(drop=True)
X = X[inds]
assert (df['pmid'] == pmids).all()

# build the text only after the reorder, so that it shares df's row order and
# index instead of keeping the pre-sort index
text = df["name"] + " [SEP] " + df["description"]

In [10]:
# cosine similarity as target
# Rows of ngram_emb are scaled to unit length here; the product is a true
# cosine only if the rows of `latent` are unit-norm as well — TODO confirm.
y = latent @ (ngram_emb / ngram_emb.norm(dim=1, keepdim=True)).T

# entries to zero out: negative similarity, or a zero entry in the n-gram matrix X
m = (y < 0.) | (torch.from_numpy(X) == 0.)

# transform cosine similarity ~= probabilities
t = 0.03    # similarity that maps to 0.5
tau = 0.08  # scale of the sigmoid; smaller values make it steeper
y = torch.sigmoid((y - t) / tau)

# Mask once, after the sigmoid. Zeroing the same entries before the sigmoid
# had no effect on the result, because they are overwritten here anyway.
y[m] = 0.
y = y.numpy()

In [11]:
# ensure latent neuro vectors align with df
# NOTE: the unpacking relies on the saved dict holding exactly two values, in
# the order (latent vectors, pmids).
latent_neuro, pmid = torch.load(
    data_dir / "latent_neuro.pt", weights_only=False, map_location="cpu"
).values()

# df must be sorted by pmid. Comparing the column with its own sort_values()
# compares two Series with different indexes whenever df is unsorted, which
# raises ValueError instead of failing the assertion.
assert df["pmid"].is_monotonic_increasing, "df is not sorted by pmid"

# keep only the rows that have a latent neuro vector
mask = df['pmid'].isin(pmid)
df, y = df[mask].reset_index(drop=True), y[mask.to_numpy()]

# Row i of latent_neuro must describe row i of df: later cells index
# latent_neuro with boolean masks computed from df.
assert np.array_equal(df["pmid"].to_numpy(), np.asarray(pmid)), (
    "latent_neuro pmids do not match df['pmid'] row for row"
)

In [12]:
# load data splits
# NOTE(review): the unpacking relies on the saved dict's values being ordered
# (train, test, val) — confirm against the code that wrote pmids_split.pt.
train_ids, test_ids, val_ids = torch.load(
    data_dir / "pmids_split.pt", weights_only=False, map_location="cpu"
).values()

# np.sort returns a sorted ndarray for any array-like input. A bare `.sort()`
# only sorts in place for ndarrays and lists; on a torch tensor it returns the
# result, which would be discarded.
train_ids = np.sort(train_ids)
val_ids = np.sort(val_ids)
test_ids = np.sort(test_ids)

def split(df, latent, y, pmids, device):
    """Extract the rows belonging to one data split and move them to `device`.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a "pmid" column. Row i is taken to correspond to row i of
        both `latent` and `y`.
    latent : torch.Tensor
        Input vectors, indexed along the first dimension.
    y : numpy.ndarray
        Targets, indexed along the first axis.
    pmids : array-like supporting boolean-mask indexing
        PMIDs that make up the split.
    device : str or torch.device
        Device the returned tensors are moved to.

    Returns
    -------
    tuple
        ``(X, y, pmids)``: the selected rows of `latent`, the selected rows of
        `y` as a float32 tensor, and the entries of `pmids` that occur in
        ``df["pmid"]``.
    """
    # boolean row selector over df: True where the row's pmid is in the split
    in_split = df["pmid"].isin(pmids).to_numpy()

    inputs = latent[torch.from_numpy(in_split)].clone().to(device)
    targets = torch.from_numpy(y[in_split].copy()).float().to(device)

    # drop requested ids that have no row in df
    present = pd.Series(pmids).isin(df["pmid"])
    return inputs, targets, pmids[present]

device = which_device()

# each call returns the split's inputs, its targets, and the subset of the
# requested ids that are present in df
X_train, y_train, train_ids = split(df, latent_neuro, y, train_ids, device)
X_val, y_val, val_ids = split(df, latent_neuro, y, val_ids, device)
X_test, y_test, test_ids = split(df, latent_neuro, y, test_ids, device)

# ensure sorted
# is_monotonic_increasing is used for df because comparing the column with its
# own sort_values() raises ValueError (indexes differ) when df is unsorted,
# rather than failing the assertion.
assert df['pmid'].is_monotonic_increasing
assert (train_ids == np.sort(train_ids)).all()
assert (val_ids == np.sort(val_ids)).all()
assert (test_ids == np.sort(test_ids)).all()

In [14]:
# Seed the global torch and numpy generators so that weight initialisation
# (and any shuffling that draws from them) repeats across runs.
torch.manual_seed(0)
np.random.seed(0)

# X is the masked n-gram matrix, so X.shape[1] is the number of n-gram
# concepts; it equals the number of target columns in y.
clf = ConceptClf(X.shape[1]).to(device)

# Targets are sigmoid-transformed similarities in [0, 1] rather than hard
# labels. BCEWithLogitsLoss applies the sigmoid itself, so clf is presumably
# expected to return raw logits — confirm against ConceptClf.
loss_fn = nn.BCEWithLogitsLoss()

# NOTE(review): batch_size=1028 may be a typo for 1024 — confirm.
# NOTE(review): in the recorded run the validation loss is lowest around
# epoch 40 and rises afterwards; confirm whether Trainer keeps the best or the
# final weights, and whether 200 epochs are needed.
trainer = Trainer(
    clf,
    loss_fn,
    lr=5e-5,
    n_epochs=200,
    batch_size=1028,
    optimizer=torch.optim.AdamW,
    X_val=X_val,
    y_val=y_val,
    interval=20
)

trainer.fit(X_train, y_train)

trainer.save(data_dir / "concept_clf.pt")

Epoch: -1, val loss: 0.71944
Epoch: 0, val loss: 0.1559
Epoch: 20, val loss: 0.054673
Epoch: 40, val loss: 0.054531
Epoch: 60, val loss: 0.054543
Epoch: 80, val loss: 0.0546
Epoch: 100, val loss: 0.054684
Epoch: 120, val loss: 0.054806
Epoch: 140, val loss: 0.054949
Epoch: 160, val loss: 0.055094
Epoch: 180, val loss: 0.055262
