In [None]:
import numpy as np
import scipy.sparse as scs
from scipy.stats import multinomial
import pathlib as pl
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import mmread

import umap as um


from os import listdir
from os.path import isfile, join
import os

import random

from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors

import seaborn as sns

from sklearn.preprocessing import StandardScaler

import random
from sklearn import metrics
from sklearn.metrics import precision_recall_curve, accuracy_score, average_precision_score


import tensorflow as tf

import os
import pickle

import seaborn as sns

from sklearn.model_selection import train_test_split

from kneed import KneeLocator

from numpy.random import seed

import time
from scipy.signal import savgol_filter


In [None]:
from plot_results import get_dbl_metrics
from vae import define_clust_vae, define_vae
from PU import PU, epoch_PU2
from PU import noPU
from classifier import define_classifier
from mk_doublets import sim_inflate, sim_avg, sim_sum
from cluster import cluster, fast_cluster

In [None]:
import scanpy as sc
import anndata

In [None]:
def sigmoid(x):
    """Steep logistic curve centred at x = 0.5.

    Evaluates 1 / (1 + exp(6 - 12 * x)), so sigmoid(0.5) == 0.5, sigmoid(0) is
    about 0.0025 and sigmoid(1) is about 0.9975. Works element-wise on numpy
    arrays as well as on scalars.
    """
    exponent = 6 - 12 * x
    return 1 / (1 + np.exp(exponent))

In [None]:
#pca_comp=30, clust_weight=20000
# --- run control ---
data_dir = 'sce_normalized_data_inflate'  # sub-directory of ../data/ scanned for *real_logcounts files
save = True       # write freshly computed artifacts to disk
use_old = True    # reuse artifacts already on disk when they exist

# Gene filter: keep genes with a non-zero count in more than this fraction of cells.
gene_thresh = .01

# --- VAE hyper-parameters ---
eps = 1000             # maximum number of training epochs (early stopping may end sooner)
enc_sze = 5            # encoding size; also the number of PCA components in the 'pca' ablation
pat = 20               # early-stopping patience
LR = 1e-3              # learning rate given to define_clust_vae
clust_weight = 20000   # clust_weight given to define_clust_vae

# --- PU hyper-parameters ---
cls_eps = 250            # epoch budget handed to the classifier routines
stop_metric = 'ValAUC'
puPat = 5
puLR = 1e-3
pu_num_layers = 1
k_mult = 2               # num_cells = (#positives) * k_mult, then k = #unlabelled // num_cells
N = 1

In [None]:
# Fixed RNG seeds. seeds[0] seeds the simulated-doublet downsampling, seeds[1]
# the VAE construction/training, seeds[2] the encoder pass, and seeds[3:] is
# handed to the PU / noPU routines. Note 42 and 9231996 each appear twice.
seeds = [
    42, 29503, 432809,
    42, 132975, 9231996, 12883823, 9231996,
    1234, 62938, 57203, 109573, 23,
]

In [None]:
# Datasets to process: every file in ../data/<data_dir>/ whose name ends in
# 'real_logcounts' followed by a 4-character extension
# (e.g. 'pdx-MULTI_real_logcounts.mtx'). `data_dir` comes from the settings
# cell; it is deliberately not re-assigned here so that changing it there takes effect.
path = '../data/' + data_dir + '/'
files = [f for f in listdir(path) if isfile(join(path, f)) and f[-18:-4] == 'real_logcounts']
files = np.sort(files)
#files = files[3:4]  # uncomment to run a single dataset while debugging
files

In [None]:
#nohup jupyter nbconvert --to notebook --execute CLUST-VAEDA-ablation.ipynb > nohup_ablation_final.out &

In [None]:
# Per-stage timing table: a single 'time' row, one zero-initialised column per stage name.
time_names = ('total HVGs scaling1 knn downsample scaling2 cluster vae '
              'epoch_selection PU_loop').split()
tmp1 = np.zeros((1, len(time_names)))
time_df = pd.DataFrame(data=tmp1, columns=time_names, index=['time'])

