# Contrastive learning on two separate modalities

In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TENSORBOARD_BINARY'] = '/p/project1/hai_fzj_bda/koenig8/jupyter/kernels/contrastive_learn/bin/tensorboard'

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.optim import SGD, Adam
import numpy as np
from torch.utils.data import Dataset
from collections import defaultdict
import random

import scanpy as sc
from pytorch_lightning.callbacks import TQDMProgressBar
%load_ext tensorboard

In [2]:
class SimpleModel(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) used as a projection head.

    Args:
        out_dim: size of the output features.
        in_dim: size of the input features (last axis of the input).
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, out_dim = 512, in_dim=512, hidden_dim = 512):
        super().__init__()
        layers = [
            nn.Linear(in_features=in_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=out_dim),
        ]
        # Kept as `self.model` (an nn.Sequential) so state-dict keys are model.0.* / model.2.*.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to `x`; nn.Linear acts on the last axis, which must have size in_dim."""
        projected = self.model(x)
        return projected

In [24]:
class SimCLR(pl.LightningModule):
    """
    Cross-modal contrastive model: one MLP encoder for histology features and one
    for ST features, both projecting into a shared ``4 * hidden_dim`` space.

    Every dataset item contains samples of both modalities together with boolean
    masks that mark the positives (same class as the anchor) and the negatives.

    NOTE(review): the objective is ``-logsumexp(pos-vs-pos) + logsumexp(pos-vs-neg)``
    per positive. Unlike the vanilla InfoNCE loss of the SimCLR paper, the positives
    are not part of the denominator, so the loss can become negative.
    """

    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=500, log_every=2000, histo_size=368, st_size=50):
        """
        Args:
            hidden_dim: hidden width of both MLPs; the embedding size is 4 * hidden_dim.
            lr: AdamW learning rate (cosine-annealed down to lr / 50).
            temperature: divisor applied to the cosine similarities; must be > 0.
            weight_decay: AdamW weight decay.
            max_epochs: period (T_max) of the cosine learning-rate schedule.
            log_every: during training, ranking metrics are logged every `log_every` batches.
            histo_size: input feature size of the histology MLP.
            st_size: input feature size of the ST MLP.
        """
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        self.histo_model = SimpleModel(in_dim=histo_size, hidden_dim=hidden_dim, out_dim=4 * hidden_dim)
        self.st_model = SimpleModel(in_dim=st_size, hidden_dim=hidden_dim, out_dim=4 * hidden_dim)
        self.log_every = log_every

    def configure_optimizers(self):
        """AdamW with a per-epoch cosine schedule from lr down to lr / 50."""
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr / 50)
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, batch_idx, mode="train"):
        """Compute the contrastive loss for one batch and log it.

        Args:
            batch: dict with 'a_batch' (ST features), 'a_is_positive' (bool mask),
                'b_batch' (histology features) and 'b_is_positive' (bool mask).
            batch_idx: index of the batch, used to throttle the ranking metrics.
            mode: logging prefix, "train" or "val".

        Returns:
            Scalar loss tensor.
        """
        X_st = batch['a_batch']
        classes_st = batch['a_is_positive']
        X_histo = batch['b_batch']
        classes_histo = batch['b_is_positive']

        # L2-normalise every sample's feature vector, i.e. along the LAST axis.
        # With a DataLoader batch size of 1 the inputs are (1, n_samples, n_features),
        # so dim=1 would wrongly normalise each feature across the samples.
        X_st = F.normalize(X_st, p=2, dim=-1)
        X_histo = F.normalize(X_histo, p=2, dim=-1)
        emb_histo = self.histo_model(X_histo)
        emb_st = self.st_model(X_st)

        # Boolean-mask indexing flattens the leading axes -> (k, emb_dim).
        pos_st = emb_st[classes_st]
        pos_histo = emb_histo[classes_histo]
        neg_st = emb_st[~classes_st]
        neg_histo = emb_histo[~classes_histo]

        positives = torch.cat([pos_st, pos_histo], dim=0)
        negatives = torch.cat([neg_st, neg_histo], dim=0)

        temperature = self.hparams.temperature
        # Positive-vs-positive similarities; a sample must not count as its own positive.
        pos_sim = F.cosine_similarity(positives[:, None, :], positives[None, :, :], dim=-1)
        self_mask = torch.eye(pos_sim.shape[0], dtype=torch.bool, device=pos_sim.device)
        pos_sim = pos_sim.masked_fill(self_mask, -9e15) / temperature
        # Positive-vs-negative similarities, shape (n_pos, n_neg).
        neg_sim = F.cosine_similarity(positives[:, None, :], negatives[None, :, :], dim=-1) / temperature

        nll = (-torch.logsumexp(pos_sim, dim=-1) + torch.logsumexp(neg_sim, dim=-1)).mean()

        self.log(mode + '_loss', nll)
        # More in-depth logging (less frequent during training).
        if mode == "val" or (batch_idx % self.log_every == 0):
            self._log_ranking_metrics(pos_sim, neg_sim, self_mask, mode)

        return nll

    def _log_ranking_metrics(self, pos_sim, neg_sim, self_mask, mode):
        """Log how well positives rank above negatives by mean similarity to the positives."""
        n_pos, n_neg = neg_sim.shape
        # Mean similarity of each positive to the OTHER positives. The masked diagonal
        # must be zeroed first; otherwise its -9e15 fill dominates the mean and every
        # positive trivially sorts to one end of the ranking.
        pos_score = pos_sim.masked_fill(self_mask, 0.0).sum(dim=1) / max(n_pos - 1, 1)
        # Mean similarity of each negative to the positives.
        neg_score = neg_sim.mean(dim=0)
        comb_sim = torch.cat([pos_score, neg_score], dim=-1)
        # Most similar first, so "top-k" means the k most similar samples.
        sim_argsort = comb_sim.argsort(dim=-1, descending=True)
        classes = torch.cat([torch.ones(n_pos), torch.zeros(n_neg)], dim=-1).to(sim_argsort.device)
        top_classes = classes[sim_argsort]
        self.log(mode + '_acc_top1', top_classes[0].float())
        self.log(mode + '_acc_top5', top_classes[:5].float().mean())
        self.log(mode + '_acc_mean_pos', top_classes[:n_pos].float().mean())

    def training_step(self, batch, batch_idx):
        return self.info_nce_loss(batch, batch_idx, mode='train')

    def validation_step(self, batch, batch_idx):
        self.info_nce_loss(batch, batch_idx, mode='val')

