In [1]:
import logging
import os
import re
from typing import Optional, List

import scanpy as sc 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import wandb
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, BatchSampler
from torchvision.transforms import ColorJitter, Normalize, RandomHorizontalFlip, RandomVerticalFlip, RandomAdjustSharpness
from timm import create_model
from einops import rearrange, reduce
from skimage.color import label2rgb
from skimage.measure import regionprops_table


note that you need to make bug fix to diffusers v0.3.0

in ~/.local/lib/python3.9/site-packages/diffusers/models/unet_blocks.py you need to change out_channels parameter in DownEncoderBlock2D to make the unet work for >2 downsamples

out_channels=in_channels if add_downsample else out_channels,


In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mestorrs[0m ([33mtme-st[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

#### load datasets

In [3]:
cell_type_mapping = {
    'Epithelial': ['KRT18'],
    'T cell': ['IL7R'],
    'Macrophage': ['CD14', 'FCGR3A'],
    'DC': ['HLA-DRA', 'ITGAX'],
    'Fibroblast': ['BGN'],
    'Endothelial': ['PECAM1'],
    'Acinar': ['PRSS1'],
    'Islet': ['INS']
}

genes = sorted([v for vs in cell_type_mapping.values() for v in vs])

In [4]:
fps = [os.path.join('../data/pdac', x) for x in os.listdir('../data/pdac')]
fps

['../data/pdac/HT270P1-S1H1U1',
 '../data/pdac/HT264P1-S1H2U1',
 '../data/pdac/HT424P1-H3A1U1',
 '../data/pdac/HT434P1-S1H3U1',
 '../data/pdac/HT427P1-S1H1U1',
 '../data/pdac/HT416P1-S1H1A1U1']

In [5]:
train_samples = [
    'HT270P1-S1H1U1',
#     'HT424P1-H3A1U1',
# #     'HT434P1-S1H3U1',
#     'HT427P1-S1H1U1',
]

val_samples = [
    'HT264P1-S1H2U1',
#     'HT416P1-S1H1A1U1'
]

In [6]:
# def scale_expression_counts(a, scale=100.):
#     X = a.X.toarray()
# #     X = np.log1p(X.toarray())
# #      0-1
# #     X -= np.expand_dims(X.min(axis=0), 0)
# #     X /= np.expand_dims(X.max(axis=0), 0)
# #     X *= scale
# #     X = X.astype(np.int32).astype(np.float32)
#     a.X = X.astype(np.int32)
    
#     return a

In [19]:
sample_to_adata = {}

for fp in fps:
    s = fp.split('/')[-1]
    if s in train_samples or s in val_samples:
        a = sc.read_h5ad(os.path.join(fp, 'adata.h5ad'))
        a.X = a.layers['counts'].toarray().astype(np.int32) # using raw counts
        a = a[:, genes]
#         a = scale_expression_counts(a)
        sample_to_adata[s] = a

        print(s, a.shape)

HT270P1-S1H1U1 (3940, 10)
HT264P1-S1H2U1 (3234, 10)


In [20]:
sample_to_train_adata = {k:v for k, v in sample_to_adata.items() if k in train_samples}
sample_to_val_adata = {k:v for k, v in sample_to_adata.items() if k in val_samples}

In [21]:
a = next(iter(sample_to_train_adata.values()))
a

View of AnnData object with n_obs × n_vars = 3940 × 10
    obs: 'in_tissue', 'array_row', 'array_col', 'clusters', 'spot_index'
    var: 'gene_ids', 'feature_types', 'genome', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'full_res_he', 'hvg', 'leiden', 'log1p', 'neighbors', 'nuclei_segmentation_1X_notrim', 'nuclei_segmentation_1X_trimmed', 'pca', 'rescaled_he', 'rescaled_spot_masks', 'rescaled_spot_metadata', 'segmented_nuclei_coords', 'segmented_nuclei_coords_1X_notrim', 'segmented_nuclei_coords_1X_trimmed', 'spatial', 'trimmed', 'umap'
    obsm: 'X_pca', 'X_umap', 'spatial', 'spatial_16X_notrim', 'spatial_16X_trimmed', 'spatial_1X_notrim', 'spatial_1X_trimmed', 'spatial_2X_notrim', 'spatial_2X_trimmed', 'spatial_4X_notrim', 'spatial_4X_trimmed', 'spatial_8X_notrim', 'spatial_8X_trimmed', 'spatial_trimmed'
    varm: 'PCs'
    layers: 'counts'
    obsp: 'connectivities', 'distances'

In [22]:
a.obs

Unnamed: 0,in_tissue,array_row,array_col,clusters,spot_index
AAACAACGAATAGTTC-1,1,0,16,17,1
AAACAAGTATCTCCCA-1,1,50,102,3,2
AAACAATCTACTAGCA-1,1,3,43,2,3
AAACACCAATAACTGC-1,1,59,19,5,4
AAACAGAGCGACTCCT-1,1,14,94,14,5
...,...,...,...,...,...
TTGTTGTGTGTCAAGA-1,1,31,77,17,3936
TTGTTTCACATCCAGG-1,1,58,42,5,3937
TTGTTTCATTAGTCTA-1,1,60,30,5,3938
TTGTTTCCATACAACT-1,1,45,27,1,3939


In [24]:
class HETransform(object):
    def __init__(self, p=.8, brightness=.1, contrast=.1, saturation=.1, hue=.1, sharpness=.3,
                 no_flip=False, no_color=False, normalize=True):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.sharpness = sharpness
        self.no_flip = no_flip
        self.no_color = no_color
        
        if normalize:
            self.normalize = Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224)) # from HT397B1-H2 ffpe H&E image
        else:
            self.normalize = nn.Identity()
 
        self.p = p
    
    def apply_color_transforms(self, x, brightness, contrast, saturation, hue, sharpness):
        x = TF.adjust_brightness(x, brightness)
        x = TF.adjust_contrast(x, contrast)
        x = TF.adjust_saturation(x, saturation)
        x = TF.adjust_hue(x, hue)
        x = TF.adjust_sharpness(x, sharpness)
        return x
        
    def __call__(self, he, mask):
        """
        """
        if isinstance(he, torch.Tensor):
            hes = [he]
            masks = [mask]
            return_type = 'image'
        elif isinstance(he, dict):
            keys = list(he.keys())
            hes = [he[k] for k in keys]
            masks = [mask[k] for k in keys]
            return_type = 'dict'
        else:
            hes = he
            masks = mask
            return_type = 'list'
                                
        # we apply transforms with probability p
        if torch.rand(size=(1,)) < self.p:
            if not self.no_color:
                brightness, contrast, saturation, hue, sharpness = (
                    np.random.uniform(max(0, 1 - self.brightness), 1 + self.brightness, size=1)[0],
                    np.random.uniform(max(0, 1 - self.contrast), 1 + self.contrast, size=1)[0],
                    np.random.uniform(max(0, 1 - self.saturation), 1 + self.saturation, size=1)[0],
                    np.random.uniform(-self.hue, self.hue, size=1)[0],
                    np.random.uniform(max(0, 1 - self.sharpness), 1 + self.sharpness, size=1)[0],
                )
                # apply color jitter and sharpness
                hes = [self.apply_color_transforms(x, brightness, contrast, saturation, hue, sharpness)
                       for x in hes]
            
            # vertical and horizontal flips happen with p=.5
            do_hflip, do_vflip = torch.rand(size=(2,)) < .5 
            if do_hflip and not self.no_flip:
                hes = [TF.hflip(x) for x in hes]
                masks = [TF.hflip(x) for x in masks]
            if do_vflip and not self.no_flip:
                hes = [TF.vflip(x) for x in hes]
                masks = [TF.vflip(x) for x in masks]
        
        # normalize he
        hes = [self.normalize(x) for x in hes]
                    
        if return_type == 'image':
            return hes[0], masks[0]
        elif return_type == 'dict':
            return {k:v for k, v in zip(keys, hes)}, {k:v for k, v in zip(keys, masks)}
        return hes, masks

In [25]:
def reflection_mosiac(x, border=256, dtype=torch.int32):
    max_r, max_c = x.shape[-2], x.shape[-1]
    if len(x.shape) == 3:
        mosaic = torch.zeros((x.shape[0], max_r + (border * 2), max_c + (border * 2))).to(dtype)
    else:
        mosaic = torch.zeros((max_r + (border * 2), max_c + (border * 2))).to(dtype)
    
    # make tiles
    top_left = TF.pad(x, padding=[border, border, 0, 0], padding_mode='reflect')
    top_right = TF.pad(x, padding=[0, border, border, 0], padding_mode='reflect')
    bottom_left = TF.pad(x, padding=[border, 0, 0, border], padding_mode='reflect')
    bottom_right = TF.pad(x, padding=[0, 0, border, border], padding_mode='reflect')
    
    if len(x.shape) == 3:
        mosaic[:, :max_r + border, :max_c + border] = top_left
        mosaic[:, :max_r + border, border:] = top_right
        mosaic[:, border:, :max_c + border] = bottom_left
        mosaic[:, border:, border:] = bottom_right
    else:
        mosaic[:max_r + border, :max_c + border] = top_left
        mosaic[:max_r + border, border:] = top_right
        mosaic[border:, :max_c + border] = bottom_left
        mosaic[border:, border:] = bottom_right
    
    return mosaic

In [26]:
def get_img_dicts(adata, keys=('2X', '8X'), border=256, mode='spot'):
    he_dict = {}
    for k, x in adata.uns['rescaled_he'].items():
        if re.findall(r'^[0-9]+X.*$', k) and 'trimmed' in k:
            scale = int(re.sub(r'^([0-9]+)X.*$', r'\1', k))
            x = torch.tensor(rearrange(x, 'h w c -> c h w')).to(torch.uint8)
            x = reflection_mosiac(x, border=border, dtype=torch.uint8)
            x = rearrange(x, 'c h w -> h w c').numpy().astype(np.uint8)
            he_dict[f'{scale}X'] = x
            
    mask_dict = {}
    for k, x in adata.uns['rescaled_spot_masks'].items():
        if re.findall(r'^[0-9]+X.*$', k) and 'trimmed' in k:
            scale = int(re.sub(r'^([0-9]+)X.*$', r'\1', k))
            x = torch.tensor(x.astype(np.int32))
            x = reflection_mosiac(x, border=border, dtype=torch.int32)
            x = x.numpy().astype(np.int32)
            mask_dict[f'{scale}X'] = x
            
    if mode == 'hex':
        mask_dict = {k:convert_spot_masks(adata, mask=v, key=f'{k}_trimmed', mode='hex')
                     for k, v in mask_dict.items()}
            
    he_dict = {k:v for k, v in he_dict.items() if k in keys}
    mask_dict = {k:v for k, v in mask_dict.items() if k in keys}
    
    return he_dict, mask_dict

In [116]:
def convert_rgb(img):
    if img.shape[-1] == 3:
        img = rearrange(img, 'h w c -> c h w')
        
    if not isinstance(img, torch.Tensor):
        img = torch.tensor(img)
        
    if img.dtype == torch.uint8:
        img = TF.convert_image_dtype(img, dtype=torch.float32)
    
    if img.max() > 1.:
        img -= img.min()
        img /= img.max()
    
    return img


def create_color_augmentations(he, he_context, transform, n=10, dtype=torch.uint8):
    aug = torch.zeros(n, he.shape[0], he.shape[1], he.shape[2], dtype=dtype)
    aug_context = torch.zeros(n, he_context.shape[0], he_context.shape[1], he_context.shape[2], dtype=dtype)
    # one at a time to keep RAM down
    for i in range(n):
        (aug_he, aug_he_context), _ = transform([he, he_context], None)
        
        # rescale to 0-1
        aug_he -= aug_he.min()
        aug_he /= aug_he.max()
        aug_he_context -= aug_he_context.min()
        aug_he_context /= aug_he_context.max()
        
        # back to dtype
        aug_he = TF.convert_image_dtype(aug_he, dtype)
        aug_he_context = TF.convert_image_dtype(aug_he_context, dtype)
        
        aug[i] = aug_he
        aug_context[i] = aug_he_context
    
    return aug, aug_context


def create_masks(labeled_mask, max_area, thresh=.5):
    voxel_idxs = torch.unique(labeled_mask)[1:].numpy().astype(int)
    masks = torch.zeros((len(voxel_idxs), labeled_mask.shape[0], labeled_mask.shape[1]), dtype=torch.bool)
    for i, l in enumerate(voxel_idxs):
        m = masks[i]
        m[labeled_mask==l] = 1

    keep = masks.sum(dim=(-1,-2)) / max_area > thresh
    masks = masks[keep]
    voxel_idxs = voxel_idxs[keep]
        
    return masks, voxel_idxs
        

class STDataset(Dataset):
    """ST Dataset"""
    def __init__(self, adata, he, he_context, labeled_mask, context_scaler, coordinate_key,
                 tile_size=512, he_color_transform=None, he_post_color_transform=None,
                 border=512, max_jitter=0., n_augmentations=10, min_voxel_fraction=.5, normalize=True):
        self.spot_ids, _ = zip(*sorted([(sid, sidx) for sid, sidx in zip(adata.obs.index, adata.obs['spot_index'])],
                                    key=lambda x: x[1]))
        self.spot_ids = np.asarray(self.spot_ids)
        self.adata = adata[self.spot_ids]
        self.context_scaler = context_scaler
        self.he = convert_rgb(he)
        self.he_context = convert_rgb(he_context)
        self.labeled_mask = torch.tensor(labeled_mask, dtype=torch.int32) if not isinstance(labeled_mask, torch.Tensor) else labeled_mask
        
        self.he_color_transform = he_color_transform
        self.he_post_color_transform = he_post_color_transform
        
        self.n_augmentations = n_augmentations
        if self.n_augmentations is not None:
            self.aug_he, self.aug_he_context = create_color_augmentations(
                self.he, self.he_context, self.he_color_transform,
                n=self.n_augmentations, dtype=torch.uint8)
        else:
            self.aug_he = rearrange(self.he, 'c h w -> 1 c h w')
            self.aug_he_context = rearrange(self.he_context, 'c h w -> 1 c h w')
            
        if normalize:
            self.normalize = Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224)) # from HT397B1-H2 ffpe H&E image
        else:
            self.normalize = nn.Identity()
                
        self.exp = torch.tensor(adata.X, dtype=torch.int32)

        self.tile_size = tile_size 
        
        self.offset = int(self.tile_size // 2 + 1)
        self.border = border
        self.max_jitter = max_jitter
  
        self.pixel_coords = np.asarray([[int(r), int(c)] for c, r in self.adata.obsm[coordinate_key]])
    
        idxs = np.random.choice(np.arange(self.pixel_coords.shape[0]), size=1000)
        max_spots = 0
        for i in idxs:
            r, c = self.pixel_coords[i]
            r, c = r + self.border, c + self.border # adjust for reflection padding
            r, c = r - self.offset, c - self.offset # adjust from center to top left
            lm = TF.crop(self.labeled_mask, top=r, left=c, height=self.tile_size, width=self.tile_size)
            m, _ = create_masks(lm, 1000, thresh=.0001)
            m = m[m.sum(dim=(-1, -2))>0]
            max_spots = max(max_spots, m.shape[0])
    
        self.max_spots = 2**int(np.log2(max_spots) + 1)
        
        r, c = self.he.shape[1] // 2, self.he.shape[2] // 2
        tile = TF.crop(self.labeled_mask, top=r, left=c, height=self.tile_size, width=self.tile_size)
        self.max_voxel_area = np.max(
            regionprops_table(tile.detach().numpy(), properties=['label', 'area'])['area'])
        self.min_voxel_fraction = min_voxel_fraction
        
    def __len__(self):
        return len(self.pixel_coords)
    
    def __getitem__(self, idx):
        r, c = self.pixel_coords[idx]
        r = int(r + np.random.uniform(-self.max_jitter, self.max_jitter))
        c = int(c + np.random.uniform(-self.max_jitter, self.max_jitter))
        r_context, c_context = r // self.context_scaler, c // self.context_scaler
        
        # add offset and border
        r, c = r + self.border - self.offset, c + self.border - self.offset
        r_context, c_context = r_context + self.border - self.offset, c_context + self.border - self.offset
        
        if self.n_augmentations is not None:
            i = torch.randint(0, self.n_augmentations, (1,)).item()
        else:
            i = 0
            
        he = TF.crop(self.aug_he[i], top=r, left=c, height=self.tile_size, width=self.tile_size)
        he_context = TF.crop(self.aug_he_context[i], top=r_context, left=c_context,
                             height=self.tile_size, width=self.tile_size)
        
        # we need to go to float32
        he = TF.convert_image_dtype(he, torch.float32)
        he_context = TF.convert_image_dtype(he_context, torch.float32)
        
        # create mask
        lm = TF.crop(self.labeled_mask, top=r, left=c, height=self.tile_size, width=self.tile_size)
        m, v_idxs = create_masks(lm, self.max_voxel_area, thresh=self.min_voxel_fraction)
        idxs = v_idxs - 1
        masks = torch.zeros((self.max_spots, m.shape[1], m.shape[2]), dtype=m.dtype)
        masks[:len(idxs)] = m
        
        # post color augs
        if self.he_post_color_transform is not None:
            (he, he_context), (masks,) = self.he_post_color_transform([he, he_context], [masks])
        
        # exp
        exp = torch.zeros((self.max_spots, self.exp.shape[1]), dtype=self.exp.dtype)
        exp[:len(idxs)] = self.exp[idxs]
        
        voxel_idxs = torch.zeros((self.max_spots,), dtype=torch.int32)
        voxel_idxs[:len(idxs)] = torch.tensor(v_idxs, dtype=torch.int32)
        
        return {
            'he': self.normalize(he),
            'he_context': self.normalize(he_context),
            'he_orig': he,
            'masks': masks,
            'voxel_idxs': voxel_idxs,
            'exp': exp,
            'n_voxels': len(idxs)
        }
        
        

In [None]:
scales = [2, 8]
keys = [f'{s}X' for s in scales]
mode = 'spot'
tile_size = 256
spot_radius = next(iter(sample_to_adata.values())).uns['rescaled_spot_metadata'][keys[0] + '_trimmed']['spot_radius']
jitter = int(spot_radius * 2)
border = jitter * 10

train_he_color_transform = HETransform(p=.95, brightness=.1, contrast=.1, saturation=.1, hue=.1,
                                 no_flip=True, no_color=False, normalize=False)
train_he_post_color_transform = HETransform(p=.95, no_flip=False, no_color=True, normalize=False)

sample_to_train_ds = {}
for s in train_samples:
    a = sample_to_adata[s]

    he_dict, mask_dict = get_img_dicts(a, keys=keys, border=border, mode=mode)

    ds = STDataset(
        a, he_dict[keys[0]], he_dict[keys[1]], mask_dict[keys[0]], scales[1] // scales[0], 'spatial_2X_trimmed',
        tile_size=tile_size, he_color_transform=train_he_color_transform, 
        he_post_color_transform=train_he_post_color_transform,
        max_jitter=jitter, border=border, n_augmentations=2, min_voxel_fraction=.5, normalize=True)
    
    sample_to_train_ds[s] = ds
    

In [None]:
val_he_color_transform = HETransform(p=0., normalize=False)
val_he_post_color_transform = HETransform(p=0., normalize=False)

sample_to_val_ds = {}
for s in val_samples:
    a = sample_to_adata[s]

    he_dict, mask_dict = get_img_dicts(a, keys=keys, border=border, mode=mode)

    ds = STDataset(
        a, he_dict[keys[0]], he_dict[keys[1]], mask_dict[keys[0]], scales[1] // scales[0], 'spatial_2X_trimmed',
        tile_size=tile_size, he_color_transform=val_he_color_transform, 
        he_post_color_transform=val_he_post_color_transform,
        max_jitter=0., border=border, n_augmentations=None, min_voxel_fraction=.5, normalize=True)
    
    sample_to_val_ds[s] = ds
    

In [None]:
class MultisampleSTDataset(Dataset):
    def __init__(self, ds_dict):
        super().__init__()
        self.samples = list(ds_dict.keys())
        self.ds_dict = ds_dict
        
        self.mapping = [(k, i) for k, ds in ds_dict.items() for i in range(len(ds))]
        
    def __len__(self):
        return len(self.mapping)

    def __getitem__(self, idx):
        k, i = self.mapping[idx]
        
        d = self.ds_dict[k][i]

        return d

In [None]:
train_ds = MultisampleSTDataset(sample_to_train_ds)
len(train_ds)

In [None]:
val_ds = MultisampleSTDataset(sample_to_val_ds)
len(val_ds)

In [None]:
batch_size = 32
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=1)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [None]:
%%time
batch = next(iter(train_dl))

