In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from tqdm.auto import tqdm

from neurovlm.data import get_data_dir
from neurovlm.loss import InfoNCELoss
from neurovlm.models import TextAligner
from neurovlm.train import Trainer, which_device

# Runtime configuration shared by every cell below: the torch device chosen by
# neurovlm's `which_device()` helper and the package data directory.
device, data_dir = which_device(), get_data_dir()

# Projection Head

A projection head is a small network that aligns the latent spaces of text and neuroimages. One training regime starts with MSELoss and then gradually removes the influence of outliers through truncation, i.e. masking out the top-k% of per-instance losses from the gradient computation. The training cell below instead trains a text head and an image head jointly with a contrastive InfoNCE loss.

In [2]:
# Inputs produced by earlier notebooks. Each file holds a dict whose values are
# unpacked positionally as (latents, pmids), so this relies on the saved dict's
# key insertion order. weights_only=False permits arbitrary unpickling: only load
# trusted local files.
neuro_path = data_dir / "latent_neuro_sparse.pt"    # encoded neurovectors (second notebook)
text_path = data_dir / "latent_specter2_adhoc.pt"   # encoded text (previous notebook)

latent_neuro, pmids_latent = torch.load(neuro_path, weights_only=False).values()
latent_text_specter, pmids = torch.load(text_path, weights_only=False).values()

# Reorder the text side by ascending PMID so it can be matched against the neuro PMIDs.
inds = np.argsort(pmids)
latent_text_specter = latent_text_specter[inds]
pmids = pmids[inds]

In [3]:
# Sparse
# Restrict the text side to PMIDs that also have a neuro latent; the assert then
# checks both sides list the same PMIDs in the same order.
# NOTE(review): not idempotent — re-run from the loading cell. A second run of this
# cell alone indexes the already-filtered latents with a full-length mask.
shared = pd.Series(pmids).isin(pmids_latent).to_numpy()
pmids, latent_text_specter = pmids[shared], latent_text_specter[shared]
assert (pmids == pmids_latent).all()

# Keep only samples whose neuro latent has 2-norm below 35 (sparser targets).
# `pmids` is intentionally left unfiltered here; the save cell applies `mask` to it.
mask = latent_neuro.norm(dim=1).detach().cpu().numpy() < 35
latent_neuro, latent_text_specter = latent_neuro[mask], latent_text_specter[mask]

In [4]:
# Persist the norm-filtered latents, each paired with the matching PMIDs.
# NOTE(review): the first file overwrites latent_neuro_sparse.pt, which the loading
# cell also reads, so the pre-filter version is not kept after the first run.
pmids_sparse = pmids[mask]
for fname, lat in (
    ("latent_neuro_sparse.pt", latent_neuro),
    ("latent_text_sparse.pt", latent_text_specter),
):
    torch.save(dict(latent=lat, pmid=pmids_sparse), data_dir / fname)

In [5]:
# Train/test/validation split: 80% train, and the remaining 20% is halved into
# test and validation (each with its own fixed random_state for reproducibility).
inds = torch.arange(len(latent_neuro))
train_inds, holdout_inds = train_test_split(inds, train_size=0.8, random_state=0)
test_inds, val_inds = train_test_split(holdout_inds, train_size=0.5, random_state=1)

In [None]:
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from neurovlm.loss import InfoNCELoss
from tqdm.notebook import tqdm

# Set device
device = "mps"

X_train_image = latent_neuro[train_inds].to(device)
X_train_text  = latent_text_specter[train_inds].to(device)
X_val_image = latent_neuro[val_inds].to(device)
X_val_text  = latent_text_specter[val_inds].to(device)

# Models
proj_head_text  = TextAligner(seed=123, latent_text_dim=768, hidden_dim=512, latent_neuro_dim=384).to(device)
proj_head_image = TextAligner(seed=123, latent_text_dim=384, hidden_dim=384, latent_neuro_dim=384).to(device)

# Settings
loss_fn = InfoNCELoss()
n_epochs = 200
batch_size = 512
lr = 1e-5
optimizer = AdamW([*proj_head_text.parameters(), *proj_head_image.parameters()], lr=lr, weight_decay=1e-4)  # small wd helps
interval = 1
max_grad_norm = 1.0  # gradient clipping

# Train
iterable = tqdm(range(n_epochs), total=n_epochs)

for iepoch in iterable:
    proj_head_text.train()
    proj_head_image.train()

    # Randomly shuffle data
    torch.manual_seed(iepoch)
    rand_inds = torch.randperm(len(X_train_image), device=device)

    for i in range(0, len(X_train_image), batch_size):
        idx = rand_inds[i:i+batch_size]

        # Forward
        y_text  = proj_head_text(X_train_text[idx])
        y_image = proj_head_image(X_train_image[idx])

        # Loss
        loss = loss_fn(y_text, y_image)

        # Backward
        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        # Clip to tame spikes
        # clip_grad_norm_([*proj_head_text.parameters(), *proj_head_image.parameters()], max_grad_norm)

        # Step
        optimizer.step()

    if iepoch % interval == 0:
        proj_head_text.eval()
        proj_head_image.eval()
        with torch.no_grad():
            y_text  = proj_head_text(X_val_text)
            y_image = proj_head_image(X_val_image)

            y_text  = F.normalize(y_text,  dim=-1)
            y_image = F.normalize(y_image, dim=-1)

            val_loss = loss_fn(y_text, y_image)
            print(f"Epoch: {iepoch}, val loss: {float(val_loss):.5g}")


  0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0, val loss: 7.8888
Epoch: 1, val loss: 7.8262
Epoch: 2, val loss: 7.7599
Epoch: 3, val loss: 7.7002
Epoch: 4, val loss: 7.6519
Epoch: 5, val loss: 7.6101
Epoch: 6, val loss: 7.5721
Epoch: 7, val loss: 7.536
Epoch: 8, val loss: 7.5012
Epoch: 9, val loss: 7.4705
Epoch: 10, val loss: 7.4407
Epoch: 11, val loss: 7.4154
Epoch: 12, val loss: 7.3923
Epoch: 13, val loss: 7.3734
Epoch: 14, val loss: 7.3578
Epoch: 15, val loss: 7.343
Epoch: 16, val loss: 7.3279
Epoch: 17, val loss: 7.3169
Epoch: 18, val loss: 7.3055
Epoch: 19, val loss: 7.2949
Epoch: 20, val loss: 7.2855
Epoch: 21, val loss: 7.2787
Epoch: 22, val loss: 7.273
Epoch: 23, val loss: 7.2653
Epoch: 24, val loss: 7.258
Epoch: 25, val loss: 7.2542
Epoch: 26, val loss: 7.2478
Epoch: 27, val loss: 7.2431
Epoch: 28, val loss: 7.2385
Epoch: 29, val loss: 7.2364
Epoch: 30, val loss: 7.2302
Epoch: 31, val loss: 7.2252
Epoch: 32, val loss: 7.2208
Epoch: 33, val loss: 7.2159
Epoch: 34, val loss: 7.2129
Epoch: 35, val loss: 7.2111
Epoch:

In [28]:
# Persist both trained projection heads. The whole module objects are pickled, not
# just their state_dicts, so loading them requires the TextAligner class to be importable.
for head, fname in (
    (proj_head_text, "proj_head_text_infonce.pt"),
    (proj_head_image, "proj_head_image_infonce.pt"),
):
    torch.save(head, data_dir / fname)