<a href="https://colab.research.google.com/github/kicysh/final_task_of_world_model_lecture_2021/blob/main/src/LDVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# pip

In [None]:
# Install scanpy, which a later cell imports to read the .h5ad dataset.
!pip install scanpy

# data

In [None]:
!mkdir data
!gsutil cp gs://h5ad/2019-02-Pijuan-Sala-et-al-Nature/pijuan_sala_atlas.h5ad /content/data
path_of_data = '/content/data/pijuan_sala_atlas.h5ad'


# setting


In [2]:
# Hyper-parameters and switches shared by the model and training cells.

# Width of the hidden layer used by both encoders.
SETTING_HIDDEN_DIM = 128

# Dropout probabilities of the z-encoder and the library-size encoder.
SETTING_ENCODER_L_DROPOUT_P = 0.1
SETTING_ENCODER_Z_DROPOUT_P = 0.1

# Arguments forwarded to every nn.BatchNorm1d layer.
SETTING_BATCHNORM_MOMENTUM = 1e-2
SETTING_BATCHNORM_EPS = 1e-3

# Small constant added inside logarithms for numerical stability.
SETTING_EPS = 1e-8

# Whether the training cell should run on the GPU.
USE_CUDA = True

# model

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import logsumexp
from torch.distributions import Normal, kl_divergence

In [None]:
# Read the downloaded .h5ad file; `path_of_data` is assigned in the data cell.
import scanpy as sc
adata = sc.read_h5ad(path_of_data)
# Bare expression: displayed as the cell output when run in a notebook.
adata

In [None]:
#from math import ldexp

class LDVAE(nn.Module):
    """
    Variational auto-encoder for count data with a linear decoder.

    Two encoders map a cell's counts to Gaussian posteriors over a latent
    vector ``z`` and a scalar log library size. The decoder is a single linear
    layer followed by batch normalisation; its softmax output, scaled by the
    exponentiated library size, is the mean of a negative-binomial likelihood
    whose per-gene inverse dispersion is ``exp(self.theta)``.

    Call :meth:`set_local_l_mean_and_std` once before :meth:`loss` so that the
    prior over the log library size is defined.

    :param genes_cnt: Number of input genes
    :param latent_dim: Dimensionality of the latent space
    """
    def __init__(
        self,
        genes_cnt: int,
        latent_dim: int = 20
    ):
        super(LDVAE, self).__init__()
        # Prior statistics of the log library size, filled in by
        # set_local_l_mean_and_std(). NOTE: ``local_l_std`` stores the prior
        # *variance*; loss() takes its square root.
        self.local_l_mean = None
        self.local_l_std = None
        self.eps = SETTING_EPS

        # Unconstrained per-gene parameter; exp(theta) is the NB inverse dispersion.
        self.theta = nn.Parameter(torch.randn(genes_cnt))
        self.encoder_z = nn.Sequential(
            nn.Linear(genes_cnt, SETTING_HIDDEN_DIM),
            nn.BatchNorm1d(SETTING_HIDDEN_DIM,
                           eps=SETTING_BATCHNORM_EPS,
                           momentum=SETTING_BATCHNORM_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(SETTING_ENCODER_Z_DROPOUT_P)
        )
        self.encoder_z_mean = nn.Linear(SETTING_HIDDEN_DIM, latent_dim)
        # Outputs the log-variance of q(z|x); see forward().
        self.encoder_z_std = nn.Linear(SETTING_HIDDEN_DIM, latent_dim)

        self.encoder_l = nn.Sequential(
            nn.Linear(genes_cnt, SETTING_HIDDEN_DIM),
            nn.BatchNorm1d(SETTING_HIDDEN_DIM,
                           eps=SETTING_BATCHNORM_EPS,
                           momentum=SETTING_BATCHNORM_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(SETTING_ENCODER_L_DROPOUT_P)
        )
        self.encoder_l_mean = nn.Linear(SETTING_HIDDEN_DIM, 1)
        # Outputs the log-variance of q(l|x); see forward().
        self.encoder_l_std = nn.Linear(SETTING_HIDDEN_DIM, 1)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, genes_cnt),
            nn.BatchNorm1d(genes_cnt,
                           eps=SETTING_BATCHNORM_EPS,
                           momentum=SETTING_BATCHNORM_MOMENTUM)
        )

    def forward(self, x):
        """
        Encode ``x``, sample both latents and decode the expected counts.

        :param x: Tensor of shape (batch, genes_cnt)
        :return: ``[z_mean, z_var, z]``, ``[l_mean, l_var, library]`` and the
            decoded negative-binomial mean ``y`` of shape (batch, genes_cnt).
            The second entry of each list is a variance (strictly positive).
        """
        x_z = self.encoder_z(x)
        z_mean = self.encoder_z_mean(x_z)
        # A raw linear output can be negative, which would make sqrt() return
        # NaN; exponentiate so the variance is strictly positive.
        z_std = torch.exp(self.encoder_z_std(x_z)) + self.eps
        z = Normal(z_mean, z_std.sqrt()).rsample()

        x_l = self.encoder_l(x)
        l_mean = self.encoder_l_mean(x_l)
        l_std = torch.exp(self.encoder_l_std(x_l)) + self.eps
        library = Normal(l_mean, l_std.sqrt()).rsample()

        y = self.decoder(z)
        # Softmax gives per-gene proportions; exp(library) rescales them to counts.
        y = torch.exp(library) * torch.softmax(y, dim=-1)
        return [z_mean, z_std, z], [l_mean, l_std, library], y

    def set_local_l_mean_and_std(self, data):
        """
        Compute and store the prior mean/variance of the log library size.

        :param data: Count matrix with one row per cell; only
            ``data.sum(axis=1)`` is used.
        :return: ``(mean, var)`` as float32 NumPy arrays of shape (1, 1)
        """
        # Masked log avoids -inf for cells whose total count is zero;
        # those entries are replaced by 0.
        masked_log_sum = np.ma.log(data.sum(axis=1))
        log_counts = masked_log_sum.filled(0)
        local_mean = np.asarray(np.mean(log_counts)).reshape(-1, 1).astype(np.float32)
        local_var = np.asarray(np.var(log_counts)).reshape(-1, 1).astype(np.float32)

        # Attribute names written by the previous implementation, kept for
        # backward compatibility.
        self.local_mean = local_mean
        self.local_var = local_var
        # Attributes actually read by loss(), stored as tensors so they can be
        # passed to torch.distributions.
        self.local_l_mean = torch.tensor(local_mean)
        self.local_l_std = torch.tensor(local_var)
        return self.local_mean, self.local_var

    def reconst_error(self, x, mu, theta):
        """
        Element-wise negative-binomial log-likelihood of ``x``.

        Despite the method name, the returned value is a log-likelihood
        (higher is better); :meth:`loss` negates it.

        :param x: Observed counts
        :param mu: Mean of the negative binomial (positive)
        :param theta: Inverse dispersion (positive)
        :return: Tensor of log-probabilities with the broadcast shape of the inputs
        """
        eps = SETTING_EPS
        log_theta_mu_eps = torch.log(theta + mu + eps)

        res = (
            theta * (torch.log(theta + eps) - log_theta_mu_eps)
            + x * (torch.log(mu + eps) - log_theta_mu_eps)
            + torch.lgamma(x + theta)
            - torch.lgamma(theta)
            - torch.lgamma(x + 1)
        )

        return res

    def loss(self, x):
        """
        Compute the three per-cell terms of the negative ELBO.

        :param x: Tensor of counts of shape (batch, genes_cnt)
        :return: ``(reconst, kl_l, kl_z)``, each of shape (batch,);
            ``reconst`` is the negative log-likelihood, so the sum of the
            three terms is the quantity to minimise.
        :raises RuntimeError: if :meth:`set_local_l_mean_and_std` was not called
        """
        if self.local_l_mean is None or self.local_l_std is None:
            raise RuntimeError(
                "library-size prior is not set; call set_local_l_mean_and_std() first"
            )

        zs, ls, y = self.forward(x)
        z_mean, z_std, z = zs
        l_mean, l_std, library = ls

        # KL between q(z|x) and the standard normal prior.
        mean, std = torch.zeros_like(z_mean), torch.ones_like(z_std)
        kl_z = kl_divergence(Normal(z_mean, torch.sqrt(z_std)), Normal(mean, std)).sum(dim=1)

        # KL between q(l|x) and the empirical log-library-size prior; the prior
        # statistics are moved to the dtype/device of the posterior.
        mean = torch.as_tensor(self.local_l_mean, dtype=l_mean.dtype, device=l_mean.device)
        std = torch.as_tensor(self.local_l_std, dtype=l_mean.dtype, device=l_mean.device)
        kl_l = kl_divergence(Normal(l_mean, torch.sqrt(l_std)), Normal(mean, torch.sqrt(std))).sum(dim=1)

        # reconst_error() is a log-likelihood; negate it so that every
        # returned term is a cost.
        reconst = -self.reconst_error(x, mu=y, theta=torch.exp(self.theta)).sum(dim=-1)

        return reconst, kl_l, kl_z