In [None]:
# Ablation grid: the main loop runs every combination of these settings.
dim_reds, PUs, clsses, feats, homos = (
    ['clust_vae', 'vae', 'pca'],   # dimensionality reduction used for the encoding
    ['PU', 'noPU'],                # PU-learning routine or plain classifier
    ['knn', 'NN'],                 # value passed as `clss` to PU / noPU
    ['knnfeat', 'nofeat'],         # prepend the kNN simulated-neighbour feature or not
    ['remove', 'keep'],            # drop simulated doublets whose parents share a cluster, or keep them
)

In [None]:
# Reduced grid that runs a single combination
# (clust_vae + PU + NN + knnfeat + remove). It is inert as written: the whole
# block is a string literal, so running this cell only echoes the text.
# Un-quote it (and run it after the grid cell) to restrict the main loop to
# that one combination.
'''dim_reds = ['clust_vae']
PUs      = ['PU']
clsses   = ['NN']
feats    = ['knnfeat']
homos    = ['remove']'''

In [None]:
# Ablation study over the grid defined above. For every dataset in `files`,
# each combination of
#   feat (kNN feature or not) x homo (drop same-cluster simulated doublets or not)
#   x dim_red (clust_vae / vae / pca) x pu (PU / noPU) x clss (knn / NN)
# is scored. Per-cell scores plus AUROC / AUPRC / AP are written under
# ../results_PU/ablation_analysis/. Intermediate artifacts (simulated doublets,
# kNN feature, clusters, encodings, scores) are cached there and reloaded when
# `use_old` is True.
#
# NOTE: earlier versions of this cell cached artifacts inconsistently:
#   - simulated doublets were not saved;
#   - the 'nofeat' downsample sampled with replacement.
# Delete ../results_PU/ablation_analysis/ (or run once with use_old = False)
# to rebuild the cache.


def scheduler(epoch, lr):
    """VAE learning-rate schedule: unchanged for epochs 0-2, then multiplied
    by exp(-0.75) at every later epoch."""
    if epoch < 3:
        return lr
    else:
        return lr * tf.math.exp(-0.75)


def _vae_callbacks():
    """Return fresh [EarlyStopping on val_loss (patience `pat`), LR-scheduler] callbacks for one fit."""
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                  mode = 'min',
                                                  min_delta=0,
                                                  patience=pat,
                                                  verbose=True,
                                                  restore_best_weights=False)
    lr_sched = tf.keras.callbacks.LearningRateScheduler(scheduler)
    return [early_stop, lr_sched]


