In [None]:
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

import sys
sys.path.append('/home/anon/repos/neurovlm/src/neurovlm')

from neurovlm.data import get_data_dir
from neurovlm.models import TextAligner
from neurovlm.loss import InfoNCELoss
from neurovlm.train import Trainer

# Load data

In [27]:
# Load the precomputed latents and publication metadata.
# NOTE: these paths may need updating to point at the most recent data / latents.
neurovlm_dir = get_data_dir()

# weights_only=False unpickles arbitrary Python objects -- only load trusted local files.
# NOTE(review): unpacking `.values()` relies on the key order of the saved dict
# (presumably latents first, then PMIDs) -- confirm against the code that wrote the file.
latent_text_specter, pmids_neurovlm = torch.load(neurovlm_dir / "recall/latent_specter2_neuro.pt", weights_only=False).values()

# Use the GPU when present, but fall back to CPU instead of failing inside torch.load.
load_device = "cuda" if torch.cuda.is_available() else "cpu"
latent_neuro = torch.load(neurovlm_dir / "recall/new_latent_neuro.pt", map_location=load_device)
df = pd.read_parquet(neurovlm_dir / "recall/publications_more.parquet")

# The row mask below is computed from `df` but applied to both latent tensors,
# so all three must describe the same documents in the same order.
assert len(df) == len(latent_neuro) == len(latent_text_specter), (
    f"row counts differ: df={len(df)}, latent_neuro={len(latent_neuro)}, "
    f"latent_text_specter={len(latent_text_specter)}"
)

# Filter to common subset: keep only documents whose PMID appears in the saved PMID list.
pmids = np.load(neurovlm_dir / "recall/pmids_neurocontext.npy")
mask = np.array(df['pmid'].isin(pmids))
latent_neuro = latent_neuro[mask]
latent_text_specter = latent_text_specter[mask]

print(f"Dataset: {len(latent_neuro)} documents")

Dataset: 20674 documents


In [28]:
# Split document indices: hold out 20% for test, then carve 10% of the remaining
# documents off for validation (roughly 72% train / 8% val / 20% test overall).
all_doc_idx = np.arange(len(latent_neuro))
trainval_idx, test_idx = train_test_split(all_doc_idx, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(trainval_idx, test_size=0.1, random_state=42)

# Projection head mapping SPECTER2 text latents into the brain latent space,
# trained with an InfoNCE loss and monitored on the validation pairs.
projection_head = TextAligner(seed=42)
contrastive_loss = InfoNCELoss()
val_text = latent_text_specter[val_idx]
val_brain = latent_neuro[val_idx]

trainer = Trainer(
    projection_head,
    batch_size=4096,
    n_epochs=100,
    lr=1e-4,
    loss_fn=contrastive_loss,
    optimizer=torch.optim.AdamW,
    X_val=val_text,
    y_val=val_brain,
    device="auto",
    verbose=False,
)

# Train on cloned copies of the training rows.
train_text = latent_text_specter[train_idx].clone()
train_brain = latent_neuro[train_idx].clone()
trainer.fit(train_text, train_brain)

print("NeuroVLM training complete")

NeuroVLM training complete


In [30]:
def document_matching_accuracy(brain_embeddings, text_embeddings, n_trials=1000, random_state=None):
    """
    Two-alternative forced-choice (2AFC) document-matching accuracy.

    Intended to mirror NeuroQuery's decoding experiment. For each of the first
    ``min(n_trials, n_docs)`` documents ``i``, text embedding ``i`` is compared
    (cosine similarity) with its own brain embedding ``i`` and with the brain
    embedding of one other document drawn uniformly at random from the
    remaining ``n_docs - 1``. A trial counts as correct only when the correct
    pairing is strictly more similar; ties count as failures.

    Parameters
    ----------
    brain_embeddings : array-like, shape (n_docs, dim)
        Target brain-space embeddings, one row per document.
    text_embeddings : array-like, shape (n_docs, dim)
        Text embeddings projected into the same space, row-aligned with
        ``brain_embeddings``.
    n_trials : int, default 1000
        Maximum number of documents to evaluate; clipped to ``n_docs``.
    random_state : None, int, or numpy.random.Generator, optional
        Controls negative sampling. ``None`` derives a seed from NumPy's legacy
        global RNG, so a preceding ``np.random.seed(...)`` makes the result
        reproducible.

    Returns
    -------
    float
        Fraction of evaluated documents whose correct pairing wins.

    Raises
    ------
    ValueError
        If the inputs are not 2-D arrays of identical shape, fewer than two
        documents are given, or ``n_trials < 1``.
    """
    brain = np.asarray(brain_embeddings)
    text = np.asarray(text_embeddings)
    if brain.ndim != 2 or brain.shape != text.shape:
        raise ValueError(
            f"expected two 2-D arrays of identical shape, got {brain.shape} and {text.shape}"
        )

    n_docs = brain.shape[0]
    if n_docs < 2:
        raise ValueError("need at least two documents to draw a negative example")
    n_eval = min(n_trials, n_docs)
    if n_eval < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    if random_state is None:
        # Keep honoring np.random.seed(...) as the original global-RNG version did.
        random_state = np.random.randint(np.iinfo(np.int32).max)
    rng = np.random.default_rng(random_state)

    # L2-normalize rows so dot products are cosine similarities
    # (zero-norm rows yield NaN and therefore count as failures).
    brain_norm = brain / np.linalg.norm(brain, axis=1, keepdims=True)
    text_eval = text[:n_eval]
    text_norm = text_eval / np.linalg.norm(text_eval, axis=1, keepdims=True)

    # Draw each negative uniformly from {0, ..., n_docs-1} \ {i} in O(1):
    # sample one of n_docs-1 slots, then skip over index i.
    rows = np.arange(n_eval)
    offsets = rng.integers(0, n_docs - 1, size=n_eval)
    negatives = offsets + (offsets >= rows)

    # Row-wise dot products: text i vs. correct brain i, and vs. its random negative.
    sim_correct = np.einsum("ij,ij->i", text_norm, brain_norm[:n_eval])
    sim_random = np.einsum("ij,ij->i", text_norm, brain_norm[negatives])

    return float(np.mean(sim_correct > sim_random))

# Test NeuroVLM on the held-out test split.
# Switch to inference mode so dropout / batch-norm (if TextAligner has any) behave deterministically.
# NOTE(review): assumes Trainer re-enables train mode itself if fit() is called again -- confirm.
trainer.model.eval()
with torch.no_grad():
    brain_test = latent_neuro[test_idx].cpu().numpy()
    model_device = next(trainer.model.parameters()).device
    text_aligned = trainer.model(latent_text_specter[test_idx].to(model_device)).cpu().numpy()

# Negative sampling is random; seed NumPy's global RNG so re-runs report the same accuracy.
np.random.seed(42)
neurovlm_accuracy = document_matching_accuracy(brain_test, text_aligned, n_trials=1000)

results = {
    'Method': ['Random Chance', 'NeuroQuery', 'NeuroVLM'],
    # 0.500 is the 2AFC chance level. 0.720 is a fixed reference value for NeuroQuery:
    # TODO cite its source and confirm it was measured with the same protocol and test set.
    'Accuracy': [0.500, 0.720, neurovlm_accuracy]
}

results_df = pd.DataFrame(results)
print("=== COMPARISON RESULTS ===")
results_df.round(3)

=== COMPARISON RESULTS ===


Unnamed: 0,Method,Accuracy
0,Random Chance,0.5
1,NeuroQuery,0.72
2,NeuroVLM,0.8