In [None]:
# Build the model with one input per variable (column) of `adata`.
# NOTE(review): no visible cell initialises the library-size prior statistics
# that `LDVAE.loss` reads — confirm they are set before training.
model = LDVAE(genes_cnt = adata.n_vars,
              latent_dim = 20)
# Bare expression: displayed as the cell output when run in a notebook.
model

In [None]:
# dataloader

In [None]:
# train
#
# NOTE: `dataloader_train` and `dataloader_valid` are not defined in the
# visible cells. The loops below expect them to yield (x, label) pairs,
# where x is a tensor of shape (batch, n_genes) as required by the model's
# first Linear layer.

n_epochs  = 10

# Fall back to the CPU when CUDA is requested but not available.
device = 'cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu'

# Move the parameters to the target device before building the optimizer.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0002,  betas=(0.5,0.999))

for epoch in range(n_epochs):
    losses = []

    model.train()
    for x, _ in dataloader_train:
        x = x.to(device)

        optimizer.zero_grad()

        # forward and loss
        reconst, kl_l, kl_z = model.loss(x)
        loss = torch.mean(reconst + kl_l + kl_z)

        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    losses_val1 = []
    losses_val2 = []
    losses_val3 = []
    model.eval()
    # No gradients are needed for validation; skip building the autograd graph.
    with torch.no_grad():
        for x, _ in dataloader_valid:
            x = x.to(device)

            reconst, kl_l, kl_z = model.loss(x)

            losses_val1.append(torch.mean(reconst).item())
            losses_val2.append(torch.mean(kl_l).item())
            losses_val3.append(torch.mean(kl_z).item())

    print('EPOCH: %d    Train Loss: %lf    Valid rec: %lf    Valid kl_l: %lf    Valid kl_z: %lf' %
            (epoch+1, np.average(losses),np.average(losses_val1),np.average(losses_val2),np.average(losses_val3)))
