In [1]:
import anndata
import torch
import logging
from torch.distributions import kl_divergence
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from scvi.models.distributions import ZeroInflatedNegativeBinomial, NegativeBinomial
from scETM.model import BaseCellModel
from scETM.model.model_utils import get_kl, get_fully_connected_layers
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
from typing import *
from collections import OrderedDict
# Global scanpy figure settings applied once, right after the imports.
sc.set_figure_params(
    figsize=(10, 10),
    fontsize=10,
    dpi=120,
    dpi_save=250,
)

In [2]:
@torch.jit.script
def gaussian_analytical_kl(mu1, mu2, logsigma1, logsigma2):
    """Element-wise closed-form KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ).

    Each Gaussian is given by its mean and the log of its standard deviation.
    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    var1 = logsigma1.exp() ** 2
    var2 = logsigma2.exp() ** 2
    sq_mean_diff = (mu1 - mu2) ** 2
    return -0.5 + logsigma2 - logsigma1 + 0.5 * (var1 + sq_mean_diff) / var2


@torch.jit.script
def draw_gaussian_diag_samples(mu, logsigma):
    """Draw one reparameterised sample mu + exp(logsigma) * eps with eps ~ N(0, 1).

    The noise tensor is created with the same shape, dtype and device as `mu`.
    """
    sigma = torch.exp(logsigma)
    noise = torch.empty_like(mu).normal_(0., 1.)
    return sigma * noise + mu


class norm_act_drop(torch.nn.Module):
    """Normalisation -> activation -> dropout, applied in that order.

    Args:
        size: number of features handed to the normalisation layer.
        norm_module: one of 'batch', 'instance', 'layer' or 'none'.
        activation: name of an activation class in ``torch.nn`` (e.g. 'ReLU').
        dropout_prob: dropout probability; a falsy value disables dropout.
        final_layer: when True, neither activation nor dropout is created,
            so only the normalisation (if any) is applied.

    Raises:
        NotImplementedError: if ``norm_module`` is not a supported name.
    """

    def __init__(self, size: int, norm_module: str = 'batch', activation: str = 'ReLU', dropout_prob: float = 0.1, final_layer: bool = False):
        super().__init__()
        self.norm = self.get_norm_layer(size, norm_module) if norm_module != 'none' else None
        self.activation, self.dropout = None, None
        if not final_layer:
            self.activation = getattr(torch.nn, activation)()
            self.dropout = torch.nn.Dropout(dropout_prob) if dropout_prob else None

    @staticmethod
    def get_norm_layer(size, norm_module='none'):
        """Return the normalisation module registered under ``norm_module``.

        Raises:
            NotImplementedError: for an unknown ``norm_module`` name.
        """
        if norm_module == 'none':
            return torch.nn.Identity()
        elif norm_module == 'batch':
            return torch.nn.BatchNorm1d(size)
        elif norm_module == 'instance':
            return torch.nn.InstanceNorm1d(size)
        elif norm_module == 'layer':
            return torch.nn.LayerNorm(size)
        else:
            # Raise (not return) so a typo fails at construction time instead of
            # surfacing later as "'NotImplementedError' object is not callable".
            raise NotImplementedError(f"Not Implemented norm layer {norm_module}")

    def forward(self, x):
        """Apply whichever of norm / activation / dropout were configured."""
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x


class Layer(torch.nn.Module):
    """Pre-activation linear layer: `norm_act_drop` (default settings) followed by `nn.Linear`."""

    def __init__(self, input_channels: int, output_channels: int, bias: bool = True):
        super().__init__()
        # Kept only so __repr__ can print a compact one-line summary.
        self.input_channels, self.output_channels, self.bias = input_channels, output_channels, bias
        pre_activation = norm_act_drop(input_channels)
        projection = nn.Linear(input_channels, output_channels, bias = bias)
        self.layers = nn.Sequential(pre_activation, projection)

    def forward(self, x):
        """Run the input through norm/activation/dropout and then the linear map."""
        out = self.layers(x)
        return out

    def __repr__(self):
        """Compact summary used in place of the default nested module printout."""
        return f'Layer(input_channels={self.input_channels}, output_channels={self.output_channels}, bias={self.bias})'


class Block(torch.nn.Module):
    """Resize the input with `pre_layer`, then run `hidden_layers` same-width `Layer`s.

    Args:
        input_channels: width of the incoming features.
        output_channels: width produced by `pre_layer` and kept by the hidden layers.
        hidden_layers: number of `Layer(output_channels, output_channels)` modules.
        residual: when True, the output of `pre_layer` is added back to the
            output of the hidden layers.
        pre_norm_act_drop: when True `pre_layer` is a `Layer`; otherwise it is
            a plain `nn.Linear`.
    """

    def __init__(self, input_channels: int, output_channels: int, hidden_layers: int, residual: bool = True, pre_norm_act_drop: bool = True):
        super().__init__()
        if pre_norm_act_drop:
            self.pre_layer = Layer(input_channels, output_channels)
        else:
            self.pre_layer = nn.Linear(input_channels, output_channels)
        hidden_stack = [Layer(output_channels, output_channels) for _ in range(hidden_layers)]
        self.main_layers = nn.Sequential(*hidden_stack)
        self.residual = residual

    def forward(self, x):
        """Return the hidden-layer output, plus the resized input if `residual` is set."""
        resized = self.pre_layer(x)
        out = self.main_layers(resized)
        if self.residual:
            out = out + resized
        return out


class Block2(torch.nn.Module):
    """Run `hidden_layers` same-width layers first, then project with `post_layer`.

    Args:
        input_channels: width of the input and of every hidden layer.
        output_channels: width produced by the final `Layer`.
        hidden_layers: number of same-width layers before the projection.
        residual: when True, the block input is added to the hidden-layer
            output before `post_layer`.
        pre_norm_act_drop: when False, the first hidden layer is a plain
            `nn.Linear` instead of a `Layer`.
    """

    def __init__(self, input_channels: int, output_channels: int, hidden_layers: int, residual: bool = True, pre_norm_act_drop: bool = True):
        super().__init__()
        hidden_stack = []
        for index in range(hidden_layers):
            use_plain_linear = index == 0 and not pre_norm_act_drop
            if use_plain_linear:
                hidden_stack.append(nn.Linear(input_channels, input_channels))
            else:
                hidden_stack.append(Layer(input_channels, input_channels))
        self.main_layers = nn.Sequential(*hidden_stack)
        self.post_layer = Layer(input_channels, output_channels)
        self.residual = residual

    def forward(self, x):
        """Apply the hidden layers (optionally residual) and then the output projection."""
        hidden = self.main_layers(x)
        if self.residual:
            hidden = hidden + x
        return self.post_layer(hidden)


class Encoder(torch.nn.Module):
    """Chain of `Block`s, one per entry of `hidden_sizes`, keyed by `str(size)`.

    The first block uses a plain linear `pre_layer` (no norm/activation/dropout
    on the raw input); every later block uses a `Layer`.
    """

    def __init__(self, input_channels: int, hidden_sizes: Sequence[int], hidden_layers: int):
        super().__init__()
        self.blocks = nn.ModuleDict()
        prev_width = input_channels
        for position, width in enumerate(hidden_sizes):
            is_first = position == 0
            self.blocks[str(width)] = Block(prev_width, width, hidden_layers, pre_norm_act_drop = not is_first)
            prev_width = width

    def forward(self, x):
        """Return a dict mapping each block's key to that block's output."""
        activations = {}
        hidden = x
        for name, block in self.blocks.items():
            hidden = block(hidden)
            activations[name] = hidden
        return activations


class DecoderBlock(torch.nn.Module):
    """One stochastic level of the top-down decoder.

    Sub-modules:
        prior: maps the decoder state to (prior mean, prior log-sigma, additive
            update to the decoder state), each `input_channels` wide.
        posterior: maps [decoder state, encoder state] to (posterior mean,
            posterior log-sigma).
        output_projection: maps [updated decoder state, z] to `output_channels`.

    Log-sigmas are clamped to [-10, 10] wherever they are used, both for the KL
    term and for sampling, so the prior sampled from matches the one trained.
    """

    def __init__(self, input_channels: int, output_channels: int, hidden_layers: int, pre_norm_act_drop: bool = True):
        super().__init__()
        self.prior = Block2(input_channels, input_channels * 3, hidden_layers, residual = False, pre_norm_act_drop = pre_norm_act_drop)
        self.posterior = Block2(input_channels * 2, input_channels * 2, hidden_layers, residual = False, pre_norm_act_drop = pre_norm_act_drop)
        self.output_projection = Block2(input_channels * 2, output_channels, hidden_layers)

    def _prior_params(self, dec_hidden):
        """Return (prior mean, clamped prior log-sigma, updated decoder state)."""
        pm, pv, dec_adduct = self.prior(dec_hidden).chunk(3, dim=-1)
        return pm, pv.clamp(-10, 10), dec_hidden + dec_adduct

    def sample_posterior(self, dec_hidden, enc_hidden):
        """Sample z from the posterior and compute its KL to the prior.

        Returns:
            (z, updated decoder state, kl) where kl is summed over the last
            dimension and averaged over the remaining ones. In eval mode z is
            the posterior mean instead of a random sample.
        """
        if dec_hidden.shape != enc_hidden.shape:
            # e.g. the learned (1, H) initial state being broadcast to the encoder state's shape.
            dec_hidden = dec_hidden.expand(enc_hidden.shape)
        # The posterior network runs before the prior network, as in the original
        # code, so the order of dropout random draws is unchanged.
        qm, qv = self.posterior(torch.cat([dec_hidden, enc_hidden], dim=-1)).chunk(2, dim=-1)
        qv = qv.clamp(-10, 10)
        pm, pv, dec_hidden = self._prior_params(dec_hidden)
        if self.training:
            z = draw_gaussian_diag_samples(qm, qv)
        else:
            z = qm
        kl = gaussian_analytical_kl(qm, pm, qv, pv).sum(-1).mean()
        return z, dec_hidden, kl

    def sample_prior(self, dec_hidden):
        """Sample z from the prior; returns (z, updated decoder state).

        Uses the same clamped log-sigma as `sample_posterior`, which also keeps
        `exp(log-sigma)` finite.
        """
        pm, pv, dec_hidden = self._prior_params(dec_hidden)
        z = draw_gaussian_diag_samples(pm, pv)
        return z, dec_hidden

    def forward(self, dec_hidden, enc_hidden):
        """Posterior-sample z, then project [decoder state, z] to the next level's width."""
        z, dec_hidden, kl = self.sample_posterior(dec_hidden, enc_hidden)
        dec_hidden = self.output_projection(torch.cat([dec_hidden, z], dim=-1))
        return z, dec_hidden, kl


