In [None]:
import logging
import os
import re
from typing import Optional, List

import scanpy as sc 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import wandb
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler, BatchSampler
from torchvision.transforms import ColorJitter, Normalize, RandomHorizontalFlip, RandomVerticalFlip, RandomAdjustSharpness
from timm import create_model
from einops import rearrange, reduce
from skimage.color import label2rgb
from skimage.measure import regionprops_table


**Note:** diffusers v0.3.0 needs a local bug fix before its UNet works with more than two downsampling blocks.

In `~/.local/lib/python3.9/site-packages/diffusers/models/unet_blocks.py`, change the `out_channels` argument inside `DownEncoderBlock2D` to:

`out_channels=in_channels if add_downsample else out_channels,`


In [None]:
# Authenticate with Weights & Biases so the WandbLogger created in the training section can log runs.
wandb.login()

#### load datasets

In [None]:
# Root of the preprocessed dataset; the cells below read its train/ and val/ subdirectories.
dataset_dir = '../data/pytorch_datasets/pdac_v17/'

In [None]:
# !ls ../data/pytorch_datasets/pdac_v8/

In [None]:
def get_n_voxels(padded_voxel_idxs):
    """Return the number of real (unpadded) voxels in a 1-D index tensor padded with trailing zeros.

    The count is the length minus the number of trailing zeros. Previously an input with no
    padding returned 0 (because ``x[:-0]`` is empty), and an all-zero input raised IndexError.

    Args:
        padded_voxel_idxs: 1-D tensor of voxel indices followed by zero padding.

    Returns:
        int: number of entries up to and including the last nonzero one (0 if all are zero).
    """
    # first nonzero position in the reversed tensor == number of trailing zeros
    nonzero = padded_voxel_idxs.flip((0,)).nonzero()
    if len(nonzero) == 0:
        return 0
    n_trailing_pad = nonzero[0].item()
    return len(padded_voxel_idxs) - n_trailing_pad

In [None]:
class DirectorySTDataset(Dataset):
    """Dataset over the ``*.pt`` files in a directory, one dataset item per file.

    Each file is loaded with ``torch.load`` and must contain the keys ``'he'``,
    ``'he_context'``, ``'masks'``, ``'voxel_idxs'``, ``'exp'`` and ``'exp_tiles'``.

    Args:
        directory: directory containing the ``.pt`` files.
        normalize: if True, apply a fixed per-channel ``Normalize`` to ``he`` and ``he_context``;
            otherwise use ``nn.Identity``.
    """
    def __init__(self, directory, normalize=True):
        super().__init__()
        self.dir = directory
        # sorted so the index -> file mapping is reproducible (os.listdir order is arbitrary)
        self.fps = sorted(
            os.path.join(self.dir, fp) for fp in os.listdir(self.dir) if fp.endswith('.pt')
        )

        if normalize:
            self.normalize = Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224)) # from HT397B1-H2 ffpe H&E image
        else:
            self.normalize = nn.Identity()

    def __len__(self):
        return len(self.fps)

    def __getitem__(self, idx):
        fp = self.fps[idx]
        # NOTE: torch.load unpickles the file; only use with trusted data directories.
        obj = torch.load(fp)
        img = TF.convert_image_dtype(obj['he'], dtype=torch.float32)

        return {
            'he': self.normalize(img),
            'he_context': self.normalize(TF.convert_image_dtype(obj['he_context'], dtype=torch.float32)),
            'he_orig': img,  # float32 but un-normalized, for display
            'masks': obj['masks'],
            'voxel_idxs': obj['voxel_idxs'],
            'exp': obj['exp'],
            'exp_tiles': obj['exp_tiles'],
            'n_voxels': get_n_voxels(obj['voxel_idxs']),
            'b': torch.tensor([1.])
        }

In [None]:
# Training set: one .pt file per dataset item under <dataset_dir>/train/data.
train_ds = DirectorySTDataset(os.path.join(dataset_dir, 'train', 'data'))
# train_ds = DirectorySTDataset(os.path.join(dataset_dir, 'train'))
len(train_ds)

In [None]:
# import time
# time.sleep(60 * 10)

In [None]:
# Validation set: one .pt file per dataset item under <dataset_dir>/val/data.
val_ds = DirectorySTDataset(os.path.join(dataset_dir, 'val', 'data'))
# val_ds = DirectorySTDataset(os.path.join(dataset_dir, 'val'))
len(val_ds)

In [None]:
# Load every training-sample AnnData, keyed by file name without the .h5ad extension.
# Sorted so key order (logged later as config['training']['train_samples']) is reproducible.
directory = os.path.join(dataset_dir, 'train', 'adatas')
# directory = os.path.join(dataset_dir, 'train')
fps = sorted(os.path.join(directory, fp) for fp in os.listdir(directory) if fp.endswith('.h5ad'))
sample_to_train_adata = {
    os.path.splitext(os.path.basename(fp))[0]: sc.read_h5ad(fp)
    for fp in fps
}
sample_to_train_adata.keys()

In [None]:
# Load every validation-sample AnnData, keyed by file name without the .h5ad extension.
# Sorted so key order (logged later as config['training']['val_samples']) is reproducible.
directory = os.path.join(dataset_dir, 'val', 'adatas')
# directory = os.path.join(dataset_dir, 'val')
fps = sorted(os.path.join(directory, fp) for fp in os.listdir(directory) if fp.endswith('.h5ad'))
sample_to_val_adata = {
    os.path.splitext(os.path.basename(fp))[0]: sc.read_h5ad(fp)
    for fp in fps
}
sample_to_val_adata.keys()

In [None]:
# NOTE: `batch_size` is also recorded in the training config below.
batch_size = 32
# only the training loader shuffles; validation order stays fixed between epochs
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=1)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [None]:
%%time
# time loading one training batch (sanity check of data-loading speed)
batch = next(iter(train_dl))

###### data inspection

In [None]:
def cte(padded_exp, masks, n_voxels):
    """Paint per-voxel expression vectors onto their masks to build an expression image.

    Args:
        padded_exp: per-voxel expression rows, indexed ``padded_exp[k]`` -> gene vector.
        masks: per-voxel masks; ``masks.shape[1:3]`` gives the output height/width.
        n_voxels: number of leading (voxel, mask) pairs to use; the rest is padding.

    Returns:
        float32 tensor of shape ``(height, width, padded_exp.shape[1])``; pixels where a
        mask equals 1 get that voxel's expression (later voxels overwrite earlier ones).
    """
    height, width = masks.shape[1], masks.shape[2]
    n_genes = padded_exp.shape[1]
    canvas = torch.zeros((height, width, n_genes))
    voxel_pairs = list(zip(padded_exp, masks))[:n_voxels]
    for voxel_exp, voxel_mask in voxel_pairs:
        canvas[voxel_mask == 1] = voxel_exp.to(torch.float32)
    return canvas

In [None]:
# Preview (up to) the first 100 normalized training tiles, min-max rescaled to [0, 1] for display.
# Bounded by len(train_ds) so small datasets don't raise IndexError; uses its own loop variable
# so the global `i` used by the inspection cells below is not clobbered.
for tile_idx in range(min(100, len(train_ds))):
    img = rearrange(train_ds[tile_idx]['he'], 'c h w -> h w c')
    img -= img.min()
    img /= img.max()
    plt.imshow(img)
    plt.title(f'train tile {tile_idx}')
    plt.show()

