In [1]:
%%capture
!pip install scprep
!pip install anndata
!pip install scanpy

In [30]:
import numpy as np
import pandas as pd
import anndata
import scprep
import scanpy as sc
import sklearn
from sklearn.model_selection import train_test_split
import tempfile
import os
import sys
import scipy
from scipy import sparse

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

import matplotlib.pyplot as plt
import load_raw
import normalize_tools as nm
import metrics

# **try out with scicar cell lines dataset**

**1. URLs for raw data**

In [5]:
rna_data, atac_data, rna_cells, atac_cells, rna_genes, atac_genes = load_raw.load_raw_cell_lines()

**2. select the joint sub-datasets** 

In [6]:
scicar_data, joint_index, keep_cells_idx = load_raw.merge_data(rna_data, atac_data, rna_cells, atac_cells, rna_genes, atac_genes)
#rna_df, atac_df = ann2df(scicar_data)

# **logcpm**

In [8]:
#tryout log cpm scicar_data
nm.log_cpm(scicar_data)
nm.log_cpm(scicar_data, obsm = "mode2", obs = "mode2_obs", var = "mode2_var")
nm.hvg_by_sc(scicar_data)

In [11]:
nm.hvg_by_sc(scicar_data,  obsm = "mode2", obs = "mode2_obs", var = "mode2_var")

In [12]:
print(len(scicar_data.uns["mode2_var"]))
print(len(scicar_data.var))

14671
6057


# **define pytorch datasets for RNA and ATAC**

In [15]:
class Merge_Dataset(Dataset):
  def __init__(self, adata):
    self.rna_data, self.atac_data = self._load_merge_data(adata)

  def __len__(self):
    #assert(len(self.rna_data) == len(self.atac_data))
    return len(self.atac_data)
  
  def __getitem__(self, idx):
    rna_sample = self.rna_data.values[idx]
    atac_sample = self.atac_data.values[idx]
    #return a tensor that for a single observation
    return {"rna_tensor": torch.from_numpy(rna_sample).float(), "atac_tensor": torch.from_numpy(atac_sample).float()}
  
  def _load_merge_data(self, adata):
    rna_df = pd.DataFrame(data = adata.X.toarray(), index = np.array(adata.obs.index), columns = np.array(adata.var.index))
    atac_df = pd.DataFrame(data = adata.obsm["mode2"].toarray(), index = np.array(adata.uns["mode2_obs"]), columns = np.array(adata.uns["mode2_var"]))
    return rna_df, atac_df

# **Compute DCCA loss (-corr(H1, H2))**



In [16]:
class cca_loss():
  def __init__(self, out_dim, device, use_all_singvals=False):
    self.out_dim = out_dim #parameter o in original paper
    self.use_all_singvals = use_all_singvals
    self.device = device
  
  def loss(self, H1, H2):
    r1 = 1e-3
    r2 = 1e-3
    eps = 1e-9

    #transpose H1, H2: m x o -> o x m
    H1 = H1.t()
    H2 = H2.t()
    #assert torch.isnan(H1).sum().item() == 0
    #assert torch.isnan(H2).sum().item() == 0

    m = H1.size(1)
    o1, o2 = H1.size(0), H2.size(0)

    #produce the centered data matrices: H1 - 1/m*H1·I (same for H2bar)
    H1bar = H1 - H1.mean(dim=1).unsqueeze(dim=1)
    H2bar = H2 - H2.mean(dim=1).unsqueeze(dim=1)
    assert torch.isnan(H1bar).sum().item() == 0
    assert torch.isnan(H2bar).sum().item() == 0

    SigmaHat12 = (1.0/(m-1))*torch.matmul(H1bar, H2bar.t())
    SigmaHat11 = (1.0/(m-1))*torch.matmul(H1bar, H1bar.t()) + r1*torch.eye(o1, device=self.device)
    SigmaHat22 = (1.0/(m-1))*torch.matmul(H2bar, H2bar.t()) + r2*torch.eye(o2, device=self.device)
    #assert torch.isnan(SigmaHat11).sum().item() == 0
    #assert torch.isnan(SigmaHat12).sum().item() == 0
    #assert torch.isnan(SigmaHat22).sum().item() == 0

    #calculate the root inverse (e.g. SigmaHat11^(-1/2)) using sigular value decomposition
    D1, V1 = torch.symeig(SigmaHat11, eigenvectors=True)
    D2, V2 = torch.symeig(SigmaHat22, eigenvectors=True)

    # ??? probably problemetic in gene count setting
    posIdx1 = torch.gt(D1, eps).nonzero()[:, 0]
    D1 = D1[posIdx1]
    V1 = V1[:, posIdx1]

    posIdx2 = torch.gt(D2, eps).nonzero()[:, 0]
    D2 = D2[posIdx2]
    V2 = V2[:,posIdx2]

    #???take care of torch.sqrt
    SigmaHatRootInv11 = torch.matmul(torch.matmul(V1, torch.diag((D1)**(-0.5))), V1.t())
    SigmaHatRootInv22 = torch.matmul(torch.matmul(V2, torch.diag((D2)**(-0.5))), V2.t())

    #calculate T
    Tval = torch.matmul(torch.matmul(SigmaHatRootInv11, SigmaHat12), SigmaHatRootInv22)

    #calculate corr(H1, H2): matrix trace norm of T or sum of top k singular vals of T
    trace_TT = torch.matmul(Tval.t(), Tval)
    if self.use_all_singvals:
      corr = torch.trace(torch.sqrt(trace_TT))
      #assert torch.isnan(corr).item() == 0
      
    else:
      trace_TT = torch.add(trace_TT, (torch.eye(trace_TT.shape[0])*r1).to(self.device))
      U, V = torch.symeig(trace_TT, eigenvectors=True)
      U = torch.where(U>eps, U, (torch.ones(U.shape).double()*eps).to(self.device))
      U = U.topk(self.out_dim)[0]
      corr = torch.sum(torch.sqrt(U))
    #print("loss: " + str(-corr))
    return -corr

