In [1]:
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable

from ConAAE.dataloader import RNA_Dataset
from ConAAE.dataloader import ATAC_Dataset
from ConAAE.model import FC_Autoencoder, FC_Classifier, FC_VAE, Simple_Classifier,TripletLoss
from ConAAE import conAAE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext autoreload
%autoreload 1
%aimport ConAAE

In [3]:
import os
import argparse
import time

In [4]:
def setup_args(args=()):
    """Build the ConAAE option parser and parse ``args``.

    The default is an empty (immutable) sequence rather than ``None``, so inside a
    notebook argparse does not fall back to ``sys.argv``, which holds the kernel's
    own arguments. Pass ``None`` explicitly to parse ``sys.argv`` instead.

    Args:
        args: sequence of command-line tokens, e.g.
            ``['-i', 'sci-CAR', '-sd', 'demo', '--contrastive-loss']``.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Raises:
        SystemExit: on unknown options or malformed values (standard argparse behavior).
    """
    options = argparse.ArgumentParser()

    # save and directory options
    # Neither directory has a default, so callers must pass both -sd and -i.
    options.add_argument('-sd', '--save-dir', action="store", dest="save_dir",
                         help="output / checkpoint directory")
    options.add_argument('-i', '--input-dir', action="store", dest="input_dir",
                         help="dataset name; the notebook reads ./data/<input-dir>/*.csv")
    options.add_argument('--save-freq', action="store", dest="save_freq", default=10, type=int)
    options.add_argument('--pretrained-file', action="store")

    # training parameters
    options.add_argument('-bs', '--batch-size', action="store", dest="batch_size", default=32, type=int)
    options.add_argument('-nz', '--latent-dimension', action="store", dest="nz", default=50, type=int)
    options.add_argument('-w', '--num-workers', action="store", dest="num_workers", default=10, type=int)

    options.add_argument('-lrD', '--learning-rate-D', action="store", dest="learning_rate_D", default=1e-4, type=float)
    options.add_argument('-e', '--max-epochs', action="store", dest="max_epochs", default=101, type=int)
    # argparse applies `type` only to *string* defaults, so a bare 0 would stay an
    # int; use a float literal so args.weight_decay is always a float.
    options.add_argument('-wd', '--weight-decay', action="store", dest="weight_decay", default=0.0, type=float)
    # Optional loss terms / model variants; each is off unless its flag is passed.
    options.add_argument('--contrastive-loss', action="store_true")
    options.add_argument('--consistency-loss', action="store_true")
    options.add_argument('--anchor-loss', action="store_true")
    options.add_argument('--MMD-loss', action="store_true")
    options.add_argument('--VAE', action="store_true")
    options.add_argument('--discriminator', action="store_true")
    options.add_argument('--augmentation', action="store_true")

    # hyperparameters
    options.add_argument('-lrAE', '--learning-rate-AE', action="store", dest="learning_rate_AE", default=1e-4, type=float)
    options.add_argument('--margin', action="store", default=0.3, type=float)
    options.add_argument('--alpha', action="store", default=10.0, type=float)
    options.add_argument('--beta', action="store", default=1., type=float)
    # presumably optimizer (Adam-style) betas -- TODO confirm in ConAAE.conAAE
    options.add_argument('--beta1', action="store", default=0.5, type=float)
    options.add_argument('--beta2', action="store", default=0.999, type=float)

    # gpu options
    # NOTE: inverted flag. use_gpu defaults to True and passing -gpu/--use-gpu turns
    # it OFF (action="store_false"). Kept as-is for CLI compatibility.
    options.add_argument('-gpu', '--use-gpu', action="store_false", dest="use_gpu",
                         help="pass to DISABLE the GPU (use_gpu defaults to True)")

    return options.parse_args(args)

In [5]:
# Demo configuration: sci-CAR data, contrastive + consistency losses, outputs in ./demo.
args = setup_args(args=[
    '-i', 'sci-CAR',         # dataset folder under ./data/
    '-sd', 'demo',           # output directory
    '--consistency-loss',
    '--contrastive-loss',
])

In [6]:
# Without CUDA, force CPU regardless of the flag; otherwise keep the user's choice.
args.use_gpu = torch.cuda.is_available() and args.use_gpu

# Create the output directory; exist_ok makes re-running this cell harmless.
os.makedirs(args.save_dir, exist_ok=True)

