In [3]:
%%capture
!pip install scprep
!pip install anndata
!pip install scanpy

In [4]:
import numpy as np
import pandas as pd
import anndata
import scprep
import scanpy as sc
import sklearn
from sklearn.model_selection import train_test_split
import tempfile
import os
import sys
import scipy
from scipy import sparse

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

In [5]:
#set up all hyper-parameters
hyper = {
    "nEpochs":400,
    "dimRNA":60550,
    "dimATAC":146713,
    "n_hidden":1024,
    "nz":128,
    "batchSize":128,
    "lr":1e-3,
    "lamb_kl":0.0000001,
    "lamb_anc":0.0001,
    "weightDirName": './checkpoint/'
}

# **try out with scicar cell lines dataset**

**1. URLs for raw data**

In [6]:
rna_url = ("https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSM3271040"
    "&format=file&file=GSM3271040%5FRNA%5FsciCAR%5FA549%5Fgene%5Fcount.txt.gz")
rna_cells_url = (
    "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSM3271040"
    "&format=file&file=GSM3271040%5FRNA%5FsciCAR%5FA549%5Fcell.txt.gz"
)
rna_genes_url = (
    "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSM3271040"
    "&format=file&file=GSM3271040%5FRNA%5FsciCAR%5FA549%5Fgene.txt.gz"
)
atac_url = (
    "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSM3271041"
    "&format=file&file=GSM3271041%5FATAC%5FsciCAR%5FA549%5Fpeak%5Fcount.txt.gz"
)
atac_cells_url = (
    "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSM3271041"
    "&format=file&file=GSM3271041%5FATAC%5FsciCAR%5FA549%5Fcell.txt.gz"
)
atac_genes_url = (
    "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSM3271041"
    "&format=file&file=GSM3271041%5FATAC%5FsciCAR%5FA549%5Fpeak.txt.gz"
)

In [7]:
rna_genes = pd.read_csv(rna_genes_url, low_memory=False, index_col=0)
atac_genes =  pd.read_csv(atac_genes_url, low_memory=False, index_col=1)
rna_cells = pd.read_csv(rna_cells_url, low_memory=False, index_col=0)
atac_cells = pd.read_csv(atac_cells_url, low_memory=False, index_col=0)

In [8]:
with tempfile.TemporaryDirectory() as tempdir:
  rna_file = os.path.join(tempdir, "rna.mtx.gz")
  scprep.io.download.download_url(rna_url, rna_file)
  rna_data = scprep.io.load_mtx(rna_file, cell_axis="col").tocsr()
  atac_file = os.path.join(tempdir, "atac.mtx.gz")
  scprep.io.download.download_url(atac_url, atac_file)
  atac_data = scprep.io.load_mtx(atac_file, cell_axis="col").tocsr()

 **2. select the joint sub-datasets and store them into csv files**

In [9]:
def create_joint_dataset(X, Y, X_index=None, X_columns=None, Y_index=None, Y_columns=None):
    if X_index is None:
        X_index = X.index
    if X_columns is None:
        X_columns = X.columns
    if Y_index is None:
        Y_index = Y.index
    if Y_columns is None:
        Y_columns = Y.columns
    joint_index = np.sort(np.intersect1d(X_index, Y_index))
    try:
        X = X.loc[joint_index]
        Y = Y.loc[joint_index]
    except AttributeError:
        x_keep_idx = np.isin(X_index, joint_index)
        y_keep_idx = np.isin(Y_index, joint_index)
        X = X[x_keep_idx]
        Y = Y[y_keep_idx]
        X_index_sub = scprep.utils.toarray(X_index[x_keep_idx])
        Y_index_sub = scprep.utils.toarray(Y_index[y_keep_idx])
        X = X[np.argsort(X_index_sub)]
        Y = Y[np.argsort(Y_index_sub)]
        # check order is correct
        assert (X_index_sub[np.argsort(X_index_sub)] == joint_index).all()
        assert (Y_index_sub[np.argsort(Y_index_sub)] == joint_index).all()
    adata = anndata.AnnData(
        scprep.utils.to_array_or_spmatrix(X).tocsr(),
        obs = pd.DataFrame(index = joint_index),
        var = pd.DataFrame(index = X_columns),
    )
    adata.obsm["mode2"] = scprep.utils.to_array_or_spmatrix(Y).tocsr()
    adata.uns["mode2_obs"] = joint_index
    adata.uns["mode2_var"] = scprep.utils.toarray(Y_columns)
    return adata

In [10]:
def subset_mode2_genes(adata, keep_genes):
  adata.obsm["mode2"] = adata.obsm["mode2"][:, keep_genes]
  adata.uns["mode2_var"] = adata.uns["mode2_var"][keep_genes]
  if "mode2_varnames" in adata.uns:
    for varname in adata.uns["mode2_varnames"]:
      adata.uns[varname] = adata.uns[varname][keep_genes]
  return adata