In [4]:
# Toy check of the ranking metric: mean cosine similarity of every row of feat1
# to all rows of feat2.
feat1 = torch.tensor([[1000, 9000, 3000], [1, 2, 3], [1, 5, 3], [5, 6, 7]], dtype=torch.float32)
feat2 = torch.tensor([[1, 2, 3], [2000, 2000, 1000]], dtype=torch.float32)
pairwise = F.cosine_similarity(feat1.unsqueeze(1), feat2.unsqueeze(0), dim=-1)
comb_sim = pairwise.mean(dim=1)
comb_sim

tensor([0.7941, 0.9009, 0.8743, 0.9450])

In [5]:
# Rank the toy scores from lowest to highest and look up the class at each rank
# (first two entries are class 1, last two class 0).
sim_argsort = torch.argsort(comb_sim, dim=-1, descending=False)
classes = torch.cat([torch.ones(2), torch.zeros(2)], dim=-1)
classes[sim_argsort]

tensor([1., 0., 1., 0.])

In [6]:
# Boolean identity matrix: the pattern used to mask self-similarities in the loss.
torch.eye(10, 10, dtype=torch.bool)

tensor([[ True, False, False, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False, False, False],
        [False, False,  True, False, False, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False, False, False],
        [False, False, False, False,  True, False, False, False, False, False],
        [False, False, False, False, False,  True, False, False, False, False],
        [False, False, False, False, False, False,  True, False, False, False],
        [False, False, False, False, False, False, False,  True, False, False],
        [False, False, False, False, False, False, False, False,  True, False],
        [False, False, False, False, False, False, False, False, False,  True]])

# Dataset
From ChatGPT: precomputing the sample indices is fine, especially with a lot of data. You trade sampling variance for speed and consistency, and with a large enough dataset that is often a win.
If the precomputed samples turn out not to be diverse enough, I can reshuffle them every k epochs, sample dynamically, or draw multiple samples per anchor.

