In [66]:
import json
import re
import os
from pathlib import Path

import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import torch
import torch.nn.functional as T
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from einops import rearrange

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2

In [4]:
import mushroom.utils as utils

In [13]:
# Load the per-section result files. Per their paths, s0/s3 come from the
# metagene_dino_v2 run (files suffixed _50knn) and s1/s2 from the codex run
# (files suffixed _20knn).
RUNS_DIR = '/data/estorrs/DINO-extended/data/runs'
s0 = torch.load(f'{RUNS_DIR}/HT397B1/metagene_dino_v2/results/s0_50knn.pt')
s1 = torch.load(f'{RUNS_DIR}/HT397B1_v1/codex/run_1/results/s1_20knn.pt')
s2 = torch.load(f'{RUNS_DIR}/HT397B1_v1/codex/run_1/results/s2_20knn.pt')
s3 = torch.load(f'{RUNS_DIR}/HT397B1/metagene_dino_v2/results/s3_50knn.pt')

# List every entry stored in a result file together with its shape.
for name, tensor in s0.items():
    print(name, tensor.shape)

clustered_patches torch.Size([192, 188])
patch_centroids torch.Size([50, 1024])
patch_embs torch.Size([1024, 192, 188])


In [14]:
# Compare the patch-embedding grid of every section before aligning them.
for results in (s0, s1, s2, s3):
    print(results['patch_embs'].shape)

torch.Size([1024, 192, 188])
torch.Size([1024, 240, 236])
torch.Size([1024, 240, 236])
torch.Size([1024, 192, 188])


In [15]:
# Spatial (H, W) grid of section s1, the target grid used for alignment. Read it
# from the loaded results: `x1` is only assigned in a later cell, so referencing
# it here raises NameError on a fresh kernel.
s1['patch_embs'].shape[1:]

torch.Size([240, 236])

In [16]:
# Put every section's (C, H, W) embedding grid on s1's spatial grid so all
# sections have the same number of patches once flattened.
x0, x1, x2, x3 = (s['patch_embs'] for s in (s0, s1, s2, s3))
target_hw = x1.shape[1:]
# Only resample grids that differ from the target; the others pass through untouched.
x0, x1, x2, x3 = (
    x if x.shape[1:] == target_hw else TF.resize(x, target_hw, antialias=False)
    for x in (x0, x1, x2, x3)
)
# Downstream cells index all sections with the same flat patch index, so the
# grids must agree exactly.
assert all(x.shape[1:] == target_hw for x in (x0, x1, x2, x3)), 'spatial grids still differ'
x0.shape, x1.shape, x2.shape, x3.shape

(torch.Size([1024, 240, 236]),
 torch.Size([1024, 240, 236]),
 torch.Size([1024, 240, 236]),
 torch.Size([1024, 240, 236]))

In [23]:
# Collect the aligned per-section embedding grids in section order.
slices = list((x0, x1, x2, x3))

In [25]:
# Flatten each (C, H, W) grid into an (H*W, C) patch-by-feature matrix. Build from
# x0..x3 rather than from `slices` itself, so re-running this cell is idempotent
# instead of failing on tensors that are already flattened.
slices = [rearrange(x, 'c h w -> (h w) c') for x in (x0, x1, x2, x3)]
slices[0].shape

torch.Size([56640, 1024])

In [46]:
def normalize(x, eps=1e-8):
    """Standardize every feature to zero mean and unit (unbiased) std over dim 0.

    The input is not modified; a new tensor is returned.

    Args:
        x: floating-point tensor whose dim 0 indexes samples (in this notebook:
            patches of a flattened (H*W, C) section matrix).
        eps: lower bound for the per-feature std, so constant features map to 0
            instead of NaN/inf from a 0/0 or x/0 division.

    Returns:
        Tensor with the same shape as ``x``.
    """
    mean = x.mean(dim=0)
    std = x.std(dim=0).clamp_min(eps)
    return (x - mean) / std

In [54]:
# Standardize the features of each section independently.
slices = list(map(normalize, slices))