class Decoder(torch.nn.Module):
    """Top-down stack of `DecoderBlock`s that mirrors the encoder's hidden sizes.

    Blocks are keyed by `str(size)` over the reversed `hidden_sizes`; each block
    projects to the next size in that order, and the last one to `output_channels`.
    The decoder starts from a learned `(1, first_size)` state.
    """

    def __init__(self, output_channels: int, hidden_sizes: Sequence[int], hidden_layers: int):
        super().__init__()
        self.blocks = nn.ModuleDict()
        widths = [*reversed(hidden_sizes), output_channels]
        for level, (width, next_width) in enumerate(zip(widths, widths[1:])):
            # Only the first (top-most) block skips norm/activation/dropout on its input.
            self.blocks[str(width)] = DecoderBlock(width, next_width, hidden_layers, pre_norm_act_drop = level > 0)
        self.dec_hidden_init = nn.Parameter(torch.randn((1, widths[0])))

    def forward(self, enc_hiddens):
        """Run every block against the encoder state stored under the same key.

        Returns:
            (zs, final decoder state, kls), where zs and kls are OrderedDicts
            keyed like `self.blocks`.
        """
        zs = OrderedDict()
        kls = OrderedDict()
        state = self.dec_hidden_init
        for name, block in self.blocks.items():
            zs[name], state, kls[name] = block(state, enc_hiddens[name])
        return zs, state, kls

