<a href="https://colab.research.google.com/github/kicysh/final_task_of_world_model_lecture_2021/blob/main/src/LDVAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# pip

In [1]:
!pip install scanpy scikit-misc



# data

In [2]:
!mkdir data
!gsutil cp gs://h5ad/2019-02-Pijuan-Sala-et-al-Nature/pijuan_sala_atlas.h5ad /content/data
path_of_data = '/content/data/pijuan_sala_atlas.h5ad'


mkdir: cannot create directory ‘data’: File exists
Copying gs://h5ad/2019-02-Pijuan-Sala-et-al-Nature/pijuan_sala_atlas.h5ad...
\ [1 files][  1.0 GiB/  1.0 GiB]   46.6 MiB/s                                   
Operation completed over 1 objects/1.0 GiB.                                      


# setting


In [3]:
# setting
SETTING_BATCHNORM_EPS = 0.001
SETTING_BATCHNORM_MOMENTUM = 0.01
SETTING_ENCODER_Z_DROPOUT_P = 0.1
SETTING_ENCODER_L_DROPOUT_P = 0.1
SETTING_HIDDEN_DIM = 128
SETTING_EPS = 1e-8

USE_CUDA = True
SETTING_BATCH_SIZE = 256

# model

In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import logsumexp
from torch.distributions import Normal, kl_divergence


rng = np.random.RandomState(1234)
random_state = 42


In [5]:
import scanpy as sc
adata = sc.read_h5ad(path_of_data)
adata

AnnData object with n_obs × n_vars = 139331 × 29452
    obs: 'barcode', 'sample', 'stage', 'sequencing.batch', 'theiler', 'doub.density', 'doublet', 'cluster', 'cluster.sub', 'cluster.stage', 'cluster.theiler', 'stripped', 'celltype', 'colour', 'umapX', 'umapY', 'haem_gephiX', 'haem_gephiY', 'haem_subclust', 'endo_gephiX', 'endo_gephiY', 'endo_trajectoryName', 'endo_trajectoryDPT', 'endo_gutX', 'endo_gutY', 'endo_gutDPT', 'endo_gutCluster'
    var: 'gene_name'

In [29]:
idx = adata.obs.query('not doublet').index
idx = np.random.choice(idx, 20000, replace=False)
adata = adata[idx]
sc.pp.highly_variable_genes(adata,n_top_genes=1000 ,flavor='seurat_v3')
print(adata)
#sc.pp.filter_genes(adata, min_cells=100)


  self.data[key] = value


AnnData object with n_obs × n_vars = 20000 × 29452
    obs: 'barcode', 'sample', 'stage', 'sequencing.batch', 'theiler', 'doub.density', 'doublet', 'cluster', 'cluster.sub', 'cluster.stage', 'cluster.theiler', 'stripped', 'celltype', 'colour', 'umapX', 'umapY', 'haem_gephiX', 'haem_gephiY', 'haem_subclust', 'endo_gephiX', 'endo_gephiY', 'endo_trajectoryName', 'endo_trajectoryDPT', 'endo_gutX', 'endo_gutY', 'endo_gutDPT', 'endo_gutCluster'
    var: 'gene_name', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
    uns: 'hvg'


In [47]:
gene_index = adata.var[adata.var['highly_variable']].index

In [31]:
adata_df = adata.to_df()[gene_index]

In [32]:
#from math import ldexp