In [67]:
ds = next(iter(sample_to_train_ds.values()))

In [68]:
d = ds[0]

In [69]:
d['he'].shape

torch.Size([3, 256, 256])

In [81]:
def convert_rgb(img):
    if img.shape[-1] == 3:
        img = rearrange(img, 'h w c -> c h w')
        
    if not isinstance(img, torch.Tensor):
        img = torch.tensor(img)
        
    if torch.max(img)>1.1:
        img = TF.convert_image_dtype(img, dtype=torch.float32)
            
    return img


def generate_voxel_masks(labeled_mask, padded_spot_idxs, n_voxels):
    masks = torch.zeros((padded_spot_idxs.shape[0], labeled_mask.shape[0], labeled_mask.shape[1]))
    for i, l in enumerate(padded_spot_idxs[:n_voxels]):
        m = labeled_mask.clone().detach()
        m[m!=l] = 0.
        m[m==l] = 1.
        masks[i] = m
    return masks


def generate_padded_exp(adata, spot_idxs, max_spots, masks, pixels_per_voxel, use_raw=False):
    if use_raw:
        x = adata.raw.X[spot_idxs]
    else:
        x = adata.X[spot_idxs]

    if 'sparse' in str(type(x)).lower():
        x = x.toarray()

    padded_exp = torch.zeros((max_spots, x.shape[1]))
    padded_exp[:x.shape[0]] = torch.tensor(x)

    padded_exp *= (masks.sum(dim=(1,2)).unsqueeze(dim=-1) + 1.) / pixels_per_voxel
    padded_exp = torch.round(padded_exp)
    
    return padded_exp

