In [None]:
import json
from collections import Counter
from dataclasses import dataclass

import anndata
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import seaborn as sns
import tifffile
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from einops import rearrange
from kmeans_pytorch import kmeans
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, Normalize, RandomCrop, Compose

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import multiplex_imaging_pipeline.utils as utils

In [None]:
scale = .1

In [None]:
# Read the registration metadata. The context manager closes the file handle
# as soon as parsing is done instead of leaving it to the garbage collector.
with open('../data/test_registration/HT397B1_v2/registered/metadata.json') as metadata_file:
    metadata = json.load(metadata_file)
metadata

In [None]:
fps = sorted(utils.listfiles('/data/estorrs/mushroom/data/test_registration/HT397B1_v2/registered',
                     regex='[0-9].h5ad$'))
fps

In [None]:
# Keep the genes that are detected (value > 0) in more than `pct_expression`
# of the spots of a section, and that pass this filter in every section.
pct_expression = .02
pool = []
for fp in fps:
    a = sc.read_h5ad(fp)

    # Count expressing spots per gene without densifying the whole
    # spots x genes matrix; np.asarray(...).ravel() flattens the (1, n_genes)
    # matrix a sparse sum returns and is a no-op for a dense array.
    spot_count = np.asarray((a.X > 0).sum(0)).ravel()
    mask = spot_count > pct_expression * a.shape[0]

    # Only the gene names are needed, so index the var index directly
    # rather than creating a sliced AnnData view.
    pool += a.var.index[mask].to_list()
counts = Counter(pool)
channels = sorted([c for c, count in counts.items() if count==len(fps)])
len(channels), channels[:5]

In [None]:
# Load every Visium section, attach an integer label <-> barcode mapping,
# restrict it to the shared genes and log-transform the expression.
slide_to_visium = {}
for fp in fps:
    sample = fp.split('/')[-1].replace('.h5ad', '')
    a = sc.read_h5ad(fp)
    # Labels start at 1 because 0 marks "no spot" in the labeled images
    # (format_expression drops label 0).
    label_to_barcode = {i+1:x for i, x in enumerate(a.obs.index)}
    barcode_to_label = {v:k for k, v in label_to_barcode.items()}
    a.uns['label_to_barcode'] = label_to_barcode
    a.uns['barcode_to_label'] = barcode_to_label

    # Slicing returns a view; copy it so log1p modifies an object we own
    # instead of triggering an implicit view-to-copy conversion.
    a = a[:, channels].copy()
    sc.pp.log1p(a)

    slide_to_visium[sample] = a
slide_to_visium.keys()

In [None]:
visium_channels = list(channels)

In [None]:
a = next(iter(slide_to_visium.values()))
a

In [None]:
fps = sorted(utils.listfiles('/data/estorrs/mushroom/data/test_registration/HT397B1_v2/registered',
                     regex='ome.tiff$'))
fps

In [None]:
pool = []
for fp in fps:
    channels = utils.get_ome_tiff_channels(fp)
    channels = [utils.R_CHANNEL_MAPPING.get(c, c) for c in channels]
    pool += channels
Counter(pool).most_common()

In [None]:
channels = sorted([c for c, count in Counter(pool).items() if count==len(fps)])
channels

In [None]:
# Load every multiplex image, keep the channels shared by all sections and
# downscale them to float32 thumbnails of shape (c, h * scale, w * scale).
slide_to_multiplex = {}
for fp in fps:
    sample = fp.split('/')[-1].replace('.ome.tiff', '')
    cs, img = utils.extract_ome_tiff(fp, as_dict=False)

    # Harmonise channel names exactly like the channel-pooling cell does:
    # unmapped names fall back to themselves instead of raising KeyError.
    cs = [utils.R_CHANNEL_MAPPING.get(c, c) for c in cs]
    idxs = [cs.index(c) for c in channels]

    # Select the shared channels first so only those are resized; resizing is
    # per channel, so the result equals resizing everything and then indexing.
    img = torch.tensor(img)[idxs]
    thumbnail = TF.resize(img, (int(scale * img.shape[-2]), int(scale * img.shape[-1])))
    thumbnail = thumbnail.to(torch.float32)

    slide_to_multiplex[sample] = thumbnail

In [None]:
multiplex_channels = list(channels)

In [None]:
slide_to_data = {k:v for k, v in slide_to_visium.items()}
slide_to_data.update(slide_to_multiplex)
slide_to_data.keys()

In [None]:
slide_to_dtype = {s:'visium' for s in slide_to_visium.keys()}
slide_to_dtype.update({s:'multiplex' for s in slide_to_multiplex.keys()})
slide_to_dtype

In [None]:
samples = sorted(slide_to_data.keys())
samples

In [None]:
@dataclass
class TransformArgs:
    """Crop/flip parameters sampled once and applied by the transforms in this notebook."""
    # (top, left) corner of the crop; passed to TF.crop as its `top` and `left` arguments.
    top_left: tuple
    # (height, width) of the crop; passed to TF.crop as its `height` and `width` arguments.
    size: tuple
    # When True the cropped tile is flipped with TF.vflip.
    vflip: bool
    # When True the cropped tile is flipped with TF.hflip.
    hflip: bool