class LDVAE(nn.Module):
    """
    :param genes_cnt: Number of input genes
    :param latent_dim: Dimensionality of the latent space 
    """
    def __init__(
        self,
        genes_cnt: int, 
        latent_dim: int = 20
    ):
        super(LDVAE,self).__init__()
        self.local_l_mean = None
        self.local_l_std = None
        self.eps = SETTING_EPS

        self.theta = nn.Parameter(torch.randn(genes_cnt))
        self.encoder_z = nn.Sequential(
            nn.Linear(genes_cnt, SETTING_HIDDEN_DIM),
            nn.BatchNorm1d(SETTING_HIDDEN_DIM,
                           eps=SETTING_BATCHNORM_EPS, 
                           momentum=SETTING_BATCHNORM_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(SETTING_ENCODER_Z_DROPOUT_P)
        )
        self.encoder_z_mean = nn.Linear(SETTING_HIDDEN_DIM,latent_dim)
        self.encoder_z_std = nn.Linear(SETTING_HIDDEN_DIM,latent_dim)

        self.encoder_l = nn.Sequential(
            nn.Linear(genes_cnt, SETTING_HIDDEN_DIM),
            nn.BatchNorm1d(SETTING_HIDDEN_DIM,
                           eps=SETTING_BATCHNORM_EPS, 
                           momentum=SETTING_BATCHNORM_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(SETTING_ENCODER_L_DROPOUT_P)
        )
        self.encoder_l_mean = nn.Linear(SETTING_HIDDEN_DIM, 1)
        self.encoder_l_std = nn.Linear(SETTING_HIDDEN_DIM, 1)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, genes_cnt),
            nn.BatchNorm1d(genes_cnt,
                           eps=SETTING_BATCHNORM_EPS, 
                           momentum=SETTING_BATCHNORM_MOMENTUM)
        )


    def forward(self,x):
        x_z = self.encoder_z(x)
        z_mean = self.encoder_z_mean(x_z)
        z_std = torch.exp(self.encoder_z_std(x_z)) 
        z = Normal(z_mean, z_std.sqrt()).rsample()

        x_l = self.encoder_l(x)
        l_mean = self.encoder_l_mean(x_l)
        l_std = torch.exp(self.encoder_l_std(x_l))
        library = Normal(l_mean, l_std.sqrt()).rsample()

        y = self.decoder(z)
        y = torch.exp(library)*torch.softmax(y, dim=-1)
        return [z_mean, z_std, z], [l_mean, l_std, library], y


    def set_local_l_mean_and_std(self, data):
        masked_log_sum =np.ma.log(data.sum(axis=1))
        log_counts = masked_log_sum.filled(0)
        self.local_l_mean = (np.mean(log_counts).reshape(-1, 1)).astype(np.float32)[0][0]
        self.local_l_std = (np.var(log_counts).reshape(-1, 1)).astype(np.float32)[0][0]
        return self.local_l_mean, self.local_l_std


    def reconst_error(self,x, mu, theta):
        eps = SETTING_EPS
        log_theta_mu_eps = torch.log(theta + mu + eps)

        res = (
            theta * (torch.log(theta + eps) - log_theta_mu_eps)
            + x * (torch.log(mu + eps) - log_theta_mu_eps)
            + torch.lgamma(x + theta)
            - torch.lgamma(theta)
            - torch.lgamma(x + 1)
        )
        return res


    def loss(self,x):
        zs,ls,y = self.forward(x)
        z_mean, z_std, z = zs
        l_mean, l_std, library = ls

        mean, std = torch.zeros_like(z_mean), torch.ones_like(z_std)
        kl_z = kl_divergence(Normal(z_mean,torch.sqrt(z_std)), Normal(mean, std)).sum(dim=1)

        mean, std = self.local_l_mean*torch.ones_like(l_mean), self.local_l_std*torch.ones_like(l_std)
        kl_l = kl_divergence(Normal(l_mean,torch.sqrt(l_std)), Normal(mean, torch.sqrt(std))).sum(dim=1)

        reconst = self.reconst_error(x, mu=y, theta=torch.exp(self.theta)).sum(dim=-1)        
        return reconst, kl_l ,kl_z

In [39]:
model = LDVAE(genes_cnt = len(adata_df.columns),
              latent_dim = 20)
model.set_local_l_mean_and_std(adata_df.values)
model