# **define basic models(now just encoder net) for learning latent space**

In [17]:
class EN_NET(nn.Module):
  def __init__(self, n_input, n_out, layer_sizes):
    super(EN_NET, self).__init__()
    self.n_input = n_input
    self.n_out = n_out
    self.layers = []
    self.layer_sizes = [n_input] + layer_sizes + [n_out]

    for layer_idx in range(len(self.layer_sizes)-1):
      if layer_idx == len(self.layer_sizes) - 2:
        self.layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx+1]))
      else:
        self.layers.append(nn.Linear(self.layer_sizes[layer_idx], self.layer_sizes[layer_idx+1]))
        self.layers.append(nn.BatchNorm1d(self.layer_sizes[layer_idx+1]))
        self.layers.append(nn.Sigmoid())

    self.encoder = nn.Sequential(
        *self.layers
    )
    
  def encode(self, x):
    return self.encoder(x)

  def reparametrize(self, mu, logvar):
    #calculate std from log(var)
    std = logvar.mul(0.5).exp_()
    if torch.cuda.is_available():
      eps = torch.cuda.FloatTensor(std.size()).normal_()
    else:
      eps = torch.FloatTensor(std.size()).normal_()
    eps = Variable(eps)
    return eps.mul(std).add_(mu)

  def forward(self, x):
    out = self.encode(x)
    return out

# Assembly Neural Net And Loss Into DCCA Model

In [18]:
class DCCA(nn.Module):
  def __init__(self, n_input1, n_input2, layer_sizes1, layer_sizes2, n_out, use_all_singvals=False, device=torch.device("cpu")):
    super(DCCA, self).__init__()
    self.Net1 = EN_NET(n_input1, n_out, layer_sizes1).double()
    self.Net2 = EN_NET(n_input2, n_out, layer_sizes2).double()
    self.loss = cca_loss(out_dim=n_out, use_all_singvals=use_all_singvals, device=device).loss

  def forward(self, x1, x2):
    output1 = self.Net1(x1)
    output2 = self.Net2(x2)
    return output1, output2

# **Train Basic Model**

In [22]:
#set up all hyper-parameters
hyper = {
    "nEpochs":45,
    "dimRNA":6057,
    "dimATAC":14671,
    "n_hidden":1024,
    "n_out":64,
    "batchSize":128,
    "lr":1e-3,
    "weightDirName": './checkpoint/'
}

In [23]:
#load dataset and split train and test data
merge_dataset = Merge_Dataset(scicar_data)
train_len = int(len(merge_dataset)*0.8)
lengths = [train_len, len(merge_dataset)-train_len]
trainset, testset = random_split(merge_dataset, lengths)

In [24]:
#use GPU if available
my_device = "cuda" if torch.cuda.is_available() else "cpu"

In [25]:
#load data loader
train_loader = DataLoader(trainset, batch_size=hyper["batchSize"], drop_last=False, shuffle=True)
test_loader = DataLoader(testset, batch_size=hyper["batchSize"], drop_last=False, shuffle=False)

#load basic models
toy_model = DCCA(n_input1=hyper["dimRNA"], 
                 n_input2=hyper["dimATAC"], 
                 layer_sizes1=[2048, 1024, 256], 
                 layer_sizes2=[4096, 2048, 512], 
                 n_out=hyper["n_out"],
                 use_all_singvals=False,
                 device=torch.device(my_device))

#set up optimizer
optimizer = optim.Adam(list(toy_model.parameters()), lr=hyper["lr"])