def format_expression(tiles, adatas, patch_size):
    """Turn labeled spot tiles into per-patch mean-expression images.

    Args:
        tiles: integer tensor of shape (h, w) or (b, h, w). Non-zero values
            are spot labels, 0 means "no spot". h and w must be divisible by
            `patch_size`.
        adatas: one AnnData, or a sequence with one AnnData per tile. Each
            must carry `uns['label_to_barcode']` mapping label -> barcode.
        patch_size: side length, in pixels, of the square patches.

    Returns:
        float32 tensor of shape (b, n_genes, h // patch_size, w // patch_size).
        Each spatial position holds the mean expression of the spots inside
        that patch, or zeros when the patch contains no spot.
    """
    # add batch dim if there is none
    if len(tiles.shape) == 2:
        tiles = tiles.unsqueeze(0)
    if isinstance(adatas, anndata.AnnData):
        adatas = [adatas]

    exp_imgs = []
    for tile, adata in zip(tiles, adatas):
        # In einops the first name of a composite axis is the slow (outer)
        # one, so '(h ph)' makes h the patch index and ph the pixel offset
        # inside the patch. This groups *contiguous* patch_size x patch_size
        # blocks. The previous '(ph h)' spelling grouped pixels lying
        # h // patch_size apart, i.e. scattered over the whole tile.
        tile = rearrange(tile, '(h ph) (w pw) -> h w (ph pw)', ph=patch_size, pw=patch_size)
        # Collapse duplicate pixel slices (mostly all-zero ones) so fewer
        # values have to be scanned per patch.
        x = torch.unique(tile, dim=-1)

        exp = torch.zeros(x.shape[0], x.shape[1], adata.shape[1], dtype=torch.float32)
        l2b = adata.uns['label_to_barcode']
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                labels = x[i, j]
                labels = labels[labels!=0]
                if len(labels):
                    barcodes = [l2b[l.item()] for l in labels]
                    # A sparse X returns a (1, n_genes) matrix from mean(0);
                    # flatten so dense and sparse inputs both give (n_genes,).
                    mean_exp = np.asarray(adata[barcodes].X.mean(0)).ravel()
                    exp[i, j] = torch.tensor(mean_exp, dtype=torch.float32)
        exp = rearrange(exp, 'h w c -> c h w')
        exp_imgs.append(exp)

    return torch.stack(exp_imgs)

def get_slide_to_labeled(slide_to_adata, crop=True, scale=.1):
    """Rasterise the spot positions of each slide into an integer label image.

    Args:
        slide_to_adata: mapping slide name -> object exposing `obs.index`
            (barcodes), `obsm['spatial']` ((x, y) = (column, row) per spot),
            `uns['he_rescaled_warped']` (image whose first two dims give the
            full-resolution height and width) and `uns['barcode_to_label']`.
        crop: when True, crop each label image to the bounding box of its
            spots; otherwise keep the full scaled image extent.
        scale: factor applied to the spot coordinates and the image size.

    Returns:
        dict slide name -> integer tensor (h, w) holding the spot label at
        each spot pixel and 0 elsewhere.

    Side effect: writes the scaled integer coordinates to
    `obsm['spatial_scaled']` of every input object.
    """
    slide_to_labeled = {}
    for s, a in slide_to_adata.items():
        a.obsm['spatial_scaled'] = (a.obsm['spatial'] * scale).astype(np.int32)
        labeled_locations = np.zeros(
            (np.asarray(a.uns['he_rescaled_warped'].shape[:2]) * scale).astype(int), dtype=int)
        barcode_to_label = a.uns['barcode_to_label']
        for barcode, (c, r) in zip(a.obs.index, a.obsm['spatial_scaled']):
            labeled_locations[r, c] = barcode_to_label[barcode]

        if crop:
            min_c, min_r = a.obsm['spatial_scaled'].min(0)
            max_c, max_r = a.obsm['spatial_scaled'].max(0)
            # Slice ends are exclusive: step one past the largest coordinate
            # so the spots on the last row/column are not cropped away.
            max_r, max_c = max_r + 1, max_c + 1
        else:
            min_r, min_c = 0, 0
            max_r, max_c = labeled_locations.shape
        labeled_locations = labeled_locations[min_r:max_r, min_c:max_c]
        slide_to_labeled[s] = torch.tensor(labeled_locations)
    return slide_to_labeled

class TransformVisium(object):
    """Crop/flip labeled Visium tiles, then convert them to expression images."""

    def __init__(self, size=(256, 256), patch_size=32, normalize=None):
        # `size` is only stored; the crop size actually applied comes from
        # the `transform_args` handed to __call__.
        self.size = size
        self.patch_size = patch_size

        # Fall back to a pass-through module when no normalisation is given.
        if normalize is None:
            normalize = nn.Identity()
        self.normalize = normalize

    def __call__(self, x, adatas, transform_args=None):
        """Apply the optional crop/flips to `x`, then build and normalise expression."""
        if transform_args is not None:
            top, left = transform_args.top_left[0], transform_args.top_left[1]
            height, width = transform_args.size[0], transform_args.size[1]
            x = TF.crop(x, top, left, height, width)
            x = TF.hflip(x) if transform_args.hflip else x
            x = TF.vflip(x) if transform_args.vflip else x

        expression = format_expression(x, adatas, patch_size=self.patch_size)
        return self.normalize(expression)
    
class TransformMultiplex(object):
    """Crop/flip multiplex image tiles and normalise them."""

    def __init__(self, size=(256, 256), normalize=None):
        # `size` is only stored; the crop size actually applied comes from
        # the `transform_args` handed to __call__.
        self.size = size

        # Fall back to a pass-through module when no normalisation is given.
        if normalize is None:
            normalize = nn.Identity()
        self.normalize = normalize

    def __call__(self, x, transform_args=None):
        """Apply the optional crop/flips to `x`, then normalise it."""
        if transform_args is not None:
            top, left = transform_args.top_left[0], transform_args.top_left[1]
            height, width = transform_args.size[0], transform_args.size[1]
            x = TF.crop(x, top, left, height, width)
            x = TF.hflip(x) if transform_args.hflip else x
            x = TF.vflip(x) if transform_args.vflip else x

        return self.normalize(x)
    
    