In [3]:
class Model(BaseCellModel):
    """Hierarchical VAE built from `Encoder` and `Decoder`.

    Args:
        n_genes: width of the encoder input and of the decoder output.
        n_batches: forwarded to `BaseCellModel`.
        input_batch_id: forwarded to `BaseCellModel` as `need_batch`.
        hidden_sizes: widths of the encoder levels; the decoder uses them reversed.
        hidden_layers: number of hidden layers inside every block.
        norm_cells: feed library-size-normalised counts (True) or raw counts
            (False) to the encoder.
        normed_loss: weight the reconstruction log-probabilities with normalised
            counts (True) or raw counts (False).
        device: forwarded to `BaseCellModel`.
    """

    emb_names = ['z']
    clustering_input = 'z'

    def __init__(self, n_genes, n_batches, input_batch_id = False,
        hidden_sizes = (256, 128),
        hidden_layers = 1,
        norm_cells = True,
        normed_loss = True,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ):
        super().__init__(n_genes, n_batches, need_batch = input_batch_id, device = device)

        self.norm_cells = norm_cells
        self.normed_loss = normed_loss
        self.encoder = Encoder(n_genes, hidden_sizes, hidden_layers)
        self.decoder = Decoder(n_genes, hidden_sizes, hidden_layers)

    def forward(self, data_dict, hyper_param_dict=dict(val=True)):
        """Encode, decode and compute the loss terms for one mini-batch.

        Args:
            data_dict: must provide 'cells' and 'library_size'
                (assumes cells is (batch, n_genes) and library_size broadcasts
                against it — TODO confirm against the data loader).
            hyper_param_dict: must provide 'beta' (KL weight) in training mode.
                The default dict is never mutated here.

        Returns:
            In eval mode: `fwd_dict` with 'z' (all per-level latents concatenated
            on the last dimension), 'recon_log', 'nll' and one entry per level.
            In training mode: `(loss, fwd_dict, record)` where `record` holds
            Python floats for 'loss', 'nll' and 'total_kl'.
        """
        cells, library_size = data_dict['cells'], data_dict['library_size']
        normed_cells = cells / library_size
        input_cells = normed_cells if self.norm_cells else cells
        # NOTE: batch indices are not concatenated to the encoder input here, even
        # when input_batch_id was requested at construction time.

        hiddens = self.encoder(input_cells)
        zs, recon_logit, kls = self.decoder(hiddens)
        recon_log = F.log_softmax(recon_logit, dim=-1)
        # Select the target first, then weight the log-probabilities. The previous
        # one-liner parsed as `(-recon_log * normed_cells) if ... else cells`, so
        # normed_loss=False produced a loss independent of the reconstruction.
        target = normed_cells if self.normed_loss else cells
        nll = (-recon_log * target).sum(-1).mean()

        fwd_dict = dict(
            z = torch.cat(list(zs.values()), dim=-1),
            recon_log=recon_log,
            nll = nll
        )
        fwd_dict.update(zs)

        if not self.training:
            return fwd_dict

        total_kl = 0.
        for kl in kls.values():
            total_kl += kl
        loss = nll + hyper_param_dict['beta'] * total_kl

        record = dict(loss=loss, nll=nll, total_kl=total_kl)
        record = {k: v.detach().item() for k, v in record.items()}
        return loss, fwd_dict, record