In [None]:
# index of the training item inspected below (previously inspected indices kept for reference)
# i = 166
# i = 66
# i = 22
i = 8

In [None]:
# normalized H&E tile, min-max rescaled to [0, 1] so imshow can display it
img = rearrange(train_ds[i]['he'], 'c h w -> h w c')
img -= img.min()
img /= img.max()
plt.imshow(img)

In [None]:
# un-normalized (float32) H&E tile
img = rearrange(train_ds[i]['he_orig'], 'c h w -> h w c')
plt.imshow(img)

In [None]:
# sum of all voxel masks for this item (overlap shows as values > 1)
plt.imshow(torch.sum(train_ds[i]['masks'], dim=0))

In [None]:
# show each real (unpadded) voxel mask of this item; the item is loaded once, not per mask
sample = train_ds[i]
for voxel in range(sample['n_voxels']):
    plt.imshow(sample['masks'][voxel])
    plt.show()

In [None]:
# pixel count of each voxel mask
torch.sum(train_ds[i]['masks'], dim=(-1, -2))

In [None]:
gene = 'IL7R'
# NOTE(review): `train_adata` was never defined in this notebook, so this cell (and the ones below
# that use it) failed on a fresh kernel. Use the first training sample's AnnData, as the config
# cell does, for the gene -> column lookup. This assumes all samples share the same gene order.
# Confirm that, and pick the matching sample if the spatial plots below should show item `i`.
train_adata = next(iter(sample_to_train_adata.values()))
sample = train_ds[i]  # load once instead of three times
recon = cte(sample['exp'], sample['masks'], sample['n_voxels'])
plt.imshow(recon[:, :, train_adata.var.index.to_list().index(gene)])

In [None]:
# raw values of the reconstructed gene channel plotted above
recon[:, :, train_adata.var.index.to_list().index(gene)]

In [None]:
# the dataset's stored expression tile for the same gene, for comparison with the reconstruction
plt.imshow(train_ds[i]['exp_tiles'][:, :, train_adata.var.index.to_list().index(gene)])

In [None]:
# padded per-voxel expression rows
train_ds[i]['exp']

In [None]:
# padded voxel indices (get_n_voxels treats trailing zeros as padding)
train_ds[i]['voxel_idxs']

In [None]:
# mark spots whose spot_index is among this item's real (unpadded) voxel indices
# NOTE(review): relies on `train_adata` being defined by an earlier cell — confirm.
pool = set(train_ds[i]['voxel_idxs'][:train_ds[i]['n_voxels']].detach().numpy())
train_adata.obs['highlight'] = ['yes' if i in pool else 'no'
                               for i in train_adata.obs['spot_index']]
sc.pl.spatial(train_adata, color='highlight')

In [None]:
# IL7R expression over the whole sample, for comparison with the tile above
sc.pl.spatial(train_adata, color='IL7R')

In [None]:
# index of the validation item inspected below
i = 1

In [None]:
# un-normalized (float32) validation H&E tile
img = rearrange(val_ds[i]['he_orig'], 'c h w -> h w c')
plt.imshow(img)

In [None]:
gene = 'EPCAM'
# NOTE(review): `val_adata` was never defined in this notebook, so this cell failed on a fresh
# kernel. Use the first validation sample's AnnData for the gene -> column lookup, assuming all
# samples share the same gene order. Confirm that, and pick the matching sample for the
# spatial plot below.
val_adata = next(iter(sample_to_val_adata.values()))
sample = val_ds[i]  # load once instead of three times
recon = cte(sample['exp'], sample['masks'], sample['n_voxels'])
plt.imshow(recon[:, :, val_adata.var.index.to_list().index(gene)])

In [None]:
# mark spots whose spot_index is among this validation item's real (unpadded) voxel indices
# NOTE(review): relies on `val_adata` being defined by an earlier cell — confirm.
pool = set(val_ds[i]['voxel_idxs'][:val_ds[i]['n_voxels']].detach().numpy())
val_adata.obs['highlight'] = ['yes' if i in pool else 'no'
                               for i in val_adata.obs['spot_index']]
sc.pl.spatial(val_adata, color='highlight')

#### model

