In [1]:
%%capture
!pip install scprep
!pip install anndata
!pip install scanpy

In [2]:
import numpy as np
import pandas as pd
import anndata as ad
import scprep
import scanpy as sc
import sklearn
from sklearn.model_selection import train_test_split
import tempfile
import os
from os import path
import sys
import scipy
from scipy import sparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

import load_raw
import normalize_tools as nm
import metrics

In [3]:
torch.cuda.is_available()

  return torch._C._cuda_getDeviceCount() > 0


False

# **try out with scicar cell lines dataset**

**1. URLs for raw data**

In [4]:
input_train_mod1 = ad.read_h5ad("./matching/sample_multitome/openproblems_bmmc_multiome_starter.train_mod1.h5ad")
input_train_mod2 = ad.read_h5ad("./matching/sample_multitome/openproblems_bmmc_multiome_starter.train_mod2.h5ad")
input_train_sol = ad.read_h5ad("./matching/sample_multitome/openproblems_bmmc_multiome_starter.train_sol.h5ad")
input_test_mod1 = ad.read_h5ad("./matching/sample_multitome/openproblems_bmmc_multiome_starter.test_mod1.h5ad")
input_test_mod2 = ad.read_h5ad("./matching/sample_multitome/openproblems_bmmc_multiome_starter.test_mod2.h5ad")
input_test_sol = ad.read_h5ad("./matching/sample_multitome/openproblems_bmmc_multiome_starter.test_sol.h5ad")

In [7]:
#set up all hyper-parameters
hyper = {
    "nEpochs": 20,
    "dimRNA": input_train_mod1.X.shape[1],
    "dimATAC": input_train_mod2.shape[1],
    "train_nobs": input_train_mod1.X.shape[0],
    "test_nobs": input_test_mod1.X.shape[0],
    "layer_sizes": [1024, 512, 256],
    "nz": 64,
    "batchSize": 512,
    "lr": 1e-3,
    "lamb_kl": 1e-9,
    "lamb_anc": 1e-9,
    "clip_grad": 0.1,
    "checkpoint_path": './checkpoint/vae_matching.pt',
}

# **define pytorch datasets for RNA and ATAC**

In [8]:
class TrainDataset(Dataset):
    def __init__(self, adata_mod1, adata_mod2, labels):
        #self.rna_data_filtered, self.atac_data_filtered = self._load_merge_data(adata_filtered)
        #self.rna_data_raw = self._load_raw_ref_data(adata_raw)
        self.rna_data = adata_mod1.X.toarray()
        self.atac_data = adata_mod2.X.toarray()
        self.labels = labels.X.toarray()
        self.idx_map = []
        
        assert self.rna_data.shape[0] == self.atac_data.shape[0]
        self.n_obs = self.rna_data.shape[0]
        for i in range(self.n_obs):
            for j in range(self.n_obs):
                self.idx_map.append((i, j))
        
    def __len__(self):
        #assert(len(self.rna_data) == len(self.atac_data))
        return len(self.idx_map)
  
    def __getitem__(self, idx):
        i, j = self.idx_map[idx]
        rna_sample = torch.from_numpy(self.rna_data[i]).float()
        atac_sample = torch.from_numpy(self.atac_data[j]).float()
        label_sample = self.labels[i][j]
        #return a tensor that for a single observation
        return rna_sample, atac_sample, label_sample

In [9]:
class TestDataset(Dataset):
    def __init__(self, adata_mod1, adata_mod2):
        #self.rna_data_filtered, self.atac_data_filtered = self._load_merge_data(adata_filtered)
        #self.rna_data_raw = self._load_raw_ref_data(adata_raw)
        self.rna_data = adata_mod1.X.toarray()
        self.atac_data = adata_mod2.X.toarray()
        self.idx_map = []
        
        assert self.rna_data.shape[0] == self.atac_data.shape[0]
        self.n_obs = self.rna_data.shape[0]
        for i in range(self.n_obs):
            for j in range(self.n_obs):
                self.idx_map.append((i, j))
        
    def __len__(self):
        #assert(len(self.rna_data) == len(self.atac_data))
        return len(self.idx_map)
  
    def __getitem__(self, idx):
        i, j = self.idx_map[idx]
        rna_sample = torch.from_numpy(self.rna_data[i]).float()
        atac_sample = torch.from_numpy(self.atac_data[j]).float()
        return rna_sample, atac_sample

# **define basic models(autoencoders) for learning latent space**