LDVAE(
  (encoder_z): Sequential(
    (0): Linear(in_features=1000, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (encoder_z_mean): Linear(in_features=128, out_features=20, bias=True)
  (encoder_z_std): Linear(in_features=128, out_features=20, bias=True)
  (encoder_l): Sequential(
    (0): Linear(in_features=1000, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
  )
  (encoder_l_mean): Linear(in_features=128, out_features=1, bias=True)
  (encoder_l_std): Linear(in_features=128, out_features=1, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
  )
)

In [36]:
class GenesDataset(torch.utils.data.Dataset):
    def __init__(self, 
                 adata, 
                 transform=None, 
                 target_transform=None):
        self.data = adata
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getattr__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        data = self.data[idx]
        #label = self.img_labels.iloc[idx, 1]
        if self.transform:
            data = self.transform(data)
        #if self.target_transform:
        #    label = self.target_transform(label)
        return data

In [41]:
# dataloder
dataset = GenesDataset(adata_df.values)


n_samples = len(dataset) 
train_size = int(n_samples* 0.65)
val_size = int(n_samples * 0.15)
test_size = n_samples - train_size - val_size 

dataset_train ,dataset_valid, dataset_test = \
        torch.utils.data.random_split(dataset, [train_size, val_size,test_size])

dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=SETTING_BATCH_SIZE,
    shuffle=True
)

dataloader_valid = torch.utils.data.DataLoader(
    dataset_valid,
    batch_size=SETTING_BATCH_SIZE,
    shuffle=True
)

#dataloader_test = torch.utils.data.DataLoader(
#    dataset_test,
#    batch_size=SETTING_BATCH_SIZE,
#    shuffle=True
#)

In [44]:
from numpy.ma.core import nonzero
# train
model = LDVAE(genes_cnt = len(adata_df.columns),
              latent_dim = 20)
model.set_local_l_mean_and_std(adata_df.values)
model

n_epochs  = 100
optimizer = optim.Adam(model.parameters(), lr=0.002,  betas=(0.5,0.999))

device = 'cuda'  if USE_CUDA else 'cpu'
model.to(device)
_x = nonzero

for epoch in range(n_epochs):
    losses = []

    model.train()
    for x in dataloader_train:
        x = x.to(device)
        _x = x

        model.zero_grad()

        # forawrd and loss
        reconst, kl_l ,kl_z = model.loss(x)
        loss = torch.mean(-reconst+kl_l*0.05 +kl_z*0.25)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.cpu().detach().numpy())

    losses_val1 = []
    losses_val2 = []
    losses_val3 = []
    model.eval()
    for x in dataloader_valid:

        x = x.to(device)


        reconst, kl_l ,kl_z = model.loss(x)

        losses_val1.append(torch.mean(-reconst).cpu().detach().numpy())
        losses_val2.append(torch.mean(kl_l).cpu().detach().numpy())
        losses_val3.append(torch.mean(kl_z).cpu().detach().numpy())

    print('EPOCH: %d    Train Loss: %lf    Valid rec: %lf    Valid kl_l: %lf    Valid kl_z: %lf' %
            (epoch+1, np.average(losses),np.average(losses_val1),np.average(losses_val2),np.average(losses_val3)))


EPOCH: 1    Train Loss: 3943.533203    Valid rec: 1935.536743    Valid kl_l: 12.102735    Valid kl_z: 52.642838
EPOCH: 2    Train Loss: 1370.984985    Valid rec: 1326.448364    Valid kl_l: 6.463783    Valid kl_z: 29.813789
EPOCH: 3    Train Loss: 1133.139648    Valid rec: 1163.127808    Valid kl_l: 3.669250    Valid kl_z: 28.958839
EPOCH: 4    Train Loss: 1046.825806    Valid rec: 980.517151    Valid kl_l: 2.604321    Valid kl_z: 33.954975
EPOCH: 5    Train Loss: 998.076294    Valid rec: 967.799561    Valid kl_l: 3.338364    Valid kl_z: 37.366375
EPOCH: 6    Train Loss: 954.183105    Valid rec: 972.898865    Valid kl_l: 3.699333    Valid kl_z: 36.882259
EPOCH: 7    Train Loss: 914.192017    Valid rec: 858.116699    Valid kl_l: 2.701347    Valid kl_z: 37.534718
EPOCH: 8    Train Loss: 880.983215    Valid rec: 814.440735    Valid kl_l: 2.706646    Valid kl_z: 37.524574
EPOCH: 9    Train Loss: 855.980530    Valid rec: 800.635986    Valid kl_l: 2.544543    Valid kl_z: 38.962627
EPOCH: 10  

In [51]:
import pandas as pd

_w = model.decoder[0].weight
bn = model.decoder[1]
sigma = torch.sqrt(bn.running_var + bn.eps)
gamma = bn.weight
b = gamma / sigma
bI = torch.diag(b)
loadings = torch.matmul(bI, _w)
loadings = loadings.detach().cpu().numpy()

W = pd.DataFrame(loadings, index=gene_index)

In [52]:
W

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
ENSMUSG00000025902,-0.114847,0.014870,-0.748335,-0.063984,-0.144368,-0.545247,0.159475,-0.532325,0.221236,-0.063225,-0.833941,0.424912,0.127992,-0.549809,0.059238,0.109194,0.531668,-0.102879,-0.657995,-0.244427
ENSMUSG00000025927,-0.049106,-0.314564,-0.106663,0.515449,-0.428855,-0.189759,-0.042427,-0.180187,-0.017321,0.513675,0.607979,-0.400233,-0.130607,0.230544,0.092221,-0.049951,0.339507,-0.163870,0.025882,-0.129630
ENSMUSG00000026124,0.962166,-0.021462,-0.330671,-0.619310,-0.027634,-0.103679,-0.301403,-0.658163,-0.098231,-0.417534,-0.348092,0.102325,0.652855,-0.099561,-0.117962,0.027653,-0.273133,-0.023976,0.239798,-0.334595
ENSMUSG00000026043,0.456029,0.133721,0.117041,-0.044423,0.051522,-0.259104,0.188649,0.072159,0.050188,0.447276,0.428279,-0.116995,-0.008734,0.256060,0.100138,-0.421274,-0.102574,0.067202,0.233048,0.070419
ENSMUSG00000045954,0.677873,0.224888,-0.042514,0.125330,-0.433727,-0.706035,0.597891,-0.150792,-0.602343,0.005353,0.928849,-0.251801,0.049986,-0.643417,-0.406187,-0.512795,-0.015923,-1.032671,-0.338712,-0.323017
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSMUSG00000025219,0.151430,0.231536,-0.541314,-0.674195,-0.303684,-0.306293,-0.231648,-0.617818,0.060072,0.422939,-0.297237,-0.318274,-0.613641,0.154591,-0.229179,0.153763,-0.485376,0.330079,-0.250823,-0.429374
ENSMUSG00000025068,-0.373089,0.059569,-0.029015,0.127127,0.058646,0.150118,-0.060132,0.268836,-0.053790,-0.036490,0.182135,-0.213085,0.139468,0.225037,-0.179689,0.123369,0.083911,0.152779,0.127375,-0.110744
ENSMUSG00000064341,0.062246,0.080869,-0.046041,-0.146250,0.248021,0.042828,-0.493345,0.127642,-0.248297,0.033964,0.187736,-0.086645,-0.229711,-0.039657,0.092087,-0.060499,0.178087,0.184678,-0.028148,-0.145498
ENSMUSG00000064351,0.043916,0.047881,-0.083317,-0.212095,0.296512,0.086948,-0.551643,0.109832,-0.254271,-0.036927,0.116559,-0.030761,-0.230797,-0.029200,0.075384,-0.033997,0.228394,0.205639,-0.012916,-0.057969