In [7]:
class PairedContrastiveDataset(Dataset):
    """Precomputed cross-modal contrastive groups for two labelled modalities.

    One item is built per row of modality A, followed by one item per row of
    modality B. Each item groups the anchor with positives (same label) and
    negatives (different label) drawn from both modalities. All index lists are
    drawn once, in ``__init__``, from a ``random.Random(seed)`` generator, so
    the dataset is reproducible for a fixed seed.
    """

    def __init__(self, embeddings_a, labels_a, embeddings_b, labels_b, n_pos=1, n_neg=1, seed=42):
        """
        embeddings_a: array/Tensor [N, D_a] for modality A (e.g. ST)
        labels_a: array/Tensor [N] with integer class labels
        embeddings_b: array/Tensor [M, D_b] for modality B (e.g. histo)
        labels_b: array/Tensor [M] with integer class labels
        n_pos: number of positives per anchor; in the anchor's own modality the
            anchor itself counts as one, so only n_pos - 1 others are drawn there
        n_neg: number of negatives to draw per anchor from each modality
        seed: seed of the generator used for all sampling
        """
        super().__init__()
        # Convert to tensors
        self.emb_a = torch.tensor(embeddings_a, dtype=torch.float32)
        self.labels_a = torch.tensor(labels_a, dtype=torch.int32)
        self.emb_b = torch.tensor(embeddings_b, dtype=torch.float32)
        self.labels_b = torch.tensor(labels_b, dtype=torch.int32)
        self.n_pos = n_pos
        self.n_neg = n_neg
        # Every draw goes through this generator, so `seed` actually controls the samples.
        self.rng = random.Random(seed)

        assert len(self.emb_a) == len(self.labels_a)
        assert len(self.emb_b) == len(self.labels_b)

        labels_a_list = labels_a.tolist()
        labels_b_list = labels_b.tolist()

        # Class index lookup per modality: label -> row indices (ascending).
        self.class_to_indices_a = defaultdict(list)
        self.class_to_indices_b = defaultdict(list)
        for i, label in enumerate(labels_a_list):
            self.class_to_indices_a[label].append(i)
        for i, label in enumerate(labels_b_list):
            self.class_to_indices_b[label].append(i)

        # (modality, label) -> indices of all rows with a different label. Built once
        # per label instead of once per anchor, which made construction quadratic.
        self._negatives_cache = {}

        self.data = []  # List of dicts with precomputed sample indices
        for anchor_idx, anchor_label in enumerate(labels_a_list):
            self.data.append(self._make_sample(anchor_idx=anchor_idx, anchor_mod='a', anchor_label=anchor_label))
        for anchor_idx, anchor_label in enumerate(labels_b_list):
            self.data.append(self._make_sample(anchor_idx=anchor_idx, anchor_mod='b', anchor_label=anchor_label))

    def _negatives(self, mod, label):
        """Return the (cached) indices of all rows of modality `mod` whose label differs from `label`."""
        key = (mod, label)
        if key not in self._negatives_cache:
            class_to_indices = self.class_to_indices_a if mod == 'a' else self.class_to_indices_b
            self._negatives_cache[key] = [
                idx
                for other_label, indices in class_to_indices.items()
                if other_label != label
                for idx in indices
            ]
        return self._negatives_cache[key]

    def _make_sample(self, anchor_idx, anchor_mod, anchor_label):
        """Draw the positive/negative index lists for one anchor.

        Returns a dict with 'anchor_mod', 'anchor_idx' and the index lists
        'pos_same' / 'neg_same' (anchor's modality) and 'pos_other' /
        'neg_other' (the other modality).
        """
        if anchor_mod == 'a':
            same_class_to_indices = self.class_to_indices_a
            other_class_to_indices = self.class_to_indices_b
            other_mod = 'b'
        else:
            same_class_to_indices = self.class_to_indices_b
            other_class_to_indices = self.class_to_indices_a
            other_mod = 'a'

        # Positives from the same modality, excluding the anchor. Rather than copying
        # the whole class list without the anchor (O(class size) per anchor), draw one
        # extra element and drop the anchor - or the surplus element if the anchor was
        # not drawn. The result is a uniform sample of the non-anchor class members.
        # `.get` avoids inserting empty entries into the defaultdicts.
        same_class = same_class_to_indices.get(anchor_label, [])
        n_same = max(min(self.n_pos - 1, len(same_class) - 1), 0)
        drawn = self.rng.sample(same_class, min(n_same + 1, len(same_class)))
        pos_same = [idx for idx in drawn if idx != anchor_idx][:n_same]

        # Negatives from the same modality.
        neg_same_pool = self._negatives(anchor_mod, anchor_label)
        neg_same = self.rng.sample(neg_same_pool, min(self.n_neg, len(neg_same_pool)))

        # Positives from the other modality.
        pos_other_pool = other_class_to_indices.get(anchor_label, [])
        pos_other = self.rng.sample(pos_other_pool, min(self.n_pos, len(pos_other_pool)))

        # Negatives from the other modality.
        neg_other_pool = self._negatives(other_mod, anchor_label)
        neg_other = self.rng.sample(neg_other_pool, min(self.n_neg, len(neg_other_pool)))

        return {
            'anchor_mod': anchor_mod,
            'anchor_idx': anchor_idx,
            'pos_same': pos_same,
            'neg_same': neg_same,
            'pos_other': pos_other,
            'neg_other': neg_other
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the embeddings of one group plus boolean masks (True = positive).

        Rows of the anchor's modality are ordered pos_same + [anchor] + neg_same;
        rows of the other modality are ordered pos_other + neg_other.
        """
        sample = self.data[idx]
        pos_same, neg_same = sample['pos_same'], sample['neg_same']
        pos_other, neg_other = sample['pos_other'], sample['neg_other']

        same_rows = pos_same + [sample['anchor_idx']] + neg_same
        other_rows = pos_other + neg_other
        # The anchor counts as a positive of its own modality.
        same_mod_is_positive = torch.tensor(
            [True] * (len(pos_same) + 1) + [False] * len(neg_same), dtype=torch.bool
        )
        other_mod_is_positive = torch.tensor(
            [True] * len(pos_other) + [False] * len(neg_other), dtype=torch.bool
        )

        if sample['anchor_mod'] == 'a':
            a_rows, a_is_positive = same_rows, same_mod_is_positive
            b_rows, b_is_positive = other_rows, other_mod_is_positive
        else:
            a_rows, a_is_positive = other_rows, other_mod_is_positive
            b_rows, b_is_positive = same_rows, same_mod_is_positive

        # One gather per modality instead of stacking rows one by one.
        return {
            'a_batch': self.emb_a[torch.tensor(a_rows, dtype=torch.long)],
            'a_is_positive': a_is_positive,
            'b_batch': self.emb_b[torch.tensor(b_rows, dtype=torch.long)],
            'b_is_positive': b_is_positive,
        }


### Load the data and prepare it for PyTorch Lightning

In [8]:
path = "/p/project1/hai_fzj_bda/koenig8/ot/data/"
# Both modalities are stored side by side in the same data directory (ST first, then histology).
adata_st, adata_histo = (
    sc.read_h5ad(os.path.join(path, file_name))
    for file_name in ("adata_st.h5ad", "adata_histo.h5ad")
)
adata_st, adata_histo

(AnnData object with n_obs × n_vars = 50000 × 50
     obs: 'patch_id', 'brain_area', 'patchsize', 'x_st', 'y_st', 'z_st', 'brain_section_label', 'section'
     uns: 'neighbors', 'umap'
     obsm: 'X_umap', 'brain_area_onehot', 'brain_area_similarities', 'pca_embedding', 'pca_plus_slides', 'pca_plus_slides_scaled'
     obsp: 'connectivities', 'distances',
 AnnData object with n_obs × n_vars = 190659 × 1536
     obs: 'image_id', 'patchsize', 'center_ccf', 'pixel_coord', 'distance', 'nearest_ST', 'nearest_cell_id', 'target_atlas_plate', 'distance_new', 'x', 'y', 'z', 'x_st', 'y_st', 'z_st', 'image_nr', 'brain_area', 'group', 'slice', 'in_sample'
     obsm: 'brain_area_onehot', 'brain_area_similarities', 'uni_embedding', 'uni_pca_95', 'uni_pca_plus_coords')

In [9]:
# Column index of each non-zero entry of the (densified, 2-D) brain-area one-hot matrix.
adata_st.obsm["brain_area_onehot"].toarray().nonzero()[1]

array([ 5,  5,  5, ..., 18,  5,  5])

In [10]:
seed = 42

# Modality A = ST PCA embeddings, modality B = histology UNI PCA embeddings.
# Labels are the column indices of the non-zero entries of the brain-area one-hot matrices.
embeddings_a = adata_st.obsm["pca_embedding"]
embeddings_b = adata_histo.obsm["uni_pca_95"]
labels_a, labels_b = (
    adata.obsm["brain_area_onehot"].toarray().nonzero()[-1]
    for adata in (adata_st, adata_histo)
)

# ST: hold out the first 10 brain sections (in order of appearance) for validation.
val_slides = list(adata_st.obs["brain_section_label"].unique()[:10])
print("Validation slides", val_slides)
st_cond = adata_st.obs["brain_section_label"].isin(val_slides).to_numpy()

# Histology: hold out a random 20 % of the patches for validation (True = validation).
# NOTE(review): unlike the ST split this is per patch, not per slice, so patches from
# the same slice can land in both splits - confirm this leakage is acceptable.
rng = np.random.default_rng(seed=seed)
n_histo = embeddings_b.shape[0]
sample_size = n_histo // 5
sample = rng.choice(n_histo, size=sample_size, replace=False)
histo_cond = np.zeros(n_histo, dtype=bool)
histo_cond[sample] = True

def make_set(n_pos, n_neg, random_seed, mode = "train"):
    """Build a PairedContrastiveDataset for the training or validation split.

    Args:
        n_pos: positives drawn per anchor (forwarded to PairedContrastiveDataset).
        n_neg: negatives drawn per anchor.
        random_seed: seed for the dataset's sampler.
        mode: "train" selects the rows outside the held-out masks `st_cond` /
            `histo_cond`; "val" selects the held-out rows.

    Returns:
        PairedContrastiveDataset over the selected ST and histology rows.

    Raises:
        ValueError: if `mode` is neither "train" nor "val".
    """
    if mode == "train":
        _st_cond, _histo_cond = ~st_cond, ~histo_cond
    elif mode == "val":
        _st_cond, _histo_cond = st_cond, histo_cond
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
    return PairedContrastiveDataset(
        embeddings_a=embeddings_a[_st_cond],
        labels_a=labels_a[_st_cond],
        embeddings_b=embeddings_b[_histo_cond],
        labels_b=labels_b[_histo_cond],
        n_pos=n_pos,
        n_neg=n_neg,
        seed=random_seed
    )

Validation slides ['Zhuang-ABCA-1.079', 'Zhuang-ABCA-1.089', 'Zhuang-ABCA-1.085', 'Zhuang-ABCA-1.077', 'Zhuang-ABCA-1.087', 'Zhuang-ABCA-1.049', 'Zhuang-ABCA-1.059', 'Zhuang-ABCA-1.069', 'Zhuang-ABCA-1.072', 'Zhuang-ABCA-1.082']


In [11]:
# Both splits use the same sampling budget: 30 positives and 120 negatives per anchor.
_sampling = dict(n_pos=30, n_neg=120, random_seed=seed)
train_dataset = make_set(**_sampling, mode="train")
val_dataset = make_set(**_sampling, mode="val")
len(train_dataset), len(val_dataset)

(197302, 43357)

In [12]:
[value.shape for value in train_dataset[0].values()]

[torch.Size([150, 50]),
 torch.Size([150]),
 torch.Size([150, 368]),
 torch.Size([150])]

In [13]:
list(train_dataset[0].keys())

['a_batch', 'a_is_positive', 'b_batch', 'b_is_positive']

In [14]:
first_val_item = val_dataset[0]
first_val_item.keys()

dict_keys(['a_batch', 'a_is_positive', 'b_batch', 'b_is_positive'])

==> This works well for fetching data, but building the dataset takes quite a long time.

## Use the dataloader with the model

In [15]:
CHECKPOINT_PATH = "/p/project1/hai_fzj_bda/koenig8/cl/simple_model"


def _available_cpus():
    """Return the number of CPUs this process may run on (at least 1)."""
    try:
        # Respects the process's CPU affinity (e.g. a batch-job allocation), whereas
        # os.cpu_count() reports every CPU of the machine.
        return max(len(os.sched_getaffinity(0)), 1)
    except AttributeError:  # sched_getaffinity is not available on every platform
        # os.cpu_count() may return None when the count cannot be determined.
        return os.cpu_count() or 1


NUM_WORKERS = int(_available_cpus() * 0.75)  # Reserve some CPUs for the main process
torch.set_float32_matmul_precision('medium')

# Setting the seed
pl.seed_everything(42)

# Make cuDNN choose deterministic algorithms (if a GPU is used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Number of workers:", NUM_WORKERS)

[rank: 0] Global seed set to 42


Device: cuda:0
Number of workers: 72


In [16]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

ERROR: Could not find '/p/project1/hai_fzj_bda/koenig8/jupyter/kernels
/contrastive_learn/bin/tensorboard' (set by the `TENSORBOARD_BINARY`
environment variable). Please ensure that your PATH contains an
executable `tensorboard` program, or explicitly specify the path to a
TensorBoard binary by setting the `TENSORBOARD_BINARY` environment
variable.

In [17]:
def train_simclr(max_epochs=500, **kwargs):
    """Train the cross-modal SimCLR model, or load a pretrained one.

    If ``<CHECKPOINT_PATH>/SimCLR.ckpt`` exists it is loaded and no training
    happens. Otherwise a model is fitted on the module-level ``train_dataset`` /
    ``val_dataset``, and the checkpoint with the highest ``val_acc_top5`` is
    loaded and returned.

    Args:
        max_epochs: number of training epochs; also forwarded to ``SimCLR``.
        **kwargs: remaining ``SimCLR`` hyper-parameters (hidden_dim, lr, ...).

    Returns:
        The loaded or trained ``SimCLR`` model.
    """
    # Check whether a pretrained model exists. If yes, load it and skip training;
    # no Trainer is needed in that case.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, 'SimCLR.ckpt')
    if os.path.isfile(pretrained_filename):
        print(f'Found pretrained model at {pretrained_filename}, loading...')
        return SimCLR.load_from_checkpoint(pretrained_filename)  # Automatically loads the model with the saved hyperparameters

    progress_bar = TQDMProgressBar(refresh_rate=2000)
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, 'SimCLR'),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=max_epochs,
                         limit_train_batches=0.2,  # To only use 20% each epoch
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode='max', monitor='val_acc_top5'),
                                    LearningRateMonitor('epoch'), progress_bar])
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Each dataset item is already a full contrastive group, hence batch_size=1.
    # persistent_workers keeps the worker processes alive between epochs instead of
    # starting NUM_WORKERS new ones every epoch; it is only allowed when workers are used.
    loader_kwargs = dict(batch_size=1, pin_memory=True, num_workers=NUM_WORKERS,
                         persistent_workers=NUM_WORKERS > 0)
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    pl.seed_everything(42)  # To be reproducible
    model = SimCLR(max_epochs=max_epochs, **kwargs)
    trainer.fit(model, train_loader, val_loader)

    best_path = trainer.checkpoint_callback.best_model_path
    if not best_path:
        # No checkpoint was written (e.g. 'val_acc_top5' was never logged); keep the final weights.
        print('No best checkpoint found, returning the model from the last epoch.')
        return model
    return SimCLR.load_from_checkpoint(best_path)  # Load best checkpoint after training

In [None]:
# Hyper-parameters of this run; max_epochs is passed separately to train_simclr.
simclr_hparams = dict(hidden_dim=128, lr=5e-4, temperature=0.07, weight_decay=1e-4)
simclr_model = train_simclr(max_epochs=500, **simclr_hparams)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3     ]

  | Name        | Type        | Params
--------------------------------------------
0 | histo_model | SimpleModel | 113 K 
1 | st_model    | SimpleModel | 72.6 K
--------------------------------------------
185 K     Trainable params
0         Non-trainable params
185 K     Total params
0.743     Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 0:  46%|████▌     | 38000/82817 [06:47<08:00, 93.24it/s, loss=-8.24, v_num=1.13e+7]