class SlideDataset(Dataset):
    """Endless dataset of spatially aligned random tiles from all slides.

    Every item applies one random crop/flip to the stacked multiplex
    thumbnails and to the stacked Visium label images, so the tiles of all
    slides cover the same region.

    Args:
        order: slide names in the order they should be stacked.
        slide_to_data: slide name -> multiplex tensor (c, h, w) or Visium AnnData.
        slide_to_dtype: slide name -> 'multiplex' or 'visium'.
        multiplex_transform: callable `(stacked, transform_args=...)`; required
            by `__getitem__` despite the `None` default.
        visium_transform: callable `(stacked, adatas, transform_args=...)`;
            required by `__getitem__` despite the `None` default.
        scale: scale factor used to rasterise the Visium spot positions.
        size: (height, width) of the sampled tiles.

    Raises:
        ValueError: if the slide images are smaller than `size`.
    """

    def __init__(
        self,
        order,
        slide_to_data,
        slide_to_dtype,
        multiplex_transform=None,
        visium_transform=None,
        scale=.1,
        size=(256, 256),
    ):
        self.scale = scale
        self.size = size
        self.slides = order
        self.slide_to_dtype = slide_to_dtype
        self.dtypes = sorted(set(slide_to_dtype.values()))

        self.slide_to_multiplex = {
            s:obj for s, obj in slide_to_data.items()
            if slide_to_dtype[s] == 'multiplex'
        }

        self.slide_to_visium_adata = {
            s:obj for s, obj in slide_to_data.items()
            if slide_to_dtype[s] == 'visium'
        }
        self.slide_to_visium_labeled = get_slide_to_labeled(
            self.slide_to_visium_adata, crop=False, scale=scale)

        # One shared crop is applied to both modalities, so they must have
        # the same spatial extent.
        multiplex_img = next(iter(self.slide_to_multiplex.values())) # (c h w)
        visium_img = next(iter(self.slide_to_visium_labeled.values())) # (h w)
        assert multiplex_img.shape[-2:] == visium_img.shape[-2:]

        self.multiplex_stacked = torch.stack([
            self.slide_to_multiplex[s]
            for s in self.slides
            if s in self.slide_to_multiplex
        ]) # (b c h w)
        self.visium_stacked = torch.stack([
            self.slide_to_visium_labeled[s]
            for s in self.slides
            if s in self.slide_to_visium_labeled
        ]) # (b h w)
        # AnnData objects in the same order as `visium_stacked`. They come
        # from the data handed to this dataset, not from a notebook global.
        self.visium_adatas = [
            self.slide_to_visium_adata[s]
            for s in self.slides
            if s in self.slide_to_visium_adata
        ]

        # Index into `self.dtypes` for every slide, in stacking order.
        self.dtype_order = torch.tensor([self.dtypes.index(self.slide_to_dtype[s])
                            for s in self.slides])

        if (self.multiplex_stacked.shape[-2] < self.size[-2]
                or self.multiplex_stacked.shape[-1] < self.size[-1]):
            raise ValueError(
                f'tile size {tuple(self.size)} is larger than the slide images '
                f'{tuple(self.multiplex_stacked.shape[-2:])}')

        self.multiplex_transform = multiplex_transform
        self.visium_transform = visium_transform

    def __len__(self):
        return np.iinfo(np.int64).max # make infinite

    def __getitem__(self, idx):
        # `idx` is ignored: every item is a fresh random tile.
        # The torch RNG is used because DataLoader seeds it differently in
        # every worker; the global NumPy RNG state is copied into forked
        # workers, which would make all workers sample the same crops.
        max_top = self.multiplex_stacked.shape[-2] - self.size[-2]
        max_left = self.multiplex_stacked.shape[-1] - self.size[-1]
        transform_args = TransformArgs(
            top_left=(
                # the upper bound is exclusive, hence +1 to reach the last offset
                torch.randint(0, max_top + 1, (1,)).item(),
                torch.randint(0, max_left + 1, (1,)).item(),
            ),
            size=self.size,
            vflip=torch.rand(1).item() > .5,
            hflip=torch.rand(1).item() > .5,
        )

        multiplex_tiles = self.multiplex_transform(
            self.multiplex_stacked,
            transform_args=transform_args
        )
        visium_tiles = self.visium_transform(
            self.visium_stacked,
            self.visium_adatas,
            transform_args=transform_args
        )

        return {
            'stacked_multiplex': multiplex_tiles, # (b c h w)
            'stacked_visium': visium_tiles, # (b c h/ps w/ps)
        }

In [None]:
size = (256, 256)
patch_size = 32

In [None]:
means = np.vstack([a.X.toarray().mean(0) for a in slide_to_visium.values()]).mean(0)
stds = np.vstack([a.X.toarray().std(0) for a in slide_to_visium.values()]).mean(0)
normalize = Normalize(means, stds)
visium_transform = TransformVisium(normalize=normalize, size=size, patch_size=patch_size)

In [None]:
means = torch.cat([x.mean(dim=(-2, -1)).unsqueeze(0) for x in slide_to_multiplex.values()]).mean(0)
stds = torch.cat([x.std(dim=(-2, -1)).unsqueeze(0) for x in slide_to_multiplex.values()]).mean(0)
normalize = Normalize(means, stds)
multiplex_transform = TransformMultiplex(normalize=normalize, size=size)

In [None]:
ds = SlideDataset(
    order=samples,
    slide_to_data=slide_to_data,
    slide_to_dtype=slide_to_dtype,
    multiplex_transform=multiplex_transform,
    visium_transform=visium_transform,
    scale=scale,
    size=size,
)

