In [1]:
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.optim import AdamW

from neurovlm.loss import InfoNCELoss
from neurovlm.data import data_dir
from neurovlm.train import Trainer, which_device
from neurovlm.models import TextAligner
from neurovlm.loss import InfoNCELoss
from neurovlm.retrieval_resources import _proj_head_mse_adhoc
# Build the head that is later used as the text projection head.
# NOTE(review): presumably a head pre-fit with an MSE objective (per the helper's name) — confirm.
# This instance is superseded when the training cell calls the helper again.
proj_head = _proj_head_mse_adhoc()
# Device that the latents and both heads are moved to in the training cell.
device = which_device()

# Projection Head

A projection head is a small network that aligns the latent spaces of text and neuroimages. This notebook learns projection heads that map the text and neuro latent spaces to one another using a contrastive loss function.

In [2]:
# Load encoded neurovectors from the second notebook
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only load
# files that were produced by this project's own notebooks.
latent_neuro, pmids_latent = torch.load(data_dir / "latent_neuro_sparse.pt", weights_only=False).values()
# latent_neuro = latent_neuro / latent_neuro.norm(dim=1)[:, None]

# Load encoded text from last notebook and sort it by PMID
latent_text_specter, pmids = torch.load(data_dir / "latent_specter2_adhoc.pt", weights_only=False).values()
inds = np.argsort(pmids)
latent_text_specter, pmids = latent_text_specter[inds], pmids[inds]

# Keep only the text embeddings whose PMID also appears among the neuro latents
mask = pd.Series(pmids).isin(pmids_latent)
pmids = pmids[mask]
latent_text_specter = latent_text_specter[mask]

# Text and neuro rows are paired by position from here on, so both PMID sequences
# must match exactly. np.array_equal checks shape before values, so a length
# mismatch fails with the message below instead of an unrelated elementwise
# comparison error (or silently passing when the shapes happen to broadcast).
assert np.array_equal(np.asarray(pmids), np.asarray(pmids_latent)), (
    f"PMID mismatch between text ({len(pmids)}) and neuro ({len(pmids_latent)}) latents"
)

In [None]:
# mask = latent_neuro.norm(dim=1).detach().cpu().numpy() < 35 # sparser targets
# pmids = pmids[mask]
# latent_neuro = latent_neuro[mask]
# latent_text_specter = latent_text_specter[mask]

In [3]:
# Load splits
# NOTE(review): the unpacking relies on the saved mapping holding train, test, val in that order — confirm
ids_train, ids_test, ids_val = torch.load(data_dir / "pmids_split.pt", weights_only=False).values()

# Row positions within `pmids` that belong to each split
pmid_series = pd.Series(pmids)
train_inds, test_inds, val_inds = (
    np.flatnonzero(pmid_series.isin(split_ids).to_numpy())
    for split_ids in (ids_train, ids_test, ids_val)
)

In [14]:
# Rebuild the text head so re-running this cell starts from the same initial model
proj_head = _proj_head_mse_adhoc()

# Split data: move the train / validation partitions to the device once
X_train_image, X_train_text = (
    latent_neuro[train_inds].to(device),
    latent_text_specter[train_inds].to(device),
)
X_val_image, X_val_text = (
    latent_neuro[val_inds].to(device),
    latent_text_specter[val_inds].to(device),
)
n_train = len(X_train_image)

# Models
proj_head_text = proj_head.to(device)  # initialize with the decoder model
# proj_head_text  = TextAligner(seed=123, latent_text_dim=768, hidden_dim=512, latent_neuro_dim=384).to(device)
proj_head_image = TextAligner(
    seed=123,
    latent_text_dim=384,
    hidden_dim=384,
    latent_neuro_dim=384,
).to(device)

# Settings
n_epochs = 200
batch_size = 2048#512
lr = 1e-5
interval = 10  # report validation loss every `interval` epochs
loss_fn = InfoNCELoss(temperature=0.2)
# A single optimizer updates both heads jointly
optimizer = AdamW(
    list(proj_head_text.parameters()) + list(proj_head_image.parameters()),
    lr=lr,
)

# Train
for iepoch in tqdm(range(n_epochs), total=n_epochs):

    proj_head_text.train()
    proj_head_image.train()

    # Epoch-specific seed -> a reproducible shuffle of the training rows
    torch.manual_seed(iepoch)
    perm = torch.randperm(n_train, device=device)

    for start in range(0, n_train, batch_size):
        batch = perm[start:start + batch_size]

        # Clear old gradients, then forward both heads on the same paired rows
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(
            proj_head_text(X_train_text[batch]),
            proj_head_image(X_train_image[batch]),
        )

        # Backward and parameter update
        loss.backward()
        optimizer.step()

    # Validation is reported only every `interval` epochs and on the final epoch
    if iepoch % interval != 0 and iepoch != n_epochs - 1:
        continue

    proj_head_text.eval()
    proj_head_image.eval()
    with torch.no_grad():
        # The whole validation partition is scored as one batch
        y_text = proj_head_text(X_val_text)
        y_image = proj_head_image(X_val_image)
        val_loss = loss_fn(y_text, y_image)
    print(f"Epoch: {iepoch}, val loss: {float(val_loss):.5g}")

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0, val loss: 7.8383
Epoch: 10, val loss: 7.5797
Epoch: 20, val loss: 7.4161
Epoch: 30, val loss: 7.3353
Epoch: 40, val loss: 7.2921
Epoch: 50, val loss: 7.2657
Epoch: 60, val loss: 7.2477
Epoch: 70, val loss: 7.2343
Epoch: 80, val loss: 7.2234
Epoch: 90, val loss: 7.2146
Epoch: 100, val loss: 7.207
Epoch: 110, val loss: 7.2005
Epoch: 120, val loss: 7.1945
Epoch: 130, val loss: 7.1893
Epoch: 140, val loss: 7.1843
Epoch: 150, val loss: 7.1801
Epoch: 160, val loss: 7.1761
Epoch: 170, val loss: 7.1724
Epoch: 180, val loss: 7.169
Epoch: 190, val loss: 7.1657
Epoch: 199, val loss: 7.1633


In [15]:
# Persist both trained heads as whole module objects (not state_dicts) under data_dir
for module, filename in (
    (proj_head_text, "proj_head_text_infonce.pt"),
    (proj_head_image, "proj_head_image_infonce.pt"),
):
    torch.save(module, data_dir / filename)