In [1]:
from itertools import chain
from copy import deepcopy
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from neurovlm.data import get_data_dir
from neurovlm.models import TextAligner
from neurovlm.loss import InfoNCELoss, recall_n, mix_match
from neurovlm.train import Trainer

In [None]:
# NeuroVLM embeddings: SPECTER2 text latents ("neuro" and "proxi" files) plus the neuro latents
neurovlm_dir = get_data_dir()
latent_text_specter_neuro, pmids_neuro = torch.load(neurovlm_dir / "latent_specter2_neuro.pt", weights_only=False).values()
latent_text_specter, pmids_neurovlm = torch.load(neurovlm_dir / "latent_specter2_proxi.pt", weights_only=False).values()

# The two text embeddings are column-stacked below, so their rows must describe the same papers
assert np.array_equal(pmids_neuro, pmids_neurovlm), "latent_specter2_neuro.pt and latent_specter2_proxi.pt are not row-aligned"

latent_neuro = torch.load(neurovlm_dir / "latent_neuro_nc.pt")
df = pd.read_parquet(neurovlm_dir / "publications_more.parquet")

# Mask for neurocontext papers only
# NOTE(review): assumes the text latents are row-aligned with `df` before masking — TODO confirm
pmids = np.load(neurovlm_dir / "pmids_neurocontext.npy")
mask = np.array(df['pmid'].isin(pmids))
df = df[mask]
latent_text_specter = latent_text_specter[mask]
latent_text_specter_neuro = latent_text_specter_neuro[mask]
neuro_vectors = torch.load(neurovlm_dir / "neuro_vectors.pt")[mask]

# Prefer Apple MPS (the original target), then CUDA, then CPU, instead of hard-coding "mps"
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

latent_neuro = latent_neuro.to(device)
latent_text_specter = torch.column_stack((latent_text_specter, latent_text_specter_neuro)).to(device)

# Training indexes both tensors with the same row indices, so they must pair up one-to-one
assert len(latent_text_specter) == len(latent_neuro), "text and neuro latents have different numbers of rows"

In [None]:
# 10-fold CV of contrastive text<->neuro projection heads (text features: neuro + proxi).
# Per fold: a validation set is split off the training indices for model selection,
# the heads are trained with InfoNCE, and the validation-selected heads are scored on the
# held-out fold and saved.

# Metrics: one entry per CV fold
n_folds = 10
recall_fn = partial(recall_n, thresh=0.95, reduce_mean=True)
recall_20_nv = np.zeros(n_folds)   # recall@20
recall_200_nv = np.zeros(n_folds)  # recall@200
mix_match_nv = np.zeros(n_folds)

# CV configuration
n_epochs_nv = 101
val_size = 1000            # samples carved out of each training fold for model selection
batch_size = 2048 * 3
hidden_dim = 768
model_tag = "proxi_neuro"  # keeps this experiment's checkpoints from overwriting other runs
kfolds = KFold(n_splits=n_folds, random_state=0, shuffle=True)

# Follow the device the data cell put the latents on, and size the heads from the data
device = latent_neuro.device
text_dim = latent_text_specter.shape[1]
neuro_dim = latent_neuro.shape[1]

fold_dir = neurovlm_dir / "models"
fold_dir.mkdir(exist_ok=True, parents=True)

for i, (inds_train, inds_test) in enumerate(kfolds.split(np.arange(len(latent_neuro)))):

    print(f"Fold: {i}")

    # Split a validation set off the (seeded, shuffled) training fold
    np.random.seed(i)
    np.random.shuffle(inds_train)
    inds_val = inds_train[:val_size]
    inds_train = inds_train[val_size:]

    # Projection heads (align latent text to latent neuro)
    proj_head_text = TextAligner(latent_text_dim=text_dim, hidden_dim=hidden_dim, seed=0).to(device)
    proj_head_neuro = TextAligner(latent_text_dim=neuro_dim, hidden_dim=hidden_dim, seed=1).to(device)
    optimizer = torch.optim.AdamW(chain(proj_head_neuro.parameters(), proj_head_text.parameters()), lr=5e-5)
    loss_fn = InfoNCELoss(temperature=.1)

    # Reset per fold: otherwise a fold whose validation loss is never finite would
    # silently evaluate/save the heads left over from the previous fold.
    best_loss = np.inf
    best_proj_neuro = best_proj_text = None
    inds_train_rand = inds_train.copy()  # reshuffled every epoch

    # Train loop
    for iepoch in range(n_epochs_nv):
        proj_head_neuro.train()
        proj_head_text.train()
        np.random.shuffle(inds_train_rand)
        for start_idx in range(0, len(inds_train_rand), batch_size):
            batch_inds = inds_train_rand[start_idx:start_idx + batch_size]

            proj_neuro = proj_head_neuro(latent_neuro[batch_inds])
            proj_text = proj_head_text(latent_text_specter[batch_inds])
            loss = loss_fn(proj_text, proj_neuro)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation in eval mode (a no-op unless TextAligner has dropout/batch-norm) without autograd
        proj_head_neuro.eval()
        proj_head_text.eval()
        with torch.no_grad():
            val_neuro = proj_head_neuro(latent_neuro[inds_val])
            val_text = proj_head_text(latent_text_specter[inds_val])
            val_loss = float(loss_fn(val_text, val_neuro))
            if val_loss < best_loss:
                best_loss = val_loss
                best_proj_neuro = deepcopy(proj_head_neuro)
                best_proj_text = deepcopy(proj_head_text)

            if iepoch % 10 == 0:
                # Report retrieval metrics occasionally, reusing the validation projections
                image_embeddings_nv = val_neuro / val_neuro.norm(dim=1)[:, None]
                text_embeddings_nv = val_text / val_text.norm(dim=1)[:, None]
                similarity = (image_embeddings_nv @ text_embeddings_nv.T).cpu().numpy()
                recall_20_nv_ = recall_fn(similarity, np.eye(len(similarity)), n_first=20)
                recall_200_nv_ = recall_fn(similarity, np.eye(len(similarity)), n_first=200)
                print(f"Epoch: {iepoch}, Loss: {val_loss:.4f}, Recall@20: {recall_20_nv_:.4f}, Recall@200: {recall_200_nv_:.4f}")

    if best_proj_neuro is None:
        raise RuntimeError(f"Fold {i}: validation loss was never finite; no checkpoint was selected.")

    # Test metrics for the validation-selected heads, i.e. exactly the models saved below
    best_proj_neuro.eval()
    best_proj_text.eval()
    with torch.no_grad():
        image_embeddings_nv = best_proj_neuro(latent_neuro[inds_test])
        text_embeddings_nv = best_proj_text(latent_text_specter[inds_test])
    image_embeddings_nv = image_embeddings_nv / image_embeddings_nv.norm(dim=1)[:, None]
    text_embeddings_nv = text_embeddings_nv / text_embeddings_nv.norm(dim=1)[:, None]
    similarity = (image_embeddings_nv @ text_embeddings_nv.T).cpu().numpy()

    recall_20_nv[i] = recall_fn(similarity, np.eye(len(similarity)), n_first=20)
    recall_200_nv[i] = recall_fn(similarity, np.eye(len(similarity)), n_first=200)
    mix_match_nv[i] = mix_match(similarity)
    print(f"Test Recall@20: {recall_20_nv[i]:.4f}, Recall@200: {recall_200_nv[i]:.4f}")

    torch.save(best_proj_neuro, fold_dir / f"proj_neuro_{model_tag}_fold{i:02d}.pt")
    torch.save(best_proj_text, fold_dir / f"proj_text_{model_tag}_fold{i:02d}.pt")

In [None]:
# Persist the per-fold CV metrics for the neuro + proxi text features
for metric_fname, metric_values in (
    ("recall_20_nv_proxi_neuro.npy", recall_20_nv),
    ("recall_200_nv_proxi_neuro.npy", recall_200_nv),
    ("mix_match_nv_proxi_neuro.npy", mix_match_nv),
):
    np.save(neurovlm_dir / metric_fname, metric_values)

In [None]:
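# Test metrics recorded from earlier runs of the CV cells (one line per fold),
# kept for comparing combinations of SPECTER2 text embeddings.

# neuro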
# Test Recall@20: 0.1968, Recall@200: 0.5604
# Test Recall@20: 0.2012, Recall@200: 0.5687
# Test Recall@20: 0.1929, Recall@200: 0.5546
# Test Recall@20: 0.2079, Recall@200: 0.5638
# Test Recall@20: 0.1916, Recall@200: 0.5568
# Test Recall@20: 0.1964, Recall@200: 0.5433
# Test Recall@20: 0.2037, Recall@200: 0.5617
# Test Recall@20: 0.2042, Recall@200: 0.5636
# Test Recall@20: 0.2114, Recall@200: 0.5718
# Test Recall@20: 0.1838, Recall@200: 0.5665

# neuro + proxi
# Test Recall@20: 0.2002, Recall@200: 0.5735
# Test Recall@20: 0.2132, Recall@200: 0.5812
# Test Recall@20: 0.1949, Recall@200: 0.5725
# Test Recall@20: 0.2147, Recall@200: 0.5754
# Test Recall@20: 0.2027, Recall@200: 0.5583
# Test Recall@20: 0.2148, Recall@200: 0.5539
# Test Recall@20: 0.2075, Recall@200: 0.5704
# Test Recall@20: 0.2230, Recall@200: 0.5718
# Test Recall@20: 0.2196, Recall@200: 0.5965
# Test Recall@20: 0.2090, Recall@200: 0.5762

# neuro + proxi + adhoc
# Test Recall@20: 0.2132, Recall@200: 0.5759
# Test Recall@20: 0.2113, Recall@200: 0.5837
# Test Recall@20: 0.1958, Recall@200: 0.5638
# Test Recall@20: 0.2152, Recall@200: 0.5856
# Test Recall@20: 0.2027, Recall@200: 0.5709
# Test Recall@20: 0.1945, Recall@200: 0.5588
# Test Recall@20: 0.2080, Recall@200: 0.5728
# Test Recall@20: 0.2172, Recall@200: 0.5767
# Test Recall@20: 0.2221, Recall@200: 0.5849
# Test Recall@20: 0.1911, Recall@200: 0.5830

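# Neurocontext
# NOTE(review): in this block Recall@20 equals Recall@200 in every fold, and the "Test Loss"
# values resemble Recall@20 values elsewhere. The labels were probably misprinted; verify before citing.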
# Test Loss: 0.2191, Recall@20: 0.5919, Recall@200: 0.5919
# Test Loss: 0.2234, Recall@20: 0.5841, Recall@200: 0.5841
# Test Loss: 0.2152, Recall@20: 0.5856, Recall@200: 0.5856
# Test Loss: 0.2340, Recall@20: 0.5924, Recall@200: 0.5924
# Test Loss: 0.2167, Recall@20: 0.5936, Recall@200: 0.5936
# Test Loss: 0.2104, Recall@20: 0.5810, Recall@200: 0.5810
# Test Loss: 0.2206, Recall@20: 0.5941, Recall@200: 0.5941
# Test Loss: 0.2206, Recall@20: 0.5670, Recall@200: 0.5670
# Test Loss: 0.2216, Recall@20: 0.6067, Recall@200: 0.6067
# Test Loss: 0.1969, Recall@20: 0.5573, Recall@200: 0.5573

In [8]:
# NeuroVLM embeddings: SPECTER2 text latents ("neuro", "adhoc", "proxi" files) plus the neuro latents
neurovlm_dir = get_data_dir()
latent_text_neuro, pmids_neuro = torch.load(neurovlm_dir / "latent_specter2_neuro.pt", weights_only=False).values()
latent_text_adhoc, pmids_adhoc = torch.load(neurovlm_dir / "latent_specter2_adhoc.pt", weights_only=False).values()
latent_text_proxi, pmids_neurovlm = torch.load(neurovlm_dir / "latent_specter2_proxi.pt", weights_only=False).values()

# The three text embeddings are column-stacked below, so their rows must describe the same papers
assert np.array_equal(pmids_neuro, pmids_neurovlm) and np.array_equal(pmids_adhoc, pmids_neurovlm), \
    "latent_specter2_{neuro,adhoc,proxi}.pt are not row-aligned"

latent_neuro = torch.load(neurovlm_dir / "latent_neuro_nc.pt")
df = pd.read_parquet(neurovlm_dir / "publications_more.parquet")

# Mask for neurocontext papers only
# NOTE(review): assumes the text latents are row-aligned with `df` before masking — TODO confirm
pmids = np.load(neurovlm_dir / "pmids_neurocontext.npy")
mask = np.array(df['pmid'].isin(pmids))
df = df[mask]
latent_text_adhoc = latent_text_adhoc[mask]
latent_text_neuro = latent_text_neuro[mask]
latent_text_proxi = latent_text_proxi[mask]

neuro_vectors = torch.load(neurovlm_dir / "neuro_vectors.pt")[mask]

# Prefer Apple MPS (the original target), then CUDA, then CPU, instead of hard-coding "mps"
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

latent_neuro = latent_neuro.to(device)
latent_text_specter = torch.column_stack((latent_text_adhoc, latent_text_neuro, latent_text_proxi)).to(device)

# Training indexes both tensors with the same row indices, so they must pair up one-to-one
assert len(latent_text_specter) == len(latent_neuro), "text and neuro latents have different numbers of rows"

In [None]:
# 10-fold CV of contrastive text<->neuro projection heads (text features: adhoc + neuro + proxi).
# Per fold: a validation set is split off the training indices for model selection,
# the heads are trained with InfoNCE, and the validation-selected heads are scored on the
# held-out fold and saved.

# Metrics: one entry per CV fold
n_folds = 10
recall_fn = partial(recall_n, thresh=0.95, reduce_mean=True)
recall_20_nv = np.zeros(n_folds)   # recall@20
recall_200_nv = np.zeros(n_folds)  # recall@200
mix_match_nv = np.zeros(n_folds)

# CV configuration
n_epochs_nv = 151
val_size = 1000      # samples carved out of each training fold for model selection
batch_size = 2048 * 3
hidden_dim = 768
kfolds = KFold(n_splits=n_folds, random_state=0, shuffle=True)

# Follow the device the data cell put the latents on, and size the heads from the data
device = latent_neuro.device
text_dim = latent_text_specter.shape[1]
neuro_dim = latent_neuro.shape[1]

fold_dir = neurovlm_dir / "models"
fold_dir.mkdir(exist_ok=True, parents=True)

for i, (inds_train, inds_test) in enumerate(kfolds.split(np.arange(len(latent_neuro)))):

    print(f"Fold: {i}")

    # Split a validation set off the (seeded, shuffled) training fold
    np.random.seed(i)
    np.random.shuffle(inds_train)
    inds_val = inds_train[:val_size]
    inds_train = inds_train[val_size:]

    # Projection heads (align latent text to latent neuro)
    proj_head_text = TextAligner(latent_text_dim=text_dim, hidden_dim=hidden_dim, seed=0).to(device)
    proj_head_neuro = TextAligner(latent_text_dim=neuro_dim, hidden_dim=hidden_dim, seed=1).to(device)
    optimizer = torch.optim.AdamW(chain(proj_head_neuro.parameters(), proj_head_text.parameters()), lr=5e-5)
    loss_fn = InfoNCELoss(temperature=.1)

    # Reset per fold: otherwise a fold whose validation loss is never finite would
    # silently evaluate/save the heads left over from the previous fold.
    best_loss = np.inf
    best_proj_neuro = best_proj_text = None
    inds_train_rand = inds_train.copy()  # reshuffled every epoch

    # Train loop
    for iepoch in range(n_epochs_nv):
        proj_head_neuro.train()
        proj_head_text.train()
        np.random.shuffle(inds_train_rand)
        for start_idx in range(0, len(inds_train_rand), batch_size):
            batch_inds = inds_train_rand[start_idx:start_idx + batch_size]

            proj_neuro = proj_head_neuro(latent_neuro[batch_inds])
            proj_text = proj_head_text(latent_text_specter[batch_inds])
            loss = loss_fn(proj_text, proj_neuro)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation in eval mode (a no-op unless TextAligner has dropout/batch-norm) without autograd
        proj_head_neuro.eval()
        proj_head_text.eval()
        with torch.no_grad():
            val_neuro = proj_head_neuro(latent_neuro[inds_val])
            val_text = proj_head_text(latent_text_specter[inds_val])
            val_loss = float(loss_fn(val_text, val_neuro))
            if val_loss < best_loss:
                best_loss = val_loss
                best_proj_neuro = deepcopy(proj_head_neuro)
                best_proj_text = deepcopy(proj_head_text)

            if iepoch % 10 == 0:
                # Report retrieval metrics occasionally, reusing the validation projections
                image_embeddings_nv = val_neuro / val_neuro.norm(dim=1)[:, None]
                text_embeddings_nv = val_text / val_text.norm(dim=1)[:, None]
                similarity = (image_embeddings_nv @ text_embeddings_nv.T).cpu().numpy()
                recall_20_nv_ = recall_fn(similarity, np.eye(len(similarity)), n_first=20)
                recall_200_nv_ = recall_fn(similarity, np.eye(len(similarity)), n_first=200)
                print(f"Epoch: {iepoch}, Loss: {val_loss:.4f}, Recall@20: {recall_20_nv_:.4f}, Recall@200: {recall_200_nv_:.4f}")

    if best_proj_neuro is None:
        raise RuntimeError(f"Fold {i}: validation loss was never finite; no checkpoint was selected.")

    # Test metrics for the validation-selected heads, i.e. exactly the models saved below
    best_proj_neuro.eval()
    best_proj_text.eval()
    with torch.no_grad():
        image_embeddings_nv = best_proj_neuro(latent_neuro[inds_test])
        text_embeddings_nv = best_proj_text(latent_text_specter[inds_test])
    image_embeddings_nv = image_embeddings_nv / image_embeddings_nv.norm(dim=1)[:, None]
    text_embeddings_nv = text_embeddings_nv / text_embeddings_nv.norm(dim=1)[:, None]
    similarity = (image_embeddings_nv @ text_embeddings_nv.T).cpu().numpy()

    recall_20_nv[i] = recall_fn(similarity, np.eye(len(similarity)), n_first=20)
    recall_200_nv[i] = recall_fn(similarity, np.eye(len(similarity)), n_first=200)
    mix_match_nv[i] = mix_match(similarity)
    print(f"Test Recall@20: {recall_20_nv[i]:.4f}, Recall@200: {recall_200_nv[i]:.4f}")

    # NOTE: these filenames are not experiment-specific; any other run that writes
    # proj_*_foldXX.pt into the same `models/` directory will overwrite them.
    torch.save(best_proj_neuro, fold_dir / f"proj_neuro_fold{i:02d}.pt")
    torch.save(best_proj_text, fold_dir / f"proj_text_fold{i:02d}.pt")

In [13]:
# Persist the per-fold CV metrics for the adhoc + neuro + proxi text features
for metric_fname, metric_values in (
    ("recall_20_nv_proxi_adhoc_neuro.npy", recall_20_nv),
    ("recall_200_nv_proxi_adhoc_neuro.npy", recall_200_nv),
    ("mix_match_nv_proxi_adhoc_neuro.npy", mix_match_nv),
):
    np.save(neurovlm_dir / metric_fname, metric_values)