In [11]:
def filter_joint_data_empty_cells(adata):
    assert np.all(adata.uns["mode2_obs"] == adata.obs.index)
    #filter out cells
    n_cells_mode1 = scprep.utils.toarray(adata.X.sum(axis = 1)).flatten()
    n_cells_mode2 = scprep.utils.toarray(adata.obsm["mode2"].sum(axis = 1)).flatten()
    keep_cells = np.minimum(n_cells_mode1, n_cells_mode2) > 1
    adata.uns["mode2_obs"] = adata.uns["mode2_obs"][keep_cells]
    adata = adata[keep_cells, :].copy()
    #filter out genes
    sc.pp.filter_genes(adata, min_counts=1)
    n_genes_mode2 = scprep.utils.toarray(adata.obsm["mode2"].sum(axis=0)).flatten()
    keep_genes_mode2 = n_genes_mode2 > 0
    adata = subset_mode2_genes(adata, keep_genes_mode2)
    return adata

In [12]:
def merge_data(rna_data, atac_data, rna_cells, atac_cells, rna_genes, atac_genes):
  scicar_data = create_joint_dataset(
      rna_data, atac_data, 
      X_index=rna_cells.index, 
      X_columns=rna_genes.index, 
      Y_index=atac_cells.index,
      Y_columns=atac_genes.index)

  scicar_data.obs = rna_cells.loc[scicar_data.obs.index]
  scicar_data.var = rna_genes
  for key in atac_cells.columns:
      scicar_data.obs[key] = atac_cells[key]
  scicar_data.uns["mode2_varnames"] = []
  for key in atac_genes.columns:
      varname = "mode2_var_{}".format(key)
      scicar_data.uns[varname] = atac_genes[key].values
      scicar_data.uns["mode2_varnames"].append(varname)
  scicar_data = filter_joint_data_empty_cells(scicar_data)
  return scicar_data

In [13]:
scicar_data = merge_data(rna_data, atac_data, rna_cells, atac_cells, rna_genes, atac_genes)
#rna_df, atac_df = ann2df(scicar_data)

# **define pytorch datasets for RNA and ATAC**

In [14]:
class RNA_Dataset(Dataset):
  def __init__(self, adata):
    self.rna_data = self._load_rna_data(adata)

  def __len__(self):
    return len(self.rna_data)
  
  def __getitem__(self, idx):
    rna_sample = self.rna_data.values[idx]
    #return a tensor that for a single observation
    return torch.from_numpy(rna_sample).float()
  
  def _load_rna_data(self, adata):
     rna_df = pd.DataFrame(data = adata.X.toarray(), index = np.array(adata.obs.index), columns = np.array(adata.var.index))
     return rna_df

class ATAC_Dataset(Dataset):
  def __init__(self, adata):
    self.atac_data = self._load_atac_data(adata)

  def __len__(self):
    return len(self.atac_data)
  
  def __getitem__(self, idx):
    atac_sample = self.atac_data.values[idx]
    #return a tensor that for a single observation
    return torch.from_numpy(atac_sample).float()
  
  def _load_atac_data(self, adata):
    atac_df = pd.DataFrame(data = adata.obsm["mode2"].toarray(), index = np.array(adata.uns["mode2_obs"]), columns = np.array(adata.uns["mode2_var"]))
    return atac_df

In [15]:
class Merge_Dataset(Dataset):
  def __init__(self, adata):
    self.rna_data, self.atac_data = self._load_merge_data(adata)

  def __len__(self):
    #assert(len(self.rna_data) == len(self.atac_data))
    return len(self.atac_data)
  
  def __getitem__(self, idx):
    rna_sample = self.rna_data.values[idx]
    atac_sample = self.atac_data.values[idx]
    #return a tensor that for a single observation
    return {"rna_tensor": torch.from_numpy(rna_sample).float(), "atac_tensor": torch.from_numpy(atac_sample).float()}
  
  def _load_merge_data(self, adata):
    rna_df = pd.DataFrame(data = adata.X.toarray(), index = np.array(adata.obs.index), columns = np.array(adata.var.index))
    atac_df = pd.DataFrame(data = adata.obsm["mode2"].toarray(), index = np.array(adata.uns["mode2_obs"]), columns = np.array(adata.uns["mode2_var"]))
    return rna_df, atac_df

In [16]:
#test datasets
def test_loader(scicar_data):
  merge_dataset = Merge_Dataset(scicar_data)
  
  print(merge_dataset[0])
  print(len(merge_dataset))

# **define basic models(autoencoders) for learning latent space**