In [None]:
ds.dtype_order

In [None]:
ds.dtypes

In [None]:
ds.multiplex_stacked.shape, ds.visium_stacked.shape

In [None]:
ds.slides

In [None]:
d = ds[0]

In [None]:
d.keys()

In [None]:
d['stacked_multiplex'].shape, d['stacked_visium'].shape

In [None]:
fig, axs = plt.subplots(ncols=len(d['stacked_multiplex']))
for ax, img in zip(axs, d['stacked_multiplex']):
    ax.imshow(img[multiplex_channels.index('Pan-Cytokeratin')])
    ax.axis('off')

In [None]:
# Show the EPCAM channel of every Visium expression tile in the sample item.
fig, axs = plt.subplots(ncols=len(d['stacked_visium']))
for ax, img in zip(axs, d['stacked_visium']):
    # `visium_channels` is the gene order the Visium AnnData objects were
    # sliced with, so use it instead of whatever the scratch variable `a`
    # currently points to.
    ax.imshow(img[visium_channels.index('EPCAM')])
    ax.axis('off')

In [None]:
dl = DataLoader(ds, batch_size=16, num_workers=1)

In [None]:
b = next(iter(dl))
b.keys()

In [None]:
b['stacked_multiplex'].shape, b['stacked_visium'].shape

In [None]:
from vit_pytorch import ViT
from einops.layers.torch import Rearrange

In [None]:
v = ViT(
    image_size = (256, 256),
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
)

In [None]:
v.to_patch_embedding

In [None]:
@dataclass
class ShapeArgs:
    """Per-data-type patch geometry; one instance per entry of `dtype_to_shape_args`."""
    # Side length of one patch; plot_recons uses it as the per-patch height and width.
    patch_size: int
    # Number of channels per pixel; plot_recons uses it as the channel count `c`.
    n_channels: int

In [None]:
dtype_to_shape_args = {
    'multiplex': ShapeArgs(patch_size=32, n_channels=b['stacked_multiplex'].shape[-3]),
    'visium': ShapeArgs(patch_size=1, n_channels=b['stacked_visium'].shape[-3])
}

In [None]:
from vit_pytorch.slide_mae import SlideMAEV3

In [None]:
mae = SlideMAEV3(
    encoder=v,
    decoder_dim=512,
    n_slides=len(samples),
    dtypes=ds.dtypes,
    dtype_to_shape_args=dtype_to_shape_args,
    slide_dtype_order=ds.dtype_order,
)

In [None]:
mae = mae.cuda()

In [None]:
imgs = [
    b['stacked_multiplex'].cuda(),
    b['stacked_visium'].cuda()
]

In [None]:
recon_loss, triplet_loss, overall_loss, pixels = mae(imgs)

In [None]:
recon_loss, triplet_loss, overall_loss

In [None]:
def plot_recons(pixels, n_cols=6):
    """Plot predicted patches for the first `n_cols` batch items.

    Row 0 shows the 'Pan-Cytokeratin' channel of the multiplex prediction,
    row 1 the 'EPCAM' channel of the Visium prediction.

    Args:
        pixels: sequence indexed like `ds.dtypes`; each entry is a tensor of
            shape (batch, n_tokens, patch_size * patch_size * n_channels).
        n_cols: maximum number of batch items (columns) to draw.

    Reads the notebook globals `ds`, `dtype_to_shape_args` and
    `multiplex_channels`.
    """
    # squeeze=False keeps `axs` two-dimensional even when n_cols == 1.
    fig, axs = plt.subplots(nrows=2, ncols=n_cols, squeeze=False)
    for i, dtype in enumerate(ds.dtypes):
        pred_pixels = pixels[i]
        args = dtype_to_shape_args[dtype]
        # Derive the token grid from the token count instead of assuming
        # 8 x 8; like the original, the grid is taken to be square.
        grid = int(round(pred_pixels.shape[1] ** .5))
        # NOTE(review): '(h ph) (w pw)' must mirror how SlideMAEV3 flattens
        # patches into tokens, which is not visible in this notebook — confirm.
        pred_pixels = rearrange(pred_pixels, 'b (ph pw) (h w c) -> b c (h ph) (w pw)',
                                ph=grid, pw=grid, c=args.n_channels, h=args.patch_size, w=args.patch_size)

        if dtype == 'multiplex':
            # `multiplex_channels` is the frozen copy of the multiplex channel
            # order; the global `channels` is reassigned in several cells.
            row, c_idx = 0, multiplex_channels.index('Pan-Cytokeratin')
        elif dtype == 'visium':
            row = 1
            c_idx = list(next(iter(ds.slide_to_visium_adata.values())).var.index).index('EPCAM')
        else:
            continue

        x = pred_pixels[:n_cols, c_idx].cpu().detach()
        # Iterate over what is actually available: the batch may hold fewer
        # than `n_cols` items.
        for col in range(len(x)):
            axs[row, col].imshow(x[col])

    # Hide the frames of every panel, including unused ones.
    for ax in axs.ravel():
        ax.axis('off')

In [None]:
plot_recons(pixels)

In [None]:
iters = 100000
lr = 1e-4
opt = torch.optim.Adam(mae.parameters(), lr=lr)

In [None]:
!mkdir -p ../data/mae_v7

In [None]:
dl = DataLoader(ds, batch_size=16, num_workers=10)