for file in files:

    # drop the trailing '_real_logcounts' + 4-character extension (19 characters)
    data_name = file[:-19]
    print(data_name)

    save_path = '../results_PU/ablation_analysis/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    print(save_path)

    real_path = '../data/mtx_files/' + data_name + '.mtx'
    ano_path  = '../data/mtx_files/' + data_name + '_anno.csv'

    print('loading in real mtx')
    dat_real = mmread(real_path)
    # transposed so that rows are cells (assumes the .mtx is stored genes x cells)
    Xr = scs.csr_matrix(dat_real).toarray().T

    #scvae-dbl-btch/data/sce_normalized_data_inflate/pdx-MULTI_real_logcounts.mtx
    npz_sim_path  = save_path + '/' + data_name + '_sim_counts.npz'
    sim_ind_path  = save_path + '/' + data_name + '_sim_ind.npy'

    #- READ IN (OR SIMULATE) DOUBLETS
    npz_sim  = pl.Path(npz_sim_path)

    if (npz_sim.exists() & use_old):
        print('loading in sim npz')
        dat_sim = scs.load_npz(npz_sim)
        sim_ind = np.load(sim_ind_path)
        ind1 = sim_ind[0,:]
        ind2 = sim_ind[1,:]
        Xs = scs.csr_matrix(dat_sim).toarray()
    else:
        print('generating new sim npz')
        Xs, ind1, ind2 = sim_inflate(Xr)
        # Persist the simulated doublets. Every cached artifact below (kNN
        # feature, clusters, encodings, scores) is computed on them, so they
        # must be reloaded together on the next run or the caches go stale.
        if save:
            scs.save_npz(npz_sim_path, scs.csr_matrix(Xs))
            np.save(sim_ind_path, np.vstack([ind1, ind2]))

    dat = np.vstack([Xr,Xs])
    # 0 = real cell, 1 = simulated doublet
    Y0 = np.concatenate([np.zeros(Xr.shape[0]), np.ones(Xs.shape[0])])

    #- READ IN BARCODE ANNOTATIONS
    ano = pd.read_csv(ano_path)
    true = pd.factorize(ano.x)[0]
    labels = ano.x
    if (labels[0]=='doublet'):
        # factorize codes labels by first appearance, so 'doublet' got code 0 here;
        # swap codes so that 'doublet' -> 1 and the other label -> 0
        tmp = true + 3
        tmp[tmp==3] = 1
        tmp[tmp==4] = 0
        true = tmp
    # 2 marks simulated doublets
    true0 = np.concatenate([true, np.full(Xs.shape[0],2)])
    labels0 = np.concatenate([labels, np.full(Xs.shape[0],'simulated')])

    #Filter genes: keep genes with a non-zero count in more than gene_thresh of all cells
    thresh = dat.shape[0] * gene_thresh
    tmp    = np.sum((dat>0), axis=0)>thresh
    dat = dat[:,tmp]

    #- HVGs: the (up to) 2000 highest-variance genes
    var = np.var(dat, axis=0)
    np.random.seed(3900362577)
    # argpartition raises if fewer than 2000 genes survive the filter
    n_hvgs = min(2000, dat.shape[1])
    hvgs = np.argpartition(var, -n_hvgs)[-n_hvgs:]


    #######################################################
    ######################### KNN #########################
    #######################################################

    TRUE = true

    for feat in feats:

        save_path = '../results_PU/ablation_analysis/'

        X1 = dat[:,hvgs]
        Y1 = Y0
        labels1 = labels0
        true1 = true0

        print('dat shape:', dat.shape)
        print('Y:', Y1.shape)

        if feat == 'knnfeat':
            # kNN feature: fraction of each cell's neighbours (in 30-PC space of
            # the scaled log data) that are simulated doublets
            knn_file = save_path + data_name + '_knn_feature.npy'

            if(pl.Path(knn_file).exists() & use_old):
                print('loading old knn')
                knn_feature1 = np.load(knn_file)

            else:
                print('generating new knn')
                #HYPERPARAMS
                neighbors = int(np.sqrt(X1.shape[0]))
                comp=30

                #SCALING
                temp_X = np.log2(X1+1)
                np.random.seed(42)
                scaler = StandardScaler().fit(temp_X.T)
                np.random.seed(42)
                temp_X = scaler.transform(temp_X.T).T

                #KNN
                np.random.seed(42)
                pca = PCA(n_components=comp)
                pca_proj = pca.fit_transform(temp_X)
                del(temp_X)

                np.random.seed(42)
                knn = NearestNeighbors(n_neighbors=neighbors)
                knn.fit(pca_proj,Y1)
                graph = knn.kneighbors_graph(pca_proj)
                knn_feature1 = np.squeeze(np.array(np.sum(graph[:,Y1==1], axis=1) / neighbors)) #sum across rows

                if(save):
                    np.save(knn_file, knn_feature1)


            #estimate true fraction of doublets
            quantile = np.quantile(knn_feature1[true0==2], .25)
            num = np.sum(knn_feature1[true0<2]>=quantile)
            min_num = int(np.round((sum(Y1==0) *0.05)))
            num = np.max([min_num, num])

            prob = knn_feature1[Y1==1] / np.sum(knn_feature1[Y1==1])
            # choice(..., replace=False) needs at least `num` non-zero probabilities
            num = min(num, np.count_nonzero(prob))
            np.random.seed(seeds[0])
            ind = np.random.choice(np.arange(sum(Y1==1)), size=num, p=prob, replace=False)

            #downsample the simulated doublets (all real cells are kept, in order)
            enc_ind = np.concatenate([np.arange(sum(Y1==0)), (sum(Y1==0) + ind)])
            X1 = X1[enc_ind,:]
            Y1 = Y1[enc_ind]
            knn_feature1 = knn_feature1[enc_ind]
            new_labs1 = labels1[enc_ind]
            true1 = true1[enc_ind]

        else:
            #randomly downsample doublets to 10% of the real-cell count, without
            #replacement (as in the knnfeat branch) so no simulated doublet is duplicated
            num = int(np.round((sum(Y1==0) *0.1)))
            num = min(num, int(sum(Y1==1)))
            np.random.seed(seeds[0])
            ind = np.random.choice(np.arange(sum(Y1==1)), size=num, replace=False)

            #downsample the simulated doublets (all real cells are kept, in order)
            enc_ind = np.concatenate([np.arange(sum(Y1==0)), (sum(Y1==0) + ind)])
            X1 = X1[enc_ind,:]
            Y1 = Y1[enc_ind]
            new_labs1 = labels1[enc_ind]
            true1 = true1[enc_ind]

        print('X shape:', X1.shape)
        print('Y:', Y1.shape)
        print('end downsample')

        #re-scale
        X1 = np.log2(X1+1)
        np.random.seed(42)
        scaler = StandardScaler().fit(X1.T)
        np.random.seed(42)
        X1 = scaler.transform(X1.T).T

        #######################################################
        ####################### CLUSTER #######################
        #######################################################

        clust_path = save_path + 'FEAT' + feat + '/'
        if not os.path.exists(clust_path):
            os.makedirs(clust_path)

        clust_file = clust_path + data_name + '_clusters.npy'

        if(pl.Path(clust_file).exists() & use_old):
            print('read in old clusters')
            clust1 = np.load(clust_file)
        else:
            print('generate new clusters')
            if(X1.shape[0]>=1000):
                clust1 = fast_cluster(X1, comp=30)
            else:
                clust1 = cluster(X1, comp=30)

            if(save):
                np.save(clust_file, clust1)

        print('X shape:', X1.shape)
        print('clust shape:', clust1.shape)
        print('Y:', Y1.shape)
        print('end clust')

        for homo in homos:
            print('HOMO: ', homo)

            X2 = X1
            Y2 = Y1
            true2 = true1
            clust2 = clust1
            new_labs2 = new_labs1
            if(feat=='knnfeat'):
                knn_feature2 = knn_feature1
            ind_2 = ind

            print('ind2:', ind.shape)

            print('start homo')
            print('X2 shape:', X2.shape)
            print('Y2:', Y2.shape)


            if(homo=='remove'):
                # drop simulated doublets whose two parent cells (ind1, ind2)
                # fall in the same cluster; real cells are always kept
                print('removing homos')
                c = clust2[Y2==0]

                hetero_ind = c[ind1] != c[ind2]
                hetero_ind = hetero_ind[ind_2] #because down sampled
                hetero_ind = np.concatenate([np.full(sum(Y2==0), True), hetero_ind])

                X2 = X2[hetero_ind,:]
                Y2 = Y2[hetero_ind]
                clust2 = clust2[hetero_ind]

                if(feat=='knnfeat'):
                    knn_feature2 = knn_feature2[hetero_ind]

                new_labs2 = new_labs2[hetero_ind]
                true2 = true2[hetero_ind]


            print('X shape:', X2.shape)
            print('Y:', Y2.shape)
            print('end homo')

            for dim_red in dim_reds:

                X3 = X2
                Y3 = Y2
                true3 = true2
                clust3 = clust2
                new_labs3 = new_labs2
                if(feat=='knnfeat'):
                    knn_feature3 = knn_feature2

                print('start enc')
                print('X:', X3.shape)
                print('Y:', Y3.shape)

                print('DIMRED: ', dim_red)
                #######################################################
                ######################### VAE #########################
                #######################################################

                if(dim_red=='clust_vae'):
                    print('CLUST VAE')

                    vae_path = save_path + 'FEAT' + feat + '/HOMO' + homo + '/'
                    if not os.path.exists(vae_path):
                        os.makedirs(vae_path)

                    vae_file = vae_path + data_name + '_clust_vae_encoding.npy'

                    if(pl.Path(vae_file).exists() & use_old):
                        print('loading in old VAE encoding')
                        encoding = np.load(vae_file)
                    else:
                        X_train, X_test, clust_train, clust_test = train_test_split(X3, clust3, test_size=0.1, random_state=12345)
                        clust_train = tf.one_hot(clust_train, depth=clust3.max()+1)
                        clust_test = tf.one_hot(clust_test, depth=clust3.max()+1)

                        ngens = X3.shape[1]

                        #VAE
                        print('generating new VAE encoding')
                        tf.random.set_seed(seeds[1])
                        vae = define_clust_vae(enc_sze, ngens, clust3.max()+1, LR=LR, clust_weight=clust_weight)

                        hist = vae.fit(x=[X_train],
                                       y=[X_train, clust_train],
                                       validation_data=([X_test], [X_test, clust_test]),
                                       epochs=eps,
                                       use_multiprocessing=True,
                                       callbacks=_vae_callbacks())

                        encoder = vae.get_layer('encoder')
                        tf.random.set_seed(seeds[2])
                        encoding = np.array(tf.convert_to_tensor(encoder(X3)))

                        if save:
                            np.save(vae_file, encoding)

                    del(vae_file)
                    print(seeds[1]==29503)
                    print(seeds[2]==432809)

                if(dim_red=='vae'):
                    print('VAE')

                    vae_path = save_path + 'FEAT' + feat + '/HOMO' + homo + '/'
                    if not os.path.exists(vae_path):
                        os.makedirs(vae_path)

                    vae_file = vae_path + data_name + '_vae_encoding.npy'

                    if(pl.Path(vae_file).exists() & use_old):
                        print('loading in old VAE encoding')
                        encoding = np.load(vae_file)
                    else:
                        X_train, X_test, y_train, y_test = train_test_split(X3, Y3, test_size=0.1, random_state=12345)

                        ngens = X3.shape[1]

                        #VAE
                        print('generating new VAE encoding')
                        tf.random.set_seed(seeds[1])
                        vae = define_vae(enc_sze, ngens)

                        hist = vae.fit(x=X_train,
                                       y=X_train,
                                       validation_data=(X_test, X_test),
                                       epochs=eps,
                                       use_multiprocessing=True,
                                       callbacks=_vae_callbacks())

                        encoder = vae.get_layer('encoder')
                        tf.random.set_seed(seeds[2])
                        encoding = np.array(tf.convert_to_tensor(encoder(X3)))

                        if save:
                            np.save(vae_file, encoding)

                    del(vae_file)


                if(dim_red=='pca'):

                    pca_path = save_path + 'FEAT' + feat + '/HOMO' + homo + '/'
                    if not os.path.exists(pca_path):
                        os.makedirs(pca_path)

                    pca_file = pca_path + data_name + '_pca.npy'

                    if(pl.Path(pca_file).exists() & use_old):
                        print('loading in old PCA')
                        encoding = np.load(pca_file)
                    else:
                        print('generating new PCA encoding')
                        pca = PCA(n_components=enc_sze)
                        encoding = pca.fit_transform(X3)
                        if(save):
                            np.save(pca_file, encoding)

                print('X:', X3.shape)
                print('Y:', Y3.shape)
                print('end enc')

                if(feat=='knnfeat'):
                    # prepend the kNN feature as the first column of the encoding
                    encoding = np.vstack([knn_feature3,encoding.T]).T

                print('Y3:', Y3.shape)

                # U = real (unlabelled) cells, P = simulated doublets (positives)
                U3 = encoding[Y3==0,:]
                P3 = encoding[Y3==1,:]

                for pu in PUs:
                    for clss in clsses:

                        U4 = U3
                        P4 = P3
                        true4 = true3
                        X4 = X3
                        Y4 = Y3
                        if (feat=='knnfeat'):
                            knn_feature4 = knn_feature3

                        print('DIMRED: ', dim_red)
                        print('PU: ', pu)
                        print('CLSS: ', clss)
                        print('FEAT: ', feat)
                        print('HOMO: ', homo)

                        save_p = save_path + 'FEAT' + feat + '/HOMO' + homo + '/DIMRED' + dim_red + '/PU' + pu + '/CLSS' + clss + '/'
                        if not os.path.exists(save_p):
                            os.makedirs(save_p)

                        save_file = save_p + data_name + '_scores.csv'
                        if(pl.Path(save_file).exists() & use_old):
                            print('loading in old scores')

                            df = pd.read_csv(save_file)

                            # plain arrays (not index-bearing Series), matching the freshly-computed path
                            scores = df.score.to_numpy()
                            preds = scores[Y4==0]
                            preds_on_P = scores[Y4==1]
                        else:
                            print('generating new scores')

                            if(pu=='PU'):
                                print('doing PU')
                                #######################################################
                                ######################### PU ##########################
                                #######################################################

                                num_cells = P4.shape[0]*k_mult
                                k = int(U4.shape[0] / num_cells)
                                if(k<2):
                                    k=2

                                if(clss=='NN'):
                                    # choose the epoch count from the knee of the smoothed log-loss curve
                                    hist = epoch_PU2(U4, P4, k, N, cls_eps, seeds=seeds[3:], puLR=puLR)

                                    y=np.log(hist.history['loss'])
                                    y = savgol_filter(y, 7, 1)
                                    x=np.arange(len(y))

                                    kneedle = KneeLocator(x, y, S=10, curve='convex', direction='decreasing')

                                    knee = kneedle.knee

                                    if knee is None:
                                        knee = cls_eps
                                    elif(num < 500):
                                        # `num` = simulated doublets kept by the downsample
                                        # (before same-cluster removal); add epochs when there are few
                                        print('added 100')
                                        knee = knee+100
                                    elif knee<20:
                                        knee = 20
                                    elif knee>cls_eps:
                                        knee = cls_eps

                                    print('KNEE:', knee)

                                else:
                                    knee=cls_eps

                                preds, preds_on_P, hists, _, _, _ = PU(U4, P4, k, N, knee, clss=clss, seeds=seeds[3:], puLR=puLR)

                            if(pu=='noPU'):
                                print('NO PU')

                                preds, preds_on_P = noPU(U4, P4, cls_eps, clss=clss, seeds=seeds[3:], puPat=puPat, puLR=puLR, num_layers=pu_num_layers)


                        #SAVE SCORES: rows are real cells first, then simulated
                        #doublets, matching the row order of new_labs3
                        tmp1 = np.zeros((len(new_labs3), 2))
                        df = pd.DataFrame(tmp1, index=new_labs3, columns=['annotation', 'score'])
                        df.annotation = new_labs3
                        df.score = np.concatenate([preds, preds_on_P])
                        if(save):
                            df.to_csv(save_file)

                        #PR and ROC curves on the real cells only (true4 < 2)
                        plt.figure(4)
                        res = get_dbl_metrics(true4[true4<2], preds)
                        plt.show()
                        plt.close()

                        #save AUCs
                        hm_pr = pd.DataFrame(np.array(res).T, index=['AUROC', 'AUPRC', 'AP']).T
                        if(save):
                            hm_pr.to_csv(save_p + data_name + '_scores_ROC_PR_area_ALL.csv')

                        del(preds)
                        del(preds_on_P)

                del(U4)
                del(P4)

            del(encoding)

        if feat=='knnfeat':
            del(knn_feature1)
            del(knn_feature2)
            del(knn_feature3)
            del(knn_feature4)





In [None]:
# Disabled scratch code for inspecting the saved outputs of one dataset
# (J293t-dm) from ../results_PU/final_vaeda_result/. The whole block is a string
# literal, so running this cell only echoes the text.
# NOTE(review): if un-quoted, the `time = pd.read_csv(...)` line rebinds the
# `time` module imported at the top of the notebook — rename it first.
'''clusters_real = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_clusters_real.npy')
clusters_sim = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_clusters_sim.npy')
emb_real = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_embedding_real.npy')
emb_sim = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_embedding_sim.npy')
df_scores = pd.read_csv('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_scores_ROC_PR_area_ALL.csv')
knn_feat_real = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_knn_feature_real.npy')
knn_feat_sim = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_knn_feature_sim.npy')
sim_scores = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_scores_on_sim.npy')
scores = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_scores.npy')
sim_ind = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_sim_ind.npy')
time = pd.read_csv('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_time.csv')
dat_sim = scs.load_npz('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_sim_doubs.npz')
sim_which = np.load('../results_PU/final_vaeda_result/hyper_PcaComp30_ClustWeight20000/J293t-dm_which_sim_doubs.npy')'''