In [17]:
class FC_VAE(nn.Module):
  def __init__(self, n_input, nz, n_hidden = hyper["n_hidden"]):
    super(FC_VAE, self).__init__()
    self.n_input = n_input
    self.nz = nz
    self.n_hidden = n_hidden

    self.encoder = nn.Sequential(
        nn.Linear(n_input, n_hidden),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(n_hidden),
        nn.Linear(n_hidden, n_hidden),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(n_hidden, n_hidden),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(n_hidden, n_hidden),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(n_hidden, n_hidden),
    )
    self.fc1 = nn.Linear(n_hidden, nz)
    self.fc2 = nn.Linear(n_hidden, nz)

    self.decoder = nn.Sequential(
        nn.Linear(nz, n_hidden),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(n_hidden),
        nn.Linear(n_hidden, n_hidden),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(n_hidden, n_hidden),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(n_hidden, n_hidden),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(n_hidden, n_input),
    )
  def encode(self, x):
    h = self.encoder(x)
    return self.fc1(h), self.fc2(h)

  def reparametrize(self, mu, logvar):
    #calculate std from log(var)
    std = logvar.mul(0.5).exp_()
    if torch.cuda.is_available():
      eps = torch.cuda.FloatTensor(std.size()).normal_()
    else:
      eps = torch.FloatTensor(std.size()).normal_()
    eps = Variable(eps)
    return eps.mul(std).add_(mu)

  def decode(self, z):
    return self.decoder(z)

  def forward(self, x):
    mu, logvar = self.encode(x)
    z = self.reparametrize(mu, logvar)
    res = self.decode(z)
    return res, z, mu, logvar

  def get_latent_var(self, x):
    mu, logvar = self.encode(x)
    z = self.reparametrize(mu, logvar)
    return z
    
  def generate(self, z):
    return self.decode(z)
      

# **train VAE model based on reconstruction, KL divergence, and anchor loss**

In [18]:
#load dataset and split train and test data
merge_dataset = Merge_Dataset(scicar_data)
train_len = int(len(merge_dataset)*0.8)
lengths = [train_len, len(merge_dataset)-train_len]
trainset, testset = random_split(merge_dataset, lengths)
#netRNA = FC_VAE()

In [19]:
#load data loader
train_loader = DataLoader(trainset, batch_size=hyper["batchSize"], drop_last=False, shuffle=True)
test_loader = DataLoader(testset, batch_size=hyper["batchSize"], drop_last=False, shuffle=False)

#load basic models
netRNA = FC_VAE(n_input=hyper["dimRNA"], nz=hyper["nz"])
netATAC = FC_VAE(n_input=hyper["dimATAC"], nz=hyper["nz"])

In [20]:
#use GPU
if torch.cuda.is_available():
  print("using GPU")
  netRNA.cuda()
  netATAC.cuda()

#setup optimizers for two nets
opt_netRNA = optim.Adam(list(netRNA.parameters()), lr=hyper["lr"])
opt_netATAC = optim.Adam(list(netATAC.parameters()), lr=hyper["lr"])

using GPU


In [21]:
#set up loss function
def basic_loss(recon_x, x, mu, logvar, lamb1):
  MSE = nn.MSELoss()
  lloss = MSE(recon_x, x)
  #KL divergence
  KL_loss = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  lloss = lloss + lamb1*KL_loss
  return lloss

#anchor loss for minimizing distance between paired observation
def anchor_loss(embed_rna, embed_atac):
  L1 = nn.L1Loss(reduction = 'sum')
  anc_loss = L1(embed_rna, embed_atac)
  return anc_loss

In [22]:
#set up train functions
def train(epoch):
  netRNA.train()
  netATAC.train()
  train_loss = 0
  for idx, samples in enumerate(train_loader):
    rna_inputs, atac_inputs = Variable(samples["rna_tensor"]), Variable(samples["atac_tensor"])
    if torch.cuda.is_available():
      rna_inputs = rna_inputs.cuda()
      atac_inputs = atac_inputs.cuda()
      opt_netATAC.zero_grad()
      opt_netRNA.zero_grad()

      recon_rna, z_rna, mu_rna, logvar_rna = netRNA(rna_inputs)
      recon_atac, z_atac, mu_atac, logvar_atac = netATAC(atac_inputs)
      rna_loss = basic_loss(recon_rna, rna_inputs, mu_rna, logvar_rna, lamb1=hyper["lamb_kl"])
      atac_loss = basic_loss(recon_atac, atac_inputs, mu_atac, logvar_atac, lamb1=hyper["lamb_kl"])
      anc_loss = anchor_loss(z_rna, z_atac)
      
      #loss functions for each modalities
      train_loss = rna_loss + atac_loss + hyper["lamb_anc"]*anc_loss
      train_loss.backward()
      opt_netRNA.step()
      opt_netATAC.step()
      