def generate_expression_tiles(masks, exp, n_voxels, tile_res=32):    
    tiles = torch.zeros(masks.shape[-2],
                        masks.shape[-1],
                        exp.shape[-1])
    
    for m, e in zip(masks[:n_voxels], exp[:n_voxels]):
        tiles[m==1] = e 
        
    tiles = reduce(tiles, '(h1 h2) (w1 w2) c -> h1 w1 c', 'mean', h2=tile_res, w2=tile_res)
        
    return tiles

def get_n_voxels(padded_voxel_idxs):
    if padded_voxel_idxs.sum() == 0:
        return 0
    idx = padded_voxel_idxs.flip((0,)).nonzero()[0].item()
    return len(padded_voxel_idxs[:-idx])


class STDataset(Dataset):
    """ST Dataset"""
    def __init__(self, adata, he_dict, mask_dict, coordinate_key,
                 tile_sizes=512, use_raw=False, he_transform=None,
                 border=512, max_jitter=0.,
                 tile_res=32, normalize=True):
        """
        adata: AnnData object
            - .X must be unnormalized counts
            - must have column in .obs['spot_index'] that specified the spot index in the scaled mask dict
        scaled_he_dict: dict
            - values are rgb H&E images, keys are '[0-9]+X', where the integer in front of X is the scale factor of the H&E image.
        scaled_mask_dict: dict
            - values are labeled images where 0 is background and all other pixels coorespond to index stored in .obs['spot_index'].
        """
        super().__init__()
        # make sure we are ordered by spot index
        self.spot_ids, _ = zip(*sorted([(sid, sidx) for sid, sidx in zip(adata.obs.index, adata.obs['spot_index'])],
                                    key=lambda x: x[1]))
        self.spot_ids = np.asarray(self.spot_ids)
        self.adata = adata[self.spot_ids]
        self.he_dict = {k:convert_rgb(v) for k, v in he_dict.items()}
        self.mask_dict = {k:torch.tensor(v).to(torch.int32) if not isinstance(v, torch.Tensor) else v
                                 for k, v in mask_dict.items()}
        self.scales = sorted([int(re.sub(r'^([0-9]+)X$', r'\1', k)) for k in self.he_dict.keys()])
        self.tile_res = tile_res
        
        if normalize:
            self.normalize = Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224)) # from HT397B1-H2 ffpe H&E image
        else:
            self.normalize = nn.Identity()
        
        if isinstance(tile_sizes, int):
            self.tile_sizes = [tile_sizes] * len(self.scales)
        else:
            self.tile_sizes = tile_sizes # defines the size of crops to be taken from each h&e resolution
        
        _, n_row, n_col = self.he_dict[str(self.scales[0]) + 'X'].shape
        self.offset = int(self.tile_sizes[0] // 2 + 1)
        self.border = border
        self.max_jitter = max_jitter
  
        self.pixel_coords = np.asarray([[int(c), int(r)] for r, c in self.adata.obsm[coordinate_key]])

        self.he_transform = he_transform

        # expression related
        self.use_raw = use_raw
        
        idxs = np.random.choice(np.arange(self.pixel_coords.shape[0]), size=1000)
        key = str(self.scales[0]) + 'X'
        max_spots = 0
        self.pixels_per_voxel = 0
        for i in idxs:
            mask = self.mask_dict[key]
            r, c = self.pixel_coords[i]
            r, c = r + self.border, c + self.border # adjust for reflection padding
            r, c = r - self.offset, c - self.offset # adjust from center to top left
            m = TF.crop(mask, top=r, left=c, height=self.tile_sizes[0], width=self.tile_sizes[0])
            max_spots = max(max_spots, len(np.unique(m)))
        self.max_spots = 2**int(np.log2(max_spots) + 1)
        r, c = n_row // 2, n_col // 2
        tile = self.mask_dict[key][r:r + self.tile_sizes[0] * 2, c:c + self.tile_sizes[0] * 2]
        self.pixels_per_voxel = np.max(
            regionprops_table(tile.detach().numpy(), properties=['label', 'area'])['area'])
        

    def __len__(self):
        return len(self.pixel_coords)

    def __getitem__(self, idx):
        r, c = self.pixel_coords[idx]
        r = int(r + np.random.uniform(-self.max_jitter, self.max_jitter))
        c = int(c + np.random.uniform(-self.max_jitter, self.max_jitter))
        initial = self.scales[0]
        scale_to_coords = {s: (
                               int((r / (s / initial)) - self.offset + self.border),
                               int((c / (s / initial)) - self.offset + self.border)
                           )
                           for i, s in enumerate(self.scales)}
        he_tile_dict, mask_tile_dict = {}, {}
        for scale, tile_size in zip(self.scales, self.tile_sizes):
            key = f'{scale}X'
            he, mask = self.he_dict[key], self.mask_dict[key]
            r, c = scale_to_coords[scale]
            he_tile_dict[key] = TF.crop(he, top=r, left=c, height=tile_size, width=tile_size)
            mask_tile_dict[key] = TF.crop(mask, top=r, left=c, height=tile_size, width=tile_size)
            
        if self.he_transform is not None:
            he_tile_dict, mask_tile_dict = self.he_transform(he_tile_dict, mask_tile_dict)
        spot_idxs = torch.unique(mask_tile_dict[str(self.scales[0]) + 'X']).numpy() - 1
        if spot_idxs[0] == -1:
            spot_idxs = spot_idxs[1:] # drop first value, which is background
        padded_spot_idxs = np.asarray([0] * self.max_spots)
        padded_spot_idxs[:spot_idxs.shape[0]] = spot_idxs + 1
        masks = generate_voxel_masks(mask_tile_dict[str(self.scales[0]) + 'X'],
                                     padded_spot_idxs, len(spot_idxs))
        padded_exp = generate_padded_exp(self.adata, spot_idxs, self.max_spots,
                                         masks, self.pixels_per_voxel, use_raw=self.use_raw)
        
        padded_spot_idxs = torch.tensor(padded_spot_idxs, dtype=torch.int16)
        
        return {
            'he': self.normalize(he_tile_dict['2X']),
            'he_context': self.normalize(he_tile_dict['8X']),
            'he_orig': he_tile_dict['2X'],
            'masks': masks.to(torch.bool),
            'voxel_idxs': padded_spot_idxs,
            'exp': padded_exp.to(torch.int32),
            'n_voxels': get_n_voxels(padded_spot_idxs)
        }
        

In [82]:
def reflection_mosiac(x, border=256, dtype=torch.int32):
    max_r, max_c = x.shape[-2], x.shape[-1]
    if len(x.shape) == 3:
        mosaic = torch.zeros((x.shape[0], max_r + (border * 2), max_c + (border * 2))).to(dtype)
    else:
        mosaic = torch.zeros((max_r + (border * 2), max_c + (border * 2))).to(dtype)
    
    # make tiles
    top_left = TF.pad(x, padding=[border, border, 0, 0], padding_mode='reflect')
    top_right = TF.pad(x, padding=[0, border, border, 0], padding_mode='reflect')
    bottom_left = TF.pad(x, padding=[border, 0, 0, border], padding_mode='reflect')
    bottom_right = TF.pad(x, padding=[0, 0, border, border], padding_mode='reflect')
    
    if len(x.shape) == 3:
        mosaic[:, :max_r + border, :max_c + border] = top_left
        mosaic[:, :max_r + border, border:] = top_right
        mosaic[:, border:, :max_c + border] = bottom_left
        mosaic[:, border:, border:] = bottom_right
    else:
        mosaic[:max_r + border, :max_c + border] = top_left
        mosaic[:max_r + border, border:] = top_right
        mosaic[border:, :max_c + border] = bottom_left
        mosaic[border:, border:] = bottom_right
    
    return mosaic

In [83]:
def get_img_dicts(adata, keys=('2X', '8X'), border=256, mode='spot'):
    he_dict = {}
    for k, x in adata.uns['rescaled_he'].items():
        if re.findall(r'^[0-9]+X.*$', k) and 'trimmed' in k:
            scale = int(re.sub(r'^([0-9]+)X.*$', r'\1', k))
            x = torch.tensor(rearrange(x, 'h w c -> c h w')).to(torch.uint8)
            x = reflection_mosiac(x, border=border, dtype=torch.uint8)
            x = rearrange(x, 'c h w -> h w c').numpy().astype(np.uint8)
            he_dict[f'{scale}X'] = x
            
    mask_dict = {}
    for k, x in adata.uns['rescaled_spot_masks'].items():
        if re.findall(r'^[0-9]+X.*$', k) and 'trimmed' in k:
            scale = int(re.sub(r'^([0-9]+)X.*$', r'\1', k))
            x = torch.tensor(x.astype(np.int32))
            x = reflection_mosiac(x, border=border, dtype=torch.int32)
            x = x.numpy().astype(np.int32)
            mask_dict[f'{scale}X'] = x
            
    if mode == 'hex':
        mask_dict = {k:convert_spot_masks(adata, mask=v, key=f'{k}_trimmed', mode='hex')
                     for k, v in mask_dict.items()}
            
    he_dict = {k:v for k, v in he_dict.items() if k in keys}
    mask_dict = {k:v for k, v in mask_dict.items() if k in keys}
    
    return he_dict, mask_dict

In [84]:
keys = ['2X', '8X']
mode = 'spot'
tile_size = 256
spot_radius = next(iter(sample_to_adata.values())).uns['rescaled_spot_metadata'][keys[0] + '_trimmed']['spot_radius']
jitter = int(spot_radius * 2)
border = jitter * 10
# border = spot_radius * 4

train_he_transform = HETransform(p=.95, brightness=.1, contrast=.1, saturation=.1, hue=.1, normalize=False)

sample_to_train_ds = {}
for s in train_samples:
    a = sample_to_adata[s]

    he_dict, mask_dict = get_img_dicts(a, keys=keys, border=border, mode=mode)

    ds = STDataset(a, he_dict, mask_dict, 'spatial_2X_trimmed',
                     tile_sizes=tile_size, he_transform=train_he_transform,
                     max_jitter=jitter, border=border, normalize=True)
    
    sample_to_train_ds[s] = ds
    

In [85]:
val_he_transform = HETransform(p=0.0, normalize=False)

sample_to_val_ds = {}
for s in val_samples:
    a = sample_to_adata[s]
    he_dict, mask_dict = get_img_dicts(a, keys=keys, border=border, mode=mode)
    
    ds = STDataset(a, he_dict, mask_dict, 'spatial_2X_trimmed',
                   tile_sizes=tile_size, he_transform=val_he_transform,
                   max_jitter=0., border=border, normalize=True)
    
    sample_to_val_ds[s] = ds

In [86]:
class MultisampleSTDataset(Dataset):
    def __init__(self, ds_dict):
        super().__init__()
        self.samples = list(ds_dict.keys())
        self.ds_dict = ds_dict
        
        self.mapping = [(k, i) for k, ds in ds_dict.items() for i in range(len(ds))]
        
    def __len__(self):
        return len(self.mapping)

    def __getitem__(self, idx):
        k, i = self.mapping[idx]
        
        d = self.ds_dict[k][i]

        return d

In [87]:
train_ds = MultisampleSTDataset(sample_to_train_ds)
len(train_ds)

3940

In [88]:
# import time
# time.sleep(60 * 10)

In [89]:
val_ds = MultisampleSTDataset(sample_to_val_ds)
len(val_ds)

3234

In [90]:
batch_size = 32
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=1)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [93]:
%%time
batch = next(iter(train_dl))

