In [1]:
%%capture
!pip install scprep
!pip install anndata
!pip install scanpy

In [2]:
import numpy as np
import pandas as pd
import anndata
import scprep
import scanpy as sc
import sklearn
from sklearn.model_selection import train_test_split
import tempfile
import os
import sys
import scipy
from scipy import sparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

import load_raw
import normalize_tools as nm
import metrics

In [15]:
class StructureHingeLoss(nn.Module):
    """Clamped hinge loss over paired RNA / ATAC embeddings.

    The loss is a weighted sum of four terms, each of the form
    ``clamp(margin + positive_distance - negative_distance, 0, max_val).mean()``
    on pairwise L2 distances:

    * RNA -> ATAC matching: row ``i`` of the RNA/ATAC distance matrix, with the
      diagonal entry as the positive and all other entries as negatives.
    * ATAC -> RNA matching: the same on the transposed distance matrix.
    * RNA neighborhood: within the RNA/RNA distance matrix, entries selected
      by ``nn_indices`` are positives, the rest are negatives.
    * ATAC neighborhood: the same mask applied to the ATAC/ATAC distances.

    Args:
        margin: additive margin inside the hinge.
        max_val: upper clamp applied to every individual hinge value.
        lamb_match: weight of each of the two matching terms.
        lamb_nn: weight of each of the two neighborhood terms.
    """

    def __init__(self, margin, max_val, lamb_match, lamb_nn):
        super(StructureHingeLoss, self).__init__()
        self.margin = margin
        self.max_val = max_val
        self.lamb_match = lamb_match
        self.lamb_nn = lamb_nn

    def _clamped_hinge(self, pos_dist, neg_dist):
        """Return the mean of clamp(margin + pos - neg, 0, max_val) after broadcasting."""
        return torch.clamp(self.margin + pos_dist - neg_dist, 0, self.max_val).mean()

    def _match_loss(self, dist, match_mask):
        """Hinge between each row's diagonal (matched) distance and its off-diagonal ones."""
        n_batch = dist.shape[0]
        pos_dist = torch.masked_select(dist, match_mask).view(n_batch, 1)
        neg_dist = torch.masked_select(dist, ~match_mask).view(n_batch, -1)
        # (n_batch, 1) against (n_batch, n_batch - 1): one positive vs. every negative.
        return self._clamped_hinge(pos_dist, neg_dist)

    def _neighbor_loss(self, dist, nn_mask):
        """Hinge between every neighbor distance and every non-neighbor distance, per row."""
        n_batch = dist.shape[0]
        # The views require every row of nn_mask to select the same number of entries.
        pos_dist = torch.masked_select(dist, nn_mask).view(n_batch, -1)
        neg_dist = torch.masked_select(dist, ~nn_mask).view(n_batch, -1)
        # (n_batch, n_pos, 1) against (n_batch, 1, n_neg): all positive/negative pairs.
        return self._clamped_hinge(pos_dist[..., None], neg_dist[..., None, :])

    def forward(self, rna_outputs, atac_outputs, nn_indices):
        """Compute the combined loss.

        Args:
            rna_outputs: RNA embeddings, n_batch x n_latent.
            atac_outputs: ATAC embeddings, n_batch x n_latent; row ``i`` is
                treated as the match of RNA row ``i``.
            nn_indices: integer tensor with n_batch rows; row ``i`` lists the
                column indices (within the batch) marked as neighbors of
                sample ``i``. The same mask is used for RNA and ATAC.

        Returns:
            Scalar tensor: ``lamb_match * (rna_match + atac_match)
            + lamb_nn * (rna_nn + atac_nn)``.
        """
        assert rna_outputs.shape[0] == atac_outputs.shape[0]
        assert rna_outputs.shape[1] == atac_outputs.shape[1]
        n_batch = rna_outputs.shape[0]
        # Build masks on the embeddings' device so masked_select also works off-CPU.
        device = rna_outputs.device

        # dist_rna_atac[i][j]: L2 distance between RNA embedding i and ATAC embedding j.
        dist_rna_atac = torch.cdist(rna_outputs, atac_outputs, p=2)
        match_mask = torch.eye(n_batch, device=device) > 0

        # Every RNA embedding should be closer to its matched ATAC embedding ...
        loss_match_rna = self._match_loss(dist_rna_atac, match_mask)
        # ... and every ATAC embedding closer to its matched RNA embedding.
        # (Previously this term was overwritten with the RNA-direction loss.)
        loss_match_atac = self._match_loss(dist_rna_atac.t(), match_mask)

        # Boolean n_batch x n_batch mask with True at (i, j) when j is listed in nn_indices[i].
        nn_mask = torch.zeros(n_batch, n_batch, device=device)
        nn_mask.scatter_(1, nn_indices.to(device=device, dtype=torch.long), 1.)
        nn_mask = nn_mask > 0

        # Neighbors should stay closer than non-neighbors within each modality.
        rna_nn_loss = self._neighbor_loss(
            torch.cdist(rna_outputs, rna_outputs, p=2), nn_mask
        )
        atac_nn_loss = self._neighbor_loss(
            torch.cdist(atac_outputs, atac_outputs, p=2), nn_mask
        )

        loss = (self.lamb_match * loss_match_rna
                + self.lamb_match * loss_match_atac
                + self.lamb_nn * rna_nn_loss
                + self.lamb_nn * atac_nn_loss)
        return loss

