In [None]:
import sys
from copy import deepcopy
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from neurovlm.data import get_data_dir

# NeuroConText imports, clone the neurocontext repo and fetch their data
neurocontext_dir = '/Users/anons/projects/NeuroConText'
data_dir = Path('/Users/anon/projects/NeuroConText/data_NeuroConText')
sys.path.append(neurocontext_dir)
from layers import ClipModel, ProjectionHead, ResidualHead
from losses import ClipLoss
from training import predict, train
from metrics import mix_match
from src.utils import recall_n

In [3]:
# Load NeuroConText pmids
train_pmids = np.array(list(pd.read_pickle(data_dir/ "train_pmids.pkl")))
test_pmids = np.array(list(pd.read_pickle(data_dir / "test_pmids.pkl")))
pmids = np.concatenate((train_pmids, test_pmids))
pmids = np.sort(pmids)

# Load neurovlm data
neurolm_dir = get_data_dir()
df = pd.read_parquet(get_data_dir() / "publications_more.parquet")
assert pd.Series(pmids).isin(df["pmid"]).all() # we have all the ids the neurocontext has and 10k more

# Filter by neurocontext pmids
df_nc = df[df["pmid"].isin(pmids)].copy()
df_nc.sort_values(by="pmid", inplace=True)
df_nc["description"] = df_nc["description"].str.strip(" ").str.strip("\n")
df_nc.reset_index(drop=True, inplace=True)

# NeuroConText embeddings
output_dir = Path(data_dir)

model_dir = get_data_dir() / "models"
model_dir.mkdir(exist_ok=True, parents=True)

test_gaussian_embeddings = pd.read_pickle(data_dir / 'test_gaussian_embeddings.pkl')
train_gaussian_embeddings = pd.read_pickle(data_dir / 'train_gaussian_embeddings.pkl')
test_text_embeddings = pd.read_pickle(data_dir / 'test_abstract_embeddings.pkl').values
train_text_embeddings = pd.read_pickle(data_dir / 'train_abstract_embeddings.pkl').values

sorted_indices = np.argsort(pmids)
assert (pmids[sorted_indices] == df_nc["pmid"]).all()

dataset = TensorDataset(
    torch.from_numpy(
        np.vstack((train_gaussian_embeddings, test_gaussian_embeddings))[sorted_indices]
    ).float(),
    torch.from_numpy(
        np.vstack((train_text_embeddings, test_text_embeddings))[sorted_indices]
    ).float(),
)

# NeuroConText settings
plot_verbose = True
batch_size = 128
lr = 1e-4
weight_decay = 0.1
dropout = 0.6
output_size = test_gaussian_embeddings.shape[1]
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.mps.is_available() else "cpu"
criterion = ClipLoss()
is_clip_loss = criterion.__class__ == ClipLoss
loss_specific_kwargs = {
    "logit_scale": 10 if is_clip_loss else np.log(10),
    "logit_bias": None if is_clip_loss else -10,
}
val_size = 1000

# NeuroVLM embeddings
latent_text_specter, pmids_neurovlm = torch.load(neurolm_dir / "latent_specter2_adhoc.pt", weights_only=False).values()
latent_neuro = torch.load(neurolm_dir / "latent_neuro.pt")
autoencoder = torch.load(neurolm_dir / "autoencoder.pt", weights_only=False)
decoder = autoencoder.decoder.to("cpu")
mask = pd.Series(pmids_neurovlm).isin(pmids) # mask to match neurocontext corpus
pmids_neurovlm = pmids_neurovlm[mask]
latent_neuro = latent_neuro[mask]
latent_text_specter = latent_text_specter[mask]
assert (pmids_neurovlm == pmids).all()

In [4]:
# Metrics
recall_fn = partial(recall_n, thresh=0.95, reduce_mean=True)
recall_20_nc, recall_200_nc = np.zeros(10), np.zeros(10)
recall_20_nv, recall_200_nv = np.zeros(10), np.zeros(10)
mix_match_nc = np.zeros(10)

# CV
n_epochs_nc = 50
kfolds = KFold(n_splits=10, random_state=0, shuffle=True)