In [10]:
class FC_VAE(nn.Module):
    def __init__(self, n_input, nz, layer_sizes=hyper["layer_sizes"]):
        super(FC_VAE, self).__init__()
        self.n_input = n_input
        self.nz = nz
        self.layer_sizes = layer_sizes

        self.encoder_layers = []

        self.encoder_layers.append(nn.Linear(n_input, self.layer_sizes[0]))
        self.encoder_layers.append(nn.LeakyReLU(inplace=True))
        self.encoder_layers.append(nn.BatchNorm1d(self.layer_sizes[0]))

        for layer_idx in range(len(layer_sizes)-1):
            if layer_idx == len(layer_sizes) - 2:
                self.encoder_layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx+1]))
                self.encoder_layers.append(nn.LeakyReLU(inplace=True))
            else:
                self.encoder_layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx+1]))
                self.encoder_layers.append(nn.BatchNorm1d(self.layer_sizes[layer_idx+1]))
                self.encoder_layers.append(nn.LeakyReLU(inplace=True))

        self.encoder = nn.Sequential(
            *self.encoder_layers
        )
        self.fc1 = nn.Linear(self.layer_sizes[-1], nz)
        self.fc2 = nn.Linear(self.layer_sizes[-1], nz)

        self.decoder_layers = []
        self.decoder_layers.append(nn.Linear(nz, self.layer_sizes[-1]))
        self.decoder_layers.append(nn.BatchNorm1d(self.layer_sizes[-1]))
        self.decoder_layers.append(nn.LeakyReLU(inplace=True))

        for layer_idx in range(len(self.layer_sizes)-1, 0, -1):
            self.decoder_layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx-1]))
            self.decoder_layers.append(nn.BatchNorm1d(self.layer_sizes[layer_idx-1]))
            self.decoder_layers.append(nn.LeakyReLU(inplace=True))

        self.decoder_layers.append(nn.Linear(self.layer_sizes[0], self.n_input))

        self.decoder = nn.Sequential(
            *self.decoder_layers
        )
    def encode(self, x):
        h = self.encoder(x)
        return self.fc1(h), self.fc2(h)

    def reparametrize(self, mu, logvar):
        #calculate std from log(var)
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, z, mu, logvar

    def get_latent_var(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z
    
    def generate(self, z):
        return self.decode(z)

In [43]:
#implement contrastive loss
class ContrastiveLoss(torch.nn.Module):
    
    def __init__(self, margin=0.1, pos_coef=10):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.pos_coef = pos_coef
    
    def forward(self, output1, output2, label):
        pdist = F.pairwise_distance(output1, output2, keepdim=True)
        loss_contrastive = torch.mean(self.pos_coef * (1 - label) * torch.pow(pdist, 2) + 
                                      label * torch.pow(torch.clamp(self.margin - pdist, min = 0.0), 2))
        return loss_contrastive

# **train VAE model based on reconstruction, KL divergence, and anchor loss**

In [44]:
#load dataset and split train and test data
def get_data_loaders(train_mod1, train_mod2, train_sol, test_mod1, test_mod2):
    train_set = TrainDataset(train_mod1, train_mod2, train_sol)
    test_set = TestDataset(test_mod1, test_mod2)
    
    #load data loader
    train_loader = DataLoader(
        train_set, 
        batch_size=hyper["batchSize"], 
        drop_last=False, 
        shuffle=True,
    )
    test_loader = DataLoader(
        test_set, 
        batch_size=hyper["batchSize"],
        drop_last=False,
        shuffle=False,
    )
    return train_loader, test_loader

In [45]:
#set up loss function
def basic_loss(recon_x, x, mu, logvar, lamb1):
    MSE = nn.MSELoss()
    lloss = MSE(recon_x, x)
    #KL divergence
    KL_loss = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    lloss = lloss + lamb1*KL_loss
    return lloss

#anchor loss for minimizing distance between paired observation
def anchor_loss(embed_rna, embed_atac):
    L1 = nn.L2Loss()
    anc_loss = L2(embed_rna, embed_atac)
    return anc_loss

In [46]:
def similar_score(rna_output, atac_output):
    cos_score = F.cosine_similarity(rna_output, atac_output)
    return(cos_score.cpu().detach().numpy())

In [47]:
#set up train functions
def main():
    #load training data and testing data
    train_loader, test_loader = get_data_loaders(
        input_train_mod1,
        input_train_mod2,
        input_train_sol,
        input_test_mod1,
        input_test_mod2,
    )
    
    #load checkpoint
    checkpoint = None
    if path.exists(hyper["checkpoint_path"]):
        checkpoint = torch.load(hyper["checkpoint_path"])
    
    #load basic models
    netRNA = FC_VAE(n_input=hyper["dimRNA"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])
    netATAC = FC_VAE(n_input=hyper["dimATAC"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])
    if checkpoint != None:
        netRNA.load_state_dict(checkpoint["net_rna_state_dict"])
        netATAC.load_state_dict(checkpoint["net_atac_state_dict"])
        
    if torch.cuda.is_available():
        print("using GPU")
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    netRNA.to(device)
    netATAC.to(device)
    
    #set up additional criterion (constrastive...)
    criterion_contrastive = ContrastiveLoss()
    
    #setup optimizers for two nets
    opt_netRNA = optim.Adam(list(netRNA.parameters()), lr=hyper["lr"])
    opt_netATAC = optim.Adam(list(netATAC.parameters()), lr=hyper["lr"])
    scheduler_netRNA = optim.lr_scheduler.ReduceLROnPlateau(
        opt_netRNA,
        patience=5,
        threshold=0.01,
        mode="max",
        min_lr=1e-5,
    )
    scheduler_netATAC = optim.lr_scheduler.ReduceLROnPlateau(
        opt_netATAC,
        patience=5,
        threshold=0.01,
        mode="max",
        min_lr=1e-5,
    )
    
    best_match_score = 0
    test_sol = torch.from_numpy(input_test_sol.X.toarray()).float()
    train_sol = torch.from_numpy(input_train_sol.X.toarray()).float()
    
    if checkpoint != None:
        best_match_score = checkpoint["dev_match_score"]
        
    #training
    for epoch in range(hyper["nEpochs"]):
        train_losses = []
        train_scores = []
        netRNA.train()
        netATAC.train()
        #train for epochs
        for idx, (rna_inputs, atac_inputs, label) in enumerate(train_loader):
            opt_netATAC.zero_grad()
            opt_netRNA.zero_grad()
            rna_inputs = Variable(rna_inputs).to(device)
            atac_inputs = Variable(atac_inputs).to(device)
            
            recon_rna, z_rna, mu_rna, logvar_rna = netRNA(rna_inputs)
            recon_atac, z_atac, mu_atac, logvar_atac = netATAC(atac_inputs)
            #rna_loss = basic_loss(recon_rna, rna_inputs, mu_rna, logvar_rna, lamb1=hyper["lamb_kl"])
            atac_loss = basic_loss(recon_atac, atac_inputs, mu_atac, logvar_atac, lamb1=hyper["lamb_kl"])
            contrastive_loss = criterion_contrastive(z_rna, z_atac, label)
            
            '''if epoch % 10 == 0:
                print(f"rna_loss: {rna_loss}")
                print(f"atac_loss:{atac_loss}")
                print(f"contrastive loss: {contrastive_loss}")'''

            #loss functions for each modalities
            train_loss = atac_loss + contrastive_loss
            
            #train_loss = rna_loss + atac_loss
            #train_loss = rna_loss + atac_loss + hyper["lamb_anc"] * anc_loss
            #rain_loss = rna_loss + atac_loss + hyper["lamb_anc"] * anc_loss + h_loss
            train_loss.backward()
            #nn.utils.clip_grad_norm_(netRNA.parameters(), max_norm=hyper["clip_grad"])
            #nn.utils.clip_grad_norm_(netATAC.parameters(), max_norm=hyper["clip_grad"])
            opt_netRNA.step()
            opt_netATAC.step()
            train_losses.append(train_loss.item())
            
            scores_batch = similar_score(z_rna, z_atac)
            train_scores.append(scores_batch.reshape((len(scores_batch), 1)))
        
        avg_train_loss = np.mean(train_losses)
        train_score_matrix = torch.tensor(np.vstack(train_scores).reshape((hyper["train_nobs"], hyper["train_nobs"])))
        train_socre_matrix = train_score_matrix / torch.sum(train_score_matrix, axis = 1)
        #train_score_matrix = F.softmax(torch.tensor(train_score_matrix), dim=1)
        train_match_score = torch.sum(train_score_matrix * train_sol)
        
        if epoch % 5 == 0:
            print("Epoch: " + str(epoch) + ", train loss: " + str(avg_train_loss))
            print("Epoch: " + str(epoch) + ", train similarity score: " + str(train_match_score))
        
        #evaluating step
        netRNA.eval()
        netATAC.eval()
        scores = []
        with torch.no_grad():
            for idx, samples in enumerate(test_loader):
                rna_inputs= samples[0].float().to(device)
                atac_inputs = samples[1].float().to(device)

                _, output_rna, _, _ = netRNA(rna_inputs)
                _, output_atac, _, _ = netATAC(atac_inputs)
                scores_batch = similar_score(output_rna, output_atac)
                scores.append(scores_batch.reshape((len(scores_batch), 1)))
                
        score_matrix = torch.tensor(np.vstack(scores).reshape((hyper["test_nobs"], hyper["test_nobs"])))
        score_matrix = score_matrix / torch.sum(score_matrix, axis = 1)
        if epoch == 10:
            print(score_matrix)
        dev_match_score = torch.sum(score_matrix * test_sol)
        
        if dev_match_score > best_match_score:
            torch.save({
                "epoch": epoch,
                "clip_grad": hyper['clip_grad'],
                "layer_sizes": hyper['layer_sizes'],
                "lr": hyper["lr"],
                "net_rna_state_dict": netRNA.state_dict(),
                "net_atac_state_dict": netATAC.state_dict(),
                "train_loss": avg_train_loss,
                "dev_match_score": dev_match_score,
            }, hyper["checkpoint_path"])
                
        if epoch % 5 == 0:
            print("Epoch: " + str(epoch) + ", dev similarity score: " + str(dev_match_score))
        
        scheduler_netRNA.step(dev_match_score)
        scheduler_netATAC.step(dev_match_score)

In [48]:
main()

Epoch: 0, train loss: 1.5615716874771217
Epoch: 0, train similarity score: tensor(187.1886)
Epoch: 0, dev similarity score: tensor(1.0003)
Epoch: 5, train loss: 0.045573747396224835
Epoch: 5, train similarity score: tensor(241.1765)
Epoch: 5, dev similarity score: tensor(1.0262)
Epoch: 10, train loss: 0.1644353379999272
Epoch: 10, train similarity score: tensor(227.8022)
tensor([[ 4.7032e-03,  9.4605e-04,  2.1284e-03,  ...,  5.6483e-03,
          5.3538e-03, -1.4571e-04],
        [ 5.0714e-03,  1.8244e-04, -5.0495e-05,  ...,  6.0322e-03,
          5.7575e-03, -6.2041e-04],
        [ 4.7388e-03, -7.1514e-04,  1.8912e-03,  ...,  5.7369e-03,
          5.3974e-03,  2.1532e-03],
        ...,
        [ 4.8433e-03,  1.0796e-03, -9.8414e-04,  ...,  5.8315e-03,
          5.5583e-03,  7.6700e-04],
        [ 4.8757e-03, -1.1154e-03,  4.1089e-04,  ...,  5.8013e-03,
          5.5328e-03, -7.0081e-04],
        [ 5.1629e-03, -3.0785e-04,  2.0870e-03,  ...,  6.2070e-03,
          5.9521e-03, -3.9464e-

In [17]:
def model_eval(netRNA, netATAC, test_data_filtered, test_data_raw, title):
    netRNA.eval()
    netATAC.eval()
    rna_inputs = Variable(torch.from_numpy(test_data_filtered.X.toarray()).float())
    atac_inputs = Variable(torch.from_numpy(test_data_filtered.obsm["mode2"].toarray()).float())
    if torch.cuda.is_available():
        rna_inputs = rna_inputs.cuda()
        atac_inputs = atac_inputs.cuda()
    _, z_rna, _, _ = netRNA(rna_inputs)
    _, z_atac, _, _ = netATAC(atac_inputs)
    test_data_raw.obsm["aligned"] = sparse.csr_matrix(z_rna.cpu().detach())
    test_data_raw.obsm["mode2_aligned"] = sparse.csr_matrix(z_atac.cpu().detach())
    metrics.plot_multimodal_umap(test_data_raw, title=title, num_points=100, connect_modalities=True)
    knn_score, mse_score = metrics.knn_auc(test_data_raw), metrics.mse(test_data_raw)
    return knn_score, mse_score

In [18]:
#load checkpoint
checkpoint=None
if path.exists(hyper["checkpoint_path"]):
    checkpoint = torch.load(hyper["checkpoint_path"], map_location="cpu")
netRNA = FC_VAE(n_input=hyper["dimRNA"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])
netATAC = FC_VAE(n_input=hyper["dimATAC"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])
if checkpoint != None:
    netRNA.load_state_dict(checkpoint["net_rna_state_dict"])
    netATAC.load_state_dict(checkpoint["net_atac_state_dict"])
    
#plot UMAP result and show evaluation metrics value
model_eval(netRNA, netATAC, test_data_filtered, test_data_raw, title="VAE with Structure-Preserving Loss")

NameError: name 'test_data_filtered' is not defined

In [None]:
def similairty(netRNA, netATAC, test_data_filtered):
    netRNA.eval()
    netATAC.eval()
    rna_inputs = Variable(torch.from_numpy(test_data_filtered.X.toarray()).float())
    atac_inputs = Variable(torch.from_numpy(test_data_filtered.obsm["mode2"].toarray()).float())
    if torch.cuda.is_available():
        rna_inputs = rna_inputs.cuda()
        atac_inputs = atac_inputs.cuda()
    _, z_rna, _, _ = netRNA(rna_inputs)
    _, z_atac, _, _ = netATAC(atac_inputs)
    cos_score = nn.CosineSimilairty(z_rna, z_atac)
    return(cos_score)