In [None]:
# Training loop: one optimisation step per batch, a reconstruction preview
# every 100 iterations and a checkpoint every 5000 iterations. The dataset is
# endless, so the loop stops itself once iteration `iters` has completed.
for i, b in enumerate(dl):
    opt.zero_grad()

    imgs = [
        b['stacked_multiplex'].cuda(),
        b['stacked_visium'].cuda()
    ]
    recon_loss, triplet_loss, overall_loss, pixels = mae(imgs)
    overall_loss.backward()
    opt.step()

    print(i, recon_loss, triplet_loss, overall_loss)

    if i % 100 == 0:
        plot_recons(pixels)
        # Title the whole figure; plt.title would only label the last subplot.
        plt.suptitle('predicted')
        plt.show()

    if i % 5000 == 0:
        torch.save(mae.state_dict(), f'../data/mae_v7/{i}iter.pt')

    # Checked last so the checkpoint for iteration `iters` is still written.
    if i == iters:
        break

In [None]:
# torch.save(v.state_dict(), f'../data/mae_v3/1500iter.pt')

In [None]:
# Build the tiled inference dataset for the Visium slides.
# NOTE(review): `InferenceTransformVisium`, `InferenceSlideDatasetVisium` and
# `slide_to_adata` are not defined or imported anywhere in the visible
# notebook; this cell presumably relies on leftover kernel state — confirm.
size = (256, 256)

# NOTE(review): in top-to-bottom order `means`/`stds` were last assigned from
# the multiplex thumbnails (the cell that builds `multiplex_transform`), not
# from the Visium expression — confirm these are the intended statistics.
normalize = Normalize(means, stds)
transform = InferenceTransformVisium(size=(256, 256), patch_size=32, normalize=normalize)
inference_ds = InferenceSlideDatasetVisium(slide_to_adata, size=(256, 256), transform=transform, crop=False)

In [None]:
inference_dl = DataLoader(inference_ds, batch_size=32, shuffle=False)

In [None]:
x = inference_ds.image_from_tiles(inference_ds.slide_to_tiles[inference_ds.slides[0]],
                                  to_expression=True, adata=inference_ds.slide_to_adata['s0'])
x.shape

In [None]:
plt.imshow(x[channels.index('EPCAM')])

In [None]:
plt.imshow(x[channels.index('IL7R')])

In [None]:
d = inference_ds[0]
d.keys()

In [None]:
d['img'].shape

In [None]:
v = ViT(
    image_size = 8,
    patch_size = 1,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    channels=len(channels),
)

