In [1]:
import sys
from copy import deepcopy
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from neurovlm.data import get_data_dir
from neurovlm.models import TextAligner
from neurovlm.loss import InfoNCELoss, recall_n, mix_match
from neurovlm.train import Trainer

In [2]:
# Resolve the NeuroVLM data directory; every artifact below is read from it.
neurovlm_dir = get_data_dir()

# Text latents and their PMIDs are stored together in one saved mapping and are
# unpacked positionally (first value, second value), so the stored order matters.
# weights_only=False runs the full unpickler: only load files you trust.
# Alternative embedding file used during experiments: "latent_specter2_adhoc.pt".
specter_payload = torch.load(
    neurovlm_dir / "latent_specter2_neuro.pt",
    weights_only=False,
)
latent_text_specter, pmids_neurovlm = specter_payload.values()

# Neuro latents and the publication table (the table provides the 'pmid' column).
# NOTE(review): rows of both latents are assumed to line up with rows of `df` -- confirm.
latent_neuro = torch.load(neurovlm_dir / "latent_neuro.pt")
df = pd.read_parquet(neurovlm_dir / "publications_more.parquet")

In [None]:
# Keep only the publications whose PMID appears in the NeuroConText PMID list
# (presumably so the scores below are computed on the same papers as the
# NeuroConText reference numbers -- confirm).
#
# Provenance: pmids_neurocontext.npy was generated once with the recipe below.
# It is kept as a comment because it needs a local NeuroConText checkout.
#   data_dir = Path("<path to>/NeuroConText/data_NeuroConText")
#   train_pmids = np.array(list(pd.read_pickle(data_dir / "train_pmids.pkl")))
#   test_pmids = np.array(list(pd.read_pickle(data_dir / "test_pmids.pkl")))
#   pmids = np.concatenate((train_pmids, test_pmids))
#   pmids = np.sort(pmids)
#   np.save(neurovlm_dir / "pmids_neurocontext.npy", pmids)
pmids = np.load(neurovlm_dir / "pmids_neurocontext.npy")

# The mask is built from `df` rows but applied to the latent tensors, and this
# cell overwrites those tensors. Running it a second time without reloading
# would apply a full-length mask to already-filtered tensors, so fail early
# with an actionable message instead of an opaque mask-shape IndexError.
n_publications = len(df)
if len(latent_neuro) != n_publications or len(latent_text_specter) != n_publications:
    raise RuntimeError(
        "Latents are not row-aligned with `df` "
        f"(df={n_publications}, latent_neuro={len(latent_neuro)}, "
        f"latent_text_specter={len(latent_text_specter)}). "
        "Re-run the data loading cell before filtering again."
    )

mask = np.array(df['pmid'].isin(pmids))
latent_neuro = latent_neuro[mask]
latent_text_specter = latent_text_specter[mask]

In [None]:
# Cross-validated evaluation of the projection head that aligns text latents
# to neuro latents. For each fold: carve a validation split out of the training
# indices, train, restore the best-validation checkpoint, then score retrieval
# on the held-out test indices.

# Metrics
recall_fn = partial(recall_n, thresh=0.95, reduce_mean=True)

# CV configuration
n_splits = 10       # number of KFold partitions
max_folds = 1       # folds actually trained; 1 reproduces the original single-fold run
n_epochs_nv = 300
val_size = 1000     # samples taken from each training split for validation
kfolds = KFold(n_splits=n_splits, random_state=0, shuffle=True)

# One slot per fold. NaN marks folds that were never run, so they cannot be
# mistaken for a genuine score of 0.
perf_20_nv = np.full(n_splits, np.nan)    # recall@20
perf_200_nv = np.full(n_splits, np.nan)   # recall@200
mix_match_nv = np.full(n_splits, np.nan)

# Output directory is the same for every fold, so each fold overwrites
# proj_head.pt; created once instead of on every iteration.
fold_dir = neurovlm_dir / "models" / "tmp"
fold_dir.mkdir(exist_ok=True, parents=True)