In [23]:
#train a toy model and see the scores
max_iter = hyper["nEpochs"]
for epoch in range(max_iter):
  train(epoch)
  #set up log
  #if epoch % 50 == 0:
    #print("***saving checkpoints***")
    #path = "{}Max_iter_{}lamb_anc_{}Epoch_{}params.pth".format(hyper["weightDirName"], str(hyper["nEpochs"]), str(hyper["lamb_anc"]), str(epoch))
    
    #torch.save({
    #    "epoch": epoch,
    #    'netRNA_state_dict': netRNA.state_dict(),
    #    'netATAC_state_dict': netATAC.state_dict(),
    # }, path)

del trainset
del testset
#start to evaluate the data
netRNA.eval()
rna_inputs = Variable(torch.from_numpy(merge_dataset.rna_data.values).float())

if torch.cuda.is_available():
  rna_inputs = rna_inputs.cuda()

_, z_rna, _, _ = netRNA(rna_inputs)
scicar_data.obsm["aligned"] = sparse.csr_matrix(z_rna.cpu().detach())
del z_rna
del netRNA
del rna_inputs
del train_loader
del test_loader

In [24]:
netATAC.eval()
atac_inputs = Variable(torch.from_numpy(merge_dataset.atac_data.values).float())

del merge_dataset

if torch.cuda.is_available():
  atac_inputs = atac_inputs.cuda()

_, z_atac, _, _ = netATAC(atac_inputs)
scicar_data.obsm["mode2_aligned"] = sparse.csr_matrix(z_atac.cpu().detach())
del atac_inputs
del netATAC
del z_atac

In [25]:
#KNN-AUC
def knn_auc(adata, proportion_neighbors=0.1, n_svd=100):
  n_svd = min([n_svd, min(adata.X.shape)-1])
  n_neighbors = int(np.ceil(proportion_neighbors*adata.X.shape[0]))
  X_pca = sklearn.decomposition.TruncatedSVD(n_svd).fit_transform(adata.X)
  _, indices_true = (
      sklearn.neighbors.NearestNeighbors(n_neighbors = n_neighbors).fit(X_pca).kneighbors(X_pca)
  )
  _, indices_pred = (
      sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(adata.obsm["aligned"]).kneighbors(adata.obsm["mode2_aligned"])
  )
  neighbors_match = np.zeros(n_neighbors, dtype=int)
  for i in range(adata.shape[0]):
    _, pred_matches, true_matches = np.intersect1d(
        indices_pred[i], indices_true[i], return_indices=True
    )
    neighbors_match_idx = np.maximum(pred_matches, true_matches)
    neighbors_match += np.sum(np.arange(n_neighbors) >= neighbors_match_idx[:, None], axis = 0,)
    neighbors_match_curve = neighbors_match/(np.arange(1, n_neighbors + 1) * adata.shape[0])
    area_under_curve = np.mean(neighbors_match_curve)
    return area_under_curve

In [26]:
#MSE
def _square(X):
  if sparse.issparse(X):
        X.data = X.data ** 2
        return X
  else:
        return scprep.utils.toarray(X) ** 2

def mse(adata):
  X=scprep.utils.toarray(adata.obsm["aligned"])
  Y=scprep.utils.toarray(adata.obsm["mode2_aligned"])

  X_shuffled = X[np.random.permutation(np.arange(X.shape[0])), :]
  error_random = np.mean(np.sum(_square(X_shuffled - Y)))
  error_abs = np.mean(np.sum(_square(X - Y)))
  return error_abs/error_random

In [27]:
knn_score = knn_auc(scicar_data)
mse_score = mse(scicar_data)
print(knn_score)
print(mse_score)

#log the metrics
path = "{}Max_iter_{}lamb_anc_{}metrics.txt".format(hyper["weightDirName"], str(hyper["nEpochs"]), str(hyper["lamb_anc"]))
'''torch.save({
    "num_iter": hyper["nEpochs"],
    "lamb_anc": hyper["lamb_anc"],
    'knn_auc': knn_score,
    'mse': mse_score,
}, path)'''

6.121189098180982e-06
0.97133


'torch.save({\n    "num_iter": hyper["nEpochs"],\n    "lamb_anc": hyper["lamb_anc"],\n    \'knn_auc\': knn_score,\n    \'mse\': mse_score,\n}, path)'

In [28]:
with open(path, 'a') as f:
        print('nEpoch: ', hyper["nEpochs"], 'lamb_anc:%.8f'%float(hyper["lamb_anc"]) , ',knn_auc: %.8f' % float(knn_score), ', mse_score: %.8f' % float(mse_score), file=f)