In [None]:
# Rebuild the masked autoencoder around the ViT encoder `v` for inference.
# NOTE(review): `SlideMAEV2` and `slide_to_adata` are not imported or defined
# in the visible notebook (only `SlideMAEV3` is imported) — confirm where they
# come from before running this cell on a fresh kernel.
mae = SlideMAEV2(
    encoder = v,
    n_slides = len(slide_to_adata),
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

In [None]:
mae.load_state_dict(torch.load('../data/mae_v5/6000iter.pt'))

In [None]:
mae.eval()

In [None]:
# Run the trained model over every inference tile and collect, per tile, the
# encoder tokens, decoder tokens and predicted patch values on the token grid.
# The token grid is the tile size divided by the 32-pixel patch size.
grid_h, grid_w = size[0] // 32, size[1] // 32
all_encoded_tokens = torch.zeros(len(inference_ds), grid_h, grid_w, v.pos_embedding.shape[-1])
all_decoded_tokens = torch.zeros(len(inference_ds), grid_h, grid_w, mae.decoder_dim)
all_pred_patches = torch.zeros(len(inference_ds), len(channels), grid_h, grid_w)
bs = inference_dl.batch_size
# Run on whatever device the encoder lives on (a no-op move when it is the CPU).
device = v.pos_embedding.device
with torch.no_grad():
    for i, b in enumerate(inference_dl):
        x, slide_idx = b['img'].to(device), b['slide_idx'].to(device)

        encoded_tokens = mae.encode(x, slide_idx)
        decoded_tokens = mae.decode(encoded_tokens)
        # [:, 1:] skips the first token, as the original cell did, before
        # mapping the remaining tokens to pixel values.
        pred_pixel_values = mae.to_pixels(decoded_tokens[:, 1:])

        encoded_tokens = rearrange(encoded_tokens[:, 1:], 'b (h w) d -> b h w d',
                                  h=grid_h, w=grid_w)
        decoded_tokens = rearrange(decoded_tokens[:, 1:], 'b (h w) d -> b h w d',
                                  h=grid_h, w=grid_w)
        # The width previously reused size[0]; it only worked for square tiles.
        pred_patches = rearrange(
            pred_pixel_values, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            h=grid_h, w=grid_w, p1=1, p2=1, c=len(channels))

        all_encoded_tokens[i * bs:(i + 1) * bs] = encoded_tokens.cpu().detach()
        all_decoded_tokens[i * bs:(i + 1) * bs] = decoded_tokens.cpu().detach()
        all_pred_patches[i * bs:(i + 1) * bs] = pred_patches.cpu().detach()

In [None]:
all_encoded_tokens.shape, all_decoded_tokens.shape, all_pred_patches.shape

In [None]:
x = inference_ds.slide_from_tiles(all_pred_patches, 0, size=all_pred_patches.shape[-2:])
x.shape

In [None]:
# Show the predicted expression of a few marker genes in a grid of panels.
genes = [
    'EPCAM', 'KRT18',
    'IL7R',
    'BGN', 'SPARC', 'VIM',
]
n_cols = 2
# Ceiling division: `len(genes) // n_cols + 1` added a spare empty row
# whenever the gene count was an exact multiple of n_cols.
n_rows = -(-len(genes) // n_cols)

# squeeze=False keeps `axs` two-dimensional for any grid size.
fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, squeeze=False)
# Hide every frame first so panels without a gene stay blank.
for ax in axs.ravel():
    ax.axis('off')
# ravel() walks the grid row by row, i.e. in the same order as the genes list.
for ax, c in zip(axs.ravel(), genes):
    ax.imshow(x[channels.index(c)])
    ax.set_title(c)

In [None]:
all_encoded_tokens.shape

In [None]:
# testing regressing out tokens

In [None]:
import statsmodels.api as sm

In [None]:
# data = sm.datasets.scotland.load()
# data.exog = sm.add_constant(data.exog)

# gamma_model = sm.GLM(data.endog, data.exog, family=sm.families.Gamma())
# gamma_results = gamma_model.fit()



In [None]:
# gamma_results.summary()

In [None]:
from sklearn.linear_model import LinearRegression

In [None]:
x = all_encoded_tokens.clone()
x = rearrange(x, 'n h w d -> (n h w) d')
x = x.numpy()
x.shape

In [None]:
regressors = torch.zeros(all_encoded_tokens.shape[:-1], dtype=torch.long).unsqueeze(-1)
idx_to_var = [(r, c) for r in range(regressors.shape[1]) for c in range(regressors.shape[2])]
var_to_idx = {v:i for i, v in enumerate(idx_to_var)}
for i, (slide_idx, row_idx, col_idx) in enumerate(inference_ds.idx_to_coord):
    for r in range(regressors.shape[1]):
        for c in range(regressors.shape[2]):
            regressors[i, r, c] = var_to_idx[(r, c)]
target = rearrange(regressors, 'n h w d -> (n h w) d').squeeze()
target = torch.nn.functional.one_hot(target).numpy()
# target = target.numpy()
target.shape

In [None]:
# idxs = np.random.choice(np.arange(target.shape[0]), size=1000, replace=False)
# x = x[idxs]
# target = target[idxs]
# x.shape, target.shape

In [None]:
# lm = LinearRegression()
# lm.fit(x, target)
# residuals = lm.predict(x) - target
# residuals.shape

In [None]:
lm = LinearRegression()
lm.fit(target, x)
residuals = lm.predict(target) - x
residuals.shape

In [None]:
x = torch.tensor(residuals)
x.shape

In [None]:
# Cluster the position-regressed token residuals.
# random_state makes the cluster ids reproducible between runs; fit_predict
# returns the nearest-centroid label directly instead of materialising the
# (n_tokens, n_clusters) distance matrix only to argmin it.
clusterer = KMeans(n_clusters=20, random_state=0)
cluster_ids = torch.tensor(clusterer.fit_predict(x.numpy())).to(torch.long)
cluster_ids.shape

In [None]:
cluster_imgs = rearrange(cluster_ids, '(n h w) -> n 1 h w',
                        n=all_encoded_tokens.shape[0], h=all_encoded_tokens.shape[1], w=all_encoded_tokens.shape[2])
labeled_img = inference_ds.slide_from_tiles(
    cluster_imgs, 0, size=(cluster_imgs.shape[-2], cluster_imgs.shape[-1])).squeeze().to(torch.long)

stacked_labeled = []
for i in range(len(slide_to_adata)):
    stacked_labeled.append(inference_ds.slide_from_tiles(
        cluster_imgs, i, size=(cluster_imgs.shape[-2], cluster_imgs.shape[-1])).squeeze().to(torch.long))
stacked_labeled = torch.stack(stacked_labeled)
stacked_labeled.shape

In [None]:
cmap = sns.color_palette('tab20')
for i, labeled in enumerate(stacked_labeled):
    plt.imshow(display_labeled_as_rgb(labeled, cmap=cmap))
    plt.show()

In [None]:
cmap

In [None]:
cluster = 20

In [None]:
# Give every token position of every tile a unique integer id and remember a
# readable name for each id, to trace how tiles are stitched back together.
# The tensor used to be allocated as `to_labeled` while the loop wrote into an
# unrelated `labeled`; one name is used throughout now, and the sizes are
# derived instead of hard-coded.
n_tiles = len(inference_ds.idx_to_coord)
grid_h, grid_w = all_encoded_tokens.shape[1], all_encoded_tokens.shape[2]
# Ids increase in (tile, row, col) order, which is exactly row-major arange.
labeled = torch.arange(n_tiles * grid_h * grid_w, dtype=torch.long).reshape(n_tiles, grid_h, grid_w, 1)
idx_to_str = {}
for i, (slide_idx, row_idx, col_idx) in enumerate(inference_ds.idx_to_coord):
    for r in range(grid_h):
        for c in range(grid_w):
            idx = labeled[i, r, c, 0].item()
            idx_to_str[idx] = f'slide{slide_idx}_row{row_idx}_col{col_idx}_{r}_{c}'
labeled.shape

In [None]:
z = inference_ds.slide_to_tiles['s0']
z.shape

In [None]:
z = inference_ds.slide_to_labeled['s0']
z.shape

In [None]:
a = slide_to_adata['s3']
a

In [None]:
sc.pp.calculate_qc_metrics(a, inplace=True)

In [None]:
a

In [None]:
a.obs

In [None]:
sns.distplot(a.obs['n_genes_by_counts'])

In [None]:
sns.distplot(a.obs['n_genes_by_counts'])

In [None]:
x = all_encoded_tokens.clone()
x = rearrange(x, 'n h w d -> (n h w) d')
# x /= x.std(0)

In [None]:
from sklearn.cluster import KMeans
# Cluster the encoder tokens.
# random_state makes the cluster ids reproducible between runs; fit_predict
# returns the nearest-centroid label directly instead of materialising the
# (n_tokens, n_clusters) distance matrix only to argmin it.
clusterer = KMeans(n_clusters=20, random_state=0)
cluster_ids = torch.tensor(clusterer.fit_predict(x.numpy())).to(torch.long)
cluster_ids.shape

In [None]:
# num_clusters = 20
# cluster_ids, cluster_centers = kmeans(
#     X=x, num_clusters=num_clusters, distance='euclidean', device=torch.device('cuda:1'), tol=1.,
# )
# cluster_ids = cluster_ids.cpu().detach()

In [None]:
cluster_imgs = rearrange(cluster_ids, '(n h w) -> n 1 h w',
                        n=all_encoded_tokens.shape[0], h=all_encoded_tokens.shape[1], w=all_encoded_tokens.shape[2])
labeled_img = inference_ds.slide_from_tiles(
    cluster_imgs, 0, size=(cluster_imgs.shape[-2], cluster_imgs.shape[-1])).squeeze().to(torch.long)
labeled_img.shape

In [None]:
def display_labeled_as_rgb(labeled, cmap=None):
    """Convert a 2-D integer label image into an RGB image.

    Note: this function is defined twice in this notebook; keep both copies
    in sync.

    Args:
        labeled: 2-D array or torch tensor of non-negative integer labels.
        cmap: sequence of RGB colours indexed by label value; defaults to the
            current seaborn palette.

    Returns:
        float array of shape (h, w, 3) where every pixel holds `cmap[label]`.

    Raises:
        RuntimeError: if the largest label has no entry in `cmap`.
    """
    if isinstance(labeled, torch.Tensor):
        labeled = labeled.numpy()
    cmap = sns.color_palette() if cmap is None else cmap
    labels = np.unique(labeled)  # already sorted
    # Colours are looked up by label VALUE, so the palette must reach the
    # largest label. Comparing against the number of distinct labels let
    # sparse label sets through and then failed with an IndexError.
    if labels.size and labels.max() >= len(cmap):
        raise RuntimeError('cmap is too small')
    new = np.zeros((labeled.shape[0], labeled.shape[1], 3))
    for l in labels:
        new[labeled==l] = cmap[int(l)]
    return new

In [None]:
stacked_labeled = []
for i in range(len(slide_to_adata)):
    stacked_labeled.append(inference_ds.slide_from_tiles(
        cluster_imgs, i, size=(cluster_imgs.shape[-2], cluster_imgs.shape[-1])).squeeze().to(torch.long))
stacked_labeled = torch.stack(stacked_labeled)
stacked_labeled.shape

In [None]:
cmap = sns.color_palette('tab20')
for i, labeled in enumerate(stacked_labeled):
    plt.imshow(display_labeled_as_rgb(labeled, cmap=cmap))
    plt.show()

In [None]:
size = (256, 256)

In [None]:
shape = inference_ds.slide_to_labeled['s0'].shape
shape

In [None]:
labeled = torch.arange(shape[0] * shape[1])
labeled = rearrange(labeled, '(h w) -> h w', h=shape[0], w=shape[1])
plt.imshow(labeled)

In [None]:
tiles = inference_ds.to_tiles(labeled.unsqueeze(0))
tiles.shape

In [None]:
plt.imshow(tiles[6, 8, 0])

In [None]:
out = inference_ds.image_from_tiles(tiles).squeeze()
plt.imshow(out)

In [None]:
338, 8, 8, 1024

In [None]:
# Give every token position of every tile a unique integer id and remember a
# readable name for each id, to trace how tiles are stitched back together.
# Sizes are derived from the data instead of the hard-coded 338 x 8 x 8.
n_tiles = len(inference_ds.idx_to_coord)
grid_h, grid_w = all_encoded_tokens.shape[1], all_encoded_tokens.shape[2]
# Ids increase in (tile, row, col) order, which is exactly row-major arange.
labeled = torch.arange(n_tiles * grid_h * grid_w, dtype=torch.long).reshape(n_tiles, grid_h, grid_w, 1)
idx_to_str = {}
for i, (slide_idx, row_idx, col_idx) in enumerate(inference_ds.idx_to_coord):
    for r in range(grid_h):
        for c in range(grid_w):
            idx = labeled[i, r, c, 0].item()
            idx_to_str[idx] = f'slide{slide_idx}_row{row_idx}_col{col_idx}_{r}_{c}'
labeled.shape

In [None]:
idx_to_str[labeled[10, 1, 1, 0].item()]

In [None]:
pre_cluster_labeled = rearrange(labeled, 'n h w d -> (n h w) d')
pre_cluster_labeled.shape

In [None]:
post_cluster_labeled = pre_cluster_labeled.squeeze()
post_cluster_labeled.shape

In [None]:
post_cluster_image = rearrange(post_cluster_labeled, '(n h w) -> n 1 h w',
                        n=all_encoded_tokens.shape[0], h=all_encoded_tokens.shape[1], w=all_encoded_tokens.shape[2])
post_cluster_image.shape

In [None]:
slide_from_tiles_img = inference_ds.slide_from_tiles(
    post_cluster_image, 0, size=(post_cluster_image.shape[-2], post_cluster_image.shape[-1])).squeeze().to(torch.long)
slide_from_tiles_img.shape

In [None]:
idx = slide_from_tiles_img[0, 0].item()
idx_to_str[idx]

In [None]:
idx = slide_from_tiles_img[4, 0].item()
idx_to_str[idx]

In [None]:
idx = slide_from_tiles_img[33, 44].item()
idx_to_str[idx]

In [None]:
8 * 4

In [None]:
tiles = inference_ds.slide_to_tiles['s0']
tiles.shape

In [None]:
# Boolean mask with one 8 x 8 grid per tile: an entry is True when the labels
# grouped into that grid cell sum to more than 0, i.e. at least one spot.
mask = torch.zeros(len(inference_ds.idx_to_coord), 8, 8, dtype=torch.bool)
for i, (slide_idx, row_idx, col_idx) in enumerate(inference_ds.idx_to_coord):
    slide = inference_ds.slides[slide_idx]
    # Channel 0 of the tile is used as the label image.
    labeled_tile = inference_ds.slide_to_tiles[slide][row_idx, col_idx, 0]
    # NOTE(review): in einops the first name of '(ph h)' is the outer factor,
    # so with ph=32 each (h, w) cell gathers pixels lying h apart across the
    # tile rather than one contiguous 32 x 32 patch — confirm this matches how
    # the tokens in `all_encoded_tokens` were laid out.
    labeled_tile = rearrange(labeled_tile, '(ph h) (pw w) -> h w (ph pw)', ph=32, pw=32)
    mask[i] = labeled_tile.sum(dim=-1) > 0
mask.shape

In [None]:
mask[60]

In [None]:
# slide_to_hulls = {}
# for slide in inference_ds.slides:
#     x = inference_ds.image_from_tiles(inference_ds.slide_to_tiles[slide]).squeeze()
#     z = 52 / x.shape[0]
#     pts = (x!=0).argwhere().to(torch.float32)
#     pts *= z
#     pts = pts.to(torch.long)
    
#     mask = np.zeros((52, 52))
#     for r, c in pts:
#         mask[r, c] = True
#     slide_to_hulls[slide] = mask
# #     hull = convex_hull_image(mask)
# #     slide_to_hulls[slide] = hull
# for slide, hull in slide_to_hulls.items():
#     plt.imshow(hull)
#     plt.title(slide)
#     plt.show()

In [None]:
zzz = mask.unsqueeze(1)
zzz_img = inference_ds.slide_from_tiles(
    zzz, 0, size=(zzz.shape[-2], zzz.shape[-1])).squeeze().to(torch.long)
zzz_img.shape

In [None]:
plt.imshow(zzz_img)

In [None]:
x = all_encoded_tokens.clone()
x[~mask] = torch.zeros(x.shape[-1])
x = rearrange(x, 'n h w d -> (n h w) d')
x /= x.std(0)

In [None]:
(x.sum(-1)==0).sum()

In [None]:
x

In [None]:
from sklearn.cluster import KMeans
clusterer = KMeans(n_clusters=10)
cluster_ids = clusterer.fit_transform(x.numpy())
cluster_ids.shape

In [None]:
cluster_ids = cluster_ids.argmin(1)
np.unique(cluster_ids, return_counts=True)

In [None]:
cluster_ids = torch.tensor(cluster_ids)

In [None]:
# num_clusters = 5
# cluster_ids, cluster_centers = kmeans(
#     X=x, num_clusters=num_clusters, distance='euclidean', device=torch.device('cuda:1'), tol=1.,
# )

In [None]:
# cluster_ids, cluster_centers = cluster_ids.cpu().detach(), cluster_centers.cpu().detach()

In [None]:
# cluster_ids.shape, cluster_centers.shape

In [None]:
cluster_imgs = rearrange(cluster_ids, '(n h w) -> n 1 h w',
                        n=all_encoded_tokens.shape[0], h=all_encoded_tokens.shape[1], w=all_encoded_tokens.shape[2])
cluster_imgs.shape

In [None]:
labeled_img = inference_ds.slide_from_tiles(
    cluster_imgs, 0, size=(cluster_imgs.shape[-2], cluster_imgs.shape[-1])).squeeze().to(torch.long)
labeled_img.shape

In [None]:
labeled_img

In [None]:
def display_labeled_as_rgb(labeled, cmap=None):
    """Convert a 2-D integer label image into an RGB image.

    Note: this function is defined twice in this notebook; keep both copies
    in sync.

    Args:
        labeled: 2-D array or torch tensor of non-negative integer labels.
        cmap: sequence of RGB colours indexed by label value; defaults to the
            current seaborn palette.

    Returns:
        float array of shape (h, w, 3) where every pixel holds `cmap[label]`.

    Raises:
        RuntimeError: if the largest label has no entry in `cmap`.
    """
    if isinstance(labeled, torch.Tensor):
        labeled = labeled.numpy()
    cmap = sns.color_palette() if cmap is None else cmap
    labels = np.unique(labeled)  # already sorted
    # Colours are looked up by label VALUE, so the palette must reach the
    # largest label. Comparing against the number of distinct labels let
    # sparse label sets through and then failed with an IndexError.
    if labels.size and labels.max() >= len(cmap):
        raise RuntimeError('cmap is too small')
    new = np.zeros((labeled.shape[0], labeled.shape[1], 3))
    for l in labels:
        new[labeled==l] = cmap[int(l)]
    return new

In [None]:
import seaborn as sns

In [None]:
# cmap = sns.color_palette('tab20') + sns.color_palette('tab20b') + sns.color_palette('tab20c')
cmap = sns.color_palette('tab20')
plt.imshow(display_labeled_as_rgb(labeled_img, cmap=cmap))

In [None]:
stacked_labeled = []
for i in range(len(slide_to_adata)):
    stacked_labeled.append(inference_ds.slide_from_tiles(
        cluster_imgs, i, size=(cluster_imgs.shape[-2], cluster_imgs.shape[-1])).squeeze().to(torch.long))
stacked_labeled = torch.stack(stacked_labeled)
stacked_labeled.shape

In [None]:
cmap = sns.color_palette('tab20')
for i, labeled in enumerate(stacked_labeled):
    plt.imshow(display_labeled_as_rgb(labeled, cmap=cmap))
    plt.show()

In [None]:
sc.pl.spatial(slide_to_adata['s0'])

In [None]:
sc.pl.spatial(slide_to_adata['s3'])

In [None]:
sc.pl.spatial(slide_to_adata['s3'], color='EPCAM')