In [26]:
#set up train functions
def train(epoch):  
  #print("Epoch:"+str(epoch))
  train_losses = []
  toy_model.train()
  for idx, samples in enumerate(train_loader):
    rna_inputs = samples["rna_tensor"].double()
    atac_inputs = samples["atac_tensor"].double()
    rna_inputs, atac_inputs = Variable(rna_inputs), Variable(atac_inputs)
    rna_inputs = rna_inputs.to(my_device)
    atac_inputs = atac_inputs.to(my_device)
    optimizer.zero_grad()

    output_rna, output_atac = toy_model(rna_inputs, atac_inputs)
    #print("before loss calculated")
    loss_epoch = toy_model.loss(output_rna, output_atac)
    train_losses.append(loss_epoch.item())
    #print("after loss calculated")
    loss_epoch.backward()
    optimizer.step()
      
  #loss functions for each modalities
  train_loss = np.mean(train_losses)
  if epoch % 15 == 0:
    print("Epoch:"+str(epoch) + ", loss: " + str(train_loss))

In [27]:
def knn_criteria(rna_inputs, atac_inputs, rna_outputs, atac_outputs, proportion_neighbors=0.1, n_svd=100):
  n_svd = min([n_svd, min(rna_inputs.shape)-1])
  n_neighbors = int(np.ceil(proportion_neighbors*rna_inputs.shape[0]))
  X_pca = sklearn.decomposition.TruncatedSVD(n_svd).fit_transform(rna_inputs)
  _, indices_true = (
      sklearn.neighbors.NearestNeighbors(n_neighbors = n_neighbors).fit(rna_inputs).kneighbors(rna_inputs)
  )
  _, indices_pred = (
      sklearn.neighbors.NearestNeighbors(n_neighbors=n_neighbors).fit(rna_outputs).kneighbors(atac_outputs)
  )
  neighbors_match = np.zeros(n_neighbors, dtype=int)
  for i in range(rna_inputs.shape[0]):
    _, pred_matches, true_matches = np.intersect1d(
        indices_pred[i], indices_true[i], return_indices=True
    )
    neighbors_match_idx = np.maximum(pred_matches, true_matches)
    neighbors_match += np.sum(np.arange(n_neighbors) >= neighbors_match_idx[:, None], axis = 0,)
  neighbors_match_curve = neighbors_match/(np.arange(1, n_neighbors + 1) * rna_inputs.shape[0])
  area_under_curve = np.mean(neighbors_match_curve)
  return area_under_curve

In [28]:
def test_model(epoch, test_loader, device):
  with torch.no_grad():
    toy_model.eval()
    knn_acc = []
    #mse_acc = []
    for idx, samples in enumerate(test_loader):
      rna_inputs = samples["rna_tensor"].double()
      atac_inputs = samples["atac_tensor"].double()
      rna_inputs = rna_inputs.to(device)
      atac_inputs = atac_inputs.to(device)

      output_rna, output_atac = toy_model(rna_inputs, atac_inputs)
      knn_acc.append(knn_criteria(rna_inputs, atac_inputs, output_rna, output_atac))
      avg_knn_acc = np.mean(knn_acc)
    if epoch % 15 == 0:
      print("Epoch:"+str(epoch) + ", average knn_acc: " + str(avg_knn_acc))

In [29]:
max_iter = hyper["nEpochs"]
for epoch in range(max_iter):
  train(epoch)
  test_model(epoch, test_loader, my_device)

The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448234945/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2500.)
  D1, V1 = torch.symeig(SigmaHat11, eigenvectors=True)


Epoch:0, loss: -39.87804021644458
Epoch:0, average knn_acc: 0.0947296360156937
Epoch:15, loss: -45.429816941867486
Epoch:15, average knn_acc: 0.08020917925890883
Epoch:30, loss: -47.998591428690375
Epoch:30, average knn_acc: 0.03587534406532303


In [31]:
atac_inputs = Variable(torch.from_numpy(merge_dataset.atac_data.values).double())
rna_inputs = Variable(torch.from_numpy(merge_dataset.rna_data.values).double())
toy_model.eval()
out_rna, out_atac = toy_model(rna_inputs, atac_inputs)
scicar_data.obsm["aligned"] = sparse.csr_matrix(out_rna.cpu().detach())
scicar_data.obsm["mode2_aligned"] = sparse.csr_matrix(out_atac.cpu().detach())
print(metrics.knn_auc(scicar_data))
print(metrics.mse(scicar_data))

0.03049226498042545
1.0063349708800973


In [None]:
model_out = DCCA(n_input1=hyper["dimRNA"], 
                 n_input2=hyper["dimATAC"], 
                 layer_sizes1=[2048, 1024, 256], 
                 layer_sizes2=[4096, 2048, 512], 
                 n_out=hyper["n_out"],
                 use_all_singvals=False,
                 device=torch.device(my_device))
checkpoint_out64 = torch.load(path_out64)
model_out64.load_state_dict(checkpoint_out64["model_state_dict"])

NameError: ignored

In [None]:
path_out100 = ""
torch.save({
    "num_iter": hyper["nEpochs"],
    "n_out": hyper["n_out"],
    "model_state_dict": toy_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "knn_auc": 9.643159972821067e-06,
    "mse": 0.985969183721269
}, path_out100)