CPU times: user 41.3 ms, sys: 112 ms, total: 153 ms
Wall time: 2.7 s


In [None]:
a = sample_to_adata[train_samples[0]]
a.var.index.to_list()

###### data inspection

In [None]:
def cte(padded_exp, masks, n_voxels):
    tile = torch.zeros((masks.shape[1], masks.shape[2], padded_exp.shape[1]))
    for exp, m in list(zip(padded_exp, masks))[:n_voxels]:
        tile[m==1] = exp.to(torch.float32)
    return tile

In [None]:
for i in range(50):
    print(i)
    img = rearrange(train_ds[i]['he_orig'], 'c h w -> h w c')
    img -= img.min()
    img /= img.max()
    plt.imshow(img)
    plt.show()

In [None]:
# i = 166
# i = 66
# i = 22
i = 14
d = train_ds[i]

In [None]:
img = rearrange(d['he'], 'c h w -> h w c')
img -= img.min()
img /= img.max()
plt.imshow(img)

In [None]:
img = rearrange(d['he_orig'], 'c h w -> h w c')
plt.imshow(img)

In [None]:
plt.imshow(torch.sum(d['masks'], dim=0))

In [None]:
for j in range(d['n_voxels']):
    plt.imshow(d['masks'][j])
    plt.show()

In [None]:
d['masks'].shape

In [None]:
torch.sum(d['masks'], dim=(-1, -2))

In [None]:
a.var.index.to_list().index('KRT18')

In [None]:
gene = 'IL7R'
recon = cte(d['exp'], d['masks'], d['n_voxels'])
plt.imshow(recon[:, :, a.var.index.to_list().index(gene)])

In [None]:
gene = 'KRT18'
recon = cte(d['exp'], d['masks'], d['n_voxels'])
plt.imshow(recon[:, :, a.var.index.to_list().index(gene)])

In [None]:
gene = 'PECAM1'
recon = cte(d['exp'], d['masks'], d['n_voxels'])
plt.imshow(recon[:, :, a.var.index.to_list().index(gene)])

