In [None]:
#TODO: Dataset, data['label'], check optimizers, check structure

In [2]:
import numpy as np
import pandas as pd
from scipy import sparse

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import diffusion_dist as diff

import matplotlib.pyplot as plt

from sklearn import decomposition, preprocessing
from torch.utils.data import Dataset, DataLoader

import model.autoencoder as ae

from dataset import test_s_curve

In [3]:
# RNA expression table: the first CSV column is used as the row index and only
# the underlying NumPy array is kept.
expr_RNA = pd.read_csv(
    "./data/Paul/Paul_processed_expr.csv",
    index_col=0,
).to_numpy()

# Tab-separated per-cell metadata; its 'cell_type2' column provides the cluster labels.
cell_info = pd.read_csv("./data/Paul/Paul_cell_meta.txt", sep="\t")
clusters = cell_info.loc[:, 'cell_type2']

In [4]:
# Three matrices derived from the RNA expression data. Each result is wrapped
# in a 32-bit float tensor as soon as it is computed.
# NOTE(review): the helpers are named "similarity" while the comments below say
# "distance" -- confirm which one diffusion_dist actually returns.
DPT_RNA = torch.FloatTensor(diff.DPT_similarity(expr_RNA))
# Phate distance
Phate_RNA = torch.FloatTensor(
    diff.phate_similarity(expr_RNA, t = 5, use_potential = True)
)
# Diffmap distance
Diffmap_RNA = torch.FloatTensor(
    diff.phate_similarity(expr_RNA, t = 5, use_potential = False)
)

In [6]:
# TODO: ATAC data
# Placeholder: until real ATAC measurements are loaded, the RNA matrix stands
# in for them, so the three matrices below repeat the RNA computation.
expr_ATAC = expr_RNA

DPT_ATAC = torch.FloatTensor(diff.DPT_similarity(expr_ATAC))
Phate_ATAC = torch.FloatTensor(
    diff.phate_similarity(expr_ATAC, t = 5, use_potential = True)
)
Diffmap_ATAC = torch.FloatTensor(
    diff.phate_similarity(expr_ATAC, t = 5, use_potential = False)
)

In [7]:
def pairwise_distance(x, eps = 1e-2):
    """Return the matrix of smoothed Euclidean distances between the rows of ``x``.

    The squared distances are obtained from the expansion
    ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>`` and the result is
    ``sqrt(squared_distance + eps)``, so the diagonal equals ``sqrt(eps)``
    rather than exactly zero.

    Args:
        x: 2-D tensor with one point per row.
        eps: non-negative constant added under the square root. The default
            keeps the value that used to be hard-coded; presumably it is there
            to keep the square root's gradient finite at zero distance.

    Returns:
        Tensor of shape ``(x.shape[0], x.shape[0])``.
    """
    sq_norm = (x**2).sum(1).view(-1, 1)
    sq_dist = sq_norm + sq_norm.view(1, -1) - 2.0 * torch.mm(x, torch.transpose(x, 0, 1))
    # Floating-point cancellation can push the expansion slightly below zero,
    # which would turn into NaN under the square root; clamp before smoothing.
    sq_dist = torch.clamp(sq_dist, min = 0.0)
    return torch.sqrt(sq_dist + eps)

In [9]:
def dist_loss(z, diff_sim, lamb):
    """Negative, weighted agreement between latent distances and ``diff_sim``.

    The pairwise distance matrix of the latent codes and ``diff_sim`` are each
    divided by their norm (``torch.norm``), multiplied element-wise and summed,
    so minimizing the returned value pushes the two matrices to align.

    Args:
        z: 2-D tensor of latent codes, one row per sample.
        diff_sim: matrix to compare against; assumed to have the same
            ``(n, n)`` shape as the latent distance matrix -- TODO confirm
            against the caller.
        lamb: scalar weight of this term.

    Returns:
        Scalar tensor ``-lamb * sum(latent_sim * diff_sim)`` computed on the
        normalized matrices.
    """
    # Build the latent matrix from z with the file's pairwise_distance helper.
    latent_sim = pairwise_distance(z)
    latent_sim = latent_sim / torch.norm(latent_sim)
    diff_sim = diff_sim / torch.norm(diff_sim)
    # Reduce to a scalar so the result can be added to the other loss terms
    # and back-propagated.
    return - lamb * (latent_sim * diff_sim).sum()

In [8]:
def ae_loss(recon_rna, recon_atac, x_rna, x_atac, z_rna, z_atac, diff_sim_rna, diff_sim_atac, lamb_rna = 1, lamb_atac = 1, lamb_dist_rna = 1, lamb_dist_atac = 1):
    """Combine reconstruction and distance terms for both modalities.

    Each reconstruction term is ``F.mse_loss(recon, x)`` scaled by its weight;
    each distance term is ``dist_loss(z, diff_sim, weight)``.

    Returns:
        Tuple ``(total, recon_rna_term, recon_atac_term, dist_rna_term,
        dist_atac_term)`` where ``total`` is the sum of the four terms.
    """
    # Weighted mean-squared reconstruction errors, RNA first and ATAC second.
    rna_recon_term, atac_recon_term = (
        weight * F.mse_loss(recon, target)
        for weight, recon, target in (
            (lamb_rna, recon_rna, x_rna),
            (lamb_atac, recon_atac, x_atac),
        )
    )

    # Latent-structure terms, evaluated in the same RNA-then-ATAC order.
    rna_dist_term, atac_dist_term = (
        dist_loss(latent, sim, weight)
        for latent, sim, weight in (
            (z_rna, diff_sim_rna, lamb_dist_rna),
            (z_atac, diff_sim_atac, lamb_dist_atac),
        )
    )

    combined = rna_recon_term + atac_recon_term + rna_dist_term + atac_dist_term
    return combined, rna_recon_term, atac_recon_term, rna_dist_term, atac_dist_term