In [16]:
def collate_fn(batch, n_svd=100, proportion_neighbors=0.1, random_state=None):
    """Stack a batch of (RNA, ATAC) pairs and compute RNA nearest-neighbor indices.

    The RNA inputs are reduced with a truncated SVD and a k-nearest-neighbor
    search is run on the reduced batch against itself. Because the query set
    equals the fitted set, each sample is typically returned as one of its
    own neighbors.

    Args:
        batch: iterable of ``(rna_tensor, atac_tensor)`` pairs; tensors of
            each modality must be stackable with ``torch.stack``.
        n_svd: maximum number of SVD components; capped at
            ``min(n_samples, n_features) - 1`` for the current batch.
        proportion_neighbors: fraction of the batch size used as the number
            of neighbors (rounded up, at least 1, at most the batch size).
        random_state: forwarded to ``TruncatedSVD``; the default ``None``
            keeps the original unseeded behavior.

    Returns:
        Tuple ``(rna_inputs, atac_inputs, nn_indices)`` where the first two
        are the stacked tensors and ``nn_indices`` is the integer tensor of
        neighbor indices returned by ``kneighbors`` (one row per sample).
    """
    # Import the submodules explicitly: a bare `import sklearn` does not guarantee
    # that `sklearn.decomposition` / `sklearn.neighbors` are available as attributes.
    from sklearn.decomposition import TruncatedSVD
    from sklearn.neighbors import NearestNeighbors

    rna_inputs, atac_inputs = zip(*batch)
    rna_inputs = torch.stack(rna_inputs)
    atac_inputs = torch.stack(atac_inputs)

    # scikit-learn works on NumPy arrays; convert explicitly rather than
    # relying on implicit tensor -> array coercion.
    rna_matrix = rna_inputs.detach().cpu().numpy()
    n_samples = rna_matrix.shape[0]

    n_components = min(n_svd, min(rna_matrix.shape) - 1)
    n_neighbors = int(np.ceil(proportion_neighbors * n_samples))
    n_neighbors = max(1, min(n_neighbors, n_samples))

    if n_components >= 1:
        rna_reduced = TruncatedSVD(
            n_components, random_state=random_state
        ).fit_transform(rna_matrix)
    else:
        # A single sample or single feature leaves no room for an SVD;
        # search neighbors on the raw features instead of raising.
        rna_reduced = rna_matrix

    _, indices_true = (
        NearestNeighbors(n_neighbors=n_neighbors).fit(rna_reduced).kneighbors(rna_reduced)
    )

    return rna_inputs, atac_inputs, torch.from_numpy(indices_true)