In [None]:
d['exp']

In [None]:
d['voxel_idxs']

In [None]:
pool = set(d['voxel_idxs'][:d['n_voxels']].detach().numpy())
a.obs['highlight'] = ['yes' if i in pool else 'no'
                               for i in a.obs['spot_index']]
sc.pl.spatial(a, color='highlight')

In [None]:
sc.pl.spatial(a, color='IL7R')

In [None]:
i = 1

In [None]:
img = rearrange(val_ds[i]['he_orig'], 'c h w -> h w c')
plt.imshow(img)

In [None]:
gene = 'EPCAM'
recon = cte(val_ds[i]['exp'], val_ds[i]['masks'], val_ds[i]['n_voxels'])
plt.imshow(recon[:, :, val_adata.var.index.to_list().index(gene)])

In [None]:
pool = set(val_ds[i]['voxel_idxs'][:val_ds[i]['n_voxels']].detach().numpy())
val_adata.obs['highlight'] = ['yes' if i in pool else 'no'
                               for i in val_adata.obs['spot_index']]
sc.pl.spatial(val_adata, color='highlight')

#### model

In [None]:
"""
modified from https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31
"""
class Unet(nn.Module):
    """Unet is a fully convolution neural network for image semantic segmentation
    Args:
        encoder_name: name of classification model (without last dense layers) used as feature
            extractor to build segmentation model.
        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
        decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
        decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
            is used.
        num_classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
        center: if ``True`` add ``Conv2dReLU`` block on encoder head
    NOTE: This is based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
    """

    def __init__(
            self,
            backbone='resnet34',
            backbone_kwargs=None,
            backbone_indices=None,
            decoder_use_batchnorm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            in_chans=1,
            num_classes=5,
            center=False,
            norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()
        backbone_kwargs = backbone_kwargs or {}
        # NOTE some models need different backbone indices specified based on the alignment of features
        # and some models won't have a full enough range of feature strides to work properly.
        encoder = create_model(
            backbone, features_only=True, out_indices=backbone_indices, in_chans=in_chans,
            pretrained=False, **backbone_kwargs)
        encoder_channels = encoder.feature_info.channels()[::-1]
        self.encoder = encoder

        if not decoder_use_batchnorm:
            norm_layer = None
        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            final_channels=num_classes,
            norm_layer=norm_layer,
            center=center,
        )

    def forward(self, x: torch.Tensor):
        x = self.encoder(x)
        x.reverse()  # torchscript doesn't work with [::-1]
        x = self.decoder(x)
        return x


class Conv2dBnAct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = norm_layer(out_channels)
        self.act = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2.0, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer)
        self.scale_factor = scale_factor
        if norm_layer is None:
            self.conv1 = Conv2dBnAct(in_channels, out_channels, **conv_args)
            self.conv2 = Conv2dBnAct(out_channels, out_channels,  **conv_args)
        else:
            self.conv1 = Conv2dBnAct(in_channels, out_channels, norm_layer=norm_layer, **conv_args)
            self.conv2 = Conv2dBnAct(out_channels, out_channels, norm_layer=norm_layer, **conv_args)

    def forward(self, x, skip: Optional[torch.Tensor] = None):
        if self.scale_factor != 1.0:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class UnetDecoder(nn.Module):

    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            norm_layer=nn.BatchNorm2d,
            center=False,
    ):
        super().__init__()

        if center:
            channels = encoder_channels[0]
            self.center = DecoderBlock(channels, channels, scale_factor=1.0, norm_layer=norm_layer)
        else:
            self.center = nn.Identity()

        in_channels = [in_chs + skip_chs for in_chs, skip_chs in zip(
            [encoder_channels[0]] + list(decoder_channels[:-1]),
            list(encoder_channels[1:]) + [0])]
        out_channels = decoder_channels

        self.blocks = nn.ModuleList()
        for in_chs, out_chs in zip(in_channels, out_channels):
            self.blocks.append(DecoderBlock(in_chs, out_chs, norm_layer=norm_layer))
        self.final_conv = nn.Conv2d(out_channels[-1], final_channels, kernel_size=(1, 1))

        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x: List[torch.Tensor]):
        encoder_head = x[0]
        skips = x[1:]
        x = self.center(encoder_head)
        for i, b in enumerate(self.blocks):
            skip = skips[i] if i < len(skips) else None
            x = b(x, skip)
        x = self.final_conv(x)
        return x

In [None]:
class UnetBased(nn.Module):
    def __init__(
        self,
        genes,
        tile_resolution = 16,
        n_metagenes = 20,
        in_channels = 3,
        out_channels = 64,
        decoder_channels = (128, 64, 32, 16, 8),
        context_decoder_channels = (128, 64, 32, 16, 8),
        he_scaler = .1,
        kl_scaler = .001,
        exp_scaler = 1.
    ):
        super().__init__()
        
        self.genes = genes
        self.n_genes = len(genes)

        
        self.he_scaler = he_scaler
        self.kl_scaler = kl_scaler
        self.exp_scaler = exp_scaler
        
        self.unet = Unet(backbone='resnet34',
                         decoder_channels=decoder_channels,
                         in_chans=in_channels,
                         num_classes=out_channels)
        
        self.context_unet = Unet(backbone='resnet34',
                         decoder_channels=context_decoder_channels,
                         in_chans=in_channels,
                         num_classes=out_channels)
        
        self.post_unet_conv = nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels,
                                        kernel_size=1)

        # latent mu and var
        self.latent_mu = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
        self.latent_var = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
        self.latent_norm = nn.BatchNorm2d(out_channels)
        
        self.n_metagenes = n_metagenes
        self.tile_resolution = tile_resolution
        self.metagenes = torch.nn.Parameter(torch.rand(self.n_metagenes, self.n_genes))
        self.scale_factors = torch.nn.Parameter(torch.rand(self.n_genes))
        self.p = torch.nn.Parameter(torch.rand(self.n_genes))
        
        self.post_decode_he = torch.nn.Conv2d(out_channels, 3, 1)
        self.post_decode_exp = torch.nn.Conv2d(out_channels, self.n_metagenes, 1)
        
        self.he_loss = torch.nn.MSELoss()
        
    def _kl_divergence(self, z, mu, std):
        # lightning imp.
        # Monte carlo KL divergence
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)

        return kl

    def encode(self, x, x_context, use_means=False):
        x_encoded = self.unet(x)
        x_context_encoded = self.context_unet(x)
        
        x_encoded = torch.concat((x_encoded, x_context_encoded), dim=1)
        x_encoded = self.post_unet_conv(x_encoded)
        
        x_encoded = self.latent_norm(x_encoded)
        
        mu, log_var = self.latent_mu(x_encoded), self.latent_var(x_encoded)
        
        # sample z from parameterized distributions
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        # get our latent
        if use_means:
            z = mu
        else:
            z = q.rsample()

        return z, mu, std
    
    def calculate_loss(self, he_true, exp_true, result):
        exp_loss = torch.mean(-result['nb'].log_prob(exp_true))
        
        kl_loss = torch.mean(self._kl_divergence(result['z'], result['z_mu'], result['z_std']))
        
        he_loss = torch.mean(self.he_loss(he_true, result['he']))
        
        return {
            'overall_loss': exp_loss * self.exp_scaler + kl_loss * self.kl_scaler + he_loss * self.he_scaler,
            'exp_loss': exp_loss,
            'kl_loss': kl_loss,
            'he_loss': he_loss
        }
    
    def reconstruct_expression(self, dec, masks=None, voxel_idxs=None, reduce_to_voxel=True):
        x = self.post_decode_exp(dec) # (b c h w)
        
        if reduce_to_voxel:
            x = reduce_to_voxel_level(x, masks) # (b, v, m)
        else:  
            x = rearrange(x, 'b c h w -> b h w c')
        
        r = x @ self.metagenes
        r = r * self.scale_factors
        r = F.softplus(r)
        
        p = torch.sigmoid(self.p)
        
        if reduce_to_voxel:
            p = rearrange(p, 'c -> 1 1 c')
            r = mask_nb_params(r, voxel_idxs)
        else:
            p = rearrange(p, 'c -> 1 1 1 c')
            
        r += .00000001
            
        nb = torch.distributions.NegativeBinomial(r, p)
        
        return {
            'r': r,
            'p': p,
            'exp': nb.mean,
            'nb': nb,
            'metagene_activity': x # (b v m)
        }
    
    def reconstruct_he(self, dec):
        he = self.post_decode_he(dec)
        return he

    def forward(self, x, x_context, masks=None, voxel_idxs=None, reduce_to_voxel=True, use_means=False):
        z, z_mu, z_std = self.encode(x, x_context, use_means=use_means)

        he = self.reconstruct_he(z)
        
        exp_result = self.reconstruct_expression(z, masks=masks, voxel_idxs=voxel_idxs,
                                                 reduce_to_voxel=reduce_to_voxel)
        
        result = {
            'z': z,
            'z_mu': z_mu,
            'z_std': z_std,
            'he': he,
        }
        result.update(exp_result)

        return result