In [4]:
from scETM import UnsupervisedTrainer, evaluate

# Load the dataset and copy the annotation columns to the names used below.
# NOTE(review): 'cell_types' / 'batch_indices' are presumably the column names
# the scETM trainer and evaluate() look up — verify against the scETM version in use.
adata = anndata.read_h5ad("../../../data/TM/FACS.h5ad")
adata.obs['cell_types'] = adata.obs.cell_ontology_class
adata.obs['batch_indices'] = adata.obs['mouse.id']
dataset_name = 'TM'
resolutions = (0.24, 0.32, 0.48, 0.64, 0.8, 1, 1.3)
# Alternative dataset configuration (swap with the block above to use it):
# dataset_name = 'cortex'
# adata = anndata.read_h5ad("../../../data/cortex/cortex_full.h5ad")
# resolutions = (0.08, 0.12, 0.16, 0.24, 0.32, 0.48, 0.64)

# Fall back to the CPU when CUDA is unavailable, matching the default device
# chosen by Model itself; on CUDA hosts this is the same device as before.
model = Model(
    adata.n_vars,
    adata.obs.batch_indices.nunique(),
    hidden_sizes = (128, 64, 32),
    hidden_layers = 2,
).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
trainer = UnsupervisedTrainer(
        model,
        adata,
        ckpt_dir = '../results/HVAE',
        init_lr = 5e-3,
        lr_decay = 6e-5,
        batch_size = 4000,
        train_instance_name = f"{dataset_name}_1e-9"
)
# The "1e-9" suffix in train_instance_name mirrors max_kl_weight below.
trainer.train(n_epochs = 4000, eval_every = 1000, max_kl_weight=1e-9, eval_kwargs = dict(resolutions = resolutions))

[2021-04-28 20:13:31,062] INFO - scETM.logging_utils: UnsupervisedTrainer.__init__(Model(
  (encoder): Encoder(
    (blocks): ModuleDict(
      (128): Block(
        (pre_layer): Linear(in_features=22964, out_features=128, bias=True)
        (main_layers): Sequential(
          (0): Layer(input_channels=128, output_channels=128, bias=True)
          (1): Layer(input_channels=128, output_channels=128, bias=True)
        )
      )
      (64): Block(
        (pre_layer): Layer(input_channels=128, output_channels=64, bias=True)
        (main_layers): Sequential(
          (0): Layer(input_channels=64, output_channels=64, bias=True)
          (1): Layer(input_channels=64, output_channels=64, bias=True)
        )
      )
      (32): Block(
        (pre_layer): Layer(input_channels=64, output_channels=32, bias=True)
        (main_layers): Sequential(
          (0): Layer(input_channels=32, output_channels=32, bias=True)
          (1): Layer(input_channels=32, output_channels=32, bias=True)
  

KeyboardInterrupt: 

In [None]:
from scETM import evaluate

# Compute the concatenated embedding 'z' and the three per-level embeddings.
# NOTE(review): presumably this stores them on `adata` under the given names so
# that evaluate() can find them via `embedding_key` — verify.
model.get_cell_embeddings_and_nll(adata, batch_size = 4000, emb_names=['z', '128', '64', '32'])

# One evaluation per hierarchy level. Plot names use the active dataset_name
# instead of a hard-coded 'cortex', which mislabelled plots for other datasets.
for embedding_key in ('128', '64', '32'):
    evaluate(
        adata,
        embedding_key = embedding_key,
        plot_dir = trainer.ckpt_dir,
        plot_fname = f'{dataset_name}_z_epoch{int(trainer.epoch)}_{embedding_key}',
        resolutions = resolutions,
        n_jobs = 1,
    )