for i, (inds_train, inds_test) in enumerate(kfolds.split(np.arange(len(latent_neuro)))):

    print(f"Fold: {i}")

    # Fold-specific, reproducible shuffle before splitting off the validation set.
    # The global NumPy seed is kept on purpose so the split matches earlier runs.
    np.random.seed(i)
    np.random.shuffle(inds_train)

    inds_val = inds_train[:val_size]
    inds_train = inds_train[val_size:]

    # Projection head (align latent text to latent neuro)
    trainer_specter = Trainer(
        TextAligner(seed=i),
        # NOTE(review): 4098 may be a typo for 4096; left unchanged so results stay comparable.
        batch_size=4098,
        n_epochs=n_epochs_nv,
        lr=4e-5,
        loss_fn=InfoNCELoss(),
        optimizer=torch.optim.AdamW,
        X_val=latent_text_specter[inds_val],
        y_val=latent_neuro[inds_val],
        device="auto",
        verbose=True,
        interval=10,
        use_tqdm=True
    )

    trainer_specter.fit(
        latent_text_specter[inds_train].clone(),
        latent_neuro[inds_train].clone()
    )
    trainer_specter.restore_best()  # restore model with best val loss
    proj_head = trainer_specter.model.to("cpu")
    trainer_specter.save(fold_dir / "proj_head.pt")

    # Performance
    # Inference must run in eval mode so dropout / batch-norm layers (if the
    # model has any) behave deterministically. This is a no-op if the trainer
    # already left the module in eval mode.
    proj_head.eval()
    with torch.no_grad():
        # L2-normalise rows so the matrix product below is cosine similarity.
        # normalize() returns new tensors (the inputs are untouched) and clamps
        # the norm at a small eps, so an all-zero row yields zeros, not NaN.
        image_embeddings_nv = torch.nn.functional.normalize(
            latent_neuro[inds_test].to("cpu"), dim=1
        )
        text_embeddings_nv = torch.nn.functional.normalize(
            proj_head(latent_text_specter[inds_test]), dim=1
        )

    # Neurovlm: rows are images, columns are texts; the matching pairs sit on the diagonal.
    similarity = (image_embeddings_nv @ text_embeddings_nv.T).softmax(dim=1).numpy()
    targets = np.eye(len(similarity))
    perf_20_nv[i] = recall_fn(similarity, targets, n_first=20)
    perf_200_nv[i] = recall_fn(similarity, targets, n_first=200)
    mix_match_nv[i] = mix_match(similarity)

    # Stop after `max_folds` folds. `i` is left at the last fold that ran, so
    # later code that indexes the score arrays with `i` reads that fold.
    if i + 1 >= max_folds:
        break

Fold: 0
Epoch: -1, val loss: 6.9998


  0%|          | 0/300 [00:00<?, ?it/s]

Epoch: 0, val loss: 6.9417
Epoch: 10, val loss: 6.6846
Epoch: 20, val loss: 6.5093
Epoch: 30, val loss: 6.3507
Epoch: 40, val loss: 6.283
Epoch: 50, val loss: 6.2361
Epoch: 60, val loss: 6.1974
Epoch: 70, val loss: 6.1777
Epoch: 80, val loss: 6.1638
Epoch: 90, val loss: 6.1399
Epoch: 100, val loss: 6.1311
Epoch: 110, val loss: 6.121
Epoch: 120, val loss: 6.1107
Epoch: 130, val loss: 6.1072
Epoch: 140, val loss: 6.1002
Epoch: 150, val loss: 6.0958
Epoch: 160, val loss: 6.0958
Epoch: 170, val loss: 6.0889
Epoch: 180, val loss: 6.0906
Epoch: 190, val loss: 6.0854
Epoch: 200, val loss: 6.0942
Epoch: 210, val loss: 6.0899
Epoch: 220, val loss: 6.0878
Epoch: 230, val loss: 6.0903
Epoch: 240, val loss: 6.0976
Epoch: 250, val loss: 6.0864
Epoch: 260, val loss: 6.0891
Epoch: 270, val loss: 6.0924
Epoch: 280, val loss: 6.0911
Epoch: 290, val loss: 6.0951


In [11]:
# Side-by-side comparison of the NeuroVLM scores for fold `i` (the loop index
# left over from the cross-validation cell) with the NeuroConText reference values.
metric_names = ["recall@20", "recall@200", "mix&match"]
neurovlm_scores = [fold_scores[i] for fold_scores in (perf_20_nv, perf_200_nv, mix_match_nv)]
neurocontext_scores = [0.218, 0.596, 0.848]  # from neurocontext model & dataset

pd.DataFrame(
    {"neurovlm": neurovlm_scores, "neurocontext": neurocontext_scores},
    index=metric_names,
)

Unnamed: 0,neurovlm,neurocontext
recall@20,0.13588,0.218
recall@200,0.469536,0.596
mix&match,0.796885,0.848