In [None]:
def reduce_to_voxel_level(x, masks):
    """
    x - (b, c, h, w)
    masks - (b, v, h, w)
    
    out - (b, v, c)
    """
    masks = masks.unsqueeze(dim=2) # (b, v, c, h, w)
    x = x.unsqueeze(dim=1).repeat(1, 16, 1, 1, 1) # (b, v, c, h, w)
    x *= masks 
    return x.sum(dim=(-1, -2)) # (b, v, m)

def mask_nb_params(r, voxel_idxs):
    mask = torch.zeros_like(voxel_idxs, dtype=torch.bool)
    if r.is_cuda:
        mask = mask.cuda()
        
    mask[voxel_idxs == 0] = 1

    mask = mask.unsqueeze(dim=-1)
    masked_r = r.masked_fill(mask, 0.)
    
    return masked_r

def construct_tile_expression(padded_exp, masks, n_voxels, normalize=True):
    tile = torch.zeros((masks.shape[0], masks.shape[-2], masks.shape[-1], padded_exp.shape[-1]),
                       device=padded_exp.device)
    for b in range(tile.shape[0]):
        for exp, m in zip(padded_exp[b], masks[b]):
            tile[b, :, :][m==1] = exp.to(torch.float32)
            
    tile = rearrange(tile, 'b h w c -> b c h w')
    tile = tile.detach().cpu().numpy()
    
    tile /= np.expand_dims(tile.max(axis=(0, -2, -1)), (0, -2, -1))

    return rearrange(tile, 'b c h w -> b h w c')

In [None]:
def log_intermediates(logger, batch, result, plot_genes, model,
                      n_samples=8, result_full_res=None, identifier='train'):
    model_genes = np.asarray(model.genes)
    g2i = {g:i for i, g in enumerate(model_genes)}
    gene_idxs = np.asarray([g2i[g] for g in plot_genes])
    
    img = batch['he_context'][:n_samples].clone().detach()
    img -= img.min()
    img /= img.max()
    logger.log_image(
        key=f"{identifier}/he_context",
        images=[img],
        caption=[f'{identifier} he context']
    )
    
    img = batch['he'][:n_samples].clone().detach()
    img -= img.min()
    img /= img.max()
    logger.log_image(
        key=f"{identifier}/he_groundtruth",
        images=[img],
        caption=[f'{identifier} he tile 2x']
    )
    
    img = result['he'][:n_samples].clone().detach() # (b c h w)
    img -= img.min()
    img /= img.max()
    logger.log_image(
        key=f"{identifier}/he_reconstruction",
        images=[img],
        caption=[f'{identifier} he tile recon']
    )

    recon = construct_tile_expression(batch['exp'], batch['masks'], batch['n_voxels'])
    recon = recon[:n_samples, :, :, gene_idxs]
    recon = torch.tensor(rearrange(recon, 'b h w c -> c b 1 h w'))
    logger.log_image(
        key=f"{identifier}/exp_groundtruth",
        images=[img for img in recon],
        caption=[g for g in plot_genes]
    )

    recon = construct_tile_expression(result['exp'], batch['masks'], batch['n_voxels'])
    recon = recon[:n_samples, :, :, gene_idxs]
    recon = torch.tensor(rearrange(recon, 'b h w c -> c b 1 h w'))
    logger.log_image(
        key=f"{identifier}/exp_reconstruction",
        images=[img for img in recon],
        caption=[g for g in plot_genes]
    )
    
    recon = result['metagene_activity'][:n_samples, :, gene_idxs].clone().detach().to(torch.float32)
    recon -= recon.min()
    recon /= recon.max()
    logger.log_image(
        key=f"{identifier}/metagene_activity",
        images=[img for img in recon],
    )
    
    vals = model.metagenes.clone().detach().cpu().numpy()
    vals = vals[:, gene_idxs]
    df = pd.DataFrame(data=vals, columns=plot_genes)
    logger.log_text(
        key=f'{identifier}/metagenes',
        dataframe=df
    )
    
    vals = model.scale_factors.clone().detach().cpu().numpy()
    vals = vals[gene_idxs]
    df = pd.DataFrame(data=[vals], columns=plot_genes)
    logger.log_text(
        key=f'{identifier}/scale_factors',
        dataframe=df
    )
    
    if result_full_res is not None:
        recon = result_full_res['exp'][:n_samples].clone().detach()
        recon = recon[:, :, :, gene_idxs]
        recon = rearrange(recon, 'b h w c -> c b 1 h w')
        logger.log_image(
            key=f"{identifier}/full_res_exp_reconstruction",
            images=[img for img in recon],
            caption=[g for g in plot_genes]
        )
    
    

In [None]:
class xFuseLightning(pl.LightningModule):
    def __init__(self, autokl, lr=1e-4, n_samples=8, plot_genes=['IL7R', 'KRT18', 'BGN', 'PECAM1', 'INS'],
                 train_epoch_fraction=.1):
        super().__init__()
        
        self.autokl = autokl
        self.lr = lr
        self.plot_genes = plot_genes
        self.n_samples = n_samples
        self.train_epoch_fraction = train_epoch_fraction
        
        self.save_hyperparameters(ignore=['autokl'])

    def training_step(self, batch, batch_idx):
        x, x_context, masks, voxel_idxs, exp = batch['he'], batch['he_context'], batch['masks'], batch['voxel_idxs'], batch['exp']
        result = self.autokl(x, x_context, masks=masks, voxel_idxs=voxel_idxs)
        losses = self.autokl.calculate_loss(x, exp, result)
        losses = {f'train/{k}':v for k, v in losses.items()}
        self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
        losses['loss'] = losses['train/overall_loss']
        
        # only log 10-ish% of training epochs
        if batch_idx == 0 and torch.rand(1).item() < self.train_epoch_fraction:
            result_full_res = self.autokl(x[:1], x_context[:1], reduce_to_voxel=False)
            log_intermediates(self.logger, batch, result, self.plot_genes, self.autokl,
                              n_samples=self.n_samples, result_full_res=result_full_res, identifier='train')
        
        return losses
    
    def validation_step(self, batch, batch_idx):
        x, x_context, masks, voxel_idxs, exp = batch['he'], batch['he_context'], batch['masks'], batch['voxel_idxs'], batch['exp']
        result = self.autokl(x, x_context, masks=masks, voxel_idxs=voxel_idxs)
        losses = self.autokl.calculate_loss(x, exp, result)
        losses = {f'val/{k}':v for k, v in losses.items()}
        self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
        
        if batch_idx == 0:
            result_full_res = self.autokl(x[:1], x_context[:1], reduce_to_voxel=False)
            log_intermediates(self.logger, batch, result, self.plot_genes, self.autokl,
                              n_samples=self.n_samples, result_full_res=result_full_res, identifier='val')
        
        return losses

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
    
    

###### test forwards

In [None]:
batch = next(iter(train_dl))
x, b, masks, voxel_idxs, exp, exp_tiles = batch['he'], batch['b'], batch['masks'], batch['voxel_idxs'], batch['exp'], batch['exp_tiles']


In [None]:
# autokl = AutoencoderKL(
#     train_adata.shape[1],
#     tile_resolution=32,
#     n_metagenes=20,
#     in_channels=3,
#     out_channels=64,
#     down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
#     up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
#     block_out_channels=[8, 16, 32, 64],
#     norm_num_groups=8,
#     latent_channels=4,
# )

In [None]:
autokl = UnetBased(
    train_adata.var.index.to_list(),
    tile_resolution=32,
    n_metagenes=20,
    in_channels=3,
    out_channels=64,
    
)

In [None]:
result = autokl(x)

In [None]:
result.keys()

In [None]:
result['exp'].shape

In [None]:
plt.imshow(result['exp'][0, :, :, 5].detach().numpy())

In [None]:
plt.imshow(result['metagene_activity'][0, :, :, 0].detach().numpy())

In [None]:
losses = autokl.calculate_loss(x, exp_tiles, result)

In [None]:
losses

#### training loop

In [None]:
project = 'unet_based'
log_dir = '/scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/logs'

In [None]:
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project=project, save_dir=log_dir)

In [None]:
# wandb.finish()

In [None]:
config = {
    'n_genes': next(iter(sample_to_train_adata.values())).shape[1],
    'genes': next(iter(sample_to_train_adata.values())).var.index.to_numpy(),
    'plot_genes': ['IL7R', 'BGN', 'KRT18', 'PECAM1', 'INS'],
    'n_covariates': 1,
    'n_metagenes': 10,
    'latent_dim': 64,
    'tile_resolution': 32,
    'he_scale': '2X',
    'he_context_scale': '8X',
    'encoder': {
        'model': 'unet',
        'in_channels': 3,
        'decoder_channels': (128, 64, 32, 16, 8),
        'context_decoder_channels': (128, 64, 32, 16, 8),
    },
    'kl_scalers': {
        'exp_scaler': 1.,
        'kl_scaler': .0001,
        'he_scaler': .1,
    },
    'training': {
        'train_samples': list(sample_to_train_adata.keys()),
        'val_samples': list(sample_to_val_adata.keys()),
        'log_n_samples': 8,
        'max_epochs': 1000,
        'check_val_every_n_epoch': 10,
        'log_train_fraction': 1.,
        'log_every_n_steps': 1,
        'accelerator': 'gpu',
        'devices': 1,
        'limit_train_batches': 1.,
        'limit_val_batches': .1,
        'lr': 1e-4,
        'batch_size': batch_size,
        'precision': 32
    },
}
logger.experiment.config.update(config)

