In [1]:
import scanpy as sc
import scvi
import scarches as sca
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Normal, Poisson
import torch.nn.functional as F
# from dalib.modules import domain_discriminator
# from scvi.nn._base_components import DecoderSCVI

# Run on the GPU whenever CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

Global seed set to 0
 captum (see https://github.com/pytorch/captum).


device(type='cuda')

In [2]:
# Latent embedding of the reference cells, read from the working directory.
ref_latent = sc.read_h5ad("./reference_latent.h5ad")
# Query dataset. NOTE(review): absolute, machine-specific path — adjust before re-running elsewhere.
query_adata = sc.read_h5ad("/home/wyh/liver_atlas/data/Aizarani2019/Aizarani2019_plot_V2.h5ad")
# Directory holding the saved reference model and its var_names.csv;
# the commented-out line is an alternative model directory.
# ref_path = "./scvi_ref_model/"
ref_path = "./scvi_ref_model_sample/"

In [4]:
from scipy.sparse import csr_matrix
def feature_alignment(test_set, gene_list):
    """Restrict ``test_set`` to ``gene_list`` (in that order), zero-filling absent genes.

    The matrix in ``test_set.layers['counts']`` becomes ``X`` of the returned
    object. Genes of ``gene_list`` that ``test_set`` does not contain are added
    as all-zero columns. The input object itself is left unmodified.

    Parameters
    ----------
    test_set
        AnnData with a ``'counts'`` layer.
    gene_list
        Sequence of gene names defining the output variables and their order.

    Returns
    -------
    AnnData
        A new, independent object (not a view) whose ``var_names`` follow
        ``gene_list``, so it can be modified without implicit-copy warnings.
        When genes had to be zero-filled, only ``obs``, ``X`` and the
        ``'counts'`` layer are carried over.

    Raises
    ------
    KeyError
        If ``test_set`` has no ``'counts'`` layer.
    """
    # Local import: only csr_matrix is imported at file level.
    from scipy.sparse import hstack

    gene_list = list(gene_list)
    present_genes = set(test_set.var_names)
    # Ordered and de-duplicated, so the appended columns are deterministic.
    gene_extra = [g for g in dict.fromkeys(gene_list) if g not in present_genes]
    n_extra = len(gene_extra)
    # `gene_list and ...` avoids a division by zero for an empty gene list.
    if gene_list and n_extra / len(gene_list) > 0.05:
        print("Warning: %d features not exist in testset." % len(gene_extra))

    counts = test_set.layers['counts']

    if n_extra > 0:  # fill zeros for missing features
        zero_block = csr_matrix((test_set.n_obs, n_extra), dtype=counts.dtype)
        new_mtx = hstack([csr_matrix(counts), zero_block], format="csr")
        test_adata = sc.AnnData(new_mtx)
        # Copy obs so the returned object does not share its table with the input.
        test_adata.obs = test_set.obs.copy()
        test_adata.layers['counts'] = test_adata.X
        test_adata.obs_names = test_set.obs_names
        test_adata.var_names = list(test_set.var_names) + gene_extra
        return test_adata[:, gene_list].copy()

    aligned = test_set[:, gene_list].copy()
    aligned.X = aligned.layers['counts'].copy()
    return aligned
    
# Gene names read from a header-less CSV inside the reference model directory (first column).
gene_selected = pd.read_csv(ref_path + "var_names.csv", header=None)[0].tolist()
# Align the query genes to that list; genes missing from the query are zero-filled.
query_adata = feature_alignment(query_adata, gene_selected)
# Duplicate the 'batch' annotation under 'sample'.
# presumably 'sample' is the batch key the reference model was set up with — TODO confirm
query_adata.obs['sample'] = query_adata.obs['batch']

# Attach the query data to the saved reference model found in `ref_path`.
scvi_model = sca.models.SCVI.load_query_data(
    query_adata,
    ref_path,
    freeze_dropout = True,
)

# max_epochs=0: go through the training entry point without running any epoch.
scvi_model.train(max_epochs=0, plan_kwargs=dict(weight_decay=0.0))
# scvi.data.view_anndata_setup(scvi_model.adata)

[34mINFO    [0m Using data from adata.layers[1m[[0m[32m"counts"[0m[1m][0m                                              


  query_adata.obs['sample'] = query_adata.obs['batch']
INFO:scvi.data._anndata:Using data from adata.layers["counts"]


[34mINFO    [0m Registered keys:[1m[[0m[32m'X'[0m, [32m'batch_indices'[0m, [32m'labels'[0m[1m][0m                                    


INFO:scvi.data._anndata:Registered keys:['X', 'batch_indices', 'labels']


[34mINFO    [0m Successfully registered anndata object containing [1;36m10372[0m cells, [1;36m2000[0m vars, [1;36m39[0m        
         batches, [1;36m9[0m labels, and [1;36m0[0m proteins. Also registered [1;36m0[0m extra categorical covariates   
         and [1;36m0[0m extra continuous covariates.                                                  


INFO:scvi.data._anndata:Successfully registered anndata object containing 10372 cells, 2000 vars, 39 batches, 9 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Training: 0it [00:00, ?it/s]

In [35]:
# Inspect how the AnnData attached to the model was registered
# (presumably prints a setup summary — no output was captured for this cell).
scvi.data.view_anndata_setup(scvi_model.adata)



In [45]:
# Batch code of the first cell in the registered AnnData; the training cells reuse
# this single value as the batch index of every query cell.
# `.iloc[0]` selects by position explicitly: AnnData obs tables are indexed by cell
# names, and the positional fallback of `Series[0]` on such an index is deprecated in pandas.
# NOTE(review): confirm that one batch code for all cells is intended when the query
# contains several batches.
batch_id = scvi_model.adata.obs['_scvi_batch'].iloc[0]

In [30]:
class TransferMappingModule(nn.Module):
    """Small MLP that maps a latent vector to a vector of the same width.

    Layout: ``n_latent -> n_hidden[0] -> n_hidden[1] -> n_latent``. Each of the
    two hidden layers is followed by ReLU and dropout; the output layer is a
    plain linear layer.

    Parameters
    ----------
    n_latent
        Width of the input and of the output.
    n_gene
        Accepted but not used by this class.
    n_hidden
        Widths of the two hidden layers (only the first two entries are read).
    dropout_rate
        Dropout probability applied after each hidden layer.
    extra_decoder_kwargs
        Accepted but not used by this class.
    """

    def __init__(
        self,
        n_latent=10,
        n_gene=2000,
        n_hidden=[64, 32],
        dropout_rate=0.1,
        extra_decoder_kwargs=None
    ):
        super().__init__()

        first_width = n_hidden[0]
        second_width = n_hidden[1]

        # Layers are created in a fixed order (fc11, fc12, fc2) so parameter
        # names and initialisation order stay stable.
        self.fc11 = nn.Linear(n_latent, first_width)
        self.fc12 = nn.Linear(first_width, second_width)
        self.fc2 = nn.Linear(second_width, n_latent)
        self.dropout = nn.Dropout(p=dropout_rate)

    def query_transfer(self, z):
        """Apply the two hidden layers (ReLU + dropout each), then the linear output layer."""
        hidden = z
        for layer in (self.fc11, self.fc12):
            hidden = self.dropout(F.relu(layer(hidden)))
        return self.fc2(hidden)

    def forward(self, z):
        """Alias of :meth:`query_transfer`."""
        return self.query_transfer(z)

# Domain discriminator class
class DomainDiscriminator(nn.Module):
    """One-hidden-layer binary classifier producing a probability per input row.

    Layout: ``input_dim -> hidden_dim`` (ReLU, dropout) ``-> 1`` (sigmoid).
    """

    def __init__(self, input_dim=30, hidden_dim=64, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, dropout_rate=0.1):
        """Return sigmoid outputs of shape ``(len(x), 1)``.

        ``dropout_rate`` is accepted but ignored; the rate fixed at construction
        time is the one applied.
        """
        hidden = self.dropout(F.relu(self.fc1(x)))
        logits = self.fc2(hidden)
        return self.sigmoid(logits)

In [31]:
from torch.utils.data import Dataset, DataLoader

# Custom dataset class
class MyDataset(Dataset):
    """Row-wise torch dataset over one expression matrix of an AnnData-like object.

    Parameters
    ----------
    adata
        Object exposing ``X``, ``layers`` and ``n_obs``.
    X_key
        ``"X"`` to serve rows of ``adata.X``; any other value is looked up in
        ``adata.layers``.

    Each item is a tensor of shape ``(1, n_features)``, so a default-collated
    batch has shape ``(batch, 1, n_features)``. Both sparse matrices (anything
    exposing ``toarray``) and dense arrays are supported.
    """

    def __init__(self, adata, X_key):
        self.data = adata
        self.X = adata.X if X_key == "X" else adata.layers[X_key]

    def __len__(self):
        return self.data.n_obs

    def __getitem__(self, idx):
        row = self.X[idx, :]
        if hasattr(row, "toarray"):
            # Sparse row: densify to a (1, n_features) array.
            row = row.toarray()
        else:
            # Dense row comes back 1-D; lift it to the same (1, n_features) shape.
            row = np.atleast_2d(np.asarray(row))
        return torch.tensor(row)
    

In [32]:
def discriminator_loss(discriminator, ref_latent, query_latent, mode):
    """Binary cross-entropy of the discriminator on stacked reference + query latents.

    Parameters
    ----------
    discriminator
        Callable mapping a ``(n, dim)`` tensor to probabilities of shape ``(n, 1)``.
    ref_latent, query_latent
        2-D tensors with the same number of columns; the reference rows are
        stacked before the query rows.
    mode
        ``"trainD"``: targets are 0 for reference rows and 1 for query rows.
        ``"trainM"``: targets are 0 for every row, i.e. the loss is low when the
        discriminator takes query rows for reference rows.

    Returns
    -------
    torch.Tensor
        Scalar mean BCE loss.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``"trainD"`` nor ``"trainM"``.
    """
    n_ref = ref_latent.shape[0]
    n_query = query_latent.shape[0]

    # Compute on the discriminator's own device; fall back to the inputs' device
    # for callables without parameters.
    params = getattr(discriminator, "parameters", None)
    first_param = next(params(), None) if callable(params) else None
    target_device = first_param.device if first_param is not None else ref_latent.device

    if mode == "trainD":
        labels = torch.cat(
            (torch.zeros(n_ref, 1), torch.ones(n_query, 1)), dim=0
        ).to(target_device)
    elif mode == "trainM":
        labels = torch.zeros(n_ref + n_query, 1).to(target_device)
    else:
        raise ValueError(
            "No such mode for discriminator loss: %r (expected 'trainD' or 'trainM')." % (mode,)
        )

    # Stack reference and query latents and feed them to the discriminator.
    data = torch.cat((ref_latent, query_latent), dim=0).to(target_device)
    predictions = discriminator(data)

    # Discriminator loss.
    criterion = nn.BCELoss()
    return criterion(predictions, labels)

def reconstruction_loss(scvi_model, transfer_mapping_module, X, batch_id):
    """Mean reconstruction loss of ``X`` when decoding from the transferred latent.

    Steps: encode ``X`` with ``scvi_model.module.inference``, pass the latent
    ``z`` through ``transfer_mapping_module``, decode with
    ``scvi_model.module.generative`` and score the result against ``X`` with
    ``scvi_model.module.get_reconstruction_loss``.

    Parameters
    ----------
    scvi_model
        Object whose ``module`` exposes ``inference``, ``generative`` and
        ``get_reconstruction_loss``.
    transfer_mapping_module
        Callable applied to the latent ``z``.
    X
        Input tensor; its first dimension is the number of cells.
    batch_id
        Single batch code assigned to every row of ``X``.

    Returns
    -------
    torch.Tensor
        Scalar mean of the per-cell reconstruction loss.
    """
    module = scvi_model.module
    # Allocate the batch index on X's device so it can be combined with
    # tensors derived from X wherever the model lives.
    batch_index = torch.full((X.shape[0], 1), batch_id, device=X.device)

    inference_output = module.inference(X, batch_index=batch_index)
    transferred_z = transfer_mapping_module(inference_output['z'])

    generative_output = module.generative(
        z=transferred_z,
        library=inference_output['library'],
        batch_index=batch_index,
    )
    reconst_loss = module.get_reconstruction_loss(
        X,
        generative_output['px_rate'],
        generative_output['px_r'],
        generative_output['px_dropout'],
    )
    return torch.mean(reconst_loss)


In [33]:
import scanpy as sc
from torch.optim.lr_scheduler import StepLR

# hyper-parameters
num_epoch = 250
batch_size = 256
learning_rate = 0.01
num_workers = 4
shuffle = True

# Query cells, served row by row from the "counts" layer.
dataset = MyDataset(query_adata, "counts")

# Instantiate trainloader
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=num_workers,
)

# Instantiate mapping module (created before the discriminator so the
# parameter-initialisation order stays the same)
mapping_module = TransferMappingModule(n_latent=30, n_gene=2000, dropout_rate=0.1)
mapping_module = mapping_module.to(device)

# Instantiate discriminator
discriminator = DomainDiscriminator(input_dim=30, dropout_rate=0.1)
discriminator = discriminator.to(device)

# One Adam optimizer per network, sharing the same learning rate.
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)
optimizerM = optim.Adam(mapping_module.parameters(), lr=learning_rate)

# Halve both learning rates every 50 scheduler steps.
schedulerD = StepLR(optimizerD, step_size=50, gamma=0.5)
schedulerM = StepLR(optimizerM, step_size=50, gamma=0.5)

In [34]:
# Weight of the reconstruction term inside the mapping-module loss.
reconst_loss_rate = 0.001
num_batches = len(train_loader)

mapping_module.train()
discriminator.train()

# epoch index -> transferred latent of all query cells, kept as a detached CPU
# tensor so that recorded epochs hold neither GPU memory nor autograd graphs.
z_record = {}
for epoch in range(num_epoch):
    # ref_latent to GPU: a fresh 10000-cell reference subsample per epoch,
    # seeded with the epoch index.
    ref_latent_subset = sc.pp.subsample(ref_latent, n_obs = 10000,
                                        random_state = epoch, copy=True)
    ref_latent_tensor = torch.tensor(ref_latent_subset.X).to(device)

    for i, x in enumerate(train_loader):
        query_x = x.to(device)
        n_sample = query_x.shape[0]
        # Same device as the data, so the index can be combined with it.
        batch_index = torch.full((n_sample, 1), batch_id, device=query_x.device)

        ##############################
        #   Training discriminator   #
        ##############################
        # The query latent is only an input here: optimizerD updates the
        # discriminator alone, so encoder and mapping module run without autograd.
        with torch.no_grad():
            inference_output = scvi_model.module.inference(
                query_x,
                batch_index=batch_index
            )
            # reshape (rather than squeeze) keeps the batch dimension even for
            # a batch holding a single cell.
            query_latent = inference_output['z'].reshape(n_sample, -1)
            query_latent_transferred = mapping_module(query_latent)

        # forward
        lossD = discriminator_loss(
            discriminator,
            ref_latent_tensor,
            query_latent_transferred,
            "trainD"
        )
        # backward
        optimizerD.zero_grad()
        lossD.backward()
        optimizerD.step()

        ##########################
        #   Training generator   #
        ##########################
        # Fresh forward pass with autograd enabled, scored by the updated discriminator.
        inference_output = scvi_model.module.inference(
            query_x,
            batch_index=batch_index
        )
        query_latent = inference_output['z'].reshape(n_sample, -1)
        query_latent_transferred = mapping_module(query_latent)

        # forward
        discrim_loss = discriminator_loss(
            discriminator,
            ref_latent_tensor,
            query_latent_transferred,
            "trainM"
        )
        reconst_loss = reconstruction_loss(scvi_model, mapping_module,
                                           query_x.reshape(n_sample, -1), batch_id)
        lossM = reconst_loss_rate * reconst_loss + discrim_loss

        # backward (zero_grad also discards gradients left on the mapping
        # module by earlier backward passes)
        optimizerM.zero_grad()
        lossM.backward()
        optimizerM.step()

        if i == 0:
            # Snapshot of the whole query set once per epoch, taken in eval mode
            # (dropout off) and without autograd so the record is deterministic.
            mapping_module.eval()
            with torch.no_grad():
                z_mean = torch.tensor(scvi_model.get_latent_representation()).to(device)
                z_mean_transferred = mapping_module(z_mean)
            mapping_module.train()
            z_record[epoch] = z_mean_transferred.cpu()

        if i%40 == 0:
            print('Epoch [{}/{}], step [{}/{}], D_loss: {:.4f}, M_loss: {:.4f}, M_d_loss: {:.4f}, M_r_loss: {:.4f}'.format(
                epoch+1, num_epoch, i+1, num_batches,
                lossD.item(), lossM.item(),
                discrim_loss.item(), reconst_loss.item()))
    schedulerD.step()
    schedulerM.step()





Epoch [1/250], step [1/41], D_loss: 0.6527, M_loss: 5.8408, M_d_loss: 0.5479, M_r_loss: 5292.8696
Epoch [1/250], step [41/41], D_loss: 0.1227, M_loss: 1.1967, M_d_loss: 0.0341, M_r_loss: 1162.6326
Epoch [2/250], step [1/41], D_loss: 0.2269, M_loss: 1.3288, M_d_loss: 0.0559, M_r_loss: 1272.8298
Epoch [2/250], step [41/41], D_loss: 0.0979, M_loss: 1.9963, M_d_loss: 0.7223, M_r_loss: 1274.0557
Epoch [3/250], step [1/41], D_loss: 0.1201, M_loss: 2.6296, M_d_loss: 1.4264, M_r_loss: 1203.2478
Epoch [3/250], step [41/41], D_loss: 0.0956, M_loss: 1.4968, M_d_loss: 0.4142, M_r_loss: 1082.5873
Epoch [4/250], step [1/41], D_loss: 0.1470, M_loss: 2.1062, M_d_loss: 0.9335, M_r_loss: 1172.6942
Epoch [4/250], step [41/41], D_loss: 0.0592, M_loss: 1.5415, M_d_loss: 0.4551, M_r_loss: 1086.3967
Epoch [5/250], step [1/41], D_loss: 0.0834, M_loss: 2.0566, M_d_loss: 0.8800, M_r_loss: 1176.6519
Epoch [5/250], step [41/41], D_loss: 0.0857, M_loss: 1.6589, M_d_loss: 0.5440, M_r_loss: 1114.9553
Epoch [6/250], 

Epoch [42/250], step [41/41], D_loss: 0.0240, M_loss: 1.7389, M_d_loss: 0.5582, M_r_loss: 1180.7108
Epoch [43/250], step [1/41], D_loss: 0.0404, M_loss: 2.0861, M_d_loss: 1.0113, M_r_loss: 1074.7848
Epoch [43/250], step [41/41], D_loss: 0.0198, M_loss: 1.7189, M_d_loss: 0.5442, M_r_loss: 1174.7056
Epoch [44/250], step [1/41], D_loss: 0.0317, M_loss: 1.9461, M_d_loss: 0.8684, M_r_loss: 1077.6292
Epoch [44/250], step [41/41], D_loss: 0.0193, M_loss: 1.5864, M_d_loss: 0.4726, M_r_loss: 1113.7263
Epoch [45/250], step [1/41], D_loss: 0.0356, M_loss: 1.9645, M_d_loss: 0.8673, M_r_loss: 1097.2344
Epoch [45/250], step [41/41], D_loss: 0.0174, M_loss: 1.6033, M_d_loss: 0.4854, M_r_loss: 1117.8346
Epoch [46/250], step [1/41], D_loss: 0.0254, M_loss: 1.9909, M_d_loss: 0.8971, M_r_loss: 1093.8157
Epoch [46/250], step [41/41], D_loss: 0.0202, M_loss: 1.4811, M_d_loss: 0.4549, M_r_loss: 1026.1608
Epoch [47/250], step [1/41], D_loss: 0.0298, M_loss: 2.0057, M_d_loss: 0.9328, M_r_loss: 1072.8801
Epoch

Epoch [84/250], step [1/41], D_loss: 0.0197, M_loss: 2.0816, M_d_loss: 1.0320, M_r_loss: 1049.5583
Epoch [84/250], step [41/41], D_loss: 0.0240, M_loss: 1.4544, M_d_loss: 0.3913, M_r_loss: 1063.1399
Epoch [85/250], step [1/41], D_loss: 0.0343, M_loss: 2.1115, M_d_loss: 1.0594, M_r_loss: 1052.1801
Epoch [85/250], step [41/41], D_loss: 0.0172, M_loss: 1.4257, M_d_loss: 0.4142, M_r_loss: 1011.4954
Epoch [86/250], step [1/41], D_loss: 0.0250, M_loss: 2.2495, M_d_loss: 1.1669, M_r_loss: 1082.6277
Epoch [86/250], step [41/41], D_loss: 0.0134, M_loss: 1.5407, M_d_loss: 0.4732, M_r_loss: 1067.4553
Epoch [87/250], step [1/41], D_loss: 0.0195, M_loss: 2.1288, M_d_loss: 1.0354, M_r_loss: 1093.4231
Epoch [87/250], step [41/41], D_loss: 0.0199, M_loss: 1.5407, M_d_loss: 0.4490, M_r_loss: 1091.6115
Epoch [88/250], step [1/41], D_loss: 0.0293, M_loss: 2.1019, M_d_loss: 1.0144, M_r_loss: 1087.4735
Epoch [88/250], step [41/41], D_loss: 0.0119, M_loss: 1.6164, M_d_loss: 0.5334, M_r_loss: 1083.0756
Epoch


KeyboardInterrupt



In [28]:
import os

model_path = './model'
os.makedirs(model_path, exist_ok=True)
# Each network gets its own file. The mapping module keeps the original file
# name (it was the one that ended up there before); the discriminator used to
# be written to that same path and was overwritten immediately.
torch.save(discriminator.state_dict(), model_path + "/sample_aizarani_discriminator.pth")
torch.save(mapping_module.state_dict(), model_path + "/sample_aizarani.pth")

In [29]:
import os

save_path = './saves'
os.makedirs(save_path, exist_ok=True)


def latent_snapshot_to_adata(latent):
    """Wrap a recorded latent tensor in an AnnData annotated with the query cell type and batch."""
    snapshot = sc.AnnData(latent.detach().cpu().numpy())
    snapshot.obs["cell_type"] = query_adata.obs["level1"].tolist()
    snapshot.obs["batch"] = query_adata.obs["batch"].tolist()
    snapshot.raw = snapshot
    return snapshot


# Epoch index in z_record -> output file (relative to save_path).
snapshot_files = {
    0: "/ad_sample_aizarani_epo00.h5ad",
    99: "/ad_sample_aizarani_epo99.h5ad",
    199: "/ad_sample_aizarani_epo199.h5ad",
    249: "/ad_sample_aizarani.h5ad",
}

for snapshot_epoch, snapshot_file in snapshot_files.items():
    # Training may have stopped early; report epochs that were never recorded
    # instead of failing with a KeyError halfway through.
    if snapshot_epoch not in z_record:
        print("Skipping epoch %d: no latent snapshot was recorded for it." % snapshot_epoch)
        continue
    ad = latent_snapshot_to_adata(z_record[snapshot_epoch])
    ad.write(save_path + snapshot_file)

In [17]:
import os

save_path = './saves'
os.makedirs(save_path, exist_ok=True)

# Latent representation of the query cells from the scVI model; computed once
# and used for both outputs below.
query_latent_array = scvi_model.get_latent_representation()

# Untransferred query latent.
reference_latent1 = sc.AnnData(query_latent_array)
reference_latent1.obs["cell_type"] = query_adata.obs["level1"].tolist()
reference_latent1.obs["batch"] = query_adata.obs["batch"].tolist()
reference_latent1.write(save_path + "/latent_origin2.h5ad")

# Query latent after the mapping module. Run in eval mode (dropout off) and
# without autograd so the saved embedding is deterministic; the module's
# previous train/eval state is restored afterwards.
z_mean = torch.tensor(query_latent_array).to(device)
was_training = mapping_module.training
mapping_module.eval()
with torch.no_grad():
    z_mean_transferred = mapping_module(z_mean)
mapping_module.train(was_training)

reference_latent = sc.AnnData(z_mean_transferred.detach().cpu().numpy())
reference_latent.obs["cell_type"] = query_adata.obs["level1"].tolist()
reference_latent.obs["batch"] = query_adata.obs["batch"].tolist()
reference_latent.write(save_path + "/latent_transferred2.h5ad")



In [None]:
import scanpy as sc
# Read the two query embeddings back from disk: the untransferred latent and the
# latent after the mapping module.
save_path = './saves'
query_latent_origin = sc.read_h5ad(save_path + "/latent_origin2.h5ad")
query_latent_transferred = sc.read_h5ad(save_path + "/latent_transferred2.h5ad")

In [None]:
# Reload the reference embedding, thin it to 10000 cells, then stack it with
# the transferred query embedding; the "ref_query" column records the origin.
ref_latent = sc.read_h5ad("./reference_latent.h5ad")
sc.pp.subsample(ref_latent, n_obs=10000)
merged_adata = ref_latent.concatenate(
    query_latent_transferred,
    batch_key="ref_query",
)
merged_adata

In [None]:
# Neighbour graph (4 neighbours) -> UMAP -> plot coloured by cell type and origin.
sc.pp.neighbors(merged_adata, n_neighbors=4)
sc.tl.umap(merged_adata)
sc.pl.umap(merged_adata, color=['cell_type', "ref_query"], frameon=False, wspace=0.6)

In [None]:
# Same comparison as for the transferred latent, but with the untransferred
# query latent stacked under the reference.
merged_adata0 = ref_latent.concatenate(query_latent_origin, batch_key="ref_query")

sc.pp.neighbors(merged_adata0, n_neighbors=4)
sc.tl.umap(merged_adata0)
sc.pl.umap(merged_adata0, color=['cell_type', "ref_query"], frameon=False, wspace=0.6)