<a href="https://colab.research.google.com/github/kicysh/final_task_of_world_model_lecture_2021/blob/main/src/LDVAE_no_output.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# pip

In [None]:
!pip install scanpy scikit-misc

# data

In [None]:
!mkdir data figures
!gsutil cp gs://h5ad/2019-02-Pijuan-Sala-et-al-Nature/pijuan_sala_atlas.h5ad /content/data
path_of_data = '/content/data/pijuan_sala_atlas.h5ad'


# setting


In [None]:
# Hyper-parameters read by the LDVAE class and the training cells below.
# eps / momentum passed to every nn.BatchNorm1d layer of the model.
SETTING_BATCHNORM_EPS = 0.001
SETTING_BATCHNORM_MOMENTUM = 0.01
# Dropout probability after the hidden layer of the z-encoder and the l-encoder.
SETTING_ENCODER_Z_DROPOUT_P = 0.1
SETTING_ENCODER_L_DROPOUT_P = 0.1
# Width of the single hidden layer in both encoders.
SETTING_HIDDEN_DIM = 128
# Small constant added inside the logarithms of the reconstruction term.
SETTING_EPS = 1e-8

# When True the training cell selects the 'cuda' device.
USE_CUDA = True
# Batch size used by all DataLoaders.
SETTING_BATCH_SIZE = 1024

# model

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import logsumexp
from torch.distributions import Normal, kl_divergence

import matplotlib.pyplot as plt

rng = np.random.RandomState(1234)
random_state = 42


In [None]:
import scanpy as sc
# Load the downloaded atlas into memory as an AnnData object.
adata = sc.read_h5ad(path_of_data)
# Bare last expression: display the AnnData summary.
adata

In [None]:
# Keep only the cells whose `doublet` flag in .obs is False.
idx = adata.obs.query('not doublet').index
#idx = np.random.choice(idx, 20000, replace=False)
# Slicing an AnnData returns a view; highly_variable_genes writes into .var,
# so materialise the subset explicitly instead of relying on an implicit copy.
adata = adata[idx].copy()
# Flag the 1000 most variable genes in adata.var['highly_variable'].
sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor='seurat_v3')
print(adata)
#sc.pp.filter_genes(adata, min_cells=100)


In [None]:
# Index labels of the genes flagged as highly variable in the previous cell.
gene_index = adata.var[adata.var['highly_variable']].index

In [None]:
# Dense cells x genes DataFrame restricted to the highly variable genes.
adata_df = adata.to_df()[gene_index]

In [None]:
#from math import ldexp