In [10]:
def train(model, disc, data_loader, diff_sim_rna, diff_sim_atac, n_epochs = 100, learning_rate = 1e-3, lamb_rna = 1, lamb_atac = 1, lamb_dist_rna = 1, lamb_dist_atac = 1):
    """Alternately train the autoencoder, the discriminator and the encoders.

    For every batch three phases run in order:
      1. one Adam step on the full autoencoder loss (``ae_loss``);
      2. ``n_iter`` Adam steps on the discriminator, using latent codes that
         are detached from the encoders;
      3. one adversarial Adam step on the two encoders that maximizes the
         discriminator loss, using freshly computed latent codes.

    Args:
        model: callable as ``model(x_rna, x_atac)`` returning
            ``(recon_rna, recon_atac, z_rna, z_atac)``; must expose
            ``parameters()`` and the sub-modules ``rna_encoder`` and
            ``atac_encoder``.
        disc: discriminator module applied to the latent codes. Its output is
            passed to ``F.nll_loss``, which expects log-probabilities --
            TODO confirm that ``disc`` ends in a log-softmax.
        data_loader: iterable of dict batches with the keys 'index', 'RNA',
            'ATAC' and 'label'.
        diff_sim_rna, diff_sim_atac: 2-D matrices; the batch's 'index' entries
            select both the rows and the columns used for that batch.
        n_epochs: number of passes over ``data_loader``.
        learning_rate: learning rate shared by all four Adam optimizers.
        lamb_rna, lamb_atac, lamb_dist_rna, lamb_dist_atac: loss weights
            forwarded to ``ae_loss``.

    Returns:
        None. Every 10th epoch the last batch's autoencoder loss and mean
        discriminator loss are printed (NaN if the loader yielded no batch).
    """
    # not sure about this
    optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
    optimizer_E_rna = torch.optim.Adam(model.rna_encoder.parameters(), lr = learning_rate)
    optimizer_E_atac = torch.optim.Adam(model.atac_encoder.parameters(), lr = learning_rate)
    optimizer_D = torch.optim.Adam(disc.parameters(), lr = learning_rate)

    # Number of discriminator steps per batch.
    n_iter = 10

    # Values reported by the periodic print; they stay NaN for an empty loader.
    ae_loss_value = float("nan")
    D_loss_avg = float("nan")

    for epoch in range(n_epochs):
        for data in data_loader:
            batch_cols = data['index']
            batch_sim_rna = diff_sim_rna[batch_cols,:][:,batch_cols]
            batch_sim_atac = diff_sim_atac[batch_cols,:][:,batch_cols]
            batch_x_rna = data['RNA']
            batch_x_atac = data['ATAC']
            # NOTE(review): the discriminator output below stacks the RNA rows
            # on top of the ATAC rows, so 'label' must hold one class index per
            # stacked row -- confirm against the Dataset implementation.
            batch_label = data['label']

            # update AE
            optimizer.zero_grad()
            recon_rna, recon_atac, z_rna, z_atac = model(batch_x_rna, batch_x_atac)
            total_loss = ae_loss(
                recon_rna, recon_atac, batch_x_rna, batch_x_atac,
                z_rna, z_atac, batch_sim_rna, batch_sim_atac,
                lamb_rna = lamb_rna, lamb_atac = lamb_atac,
                lamb_dist_rna = lamb_dist_rna, lamb_dist_atac = lamb_dist_atac,
            )[0]
            total_loss.backward()
            optimizer.step()
            ae_loss_value = total_loss.item()

            # update Discriminator
            # The encoder graph was consumed by the backward pass above, so the
            # discriminator is trained on detached latent codes.
            z_rna_fixed = z_rna.detach()
            z_atac_fixed = z_atac.detach()
            D_loss_sum = 0.0
            for disc_iter in range(n_iter):
                optimizer_D.zero_grad()
                output = torch.cat((disc(z_rna_fixed), disc(z_atac_fixed)))
                D_loss = F.nll_loss(output, batch_label)
                D_loss.backward()
                optimizer_D.step()
                D_loss_sum += D_loss.item()
            D_loss_avg = D_loss_sum / n_iter

            # update Encoder
            # A fresh forward pass is needed: the earlier latent codes belong to
            # a graph that was already back-propagated and to pre-update weights.
            optimizer_E_rna.zero_grad()
            optimizer_E_atac.zero_grad()
            _, _, z_rna, z_atac = model(batch_x_rna, batch_x_atac)
            output = torch.cat((disc(z_rna), disc(z_atac)))
            E_loss = -1 * F.nll_loss(output, batch_label)
            E_loss.backward()
            optimizer_E_rna.step()
            optimizer_E_atac.step()
            # The adversarial backward pass also filled the discriminator's
            # gradients; drop them so they never reach a discriminator step.
            optimizer_D.zero_grad()

        if epoch % 10 == 0:
            print("AE loss: ", ae_loss_value, "D loss: ", D_loss_avg)