In [None]:
"""
modified from https://gist.github.com/rwightman/f8b24f4e6f5504aba03e999e02460d31
"""
class Unet(nn.Module):
    """U-Net built from a timm ``features_only`` backbone (encoder) and a ``UnetDecoder``.

    Args:
        backbone: timm model name used as the encoder (created with ``pretrained=False``).
        backbone_kwargs: extra keyword arguments forwarded to ``timm.create_model``.
        backbone_indices: ``out_indices`` forwarded to ``timm.create_model``.
        decoder_use_batchnorm: if ``False``, ``norm_layer=None`` is passed to the decoder.
        decoder_channels: output channels of each decoder block.
        in_chans: number of input image channels.
        num_classes: number of output channels (output shape ``(batch, num_classes, h, w)``).
        center: if ``True`` the decoder applies an extra block to the deepest encoder feature map.
        norm_layer: normalization layer class passed to the decoder.

    NOTE: This is based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
    """

    def __init__(
            self,
            backbone='resnet34',
            backbone_kwargs=None,
            backbone_indices=None,
            decoder_use_batchnorm=True,
            decoder_channels=(256, 128, 64, 32, 16),
            in_chans=1,
            num_classes=5,
            center=False,
            norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()
        # NOTE some models need different backbone indices specified based on the alignment of features
        # and some models won't have a full enough range of feature strides to work properly.
        self.encoder = create_model(
            backbone,
            features_only=True,
            out_indices=backbone_indices,
            in_chans=in_chans,
            pretrained=False,
            **(backbone_kwargs or {}),
        )
        # the decoder consumes features deepest-first, so reverse the channel list to match
        deepest_first_channels = self.encoder.feature_info.channels()[::-1]

        self.decoder = UnetDecoder(
            encoder_channels=deepest_first_channels,
            decoder_channels=decoder_channels,
            final_channels=num_classes,
            norm_layer=norm_layer if decoder_use_batchnorm else None,
            center=center,
        )

    def forward(self, x: torch.Tensor):
        features = self.encoder(x)
        # in-place reverse (deepest first) because torchscript doesn't support [::-1] on lists
        features.reverse()
        return self.decoder(features)


class Conv2dBnAct(nn.Module):
    """Conv2d -> normalization -> activation.

    Args:
        in_channels, out_channels, kernel_size, padding, stride: forwarded to ``nn.Conv2d``.
        act_layer: activation class, instantiated with ``inplace=True``.
        norm_layer: normalization class instantiated with ``out_channels``. Pass ``None`` to
            skip normalization; the conv then gets a bias because no norm layer follows it.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding=0,
                 stride=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # a conv bias is redundant right before a normalization layer, so only add one without it
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                              bias=norm_layer is None)
        # attribute kept as `bn` so existing state_dict keys still match
        self.bn = norm_layer(out_channels) if norm_layer is not None else nn.Identity()
        self.act = act_layer(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x


class DecoderBlock(nn.Module):
    """Optionally upsample, concatenate a skip connection, then apply two 3x3 Conv2dBnAct layers.

    Args:
        in_channels: channels after concatenating the (upsampled) input with the skip.
        out_channels: output channels of both convs.
        scale_factor: nearest-neighbour upsampling factor; ``1.0`` disables upsampling.
        act_layer: activation class for both convs.
        norm_layer: normalization class for both convs, or ``None`` for no normalization.
    """
    def __init__(self, in_channels, out_channels, scale_factor=2.0, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # norm_layer is forwarded as-is, so None really disables normalization. Previously the
        # None branch omitted the argument and silently fell back to BatchNorm2d.
        conv_args = dict(kernel_size=3, padding=1, act_layer=act_layer, norm_layer=norm_layer)
        self.scale_factor = scale_factor
        self.conv1 = Conv2dBnAct(in_channels, out_channels, **conv_args)
        self.conv2 = Conv2dBnAct(out_channels, out_channels, **conv_args)

    def forward(self, x, skip: Optional[torch.Tensor] = None):
        if self.scale_factor != 1.0:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        if skip is not None:
            # channel-wise concat: skip must match x's spatial size after upsampling
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class UnetDecoder(nn.Module):
    """U-Net decoder: upsample the deepest encoder feature map step by step, fusing encoder skips.

    Args:
        encoder_channels: encoder feature channels ordered deepest first.
        decoder_channels: output channels of each ``DecoderBlock``.
        final_channels: channels produced by the final 1x1 conv.
        norm_layer: normalization layer class forwarded to every ``DecoderBlock``.
        center: if ``True``, apply a non-upsampling ``DecoderBlock`` to the deepest feature map first.
    """

    def __init__(
            self,
            encoder_channels,
            decoder_channels=(256, 128, 64, 32, 16),
            final_channels=1,
            norm_layer=nn.BatchNorm2d,
            center=False,
    ):
        super().__init__()

        head_channels = encoder_channels[0]
        self.center = (
            DecoderBlock(head_channels, head_channels, scale_factor=1.0, norm_layer=norm_layer)
            if center else nn.Identity()
        )

        # block k receives the previous output (the encoder head for k == 0) concatenated with
        # encoder skip k + 1; the last block has no skip left (0 extra channels)
        prev_channels = [head_channels, *decoder_channels[:-1]]
        skip_channels = [*encoder_channels[1:], 0]
        block_in_channels = [prev + skip for prev, skip in zip(prev_channels, skip_channels)]

        self.blocks = nn.ModuleList(
            DecoderBlock(in_chs, out_chs, norm_layer=norm_layer)
            for in_chs, out_chs in zip(block_in_channels, decoder_channels)
        )
        self.final_conv = nn.Conv2d(decoder_channels[-1], final_channels, kernel_size=(1, 1))

        self._init_weight()

    def _init_weight(self):
        # Kaiming-normal conv weights; BatchNorm starts as the identity affine transform
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x: List[torch.Tensor]):
        head, skips = x[0], x[1:]
        out = self.center(head)
        for idx, block in enumerate(self.blocks):
            out = block(out, skips[idx] if idx < len(skips) else None)
        return self.final_conv(out)

In [None]:
class UnetBased(nn.Module):
    """Predict spatial gene expression from an H&E tile plus a wider H&E context image.

    Two ``Unet`` encoders (one for the tile, one for the context image) produce feature maps.
    These are concatenated, projected with a 1x1 conv and average-pooled with kernel
    ``tile_resolution``. The pooled features parameterize a Normal latent ``z``. ``z`` is mapped
    to metagene activities and then to per-gene negative-binomial distributions.

    Args:
        genes: gene names; ``len(genes)`` sets the number of modeled genes.
        tile_resolution: kernel size of the ``AvgPool2d`` applied after fusing the encoders.
        n_metagenes: number of metagenes (rows of ``self.metagenes``).
        in_channels: input image channels for both encoders.
        out_channels: encoder output / latent channels.
        decoder_channels: decoder channels of the tile ``Unet``.
        context_decoder_channels: decoder channels of the context ``Unet``.
        he_scaler: stored but currently unused (the H&E reconstruction loss is commented out).
        kl_scaler: weight of the KL term in ``overall_loss``.
        exp_scaler: weight of the expression NLL term in ``overall_loss``.
    """
    def __init__(
        self,
        genes,
        tile_resolution = 16,
        n_metagenes = 20,
        in_channels = 3,
        out_channels = 64,
        decoder_channels = (128, 64, 32, 16, 8),
        context_decoder_channels = (128, 64, 32, 16, 8),
        he_scaler = .1,
        kl_scaler = .001,
        exp_scaler = 1.
    ):
        super().__init__()

        self.genes = genes
        self.n_genes = len(genes)

        self.he_scaler = he_scaler
        self.kl_scaler = kl_scaler
        self.exp_scaler = exp_scaler

        # encoder for the tile itself
        self.unet = Unet(backbone='resnet34',
                         decoder_channels=decoder_channels,
                         in_chans=in_channels,
                         num_classes=out_channels)

        # encoder for the surrounding context image
        self.context_unet = Unet(backbone='resnet34',
                                 decoder_channels=context_decoder_channels,
                                 in_chans=in_channels,
                                 num_classes=out_channels)

        # fuse both encodings (channel concat -> 1x1 conv), then pool by tile_resolution
        self.post_unet_conv = nn.Sequential(
            nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=tile_resolution)
        )

        # latent mu and log-variance heads
        self.latent_mu = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)
        self.latent_var = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1)

        self.n_metagenes = n_metagenes
        self.tile_resolution = tile_resolution
        self.metagenes = torch.nn.Parameter(torch.rand(self.n_metagenes, self.n_genes))
        self.scale_factors = torch.nn.Parameter(torch.rand(self.n_genes))
        self.p = torch.nn.Parameter(torch.rand(self.n_genes))

        self.post_decode_he = torch.nn.Conv2d(out_channels, 3, 1)
        self.post_decode_exp = torch.nn.Conv2d(out_channels, self.n_metagenes, 1)

        self.he_loss = torch.nn.MSELoss()

    def _kl_divergence(self, z, mu, std):
        """Single-sample Monte Carlo estimate of KL(N(mu, std) || N(0, 1)) at ``z``.

        The log-density difference is summed over the last dimension only; the caller averages
        the rest.
        """
        # lightning imp.
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)

        log_qzx = q.log_prob(z)
        log_pz = p.log_prob(z)

        kl = (log_qzx - log_pz)
        kl = kl.sum(-1)

        return kl

    def encode(self, x, x_context, use_means=False):
        """Encode a tile and its context image into a latent sample.

        Args:
            x: H&E tile batch for ``self.unet``.
            x_context: context H&E batch for ``self.context_unet``. Previously ``x`` was
                passed here by mistake, so the context image was ignored.
            use_means: if True return ``mu`` instead of a reparameterized sample.

        Returns:
            ``(z, mu, std)`` where ``std = exp(log_var / 2)``.
        """
        x_encoded = self.unet(x)
        x_context_encoded = self.context_unet(x_context)

        x_encoded = torch.concat((x_encoded, x_context_encoded), dim=1)
        x_encoded = self.post_unet_conv(x_encoded)

        mu, log_var = self.latent_mu(x_encoded), self.latent_var(x_encoded)

        # sample z from parameterized distributions
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        # get our latent
        if use_means:
            z = mu
        else:
            z = q.rsample()  # reparameterized, so gradients flow to mu/std

        return z, mu, std

    def calculate_loss(self, exp_true, result):
        """Return the weighted overall loss and its components.

        Args:
            exp_true: observed expression; presumably shaped like ``result['exp']``
                (b, h, w, n_genes) — confirm against the dataset's ``exp_tiles``.
            result: output of ``forward``.
        """
        exp_loss = torch.mean(-result['nb'].log_prob(exp_true))

        kl_loss = torch.mean(self._kl_divergence(result['z'], result['z_mu'], result['z_std']))

#         he_loss = torch.mean(self.he_loss(he_true, result['he']))

        return {
            'overall_loss': exp_loss * self.exp_scaler + kl_loss * self.kl_scaler,
            'exp_loss': exp_loss,
            'kl_loss': kl_loss,
#             'he_loss': he_loss
        }

    def reconstruct_expression(self, dec, reduce_to_tile=True):
        """Map a latent map ``dec`` (b, c, h, w) to per-gene negative-binomial distributions.

        ``reduce_to_tile`` is currently ignored (the tile-level reduction is commented out).

        Returns:
            dict with ``'r'``/``'p'`` NB parameters, ``'exp'`` (NB mean, (b, h, w, n_genes)),
            ``'nb'`` (the distribution) and ``'metagene_activity'`` ((b, h, w, n_metagenes)).
        """
        x = self.post_decode_exp(dec)

#         if reduce_to_tile:
#             x = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'sum', h2=self.tile_resolution, w2=self.tile_resolution)

        x = rearrange(x, 'b c h w -> b h w c')

        r = x @ self.metagenes
        r = r * self.scale_factors
        r = F.softplus(r)
        r += .00000001  # keep the NB total_count strictly positive

        p = torch.sigmoid(self.p)
        p = rearrange(p, 'c -> 1 1 1 c')

        nb = torch.distributions.NegativeBinomial(r, p)

        return {
            'r': r,
            'p': p,
            'exp': nb.mean,
            'nb': nb,
            'metagene_activity': x # (b, h w m)
        }

#     def reconstruct_he(self, dec):
#         he = self.post_decode_he(dec)
#         return he

    def forward(self, x, x_context, reduce_to_tile=True, use_means=False):
        """Encode ``x``/``x_context`` and decode expression; returns latent stats plus NB outputs."""
        z, z_mu, z_std = self.encode(x, x_context, use_means=use_means)

#         he = self.reconstruct_he(z)

        exp_result = self.reconstruct_expression(z, reduce_to_tile=reduce_to_tile)

        result = {
            'z': z,
            'z_mu': z_mu,
            'z_std': z_std,
#             'he': he,
        }
        result.update(exp_result)

        return result

In [None]:
def _minmax_scale(t):
    """Rescale ``t`` in place to [0, 1] using its global min/max and return it.

    A constant tensor is left at all zeros instead of turning into NaN from 0/0.
    """
    t -= t.min()
    t_max = t.max()
    if t_max > 0:
        t /= t_max
    return t


def log_intermediates(logger, batch, result, plot_genes, model,
                      n_samples=8, result_full_res=None, identifier='train'):
    """Log input images, expression ground truth/reconstructions and model parameters.

    Args:
        logger: object exposing ``log_image`` and ``log_text`` (e.g. the ``WandbLogger`` used here).
        batch: batch dict with ``'he'``, ``'he_context'`` and ``'exp_tiles'``.
        result: output dict of the model's forward (reads ``'exp'`` and ``'metagene_activity'``).
        plot_genes: gene names to visualize; each must be in ``model.genes``.
        model: the wrapped model (reads ``genes``, ``metagenes``, ``scale_factors``, ``n_metagenes``).
        n_samples: number of batch items to log.
        result_full_res: optional extra forward result whose ``'exp'`` is also logged.
        identifier: key prefix, e.g. ``'train'`` or ``'val'``.

    Raises:
        KeyError: if a gene in ``plot_genes`` is not in ``model.genes``.
    """
    g2i = {g: i for i, g in enumerate(np.asarray(model.genes))}
    gene_idxs = np.asarray([g2i[g] for g in plot_genes])

    img = _minmax_scale(batch['he_context'][:n_samples].clone().detach())
    logger.log_image(
        key=f"{identifier}/he_context",
        images=[img],
        caption=[f'{identifier} he context']
    )

    img = _minmax_scale(batch['he'][:n_samples].clone().detach())
    logger.log_image(
        key=f"{identifier}/he_groundtruth",
        images=[img],
        caption=[f'{identifier} he tile 2x']
    )

#     img = result['he'][:n_samples].clone().detach() # (b c h w)
#     img -= img.min()
#     img /= img.max()
#     logger.log_image(
#         key=f"{identifier}/he_reconstruction",
#         images=[img],
#         caption=[f'{identifier} he tile recon']
#     )

    # (b h w genes) -> one (b 1 h w) image stack per plotted gene
    recon = batch['exp_tiles'][:n_samples].clone().detach().to(torch.float16)
    recon = _minmax_scale(recon[:, :, :, gene_idxs])
    recon = rearrange(recon, 'b h w c -> c b 1 h w')
    logger.log_image(
        key=f"{identifier}/exp_groundtruth",
        images=list(recon),
        caption=list(plot_genes)
    )

    recon = result['exp'][:n_samples].clone().detach().to(torch.float16)
    recon = _minmax_scale(recon[:, :, :, gene_idxs])
    recon = rearrange(recon, 'b h w c -> c b 1 h w')
    logger.log_image(
        key=f"{identifier}/exp_reconstruction",
        images=list(recon),
        caption=list(plot_genes)
    )

    # metagene activity of the first item only, one image per metagene
    recon = _minmax_scale(result['metagene_activity'][:1].clone().detach().to(torch.float32))
    recon = rearrange(recon, 'b h w c -> c b 1 h w')
    logger.log_image(
        key=f"{identifier}/metagene_activity",
        images=list(recon),
        caption=list(range(model.n_metagenes))
    )

    # metagene loadings of the plotted genes; previously logged under the 'scale_factors' key,
    # which collided with the real scale-factor table below
    vals = model.metagenes.clone().detach().cpu().numpy()[:, gene_idxs]
    logger.log_text(
        key=f'{identifier}/metagenes',
        dataframe=pd.DataFrame(data=vals, columns=plot_genes)
    )

    vals = model.scale_factors.clone().detach().cpu().numpy()[gene_idxs]
    logger.log_text(
        key=f'{identifier}/scale_factors',
        dataframe=pd.DataFrame(data=[vals], columns=plot_genes)
    )

    if result_full_res is not None:
        # NOTE(review): logged without min-max scaling, unlike the panels above — confirm intended.
        recon = result_full_res['exp'][:n_samples].clone().detach().to(torch.float16)
        recon = recon[:, :, :, gene_idxs]
        recon = rearrange(recon, 'b h w c -> c b 1 h w')
        logger.log_image(
            key=f"{identifier}/full_res_exp_reconstruction",
            images=list(recon),
            caption=list(plot_genes)
        )
    

In [None]:
class xFuseLightning(pl.LightningModule):
    """Lightning wrapper that trains ``autokl`` and logs diagnostics via ``log_intermediates``.

    Args:
        autokl: the wrapped model; excluded from saved hyperparameters.
        lr: Adam learning rate.
        n_samples: number of batch items passed to ``log_intermediates``.
        plot_genes: gene names visualized by ``log_intermediates`` (any iterable; stored as a list).
        train_epoch_fraction: probability that an epoch's first training batch is logged.
    """
    def __init__(self, autokl, lr=1e-3, n_samples=16, plot_genes=('IL7R', 'KRT18', 'BGN', 'PECAM1'),
                 train_epoch_fraction=.1):
        super().__init__()

        self.autokl = autokl
        self.lr = lr
        # copied into a fresh list so the default / caller's object is never shared
        self.plot_genes = list(plot_genes)
        self.n_samples = n_samples
        self.train_epoch_fraction = train_epoch_fraction

        self.save_hyperparameters(ignore=['autokl'])

    def training_step(self, batch, batch_idx):
        x, x_context, exp_tiles = batch['he'], batch['he_context'], batch['exp_tiles']
        result = self.autokl(x, x_context)
        losses = self.autokl.calculate_loss(exp_tiles, result)
        losses = {f'train/{k}': v for k, v in losses.items()}
        self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)
        losses['loss'] = losses['train/overall_loss']

        # log the first batch of roughly `train_epoch_fraction` of training epochs
        if batch_idx == 0 and torch.rand(1).item() < self.train_epoch_fraction:
            # logging-only forward: no autograd graph is needed.
            # NOTE(review): the model is in train mode here, so BatchNorm running stats are
            # still updated by this extra single-item pass.
            with torch.no_grad():
                result_full_res = self.autokl(x[:1], x_context[:1], reduce_to_tile=False)
            log_intermediates(self.logger, batch, result, self.plot_genes, self.autokl,
                              n_samples=self.n_samples, result_full_res=result_full_res, identifier='train')

        return losses

    def validation_step(self, batch, batch_idx):
        x, x_context, exp_tiles = batch['he'], batch['he_context'], batch['exp_tiles']
        result = self.autokl(x, x_context)
        losses = self.autokl.calculate_loss(exp_tiles, result)
        losses = {f'val/{k}': v for k, v in losses.items()}
        self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True)

        if batch_idx == 0:
            result_full_res = self.autokl(x[:1], x_context[:1], reduce_to_tile=False)
            log_intermediates(self.logger, batch, result, self.plot_genes, self.autokl,
                              n_samples=self.n_samples, result_full_res=result_full_res, identifier='val')

        return losses

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
    
    

###### test forwards

In [None]:
# grab one training batch for the forward-pass smoke tests below
batch = next(iter(train_dl))
x, b, masks, voxel_idxs, exp, exp_tiles = batch['he'], batch['b'], batch['masks'], batch['voxel_idxs'], batch['exp'], batch['exp_tiles']


In [None]:
# autokl = AutoencoderKL(
#     train_adata.shape[1],
#     tile_resolution=32,
#     n_metagenes=20,
#     in_channels=3,
#     out_channels=64,
#     down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
#     up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
#     block_out_channels=[8, 16, 32, 64],
#     norm_num_groups=8,
#     latent_channels=4,
# )

In [None]:
# NOTE: `train_adata` is never defined in this notebook; take the gene list from the first
# training sample, as the config cell below does.
autokl = UnetBased(
    next(iter(sample_to_train_adata.values())).var.index.to_list(),
    tile_resolution=32,
    n_metagenes=20,
    in_channels=3,
    out_channels=64,
)

In [None]:
# forward() requires both the tile and its context image
result = autokl(x, batch['he_context'])

In [None]:
# outputs of the forward pass
result.keys()

In [None]:
# predicted mean expression, (b, h, w, n_genes)
result['exp'].shape

In [None]:
# predicted mean expression of gene column 5 for the first item
plt.imshow(result['exp'][0, :, :, 5].detach().numpy())

In [None]:
# activity of metagene 0 for the first item
plt.imshow(result['metagene_activity'][0, :, :, 0].detach().numpy())

In [None]:
# calculate_loss takes (exp_true, result); the H&E input is not part of the loss
losses = autokl.calculate_loss(exp_tiles, result)

In [None]:
# inspect the individual loss terms
losses

#### training loop

In [None]:
# W&B project name and local directory for logger files.
# NOTE(review): absolute cluster path; make configurable if this notebook runs elsewhere.
project = 'unet_based'
log_dir = '/scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/logs'

In [None]:
# Lightning's W&B logger; also passed to log_intermediates via `self.logger`
from pytorch_lightning.loggers import WandbLogger
logger = WandbLogger(project=project, save_dir=log_dir)

In [None]:
# wandb.finish()

In [None]:
# Run configuration: logged to W&B via logger.experiment.config and used to build the model and trainer below.
config = {
    # gene set taken from the first training sample's AnnData
    'n_genes': next(iter(sample_to_train_adata.values())).shape[1],
    'genes': next(iter(sample_to_train_adata.values())).var.index.to_numpy(),
    'n_covariates': 1,
    'n_metagenes': 10,
    'latent_dim': 64,
    'tile_resolution': 32,
    'he_scale': '2X',
    'he_context_scale': '8X',
    'encoder': {
        'model': 'unet',
        'in_channels': 3,
        'decoder_channels': (128, 64, 32, 16, 8),
        'context_decoder_channels': (128, 64, 32, 16, 8),
    },
    # loss weights; keys match UnetBased's exp_scaler / kl_scaler / he_scaler arguments
    'kl_scalers': {
        'exp_scaler': 1.,
        'kl_scaler': .0001,
        'he_scaler': .1,
    },
    'training': {
        'train_samples': list(sample_to_train_adata.keys()),
        'val_samples': list(sample_to_val_adata.keys()),
        'log_n_samples': 8,
        'max_epochs': 200,
        'check_val_every_n_epoch': 10,
        'log_train_fraction': 1.,
        'log_every_n_steps': 1,
        'accelerator': 'gpu',
        'devices': 1,
        'limit_train_batches': 1.,
        'limit_val_batches': .1,
        'lr': 1e-5,
        'batch_size': batch_size,
        'precision': 32
    },
}
logger.experiment.config.update(config)

In [None]:
# Build the model and its Lightning wrapper from `config`. The kl_scalers keys
# (exp_scaler, kl_scaler, he_scaler) match UnetBased's keyword names, so they are unpacked directly.
encoder_cfg = config['encoder']
training_cfg = config['training']
autokl = UnetBased(
    config['genes'],
    tile_resolution=config['tile_resolution'],
    n_metagenes=config['n_metagenes'],
    in_channels=encoder_cfg['in_channels'],
    out_channels=config['latent_dim'],
    decoder_channels=encoder_cfg['decoder_channels'],
    context_decoder_channels=encoder_cfg['context_decoder_channels'],
    **config['kl_scalers'],
)
model = xFuseLightning(
    autokl,
    lr=training_cfg['lr'],
    n_samples=training_cfg['log_n_samples'],
    train_epoch_fraction=training_cfg['log_train_fraction'],
)

In [None]:
# Lightning trainer configured from config['training']. Checkpointing is disabled; the weights
# are saved manually after fit() in a later cell.
trainer = pl.Trainer(
    devices=config['training']['devices'],
    accelerator=config['training']['accelerator'],
    check_val_every_n_epoch=config['training']['check_val_every_n_epoch'],
    enable_checkpointing=False,
    limit_val_batches=config['training']['limit_val_batches'],
    limit_train_batches=config['training']['limit_train_batches'],
    log_every_n_steps=config['training']['log_every_n_steps'],
    max_epochs=config['training']['max_epochs'],
    precision=config['training']['precision'],
    logger=logger
)

In [None]:
# train; validation runs every `check_val_every_n_epoch` epochs on `limit_val_batches` of val_dl
trainer.fit(model=model, train_dataloaders=train_dl, val_dataloaders=val_dl)

In [None]:
# Save the trained weights. torch.save does not create missing parent directories, so create the
# run directory first.
run_dir = '/scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/runs/xfuse_improved_v4'
os.makedirs(run_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(run_dir, 'model.pt'))

In [None]:
# Rebuild the model from `config` and load the weights saved above.
autokl = UnetBased(
    config['genes'],
    tile_resolution=config['tile_resolution'],
    n_metagenes=config['n_metagenes'],
    in_channels=config['encoder']['in_channels'],
    out_channels=config['latent_dim'],
    decoder_channels=config['encoder']['decoder_channels'],
    context_decoder_channels=config['encoder']['context_decoder_channels'],
    he_scaler=config['kl_scalers']['he_scaler'],
    kl_scaler=config['kl_scalers']['kl_scaler'],
    exp_scaler=config['kl_scalers']['exp_scaler'],
)
model = xFuseLightning(autokl, lr=config['training']['lr'],
                       n_samples=config['training']['log_n_samples'],
                       train_epoch_fraction=config['training']['log_train_fraction'])
# map_location='cpu' lets GPU-trained weights load on a CPU-only machine; load_state_dict then
# copies them into the freshly built (CPU) model
model.load_state_dict(torch.load('/scratch1/fs1/dinglab/estorrs/deep-spatial-genomics/runs/xfuse_improved_v4/model.pt',
                                 map_location='cpu'))

In [None]:
# AnnData of the sample used for whole-slide prediction below (alternatives kept for reference)
# a = next(iter(sample_to_train_adata.values()))
# a = sample_to_val_adata['HT264P1-S1H2U1']
a = sample_to_train_adata['HT270P1-S1H1U1']
a

In [None]:
# 2X-scale H&E image stored in a.uns['rescaled_he'] ('notrim' variant; trimmed variant kept for reference)
he = a.uns['rescaled_he']['2X_notrim']
# he = a.uns['rescaled_he']['2X_trimmed']

he.shape

In [None]:
# visual check of the 2X image
plt.imshow(he)

In [None]:
# 8X-scale context H&E image ('notrim' variant; trimmed variant kept for reference)
context_he = a.uns['rescaled_he']['8X_notrim']
# context_he = a.uns['rescaled_he']['8X_trimmed']
context_he.shape

In [None]:
# visual check of the 8X context image
plt.imshow(context_he)

In [None]:
# value range check (prepare_img converts non-float32 images with TF.convert_image_dtype)
he.max()

In [None]:
def reflection_mosiac(x, border=256):
    """Pad ``x`` by ``border`` pixels on every side using reflection padding.

    The previous implementation built four tiles, each reflection-padded on two sides, and
    pasted them into a zero canvas. Their overlaps hold identical values, so the result equals
    one reflection pad on all four sides. This version does that single ``TF.pad`` call, skips
    the intermediate canvas, and keeps the result on ``x``'s device (the canvas was always on CPU).

    Args:
        x: image tensor of shape ``(C, H, W)`` or ``(H, W)``.
        border: padding width in pixels. Reflection padding presumably needs it smaller than
            ``H`` and ``W`` — confirm for your torch version.

    Returns:
        Tensor of shape ``(..., H + 2 * border, W + 2 * border)``.
    """
    # padding order is [left, top, right, bottom]
    return TF.pad(x, padding=[border, border, border, border], padding_mode='reflect')

In [None]:
def tile_img(img, context_img, context_scale, tile_size=256, window_scale=2):
    """Cut ``img`` into overlapping tiles and crop a matching tile from ``context_img`` for each.

    Tiles lie on a grid with pixel offsets ``k * tile_size // window_scale``. ``img`` is pasted
    into a canvas of ones large enough for the last row/column of windows. For each tile, the
    centre is divided by ``context_scale`` and a ``tile_size`` window is cropped around it from
    ``context_img`` after reflection-padding that image by ``tile_size``.

    Args:
        img: ``(C, H, W)`` image tensor.
        context_img: context image tensor, channels first.
        context_scale: divisor mapping ``img`` pixel coordinates to ``context_img`` coordinates.
        tile_size: side length of every tile.
        window_scale: number of grid steps per ``tile_size`` (2 -> 50% overlap).

    Returns:
        tiles: ``(n_rows * n_cols, C, tile_size, tile_size)`` tiles of ``img``.
        context_tiles: same shape, cropped from the padded ``context_img``.
        top_left: list of ``(row, col)`` grid indices (not pixel offsets), one per tile.
    """
    padded_context = reflection_mosiac(context_img, border=tile_size)

    n_rows = (img.shape[1] // tile_size) * window_scale + 1
    n_cols = (img.shape[2] // tile_size) * window_scale + 1

    def offset(k):
        # pixel offset of grid step k, evaluated as (k * tile_size) // window_scale
        return k * tile_size // window_scale

    n_tiles = n_rows * n_cols
    tiles = torch.ones(n_tiles, img.shape[0], tile_size, tile_size, dtype=img.dtype)
    canvas = torch.ones(img.shape[0],
                        offset(n_rows) + tile_size,
                        offset(n_cols) + tile_size,
                        dtype=img.dtype)
    canvas[:, :img.shape[1], :img.shape[2]] = img

    context_tiles = torch.ones(n_tiles, img.shape[0], tile_size, tile_size, dtype=padded_context.dtype)

    grid_positions = [(r, c) for r in range(n_rows) for c in range(n_cols)]
    half = tile_size // 2
    for idx, (r, c) in enumerate(grid_positions):
        top, left = offset(r), offset(c)
        tiles[idx] = canvas[:, top:top + tile_size, left:left + tile_size]

        # tile centre in context-image pixels, shifted by the tile_size reflection border
        ctx_top = int((top + (tile_size / 2)) / context_scale) - half + tile_size
        ctx_left = int((left + (tile_size / 2)) / context_scale) - half + tile_size
        context_tiles[idx] = padded_context[:, ctx_top:ctx_top + tile_size, ctx_left:ctx_left + tile_size]

    return tiles, context_tiles, grid_positions

In [None]:
def prepare_img(img):
    """Return ``img`` as a channels-first float32 tensor.

    Inputs whose first axis is not 3 are treated as ``(h, w, c)`` and transposed. Non-tensors
    are copied into a tensor, and non-float32 dtypes go through ``TF.convert_image_dtype``.
    Something that is already a 3-channel-first float32 tensor is returned unchanged (same object).
    """
    channels_first = img if img.shape[0] == 3 else rearrange(img, 'h w c -> c h w')
    as_tensor = channels_first if isinstance(channels_first, torch.Tensor) else torch.tensor(channels_first)
    if as_tensor.dtype == torch.float32:
        return as_tensor
    return TF.convert_image_dtype(as_tensor, dtype=torch.float32)

class TiledHEDataset(Dataset):
    """Dataset of overlapping H&E tiles with matching context tiles, for whole-image prediction.

    Args:
        img: H&E image (any layout/dtype accepted by ``prepare_img``).
        context_img: lower-resolution context image (same handling).
        context_scale: divisor mapping ``img`` coordinates to ``context_img`` coordinates.
        normalize: if True apply a fixed per-channel ``Normalize``; otherwise ``nn.Identity``.
        tile_size: tile side length in pixels.
        window_scale: grid steps per tile (2 -> 50% overlap).
    """
    def __init__(self, img, context_img, context_scale, normalize=True, tile_size=256, window_scale=2):
        super().__init__()

        self.img = prepare_img(img)
        self.context_img = prepare_img(context_img)
        self.context_scale = context_scale
        self.tiles, self.context_tiles, self.idx_to_top_left = tile_img(
            self.img, self.context_img, context_scale,
            tile_size=tile_size, window_scale=window_scale)

        self.normalize = (
            Normalize((0.771, 0.651, 0.752), (0.229, 0.288, 0.224))  # from HT397B1-H2 ffpe H&E image
            if normalize else nn.Identity()
        )

    def max_r(self):
        """Largest grid row index among the tiles."""
        return np.max([row for row, _ in self.idx_to_top_left])

    def max_c(self):
        """Largest grid column index among the tiles."""
        return np.max([col for _, col in self.idx_to_top_left])

    def __len__(self):
        return self.tiles.shape[0]

    def __getitem__(self, idx):
        top, left = self.idx_to_top_left[idx]
        return {
            'he': self.normalize(self.tiles[idx]),
            'context_he': self.normalize(self.context_tiles[idx]),
            'top': top,
            'left': left,
        }

In [None]:
def add_tile(tile, new, r, c, max_r, max_c, window_scale=2,):
    """Paste the central window of one predicted tile into the stitched canvas ``new``.

    Tiles sit on a grid whose stride is ``tile.shape[0] // window_scale`` along
    rows and ``tile.shape[1] // window_scale`` along columns. Each tile
    contributes only its central stride-sized window, so overlapping neighbours
    do not overwrite each other. Tiles on the outer border (row 0 / ``max_r``,
    column 0 / ``max_c``) extend out to the tile edge on that side, so the canvas
    border is filled as well.

    Args:
        tile: ``(h, w, ...)`` array/tensor predicted for one grid position.
        new: canvas written in place. It must be large enough, e.g.
            ``h * (max_r + 2) // window_scale`` rows, as ``predict_he`` allocates.
        r, c: grid row / column index of this tile.
        max_r, max_c: largest grid row / column index.
        window_scale: ratio between tile size and grid stride.
    """
    r1, r2, tr1, tr2 = _tile_span(r, max_r, tile.shape[0], window_scale)
    # Use the tile's own width for columns (previously shape[0] was reused,
    # which only worked for square tiles).
    c1, c2, tc1, tc2 = _tile_span(c, max_c, tile.shape[1], window_scale)
    new[r1:r2, c1:c2] = tile[tr1:tr2, tc1:tc2]


def _tile_span(idx, max_idx, size, window_scale):
    """Return (canvas_start, canvas_end, tile_start, tile_end) along one axis."""
    stride = size // window_scale
    normal_offset = stride // 2
    edge_offset = size // 2
    center = idx * stride + edge_offset  # tile centre in canvas coordinates
    # The leading and trailing sides are decided independently, so a grid that
    # is a single tile wide (idx == 0 == max_idx) keeps BOTH outer edges. The
    # old if-chain let the max branch overwrite the start bound set for idx == 0.
    before = edge_offset if idx == 0 else normal_offset
    after = edge_offset if idx == max_idx else normal_offset
    return (center - before, center + after,
            edge_offset - before, edge_offset + after)

def _select_tile_outputs(result, i, keep, gene_idxs):
    """Take sample ``i`` of the kept model outputs onto the CPU.

    ``'exp'`` is restricted to ``gene_idxs`` along its last axis. ``'he'`` is
    rearranged from (c, h, w) to (h, w, c).
    """
    item = {k: v[i].detach().cpu() for k, v in result.items() if k in keep}
    if 'exp' in item:
        item['exp'] = item['exp'][:, :, gene_idxs]
    if 'he' in item:
        item['he'] = rearrange(item['he'], 'c h w -> h w c')
    return item


def predict_he(img, context_img, context_scale, model,
               batch_size=8, reduce_to_tile=True,
               gene_idxs=None, keep=('he', 'exp', 'metagene_activity'),
               window_scale=2, rescale=False, crop=True):
    """Run ``model.autokl`` over a whole H&E image tile by tile and stitch the outputs.

    Args:
        img: full H&E image. It is passed through ``prepare_img`` by
            ``TiledHEDataset``, so either (3, H, W) or (H, W, 3) is accepted.
        context_img: context image, also passed through ``prepare_img``.
        context_scale: forwarded to ``TiledHEDataset`` / ``tile_img``
            (presumably the resolution ratio between img and context_img; confirm).
        model: object whose ``autokl`` is called as
            ``model.autokl(he, context_he, reduce_to_tile=...)`` and returns a dict
            of batched tensors.
        batch_size: number of tiles per forward pass.
        reduce_to_tile: forwarded to ``model.autokl``.
        gene_idxs: indices into the last axis of the ``'exp'`` output. Defaults to
            all of ``model.autokl.genes``.
        keep: output keys to stitch. Keys the model does not return are ignored.
        window_scale: tile-overlap factor. It is used for both tiling and stitching
            so the two stay consistent.
        rescale: if True, min-max rescale every stitched output in place, globally
            for 'he' and per channel for the other keys.
        crop: currently unused; kept for backward compatibility.

    Returns:
        dict mapping each stitched key to a (rows, cols, channels) CPU tensor.
        Canvas pixels not covered by any tile keep the fill value 1.
    """
    # Tile with the same window_scale that add_tile uses for stitching. Previously
    # the dataset always used its default of 2, whatever was passed here.
    ds = TiledHEDataset(img, context_img, context_scale, window_scale=window_scale)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False)

    if gene_idxs is None:
        gene_idxs = np.arange(len(model.autokl.genes))

    max_r, max_c = ds.max_r(), ds.max_c()
    # Inputs follow whatever device the model's parameters live on.
    device = next(iter(model.autokl.parameters())).device

    key_to_new = {}
    model.eval()
    with torch.no_grad():
        for batch in dl:
            he_x = batch['he'].to(device)
            he_context_x = batch['context_he'].to(device)
            top, left = batch['top'], batch['left']

            result = model.autokl(he_x, he_context_x, reduce_to_tile=reduce_to_tile)

            for i in range(he_x.shape[0]):
                item = _select_tile_outputs(result, i, keep, gene_idxs)
                r, c = top[i], left[i]
                for k, tile in item.items():
                    if k not in key_to_new:
                        # Allocate the canvas from the first tile's shape and dtype,
                        # instead of a separate forward pass just to learn them.
                        key_to_new[k] = torch.ones(tile.shape[0] * (max_r + 2) // window_scale,
                                                   tile.shape[1] * (max_c + 2) // window_scale,
                                                   tile.shape[2],
                                                   dtype=tile.dtype)
                    add_tile(tile, key_to_new[k], r, c, max_r, max_c, window_scale=window_scale)

    if rescale:
        for k, new in key_to_new.items():
            if k == 'he':
                # Global min-max over the whole RGB reconstruction.
                new -= new.min()
                peak = new.max()
            else:
                # Per-channel min-max so each gene / metagene spans [0, 1].
                new -= new.amin(dim=(0, 1))
                peak = new.amax(dim=(0, 1))
            # Constant channels would otherwise give 0 / 0 = NaN.
            new /= torch.where(peak > 0, peak, torch.ones_like(peak))

    return key_to_new


In [None]:
# Move the model to the GPU when one is available. predict_he follows the
# device of model.autokl's parameters, so a CPU-only kernel can still run.
model = model.cuda() if torch.cuda.is_available() else model

In [None]:
# ds = TiledHEDataset(train_adata.uns['rescaled_he']['2X_trimmed'],
#                     train_adata.uns['rescaled_he']['8X_trimmed'],
#                     4)

In [None]:
# show_genes = ['EPCAM', 'KRT18', 'CD8A', 'IL7R', 'MS4A1', 'PECAM1', 'BGN']
show_genes = ['KRT18', 'IL7R', 'PECAM1', 'BGN', 'INS']
# show_genes = model.autokl.genes
# gene_idxs follows the model's gene order. show_genes is re-derived from it so
# that show_genes[i] names channel i of key_to_retiled['exp'].
gene_idxs = [i for i, g in enumerate(model.autokl.genes) if g in show_genes]
show_genes = model.autokl.genes[gene_idxs]
# The cells below display 'he' and 'metagene_activity' as well as 'exp', so all
# three must be stitched (keep=('exp',) made those cells raise KeyError).
key_to_retiled = predict_he(he, context_he, 4, model,
                            gene_idxs=gene_idxs, rescale=True, window_scale=2,
                            keep=('he', 'exp', 'metagene_activity'))

In [None]:
# Shape of each stitched output (rows, cols, channels).
# NOTE(review): this top-level loop rebinds the global names `k` and `img`.
for k, img in key_to_retiled.items():
    print(k, img.shape)

In [None]:
# Stitched 'he' output in (h, w, c) layout. Requires 'he' to be in the keep= passed to predict_he.
plt.imshow(key_to_retiled['he'])

In [None]:
# One heatmap per selected gene; channel i of 'exp' corresponds to show_genes[i].
# Explicit fig/ax plus a colorbar, so each figure stands alone with its value scale.
for i, g in enumerate(show_genes):
    fig, ax = plt.subplots()
    im = ax.imshow(key_to_retiled['exp'][:, :, i])
    fig.colorbar(im, ax=ax)
    ax.set_title(g)
    plt.show()

In [None]:
# One heatmap per metagene channel, with a colorbar for the activity scale.
for i in range(key_to_retiled['metagene_activity'].shape[-1]):
    fig, ax = plt.subplots()
    im = ax.imshow(key_to_retiled['metagene_activity'][:, :, i])
    fig.colorbar(im, ax=ax)
    ax.set_title(f'metagene {i}')
    plt.show()

In [None]:
# Re-run prediction with reduce_to_tile=False and the default keep keys.
# NOTE(review): batch_size=2 (vs. the default 8) is presumably for memory; confirm.
key_to_retiled = predict_he(
    he,
    context_he,
    4,  # context_scale
    model,
    gene_idxs=gene_idxs,
    reduce_to_tile=False,
    batch_size=2,
    rescale=True,
    window_scale=2,
)

In [None]:
# log1p-transformed heatmap per selected gene, with a colorbar for the scale.
for i, g in enumerate(show_genes):
    fig, ax = plt.subplots()
    im = ax.imshow(np.log1p(key_to_retiled['exp'][:, :, i]))
    fig.colorbar(im, ax=ax)
    ax.set_title(f'{g} (log1p)')
    plt.show()

In [None]:
cr1, cr2 = 3000, 4000  # row crop (pixels)
cc1, cc2 = 3000, 4000  # column crop (pixels)
# Zoomed-in gene heatmaps over the crop, each with a colorbar.
for i, g in enumerate(show_genes):
    fig, ax = plt.subplots()
    im = ax.imshow(key_to_retiled['exp'][cr1:cr2, cc1:cc2, i])
    fig.colorbar(im, ax=ax)
    ax.set_title(g)
    plt.show()

In [None]:
cr1, cr2 = 3000, 4000  # row crop (pixels)
cc1, cc2 = 3000, 4000  # column crop (pixels)
# Zoomed-in metagene activity heatmaps over the crop, each with a colorbar.
for i in range(key_to_retiled['metagene_activity'].shape[-1]):
    fig, ax = plt.subplots()
    im = ax.imshow(key_to_retiled['metagene_activity'][cr1:cr2, cc1:cc2, i])
    fig.colorbar(im, ax=ax)
    ax.set_title(f'metagene {i}')
    plt.show()

In [None]:
# Same crop of the input H&E image, for side-by-side comparison with the heatmaps above.
plt.imshow(he[cr1:cr2, cc1:cc2])

In [None]:
# NOTE(review): `a` is not defined in this part of the notebook. It is presumably an
# AnnData object loaded earlier; confirm it is the same sample as `he`.
spot_mask = a.uns['rescaled_spot_masks']['2X_trimmed']
spot_mask

In [None]:
# Pixels where spot_mask is non-zero, over the same crop.
plt.imshow(spot_mask[cr1:cr2, cc1:cc2]>0)

In [None]:
## save gene expression image
# TF.convert_image_dtype scales float [0, 1] to uint8 [0, 255] without clamping,
# so values outside [0, 1] (or NaNs) would overflow silently. Fail loudly instead;
# predict_he must have been called with rescale=True.
assert 0 <= key_to_retiled['exp'].min() and key_to_retiled['exp'].max() <= 1, \
    "expected rescaled expression in [0, 1]"
# Crop the stitched canvas to the input H&E's height/width before converting.
exp = TF.convert_image_dtype(key_to_retiled['exp'][:he.shape[0], :he.shape[1]], torch.uint8)
exp.shape

In [None]:
# Sanity check: first channel of the uint8 expression image (show_genes[0]).
plt.imshow(exp[:, :, 0])

In [None]:
# Persist the uint8 expression image. The sample ID is hard-coded in the output path.
torch.save(exp, '../data/annotations/pdac/HT270P1-S1H1U1_exp_8x8.pt')

In [None]:
## save metagene activity image
# TF.convert_image_dtype scales float [0, 1] to uint8 [0, 255] without clamping,
# so values outside [0, 1] (or NaNs) would overflow silently. Fail loudly instead;
# predict_he must have been called with rescale=True.
assert 0 <= key_to_retiled['metagene_activity'].min() and key_to_retiled['metagene_activity'].max() <= 1, \
    "expected rescaled metagene activity in [0, 1]"
# Crop the stitched canvas to the input H&E's height/width before converting.
meta = TF.convert_image_dtype(key_to_retiled['metagene_activity'][:he.shape[0], :he.shape[1]], torch.uint8)
meta.shape

In [None]:
# Sanity check: metagene 0 of the uint8 activity image.
plt.imshow(meta[:, :, 0])

In [None]:
# Persist the uint8 metagene activity image. The sample ID is hard-coded in the output path.
torch.save(meta, '../data/annotations/pdac/HT270P1-S1H1U1_meta_8x8.pt')

In [None]:
import seaborn as sns

# Metagene-by-gene weight table (rows: metagenes, columns: model genes),
# narrowed to the genes shown above and clustered on both axes.
metagenes = model.autokl.metagenes.detach().cpu().numpy()
metagenes = pd.DataFrame(
    metagenes,
    index=np.arange(len(metagenes)),
    columns=model.autokl.genes,
)[list(show_genes)]
sns.clustermap(metagenes)

#### apply directly to H&E