class LDVAE(nn.Module):
    """
    Linearly decoded variational autoencoder for gene-count data.

    Two encoders (one hidden layer each) infer a latent code ``z`` and a scalar
    log-library size ``l`` per cell.  The decoder is a single bias-free linear
    layer followed by batch norm; its softmax, scaled by ``exp(l)``, is the mean
    of a negative-binomial likelihood with a learned per-gene inverse dispersion
    ``exp(theta)``.  Layer sizes, dropout and batch-norm settings are read from
    the module-level ``SETTING_*`` constants.

    Naming note: everything called ``*_std`` in this class (encoder outputs,
    ``local_l_std``) actually holds a *variance*; its square root is used as the
    scale of the Normal distributions.  The names are kept for compatibility.

    :param genes_cnt: Number of input genes
    :param latent_dim: Dimensionality of the latent space
    """
    def __init__(
        self,
        genes_cnt: int,
        latent_dim: int = 20
    ):
        super(LDVAE, self).__init__()
        # Prior mean / variance of the log-library size; filled in by
        # set_local_l_mean_and_std() and required by loss().
        self.local_l_mean = None
        self.local_l_std = None
        self.eps = SETTING_EPS

        # Per-gene log inverse-dispersion of the negative binomial.
        self.theta = nn.Parameter(torch.randn(genes_cnt))
        self.encoder_z = nn.Sequential(
            nn.Linear(genes_cnt, SETTING_HIDDEN_DIM),
            nn.BatchNorm1d(SETTING_HIDDEN_DIM,
                           eps=SETTING_BATCHNORM_EPS,
                           momentum=SETTING_BATCHNORM_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(SETTING_ENCODER_Z_DROPOUT_P)
        )
        self.encoder_z_mean = nn.Linear(SETTING_HIDDEN_DIM, latent_dim)
        self.encoder_z_std = nn.Linear(SETTING_HIDDEN_DIM, latent_dim)

        self.encoder_l = nn.Sequential(
            nn.Linear(genes_cnt, SETTING_HIDDEN_DIM),
            nn.BatchNorm1d(SETTING_HIDDEN_DIM,
                           eps=SETTING_BATCHNORM_EPS,
                           momentum=SETTING_BATCHNORM_MOMENTUM),
            nn.ReLU(),
            nn.Dropout(SETTING_ENCODER_L_DROPOUT_P)
        )
        self.encoder_l_mean = nn.Linear(SETTING_HIDDEN_DIM, 1)
        self.encoder_l_std = nn.Linear(SETTING_HIDDEN_DIM, 1)

        # Linear decoder (no bias) so the weights can be read as gene loadings.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, genes_cnt, bias=False),
            nn.BatchNorm1d(genes_cnt,
                           eps=SETTING_BATCHNORM_EPS,
                           momentum=SETTING_BATCHNORM_MOMENTUM)
        )

    @staticmethod
    def _reparameterize(mean, var):
        """Draw ``mean + sqrt(var) * eps`` with ``eps ~ N(0, I)`` on mean's device/dtype."""
        return mean + torch.sqrt(var) * torch.randn_like(mean)

    def forward(self, x):
        """
        Encode ``x``, sample the latents and decode the expected counts.

        :param x: batch of per-cell gene counts (last dimension = genes_cnt)
        :return: ``[z_mean, z_var, z]``, ``[l_mean, l_var, library]`` and the
            decoded negative-binomial mean ``y``
        """
        x_z = self.encoder_z(x)
        z_mean = self.encoder_z_mean(x_z)
        # Log-variance is clipped at 10 before exponentiation to avoid overflow.
        z_var = torch.exp(torch.clip(self.encoder_z_std(x_z), max=10))
        # Sample with sqrt(var) so the draw matches the Normal(mean, sqrt(var))
        # posterior that loss() plugs into the KL term.
        z = self._reparameterize(z_mean, z_var)

        x_l = self.encoder_l(x)
        l_mean = self.encoder_l_mean(x_l)
        l_var = torch.exp(torch.clip(self.encoder_l_std(x_l), max=10))
        library = self._reparameterize(l_mean, l_var)

        y = self.decoder(z)
        # Gene proportions (softmax) scaled by the clipped library size.
        y = torch.exp(torch.clip(library, max=10)) * torch.softmax(y, dim=-1)
        return [z_mean, z_var, z], [l_mean, l_var, library], y

    def set_local_l_mean_and_std(self, data):
        """
        Compute and store the prior of the log-library size from raw counts.

        :param data: 2-D array (cells x genes); rows are summed to get library sizes
        :return: ``(mean, variance)`` of the log library sizes as float32 scalars
        """
        # Masked log: non-positive row sums are masked and then counted as 0.
        masked_log_sum = np.ma.log(data.sum(axis=1))
        log_counts = masked_log_sum.filled(0)
        self.local_l_mean = np.float32(np.mean(log_counts))
        # Stored under the historical name `local_l_std`, but this is a variance.
        self.local_l_std = np.float32(np.var(log_counts))
        return self.local_l_mean, self.local_l_std

    def reconst_error(self, x, mu, theta):
        """
        Element-wise negative-binomial log-likelihood of ``x``.

        :param x: observed counts
        :param mu: negative-binomial mean (same shape as ``x``)
        :param theta: inverse dispersion, broadcastable against ``x``
        :return: log-likelihood per element (not reduced)
        """
        eps = self.eps
        log_theta_mu_eps = torch.log(theta + mu + eps)

        res = (
            theta * (torch.log(theta + eps) - log_theta_mu_eps)
            + x * (torch.log(mu + eps) - log_theta_mu_eps)
            + torch.lgamma(x + theta)
            - torch.lgamma(theta)
            - torch.lgamma(x + 1)
        )
        return res

    def loss(self, x):
        """
        Compute the three per-cell terms of the ELBO.

        :param x: batch of per-cell gene counts
        :return: ``(reconst, kl_l, kl_z)``; ``reconst`` is a log-likelihood, so
            the quantity to minimise is ``-reconst + kl_l + kl_z``
        :raises RuntimeError: if set_local_l_mean_and_std() was never called
        """
        if self.local_l_mean is None or self.local_l_std is None:
            raise RuntimeError(
                'call set_local_l_mean_and_std() before computing the loss'
            )

        zs, ls, y = self.forward(x)
        z_mean, z_var, _ = zs
        l_mean, l_var, _ = ls

        # KL(q(z|x) || N(0, I)), summed over latent dimensions.
        prior_z = Normal(torch.zeros_like(z_mean), torch.ones_like(z_var))
        kl_z = kl_divergence(Normal(z_mean, torch.sqrt(z_var)), prior_z).sum(dim=1)

        # KL(q(l|x) || N(local mean, local variance)).
        prior_l_mean = torch.full_like(l_mean, float(self.local_l_mean))
        prior_l_var = torch.full_like(l_var, float(self.local_l_std))
        kl_l = kl_divergence(
            Normal(l_mean, torch.sqrt(l_var)),
            Normal(prior_l_mean, torch.sqrt(prior_l_var)),
        ).sum(dim=1)

        reconst = self.reconst_error(x, mu=y, theta=torch.exp(self.theta)).sum(dim=-1)
        return reconst, kl_l, kl_z