for i, (inds_train, inds_test) in enumerate(kfolds.split(dataset)):

    print(f"Fold: {i}")

    # Data loaders and output directory
    fold_dir = get_data_dir() / "models" / "tmp"
    fold_dir.mkdir(exist_ok=True, parents=True)

    np.random.seed(i)
    np.random.shuffle(inds_train)

    inds_val = inds_train[:val_size] # split train into (train, val)
    inds_train = inds_train[val_size:]

    train_dataset = TensorDataset(*dataset[inds_train])
    test_dataset = TensorDataset(*dataset[inds_test])
    val_dataset = TensorDataset(*dataset[inds_val])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # NeuroConText
    model = ClipModel(
        image_model=nn.Sequential(
            ResidualHead(output_size, dropout=dropout),
            ResidualHead(output_size, dropout=dropout),
            ResidualHead(output_size, dropout=dropout),
        ),
        text_model=nn.Sequential(
            ProjectionHead(train_text_embeddings.shape[1], output_size, dropout=dropout),
            ResidualHead(output_size, dropout=dropout),
            ResidualHead(output_size, dropout=dropout),
        ),
        **loss_specific_kwargs,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
    )

    clip_model, clip_train_loss, clip_val_loss, callback_outputs = train(
        model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        scheduler=None,
        criterion=criterion,
        num_epochs=n_epochs_nc,
        device=device,
        verbose=True,
        output_dir=fold_dir,
        callbacks=[],
    )


    # Metrics
    with torch.no_grad():
        clip_model.load_state_dict(torch.load(fold_dir / "best_val.pt"))
        image_embeddings_nc, text_embeddings_nc = predict(clip_model, test_loader, device=device)
        image_embeddings_nc /= image_embeddings_nc.norm(dim=1)[:, None]
        text_embeddings_nc /= text_embeddings_nc.norm(dim=1)[:, None]

    # Neuorcontext
    similarity = (image_embeddings_nc @ text_embeddings_nc.T).softmax(dim=1).numpy()
    recall_20_nc[i] = recall_fn(similarity, np.eye(len(similarity)), n_first=20)
    recall_200_nc[i] = recall_fn(similarity, np.eye(len(similarity)), n_first=200)
    mix_match_nc[i] = mix_match(similarity)
    print(f"Test Loss: {recall_20_nc[i]:.4f}, Recall@20: {recall_200_nc[i]:.4f}, Recall@200: {recall_200_nc[i]:.4f}")

Fold: 0


100%|██████████| 50/50 [02:00<00:00,  2.41s/it]


Test Loss: 0.2200, Recall@20: 0.5938, Recall@200: 0.5938
Fold: 1


100%|██████████| 50/50 [02:00<00:00,  2.42s/it]


Test Loss: 0.2103, Recall@20: 0.5832, Recall@200: 0.5832
Fold: 2


100%|██████████| 50/50 [02:01<00:00,  2.43s/it]


Test Loss: 0.2166, Recall@20: 0.5880, Recall@200: 0.5880
Fold: 3


100%|██████████| 50/50 [02:07<00:00,  2.55s/it]


Test Loss: 0.2302, Recall@20: 0.5928, Recall@200: 0.5928
Fold: 4


100%|██████████| 50/50 [02:05<00:00,  2.50s/it]


Test Loss: 0.2269, Recall@20: 0.5893, Recall@200: 0.5893
Fold: 5


100%|██████████| 50/50 [02:05<00:00,  2.50s/it]


Test Loss: 0.2100, Recall@20: 0.5806, Recall@200: 0.5806
Fold: 6


100%|██████████| 50/50 [02:06<00:00,  2.53s/it]


Test Loss: 0.2245, Recall@20: 0.5830, Recall@200: 0.5830
Fold: 7


100%|██████████| 50/50 [02:06<00:00,  2.53s/it]


Test Loss: 0.2187, Recall@20: 0.5602, Recall@200: 0.5602
Fold: 8


100%|██████████| 50/50 [01:59<00:00,  2.39s/it]


Test Loss: 0.2163, Recall@20: 0.6004, Recall@200: 0.6004
Fold: 9


100%|██████████| 50/50 [01:58<00:00,  2.38s/it]


Test Loss: 0.1984, Recall@20: 0.5573, Recall@200: 0.5573


In [9]:
np.save(neurolm_dir / "recall_20_nc.npy", recall_20_nc)
np.save(neurolm_dir / "recall_200_nc.npy", recall_200_nc)
np.save(neurolm_dir / "mix_match_nc.npy", mix_match_nc)