In [7]:
# Training splits of the paired RNA / ATAC data, read from ./data/<input_dir>/.
# Paths are built with os.path.join instead of string concatenation, and the
# shared label file is resolved once.
data_dir = os.path.join("./data", args.input_dir)
label_path = os.path.join(data_dir, "label.csv")
genomics_dataset = RNA_Dataset(datadir=os.path.join(data_dir, "rna.csv"),
                               labeldir=label_path, mode='train')
ATAC_dataset = ATAC_Dataset(datadir=os.path.join(data_dir, "atac.csv"),
                            labeldir=label_path, mode='train')

In [8]:
# Build the ConAAE object from both training datasets and the parsed options;
# it is trained, reloaded from a checkpoint and evaluated in the cells below.
con = conAAE.conAAE(
    genomics_dataset,
    ATAC_dataset,
    args,
)

In [9]:
# Train with the options in `args` (max_epochs defaults to 101); per-epoch losses
# are printed below.
# NOTE(review): the saved output stops partway through epoch 13, yet a later cell
# loads an epoch-100 checkpoint -- that checkpoint presumably came from a separate,
# complete run. Re-run this cell to completion before trusting the test metrics.
con.train()

epoch 0
clf_loss: 0.494535
contrastive_loss: 2.137235
recon_loss: 26.340048
epoch 1
clf_loss: 0.489661
contrastive_loss: 1.772574
recon_loss: 15.681348
epoch 2
clf_loss: 0.489451
contrastive_loss: 1.324683
recon_loss: 11.201370
epoch 3
clf_loss: 0.487918
contrastive_loss: 1.048810
recon_loss: 8.745085
epoch 4
clf_loss: 0.488386
contrastive_loss: 0.906212
recon_loss: 7.211453
epoch 5
clf_loss: 0.486761
contrastive_loss: 0.826470
recon_loss: 6.196562
epoch 6
clf_loss: 0.488817
contrastive_loss: 0.743252
recon_loss: 5.452908
epoch 7
clf_loss: 0.490403
contrastive_loss: 0.680947
recon_loss: 4.961567
epoch 8
clf_loss: 0.492628
contrastive_loss: 0.639199
recon_loss: 4.634337
epoch 9
clf_loss: 0.492940
contrastive_loss: 0.632160
recon_loss: 4.429671
epoch 10
clf_loss: 0.491291
contrastive_loss: 0.608775
recon_loss: 4.254476
epoch 11
clf_loss: 0.492859
contrastive_loss: 0.592523
recon_loss: 4.072032
epoch 12
clf_loss: 0.487251
contrastive_loss: 0.592082
recon_loss: 4.008751
epoch 13
clf_loss: 

In [10]:
# Test splits of the same paired RNA / ATAC data, read from ./data/<input_dir>/.
# Paths are built with os.path.join instead of string concatenation, and the
# shared label file is resolved once.
data_dir = os.path.join("./data", args.input_dir)
label_path = os.path.join(data_dir, "label.csv")
RNA_test = RNA_Dataset(datadir=os.path.join(data_dir, "rna.csv"),
                       labeldir=label_path, mode='test')
ATAC_test = ATAC_Dataset(datadir=os.path.join(data_dir, "atac.csv"),
                         labeldir=label_path, mode='test')

In [15]:
# Reload saved weights from the run's output directory. The directory now follows
# args.save_dir; the previously hard-coded 'demo/' only matched the '-sd demo' option.
# NOTE(review): epoch 100 assumes training reached the final epoch of max_epochs=101
# and that a checkpoint was written there -- confirm against ConAAE.conAAE's save logic.
CKPT_EPOCH = 100
con.load_model(
    os.path.join(args.save_dir, "netRNA_DE_{}.pth".format(CKPT_EPOCH)),
    os.path.join(args.save_dir, "netATAC_DE_{}.pth".format(CKPT_EPOCH)),
)

In [16]:
# Evaluate the reloaded model on the RNA / ATAC test splits.
# NOTE(review): con.test prints bare, unlabeled numbers; see ConAAE.conAAE.test
# for what each printed line measures.
con.test(RNA_test, ATAC_test)

0.055710306406685235
0.11142061281337047
0.16991643454038996
0.22284122562674094
0.27019498607242337
0.0
0.6406685236768802