In [None]:
# Build the model with one input per selected gene and a 20-dimensional latent space.
model = LDVAE(genes_cnt = len(adata_df.columns),
              latent_dim = 20)
# Set the log-library-size prior from the full count matrix (needed by model.loss).
model.set_local_l_mean_and_std(adata_df.values)
# Bare last expression: display the module structure.
model

In [None]:
class GenesDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over an indexable collection of per-cell expression rows.

    :param adata: indexable container with ``len()`` (e.g. a 2-D numpy array)
    :param transform: optional callable applied to every item returned
    :param target_transform: stored but unused (the dataset has no labels)
    """
    def __init__(self,
                 adata,
                 transform=None,
                 target_transform=None):
        self.data = adata
        self.transform = transform
        self.target_transform = target_transform

    # NOTE: the former `__getattr__(self)` override was removed. Its signature
    # lacked the attribute-name argument, so every lookup of a missing
    # attribute (e.g. via hasattr) raised TypeError instead of AttributeError.

    def __len__(self):
        """Number of items (rows) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return item ``idx``, passed through ``transform`` when one is set."""
        data = self.data[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data

In [None]:
# dataloader
dataset = GenesDataset(adata_df.values)


# 80 / 20 train / validation split; the test split only receives the
# remainder left over by integer truncation (0 or 1 samples).
n_samples = len(dataset)
train_size = int(n_samples * 0.8)
val_size = int(n_samples * 0.2)
test_size = n_samples - train_size - val_size

# Seed the split with the notebook-level `random_state` so that the same cells
# land in each subset on every run.
dataset_train, dataset_valid, dataset_test = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(random_state),
)

# Un-shuffled loader over every cell: keeps the row order of adata_df, which
# the latent-export cell relies on when it re-attaches adata_df.index.
dataloader_all = torch.utils.data.DataLoader(
    dataset,
    batch_size=SETTING_BATCH_SIZE,
    shuffle=False
)

dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=SETTING_BATCH_SIZE,
    shuffle=True
)

# Validation metrics are averaged over all batches, so shuffling is unnecessary.
dataloader_valid = torch.utils.data.DataLoader(
    dataset_valid,
    batch_size=SETTING_BATCH_SIZE,
    shuffle=False
)

#dataloader_test = torch.utils.data.DataLoader(
#    dataset_test,
#    batch_size=SETTING_BATCH_SIZE,
#    shuffle=True
#)

In [None]:
# train
n_epochs  = 100
optimizer = optim.Adam(model.parameters(), lr=7e-3)

# Fall back to the CPU when CUDA is requested but not available.
device = 'cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu'
model.to(device)

for epoch in range(n_epochs):
    losses = []

    model.train()
    for x in dataloader_train:
        x = x.to(device)

        # forward and loss: negative ELBO averaged over the batch
        reconst, kl_l, kl_z = model.loss(x)
        loss = torch.mean(-reconst + kl_l + kl_z)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    losses_val1 = []
    losses_val2 = []
    losses_val3 = []

    model.eval()
    # No gradients are needed for validation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for x in dataloader_valid:
            x = x.to(device)

            reconst, kl_l, kl_z = model.loss(x)

            losses_val1.append(torch.mean(-reconst).item())
            losses_val2.append(torch.mean(kl_l).item())
            losses_val3.append(torch.mean(kl_z).item())

    print('EPOCH: %d    Train Loss: %lf    Valid rec: %lf    Valid kl_l: %lf    Valid kl_z: %lf' %
            (epoch+1, np.average(losses),np.average(losses_val1),np.average(losses_val2),np.average(losses_val3)))


In [None]:
# w: gene loadings of the linear decoder with the batch-norm scale folded in
import pandas as pd

decoder_linear = model.decoder[0]
bn = model.decoder[1]
with torch.no_grad():
    # Per-gene batch-norm scale gamma / sqrt(running_var + eps).
    bn_scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # Row-wise scaling by broadcasting; equivalent to diag(bn_scale) @ weight
    # without building the dense genes x genes diagonal matrix.
    loadings = bn_scale.unsqueeze(1) * decoder_linear.weight
loadings = loadings.cpu().numpy()

# One row per gene, one column per latent dimension.
W = pd.DataFrame(loadings, index=gene_index)
W.to_csv('/content/w.csv')


In [None]:
# Display the loadings table (genes x latent dimensions).
W

In [None]:
# Encode every cell (in adata_df row order) and export the posterior means of z.
latent_batches = []
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for x in dataloader_all:
        (z_mean, _, _), _, _ = model(x.to(device))
        latent_batches.append(z_mean.cpu().numpy())
# Concatenate once instead of re-allocating the growing array on every batch.
latent_array = np.concatenate(latent_batches)
pd.DataFrame(latent_array,index=adata_df.index).to_csv('/content/z.csv')

In [None]:
# Display the latent means (cells x latent dimensions) indexed like adata_df.
pd.DataFrame(latent_array,index=adata_df.index)

In [None]:
import pandas as pd

# Reload the exported loadings (w) and latent means (z); the first CSV column is the index.
w_df = pd.read_csv('/content/w.csv', index_col=0)
z_df = pd.read_csv('/content/z.csv', index_col=0)

In [None]:
# Prefix the column labels with 'w' / 'z'.
# NOTE(review): not idempotent — re-running this cell prefixes the columns again.
w_df = w_df.add_prefix('w')
z_df = z_df.add_prefix('z')

In [None]:
%pylab inline
%config InlineBackend.figure_format ='retina'

import time

import anndata
import matplotlib.colors as mcolors
import pandas as pd

import plotnine as p


In [None]:
def Z_covariance(Z):
    """
    Covariance of the column-scaled latent matrix and its spectrum shares.

    :param Z: 2-D array or DataFrame, rows = cells, columns = latent dimensions
    :return: ``(ZTZ, variance_explained)`` where ``ZTZ`` is the covariance of
        ``Z / Z.std(0)`` and ``variance_explained`` holds the square roots of
        its eigenvalues normalised to sum to 1, in the order returned by
        ``np.linalg.eig`` (not sorted).
    """
    # np.cov centres the data itself, so no explicit centring is required.
    Zscaled = Z / Z.std(0)
    ZTZ = np.cov(Zscaled.T)

    eigen_values, _ = np.linalg.eig(ZTZ)
    # A covariance matrix is positive semi-definite; drop round-off artefacts
    # (tiny negative or complex parts) that would otherwise yield NaN.
    eigen_values = np.clip(np.real(eigen_values), 0.0, None)
    singular_values = np.sqrt(eigen_values)
    variance_explained = singular_values / singular_values.sum()

    return ZTZ, variance_explained

In [None]:
# Spectrum shares of the latent covariance, and their positions sorted in descending order.
# NOTE(review): the entries of variance_explained are eigenvalue-derived, so their
# positions do not obviously correspond to columns of z_df; idx is nevertheless used
# below to reorder the z / w columns — confirm this is intended.
_, variance_explained = Z_covariance(z_df)
idx = np.argsort(variance_explained)[::-1]


In [None]:
# Reorder the latent / loading columns by `idx` and relabel them z0.. / w0..
# in the new order, keeping the original row index.
Z_df_ordered = pd.DataFrame(z_df.values[:, idx], index=z_df.index).add_prefix('z')
W_df_ordered = pd.DataFrame(w_df.values[:, idx], index=w_df.index).add_prefix('w')

In [None]:
# Reload the full atlas for its .obs / .var annotations.
# read_h5ad is the explicit reader; `anndata.read` is a deprecated alias.
adata = anndata.read_h5ad('/content/data/pijuan_sala_atlas.h5ad')

In [None]:
# Attach the gene annotations from adata.var to the ordered loadings (rows matched on gene index).
# NOTE(review): overwrites W_df_ordered with the joined frame, so re-running this cell
# joins onto an already-joined table — presumably it then fails on overlapping columns.
W_df_ordered = adata.var.loc[W_df_ordered.index].join(W_df_ordered)

In [None]:
def make_kde(x1, x2):
    """
    Evaluate a 2-D Gaussian kernel density estimate on a regular grid.

    The grid has 100 x 100 points and spans the range of each coordinate
    extended by 10 % of that range on both sides.

    :param x1: 1-D sequence of first coordinates
    :param x2: 1-D sequence of second coordinates (same length as ``x1``)
    :return: ``(xx1, xx2, f)`` — the two grid coordinate arrays and the
        density evaluated on the grid, all of shape ``(100, 100)``
    """
    # Imported locally so the function does not depend on a later notebook
    # cell having executed `from scipy import stats`.
    from scipy import stats

    dx1 = (x1.max() - x1.min()) / 10
    dx2 = (x2.max() - x2.min()) / 10

    x1min = x1.min() - dx1
    x2min = x2.min() - dx2
    x1max = x1.max() + dx1
    x2max = x2.max() + dx2

    # A complex step (100j) makes mgrid produce 100 points, endpoints included.
    xx1, xx2 = np.mgrid[x1min:x1max:100j, x2min:x2max:100j]

    positions = np.vstack([xx1.ravel(), xx2.ravel()])
    values = np.vstack([x1, x2])

    kernel = stats.gaussian_kde(values)
    f = np.reshape(kernel(positions).T, xx1.shape)

    return xx1, xx2, f


In [None]:
# Cell annotations for the embedded cells, with the ordered latent factors appended.
metadata = adata.obs.loc[Z_df_ordered.index].join(Z_df_ordered)

# For each cell type, print the three factors whose correlation with the
# "cell is of this type" indicator is largest in absolute value.
for ctype in np.unique(metadata['celltype']):
    is_ctype = metadata['celltype'] == ctype
    z_corrs = Z_df_ordered.corrwith(is_ctype)
    strongest = z_corrs.abs().sort_values(ascending=False).index[:3]
    summary = [f'{name.rjust(3)}: {z_corrs[name]:+.2f}' for name in strongest]
    print(summary, ctype)

# Alias used by the plotting cells to look up per-gene loadings.
vardata = W_df_ordered

In [None]:
from matplotlib.patches import Ellipse
from scipy import stats

# Alias of the annotated loadings table, queried by gene_name in the plotting cells.
vardata = W_df_ordered



In [None]:
def create_fig_0(num, close=False):
    """
    Draw and save the two-panel figure for one pair of latent factors.

    Left panel: 2-D histogram of all cells in the two factors, with KDE
    contours for three selected cell types.  Right panel: gene loadings for the
    same two factors, with arrows for three groups of marker genes.  The figure
    is written to ``figures/linear_pij_results_celltypes_<a>_<b>.pdf``.

    Reads the notebook globals ``Z_df_ordered``, ``W_df_ordered``, ``vardata``,
    ``metadata``, ``variance_explained``, ``idx`` and ``make_kde``.

    :param num: pair ``(a, b)`` of factor indices (columns ``z<a>``/``z<b>``
        and ``w<a>``/``w<b>``)
    :param close: when True, close the figure after saving; useful when many
        figures are generated in a loop (the figure is then not shown inline)
    """
    first, second = num
    z_x, z_y = f'z{first}', f'z{second}'
    w_x, w_y = f'w{first}', f'w{second}'

    # Pass the size at creation time: setting the default size after the figure
    # exists would only affect figures created later.
    fig, (ax_z, ax_w) = plt.subplots(1, 2, figsize=(7, 12 / 5))

    # ---- left panel: cells in latent space --------------------------------
    ax_z.hist2d(
        Z_df_ordered[z_x], Z_df_ordered[z_y],
        bins=256,
        norm=mcolors.PowerNorm(0.25),
        cmap=plt.cm.gray_r,
        rasterized=True
    )
    ax_z.axis('equal')
    # Braces around the index keep multi-digit subscripts intact in mathtext
    # ('$Z_10$' would subscript only the '1').
    ax_z.set_xlabel(f'$Z_{{{first}}}$ ({variance_explained[idx][first]:.1%} variance)')
    ax_z.set_ylabel(f'$Z_{{{second}}}$ ({variance_explained[idx][second]:.1%} variance)')

    contour_styles = [
        (plt.cm.Reds_r, 'Erythroid3'),
        (plt.cm.Blues_r, 'ExE endoderm'),
        (plt.cm.Greens_r, 'Epiblast'),
    ]
    for contour_cmap, ctype in contour_styles:
        X = metadata.query('celltype == @ctype')[[z_x, z_y]]
        xx1, xx2, f = make_kde(X[z_x], X[z_y])
        ax_z.contour(xx1, xx2, f, levels=6, cmap=contour_cmap, linewidths=1.)

    # ---- right panel: gene loadings ----------------------------------------
    ax_w.scatter(
        W_df_ordered[w_x], W_df_ordered[w_y],
        c='lightgrey',
        rasterized=True
    )

    # (marker genes, arrow/label colour, label position)
    gene_groups = [
        (['Hbb-bs', 'Hbb-bt', 'Hba-a2'], 'r', (0, 2)),
        (['Pou5f1', 'Tdgf1', 'Snrpn'], 'g', (1, -1.5)),
        (['Ctsh', 'Amn', 'Apoa4'], 'b', (-2.6, -0.5)),
    ]
    for genes, color, (text_x, text_y) in gene_groups:
        for g in genes:
            # First row matching the gene name; raises IndexError if the gene
            # is absent from the loadings table.
            x_, y_ = vardata.query('gene_name == @g')[[w_x, w_y]].values[0]
            ax_w.arrow(0, 0, x_, y_, length_includes_head=True, color=color)
        ax_w.text(text_x, text_y, '\n'.join(genes), color=color)

    ax_w.set_xlim(left=-3, right=3)
    ax_w.set_ylim(bottom=-3, top=3)
    ax_w.set_xlabel(f'$W_{{{first}}}$')
    ax_w.set_ylabel(f'$W_{{{second}}}$')

    for ax in (ax_z, ax_w):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    fig.savefig('figures/linear_pij_results_celltypes_{}_{}.pdf'.format(*num), bbox_inches='tight', dpi=400)

    if close:
        plt.close(fig)



In [None]:
# One figure per unordered pair of latent factors; the number of factors is
# taken from the data instead of repeating the latent dimension by hand.
n_latent = Z_df_ordered.shape[1]
for i in range(n_latent):
    for j in range(i + 1, n_latent):
        create_fig_0((i, j))

In [None]:
# Zip the folder that should be downloaded
!zip -r /content/download.zip /content/figures

# Download the zipped archive
from google.colab import files
files.download("/content/download.zip")

In [None]:
# Stand-alone two-panel figure for the first two ordered factors.
num = (0, 1)
dim_x, dim_y = num
z_x, z_y = 'z' + str(dim_x), 'z' + str(dim_y)
w_x, w_y = 'w' + str(dim_x), 'w' + str(dim_y)

figsize(7, 12 / 5)

# ---- left panel: 2-D histogram of cells with per-cell-type KDE contours ----
plt.subplot(1, 2, 1)
plt.hist2d(
    Z_df_ordered[z_x], Z_df_ordered[z_y],
    bins=256,
    norm=mcolors.PowerNorm(0.25),
    cmap=cm.gray_r,
    rasterized=True
)
plt.axis('equal')
plt.xlabel(f'$Z_{dim_x}$ ({variance_explained[idx][dim_x]:.1%} variance)')
plt.ylabel(f'$Z_{dim_y}$ ({variance_explained[idx][dim_y]:.1%} variance)')

ax = plt.gca()
contour_styles = (
    ('Erythroid3', cm.Reds_r),
    ('ExE endoderm', cm.Blues_r),
    ('Epiblast', cm.Greens_r),
)
for ctype, contour_cmap in contour_styles:
    cells = metadata.query('celltype == @ctype')[[z_x, z_y]]
    grid_x, grid_y, density = make_kde(cells[z_x], cells[z_y])
    ax.contour(grid_x, grid_y, density, levels=6, cmap=contour_cmap, linewidths=1.)

for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# ---- right panel: gene loadings with marker-gene arrows --------------------
plt.subplot(1, 2, 2)
plt.scatter(W_df_ordered[w_x], W_df_ordered[w_y], c='lightgrey', rasterized=True)

# (colour, label position, marker genes) — drawn in this order.
marker_groups = (
    ('r', (0, 2), ['Hbb-bs', 'Hbb-bt', 'Hba-a2']),
    ('g', (1, -1.5), ['Pou5f1', 'Tdgf1', 'Snrpn']),
    ('b', (-2.6, -0.5), ['Ctsh', 'Amn', 'Apoa4']),
)
for colour, (label_x, label_y), genes in marker_groups:
    for g in genes:
        x_, y_ = vardata.query('gene_name == @g')[[w_x, w_y]].values[0]
        plt.arrow(0, 0, x_, y_, length_includes_head=True, color=colour)
    plt.text(label_x, label_y, '\n'.join(genes), color=colour)

plt.xlim(left=-3, right=3)
plt.ylim(bottom=-3, top=3)
plt.xlabel('$W_{}$'.format(dim_x))
plt.ylabel('$W_{}$'.format(dim_y))

ax = plt.gca()
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

plt.savefig('figures/linear_pij_results_celltypes.pdf', bbox_inches='tight', dpi=400)



In [None]:
# Grid of 2-D histograms for every ordered pair (i, j) of latent factors.
z_length = Z_df_ordered.shape[1]  # number of latent factors, taken from the data
fig = plt.figure(figsize=(80, 80))
for i in range(z_length):
    for j in range(z_length):
        # -- Z plot --
        ax = fig.add_subplot(z_length, z_length, z_length * i + j + 1)

        ax.hist2d(
            Z_df_ordered[f'z{i}'], Z_df_ordered[f'z{j}'],
            bins=256,
            norm=mcolors.PowerNorm(0.25),
            cmap=plt.cm.gray_r,
            rasterized=True
        )
        ax.axis('equal')
        # Braces keep multi-digit subscripts intact ('$Z_10$' would subscript
        # only the '1').
        ax.set_xlabel(f'$Z_{{{i}}}$ ({variance_explained[idx][i]:.1%} variance)')
        ax.set_ylabel(f'$Z_{{{j}}}$ ({variance_explained[idx][j]:.1%} variance)')

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

fig.tight_layout()
# Saved under its own name: the five-panel figure produced further down also
# writes 'figures/linear_pij_results.pdf' and would overwrite this grid.
fig.savefig('figures/linear_pij_results_pairgrid.pdf', bbox_inches='tight', dpi=20)

In [None]:


# Manual pairs of offsets keyed by (integer index, gene name).
# NOTE(review): presumably (dx, dy) label nudges for annotated genes in a loadings
# plot, with the integer selecting the panel; no cell visible here reads this
# dict — confirm whether it is still needed.
text_shift = {
    (0, 'Fst'): (0.2, -0.2),
    (0, 'Pmp22'): (-0.1, -0.3),
    (0, 'Hoxaas3'): (0.0, 0.3),
    (0, 'Acta2'): (0.0, 0.2),
    (0, 'Nnat'): (0.0, 0.5),
    (0, 'Ifitm1'): (0.0, 0.2),
    
    (1, 'Srgn'): (0., -0.2),
    (1, 'Amn'): (-0.7, 0.),
    
    (2, 'Crabp2'): (0.0, 0.4),
    (2, 'Tdgf1'): (0.0, 0.4),
    (2, 'Cer1'): (-0.6, 0.0),
    (2, 'T'): (-1.3, 0.0),
    
    (3, 'Cdx1'): (0.0, 0.4),
    (3, 'Cdx2'): (-1.0, 0.7),
    (3, 'Cited1'): (-1.6, -0.0),
    (3, 'Phlda2'): (-0.7, 0.0),
    (3, 'T'): (0.0, 0.2),
    (3, 'Ifitm1'): (-0.2, 0.0),
#     (3, 'Rspo2'): (-0.3, 0.0),
#     (3, 'Htr2c'): (0.6, 0.1),
#     (3, 'Col1a1'): (0.0, 0.2),
    
#     (4, 'Ttn'): (0.0, -0.4),
#     (4, 'Sntb1'): (0.0, -0.3),
#     (4, 'Colec12'): (-1.1, 0.1),
#     (4, 'Adam12'): (0.1, 0.4),
#     (4, 'Spon1'): (0.0, 0.2),
#     (4, 'Gm3764'): (-0.2, 0.3),
#     (4, 'C130071C03Rik'): (1.5, -0.35),
    
}



In [None]:

# Five stacked 2-D histograms, one per consecutive factor pair
# (z0, z1), (z2, z3), ..., (z8, z9), drawn in the left column of a 5 x 2 grid.
figsize(7, 12)
n_rows = 5
for row in range(n_rows):
    x_dim = 2 * row
    y_dim = x_dim + 1

    # -- Z plot --
    plt.subplot(n_rows, 2, x_dim + 1)
    plt.hist2d(
        Z_df_ordered[f'z{x_dim}'], Z_df_ordered[f'z{y_dim}'],
        bins=256,
        norm=mcolors.PowerNorm(0.25),
        cmap=cm.gray_r,
        rasterized=True
    )
    plt.axis('equal')
    plt.xlabel(f'$Z_{x_dim}$ ({variance_explained[idx][x_dim]:.1%} variance)')
    plt.ylabel(f'$Z_{y_dim}$ ({variance_explained[idx][y_dim]:.1%} variance)')

    ax = plt.gca()
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig('figures/linear_pij_results.pdf', bbox_inches='tight', dpi=400)