In [None]:
autokl = UnetBased(
    config['genes'],
    tile_resolution=config['tile_resolution'],
    n_metagenes=config['n_metagenes'],
    in_channels=config['encoder']['in_channels'],
    out_channels=config['latent_dim'],
    decoder_channels=config['encoder']['decoder_channels'],
    context_decoder_channels=config['encoder']['context_decoder_channels'],
    he_scaler=config['kl_scalers']['he_scaler'],
    kl_scaler=config['kl_scalers']['kl_scaler'],
    exp_scaler=config['kl_scalers']['exp_scaler'],
)
model = xFuseLightning(autokl, lr=config['training']['lr'],
                       n_samples=config['training']['log_n_samples'],
                       train_epoch_fraction=config['training']['log_train_fraction'],
                       plot_genes=config['plot_genes'])

In [None]:
trainer = pl.Trainer(
    devices=config['training']['devices'],
    accelerator=config['training']['accelerator'],
    check_val_every_n_epoch=config['training']['check_val_every_n_epoch'],
    enable_checkpointing=False,
    limit_val_batches=config['training']['limit_val_batches'],
    limit_train_batches=config['training']['limit_train_batches'],
    log_every_n_steps=config['training']['log_every_n_steps'],
    max_epochs=config['training']['max_epochs'],
    precision=config['training']['precision'],
    logger=logger
)

In [None]:
trainer.fit(model=model, train_dataloaders=train_dl, val_dataloaders=val_dl)

In [None]:
!mkdir -p /scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/runs/xfuse_improved_v5/

In [None]:
torch.save(model.state_dict(), '/scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/runs/xfuse_improved_v5/model.pt')

In [None]:
autokl = UnetBased(
    config['genes'],
    tile_resolution=config['tile_resolution'],
    n_metagenes=config['n_metagenes'],
    in_channels=config['encoder']['in_channels'],
    out_channels=config['latent_dim'],
    decoder_channels=config['encoder']['decoder_channels'],
    context_decoder_channels=config['encoder']['context_decoder_channels'],
    he_scaler=config['kl_scalers']['he_scaler'],
    kl_scaler=config['kl_scalers']['kl_scaler'],
    exp_scaler=config['kl_scalers']['exp_scaler'],
)
model = xFuseLightning(autokl, lr=config['training']['lr'],
                       n_samples=config['training']['log_n_samples'],
                       train_epoch_fraction=config['training']['log_train_fraction'])
model.load_state_dict(torch.load('/scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/runs/xfuse_improved_v5/model.pt'))

In [None]:
# a = next(iter(sample_to_train_adata.values()))
# a = sample_to_val_adata['HT264P1-S1H2U1']
a = sample_to_train_adata['HT270P1-S1H1U1']
a

In [None]:
# he = a.uns['rescaled_he']['2X_notrim']
he = a.uns['rescaled_he']['2X_trimmed']

he.shape

In [None]:
plt.imshow(he)

In [None]:
# context_he = a.uns['rescaled_he']['8X_notrim']
context_he = a.uns['rescaled_he']['8X_trimmed']
context_he.shape

In [None]:
plt.imshow(context_he)

In [None]:
def reflection_mosiac(x, border=256):
    max_r, max_c = x.shape[-2], x.shape[-1]
    if len(x.shape) == 3:
        mosaic = torch.zeros((x.shape[0], max_r + (border * 2), max_c + (border * 2))).to(x.dtype)
    else:
        mosaic = torch.zeros((max_r + (border * 2), max_c + (border * 2))).to(x.dtype)
    
    # make tiles
    top_left = TF.pad(x, padding=[border, border, 0, 0], padding_mode='reflect')
    top_right = TF.pad(x, padding=[0, border, border, 0], padding_mode='reflect')
    bottom_left = TF.pad(x, padding=[border, 0, 0, border], padding_mode='reflect')
    bottom_right = TF.pad(x, padding=[0, 0, border, border], padding_mode='reflect')
    
    if len(x.shape) == 3:
        mosaic[:, :max_r + border, :max_c + border] = top_left
        mosaic[:, :max_r + border, border:] = top_right
        mosaic[:, border:, :max_c + border] = bottom_left
        mosaic[:, border:, border:] = bottom_right
    else:
        mosaic[:max_r + border, :max_c + border] = top_left
        mosaic[:max_r + border, border:] = top_right
        mosaic[border:, :max_c + border] = bottom_left
        mosaic[border:, border:] = bottom_right
    
    return mosaic

In [None]:
def tile_img(img, context_img, context_scale, tile_size=256, window_scale=2):
    context_img = reflection_mosiac(context_img, border=tile_size)
    
    n_rows = (img.shape[1] // tile_size) * window_scale + 1
    n_cols = (img.shape[2] // tile_size) * window_scale + 1
    
    tiles = torch.ones(n_rows * n_cols, img.shape[0], tile_size, tile_size, dtype=img.dtype)
    expanded = torch.ones(img.shape[0],
                          n_rows * tile_size // window_scale + tile_size,
                          n_cols * tile_size // window_scale + tile_size,
                          dtype=img.dtype)
    expanded[:, :img.shape[1], :img.shape[2]] = img
    
    context_tiles = torch.ones(n_rows * n_cols, img.shape[0], tile_size, tile_size, dtype=context_img.dtype)
    
    idx = 0
    top_left = []
    for r in range(n_rows):
        for c in range(n_cols):
            r1 = r * tile_size // window_scale
            c1 = c * tile_size // window_scale
            r2 = r1 + tile_size
            c2 = c1 + tile_size
            tiles[idx] = expanded[:, r1:r2, c1:c2]
            top_left.append((r, c))
            
            center_r = r1 + (tile_size / 2)
            center_c = c1 + (tile_size / 2)
            center_r = int(center_r / context_scale) 
            center_c = int(center_c / context_scale) 
            r1 = center_r - tile_size // 2 + tile_size
            c1 = center_c - tile_size // 2 + tile_size
            r2 = r1 + tile_size
            c2 = c1 + tile_size
            context_tiles[idx] = context_img[:, r1:r2, c1:c2]
            idx += 1
    return tiles, context_tiles, top_left

def get_tile(r, c, img, context_img, context_scale, tile_size=256, window_scale=2):
    r1 = r * tile_size // window_scale
    c1 = c * tile_size // window_scale
    r2 = r1 + tile_size
    c2 = c1 + tile_size
    tile = img[:, r1:r2, c1:c2]

    center_r = r1 + (tile_size / 2)
    center_c = c1 + (tile_size / 2)
    center_r = int(center_r / context_scale) 
    center_c = int(center_c / context_scale) 
    r1 = center_r - tile_size // 2 + tile_size
    c1 = center_c - tile_size // 2 + tile_size
    r2 = r1 + tile_size
    c2 = c1 + tile_size
    context_tile = context_img[:, r1:r2, c1:c2]
    
    return tile, context_tile

In [None]:
def prepare_img(img):
    if img.shape[0] != 3:
        img = rearrange(img, 'h w c -> c h w')
    if not isinstance(img, torch.Tensor):
        img = torch.tensor(img)
    if not img.dtype == torch.float32:
        img = TF.convert_image_dtype(img, dtype=torch.float32)
    
    return img 

class TiledHEDatasetV2(Dataset):
    def __init__(self, img, context_img, context_scale, normalize=True, tile_size=256, window_scale=2):
        super().__init__()
        
        img = prepare_img(img)
        context_img = prepare_img(context_img)

        self.context_scale = context_scale
        self.tile_size = tile_size
        self.window_scale = window_scale
        
        if normalize:
            self.normalize = Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224)) # from HT397B1-H2 ffpe H&E image
        else:
            self.normalize = nn.Identity()
            
        self.context_img = reflection_mosiac(context_img, border=tile_size)
    
        self.n_rows = (img.shape[1] // tile_size) * window_scale + 2
        self.n_cols = (img.shape[2] // tile_size) * window_scale + 2
        
        self.img = torch.ones(img.shape[0],
                          self.n_rows * tile_size // window_scale + tile_size,
                          self.n_cols * tile_size // window_scale + tile_size,
                          dtype=img.dtype)
        self.img[:, :img.shape[1], :img.shape[2]] = img
        
        self.idx_to_top_left = [(r, c)
                                for r in range(self.n_rows)
                                for c in range(self.n_cols)]
            
    def __len__(self):
        return len(self.idx_to_top_left)

    def __getitem__(self, idx):
        r, c = self.idx_to_top_left[idx]
        
        tile, context_tile = get_tile(r, c, self.img, self.context_img, self.context_scale,
                                     tile_size=self.tile_size, window_scale=self.window_scale)
        return {
            'he': self.normalize(tile),
            'context_he': self.normalize(context_tile),
            'top': r,
            'left': c
        }
    
class TiledHEDataset(Dataset):
    def __init__(self, img, context_img, context_scale, normalize=True, tile_size=256, window_scale=2):
        super().__init__()
        
        img = prepare_img(img)
        context_img = prepare_img(context_img)

        self.img = img
        self.context_img = context_img
        self.context_scale = context_scale
        self.tiles, self.context_tiles, self.idx_to_top_left = tile_img(img, context_img, context_scale,
                                                                        tile_size=tile_size, window_scale=window_scale)
        
        if normalize:
            self.normalize = Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224)) # from HT397B1-H2 ffpe H&E image
        else:
            self.normalize = nn.Identity()
            
    def max_r(self):
        rs, cs = zip(*self.idx_to_top_left)
        return np.max(rs)

    def max_c(self):
        rs, cs = zip(*self.idx_to_top_left)
        return np.max(cs)
        
    def __len__(self):
        return self.tiles.shape[0]

    def __getitem__(self, idx):
        tile, context_tile, (r, c) = self.tiles[idx], self.context_tiles[idx], self.idx_to_top_left[idx]
        return {
            'he': self.normalize(tile),
            'context_he': self.normalize(context_tile),
            'top': r,
            'left': c
        }