In [92]:
class SliceTripletDataset(Dataset):
    """Triplets of patch embeddings drawn across aligned sections.

    Every entry of ``slices`` is indexed by patch along dim 0, and row ``i`` is
    assumed to be the same spatial location in every section. For item ``idx``:

    * ``anchor``: row ``idx`` of a uniformly chosen section,
    * ``pos``: row ``idx`` of a different uniformly chosen section,
    * ``neg``: a row other than ``idx`` from a section different from the anchor's.

    All randomness comes from torch's RNG, so ``torch.manual_seed`` (and the
    DataLoader's per-worker torch seeding) controls sampling.
    """

    def __init__(self, slices):
        if len(slices) < 2:
            raise ValueError('SliceTripletDataset needs at least two slices')
        n_patches = len(slices[0])
        if any(len(s) != n_patches for s in slices):
            raise ValueError('all slices must have the same number of patches')
        if n_patches < 2:
            raise ValueError('slices must contain at least two patches')
        self.slices = slices
        self.idxs = torch.arange(n_patches)

    def __len__(self):
        return len(self.idxs)

    @staticmethod
    def _randint_excluding(high, excluded):
        """Uniform integer in [0, high) that is never ``excluded`` (needs high >= 2)."""
        value = torch.randint(0, high - 1, (1,)).item()
        # Shift values at or above the excluded one up by one to skip it.
        return value + 1 if value >= excluded else value

    def __getitem__(self, idx):
        n_patches = len(self.idxs)
        if idx < 0:
            # Support negative indexing so the excluded negative row is the real row.
            idx += n_patches

        n_slices = len(self.slices)
        anchor_slide_idx = torch.randint(0, n_slices, (1,)).item()
        pos_slide_idx = self._randint_excluding(n_slices, anchor_slide_idx)
        neg_slide_idx = self._randint_excluding(n_slices, anchor_slide_idx)
        # Never reuse the anchor's patch location as the negative.
        neg_patch_idx = self._randint_excluding(n_patches, idx)

        return {
            'anchor': self.slices[anchor_slide_idx][idx],
            'pos': self.slices[pos_slide_idx][idx],
            'neg': self.slices[neg_slide_idx][neg_patch_idx]
        }

In [93]:
# Triplet dataset over the normalized, flattened sections.
ds = SliceTripletDataset(slices=slices)

In [94]:
# Dataset sampling draws from torch's and/or numpy's global RNG; seed both so
# the triplet shown here (and later shuffling/initialization) is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
ds[0]

{'anchor': tensor([-0.7145, -0.5873,  1.8824,  ..., -2.1483,  0.8828,  1.9796]),
 'pos': tensor([-0.6299,  0.7921,  1.2935,  ..., -0.6145,  0.9024,  1.0062]),
 'neg': tensor([ 0.9497,  0.0120, -1.4810,  ..., -0.5854, -0.6937, -1.2840])}

In [95]:
# Shuffled mini-batches of triplets, loaded by 10 worker processes.
batch_size = 256
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
)

In [96]:
# Grab one batch (a dict with 'anchor'/'pos'/'neg' tensors) to smoke-test the
# model below. The iterator is not kept, so each run of this cell starts a fresh
# pass over `dl`.
b = next(iter(dl))

In [87]:
class SliceCluster(torch.nn.Module):
    """MLP encoder for patch embeddings, trained with a triplet margin loss.

    The encoder narrows the input three times (in_dim -> in_dim // 2 ->
    in_dim // 4 -> in_dim // 8). Each step is Linear + BatchNorm1d + ReLU, and
    a final Linear projects to ``emb_dim``.
    """

    def __init__(self, in_dim, emb_dim=64):
        super().__init__()
        self.in_dim = in_dim
        self.emb_dim = emb_dim

        widths = [in_dim // (2 ** k) for k in range(4)]
        hidden = []
        for width_in, width_out in zip(widths, widths[1:]):
            hidden.extend([
                torch.nn.Linear(width_in, width_out),
                torch.nn.BatchNorm1d(width_out),
                torch.nn.ReLU(),
            ])
        self.encoder = torch.nn.Sequential(*hidden, torch.nn.Linear(widths[-1], self.emb_dim))

        # Library defaults: margin=1.0, Euclidean (p=2) distance.
        self.loss = torch.nn.TripletMarginLoss()

    def calculate_loss(self, anchor, pos, neg):
        """Triplet margin loss for batches that have already been encoded."""
        return self.loss(anchor, pos, neg)

    def forward(self, anchor, pos, neg):
        """Encode all three batches with the shared encoder; returns (anchor, pos, neg)."""
        return tuple(self.encoder(batch) for batch in (anchor, pos, neg))

In [88]:
# Input width = number of features per patch. Read it from the stored section
# matrices instead of sampling a random triplet via ds[0], which would also
# consume RNG state.
model = SliceCluster(ds.slices[0].shape[1])

In [99]:
# Smoke-test one forward pass; the names are rebound to the encoded batches.
# NOTE: the model is in its default training mode here, so BatchNorm running
# statistics are updated by this call.
anchor, pos, neg = model(*(b[key] for key in ('anchor', 'pos', 'neg')))
anchor.shape

torch.Size([256, 64])

In [100]:
# Triplet loss of the smoke-test batch.
model.calculate_loss(anchor=anchor, pos=pos, neg=neg)

tensor(1.0697, grad_fn=<MeanBackward0>)