In [None]:
def add_tile(tile, new, r, c, max_r, max_c, window_scale=2,):
    ts = tile.shape[0] // window_scale
    normal_offset = ts // 2
    edge_offset = tile.shape[0] // 2
    rc, cc = r * ts + edge_offset, c * ts + edge_offset
    trc, tcc = edge_offset, edge_offset

    r1 = rc - normal_offset
    r2 = rc + normal_offset
    c1 = cc - normal_offset
    c2 = cc + normal_offset
    tr1 = trc - normal_offset
    tr2 = trc + normal_offset
    tc1 = tcc - normal_offset
    tc2 = tcc + normal_offset
    if r == 0:
        r1 = rc - edge_offset
        r2 = rc + normal_offset
        tr1 = trc - edge_offset
        tr2 = trc + normal_offset
    if c == 0:
        c1 = cc - edge_offset
        c2 = cc + normal_offset
        tc1 = tcc - edge_offset
        tc2 = tcc + normal_offset
    if r == max_r:
        r1 = rc - normal_offset
        r2 = rc + edge_offset
        tr1 = trc - normal_offset
        tr2 = trc + edge_offset
    if c == max_c:
        c1 = cc - normal_offset
        c2 = cc + edge_offset
        tc1 = tcc - normal_offset
        tc2 = tcc + edge_offset

    new[r1:r2, c1:c2] = tile[tr1:tr2, tc1:tc2]

def predict_he(img, context_img, context_scale, model,
               batch_size=8,
               gene_idxs=None, keep=('he', 'exp', 'metagene_activity'),
               window_scale=2, rescale=False, resize=True, crop=True):
    orig_shape = img.shape
    ds = TiledHEDatasetV2(img, context_img, context_scale, window_scale=window_scale)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False)
    
    if gene_idxs is None:
        gene_idxs = np.arange(len(model.autokl.genes))
    
    with torch.no_grad():
        batch = next(iter(dl))
        he, he_context, top, left = batch['he'], batch['context_he'], batch['top'], batch['left']

        if next(iter(model.autokl.parameters())).is_cuda:
            he, he_context = he.cuda(), he_context.cuda()

        result = model.autokl(he, context_he, reduce_to_voxel=False)
        item = {k:v[0].detach().cpu()
               for k, v in result.items()
               if k in keep}
        item['exp'] = item['exp'][:, :, gene_idxs]
        item['he'] = rearrange(item['he'], 'c h w -> h w c')
    max_r, max_c = ds.n_rows, ds.n_cols
    key_to_new = {k:torch.ones(tile.shape[0] * (max_r + 2) // window_scale,
                               tile.shape[1] * (max_c + 2) // window_scale,
                               tile.shape[2],
                               dtype=tile.dtype)
                 for k, tile in item.items()}
#     return item
    
    model.eval()
    with torch.no_grad():
        for batch in dl:
            he, he_context, top, left = batch['he'], batch['context_he'], batch['top'], batch['left']

            if next(iter(model.autokl.parameters())).is_cuda:
                he, he_context = he.cuda(), he_context.cuda()

            result = model.autokl(he, context_he, reduce_to_voxel=False)

            for i in range(result['exp'].shape[0]):
                item = {k:v[i].detach().cpu()
                       for k, v in result.items()
                       if k in keep}
                item['exp'] = item['exp'][:, :, gene_idxs]
                item['he'] = rearrange(item['he'], 'c h w -> h w c')
                r, c = top[i], left[i]
                for k in keep:
                    tile = item[k]
                    add_tile(tile, key_to_new[k], r, c, max_r, max_c, window_scale=window_scale)
    
    if resize:
        for k, new in key_to_new.items():
            new = new[:orig_shape[0], :orig_shape[1]]
            key_to_new[k] = new
    
    if rescale:
        for k, new in key_to_new.items():
            if k == 'he':
                new -= new.min()
                new /= new.max()
            else:
                new -= new.amin(dim=(0, 1))
                new /= new.amax(dim=(0, 1))
            key_to_new[k] = new

    return key_to_new


In [None]:
model = model.cuda()

In [None]:
# show_genes = ['EPCAM', 'KRT18', 'CD8A', 'IL7R', 'MS4A1', 'PECAM1', 'BGN']
# show_genes = ['KRT18', 'IL7R', 'PECAM1', 'BGN']
show_genes = model.autokl.genes
gene_idxs = [i for i, g in enumerate(model.autokl.genes) if g in show_genes]
show_genes = model.autokl.genes[gene_idxs]
key_to_retiled = predict_he(he, context_he, 4, model, gene_idxs=gene_idxs, rescale=True, window_scale=2)

In [None]:
he.shape

In [None]:
# for k, img in key_to_retiled.items():
#     key_to_retiled[k] = img[:he.shape[0], :he.shape[1]]

In [None]:
for k, img in key_to_retiled.items():
    print(k, img.shape, img.max(), img.amax(dim=(0, 1)))

In [None]:
plt.imshow(key_to_retiled['he'])

In [None]:
plt.imshow(he)

In [None]:
plt.imshow(key_to_retiled['exp'][:, :, 0])

In [None]:
for i, g in enumerate(show_genes):
    plt.imshow(key_to_retiled['exp'][:, :, i])
    plt.title(g)
    plt.show()

In [None]:
for i in range(key_to_retiled['metagene_activity'].shape[-1]):
    plt.imshow(key_to_retiled['metagene_activity'][:, :, i])
    plt.title(i)
    plt.show()

In [None]:
# key_to_retiled = predict_he(he, context_he, 4, model, gene_idxs=gene_idxs,
#                             reduce_to_tile=False, batch_size=2, rescale=True, window_scale=2)

In [None]:
# for i, g in enumerate(show_genes):
#     plt.imshow(np.log1p(key_to_retiled['exp'][:, :, i]))
#     plt.title(g)
#     plt.show()

In [None]:
cr1, cr2 = 3500, 4500
cc1, cc2 = 4000, 5000
for i, g in enumerate(show_genes):
    plt.imshow(key_to_retiled['exp'][cr1:cr2, cc1:cc2, i])
    plt.title(g)
    plt.show()

In [None]:
plt.imshow(he[cr1:cr2, cc1:cc2])

In [None]:
cr1, cr2 = 4000, 6000
cc1, cc2 = 4000, 6000
for i, g in enumerate(show_genes):
    plt.imshow(key_to_retiled['exp'][cr1:cr2, cc1:cc2, i])
    plt.title(g)
    plt.show()

In [None]:
plt.imshow(he[cr1:cr2, cc1:cc2])

In [None]:
spot_mask = a.uns['rescaled_spot_masks']['2X_trimmed']
spot_mask

In [None]:
plt.imshow(spot_mask[cr1:cr2, cc1:cc2]>0)

In [None]:
key_to_retiled['exp'].max()

In [None]:
## save gene expression image
exp = TF.convert_image_dtype(key_to_retiled['exp'], torch.uint8)
exp.shape

In [None]:
torch.save(exp, '../data/annotations/pdac/HT270P1-S1H1U1_exp.pt')

In [None]:
## save gene expression image
meta = TF.convert_image_dtype(key_to_retiled['metagene_activity'], torch.uint8)
meta.shape

In [None]:
torch.save(meta, '../data/annotations/pdac/HT270P1-S1H1U1_meta.pt')

In [None]:
import seaborn as sns
metagenes = model.autokl.metagenes.detach().cpu().numpy()
metagenes = pd.DataFrame(data=metagenes, columns=model.autokl.genes, index=np.arange(metagenes.shape[0]))
metagenes = metagenes.loc[:, show_genes]
sns.clustermap(data=metagenes)

#### apply directly to H&E