In [13]:
import numpy as np
import torch
import torch.utils.data
from torch import nn, optim
from torch.autograd import Variable

from ConAAE.dataloader import RNA_Dataset
from ConAAE.dataloader import ATAC_Dataset
from ConAAE.model import FC_Autoencoder, FC_Classifier, FC_VAE, Simple_Classifier,TripletLoss
from ConAAE import conAAE

In [14]:
# Enable IPython's autoreload extension so edits to the package source are
# picked up without restarting the kernel.
%load_ext autoreload
# Mode 1: only modules registered through %aimport are reloaded.
%autoreload 1
# Register the ConAAE package for autoreload.
# NOTE(review): only the top-level package is registered here; confirm that
# edits to its submodules (dataloader, model, conAAE) are actually reloaded.
%aimport ConAAE

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import os
import argparse
import time

In [16]:
def setup_args(args=()):
    """Build the argument parser for a conAAE run and parse ``args``.

    Parameters
    ----------
    args : sequence of str, optional
        Command-line style tokens, e.g. ``['-i', 'PBMC', '-sd', 'demo']``.
        The default is an empty (immutable) sequence, so every option takes
        its default value instead of argparse falling back to ``sys.argv``,
        which inside a notebook holds the kernel's own arguments.
        Passing ``None`` explicitly still makes argparse read ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        Parsed options. Attribute names follow the ``dest`` values below, or
        the long option name with dashes replaced by underscores.
    """
    options = argparse.ArgumentParser()

    # save and directory options
    options.add_argument('-sd', '--save-dir', dest="save_dir")
    options.add_argument('-i', '--input-dir', dest="input_dir")
    options.add_argument('--save-freq', dest="save_freq", default=10, type=int)
    options.add_argument('--pretrained-file')

    # training parameters
    options.add_argument('-bs', '--batch-size', dest="batch_size", default=100, type=int)
    # NOTE(review): '--latent-dimension' (stored as ``nz``) and '--latent-dims'
    # below are two separate options that both default to 50; which one the
    # model reads is not visible from this cell.
    options.add_argument('-nz', '--latent-dimension', dest="nz", default=50, type=int)
    options.add_argument('-w', '--num-workers', dest="num_workers", default=10, type=int)

    options.add_argument('-lrD', '--learning-rate-D', dest="learning_rate_D", default=1e-4, type=float)
    options.add_argument('-e', '--max-epochs', dest="max_epochs", default=101, type=int)
    options.add_argument('-wd', '--weight-decay', dest="weight_decay", default=0, type=float)

    # Boolean switches: each is False unless its flag is present.
    options.add_argument('--train-imagenet', action="store_true")
    options.add_argument('--conditional', action="store_true")
    #options.add_argument('--conditional-adv', action="store_true")
    #options.add_argument('--triplet-loss',action="store_true")
    options.add_argument('--contrastive-loss', action="store_true")
    options.add_argument('--consistency-loss', action="store_true")
    options.add_argument('--anchor-loss', action="store_true")
    options.add_argument('--MMD-loss', action="store_true")
    options.add_argument('--VAE', action="store_true")
    options.add_argument('--discriminator', action="store_true")
    options.add_argument('--augmentation', action="store_true")

    # hyperparameters
    options.add_argument('-lrAE', '--learning-rate-AE', dest="learning_rate_AE", default=1e-3, type=float)
    options.add_argument('--margin', default=0.3, type=float)
    options.add_argument('--alpha', default=10.0, type=float)
    options.add_argument('--beta', default=1., type=float)
    options.add_argument('--beta1', default=0.5, type=float)
    options.add_argument('--beta2', default=0.999, type=float)
    options.add_argument('--lamb', default=1e-8, type=float)
    options.add_argument('--latent-dims', default=50, type=int)

    #options.add_argument('-rna', '--input-rna', action="store", dest="input_rna")
    #options.add_argument('-atac', '--input-atac', action="store", dest="input_atac")
    #options.add_argument('-label', '--input-label', action="store", dest="input_label")

    # gpu options
    # Gotcha: this is a store_false flag, so ``use_gpu`` is True by default and
    # passing -gpu/--use-gpu turns the GPU OFF, despite the flag's name.
    options.add_argument('-gpu', '--use-gpu', action="store_false", dest="use_gpu")

    return options.parse_args(args)

In [17]:
# Demo configuration: read inputs from the 'PBMC' data folder, write results
# to 'demo', and switch on the contrastive and consistency losses.
args = setup_args(
    args=[
        '-i', 'PBMC',
        '-sd', 'demo',
        '--consistency-loss',
        '--contrastive-loss',
    ]
)

In [18]:
# Keep the GPU request only when CUDA is actually usable on this machine;
# otherwise force CPU execution.
args.use_gpu = torch.cuda.is_available() and args.use_gpu

# Create the output directory up front; an existing directory is fine.
os.makedirs(args.save_dir, exist_ok=True)

In [19]:
# Training-mode datasets for both modalities. Both read the same label file
# from ./data/<input_dir>/.
genomics_dataset = RNA_Dataset(
    datadir="./data/" + args.input_dir + "/rna.csv",
    labeldir="./data/" + args.input_dir + "/label.csv",
    mode='train',
)
ATAC_dataset = ATAC_Dataset(
    datadir="./data/" + args.input_dir + "/atac.csv",
    labeldir="./data/" + args.input_dir + "/label.csv",
    mode='train',
)

In [20]:
# Build the conAAE object from the RNA and ATAC training datasets and the
# parsed options; later cells call train(), load_model() and test() on it.
con = conAAE.conAAE(
    genomics_dataset,
    ATAC_dataset,
    args,
)

In [21]:
# Run training. The captured output below shows one block per epoch with
# clf_loss, contrastive_loss and recon_loss.
con.train()

epoch 0
clf_loss: 0.499890
contrastive_loss: 0.929297
recon_loss: 27.559725
epoch 1
clf_loss: 0.499831
contrastive_loss: 0.509586
recon_loss: 23.453665
epoch 2
clf_loss: 0.500394
contrastive_loss: 0.479272
recon_loss: 23.089788
epoch 3
clf_loss: 0.499978
contrastive_loss: 0.418671
recon_loss: 22.650355
epoch 4
clf_loss: 0.499774
contrastive_loss: 0.430626
recon_loss: 22.254313
epoch 5
clf_loss: 0.499617
contrastive_loss: 0.370126
recon_loss: 21.572776
epoch 6
clf_loss: 0.499933
contrastive_loss: 0.409612
recon_loss: 23.025798
epoch 7
clf_loss: 0.499321
contrastive_loss: 0.366949
recon_loss: 22.644789
epoch 8
clf_loss: 0.499865
contrastive_loss: 0.274342
recon_loss: 21.676018
epoch 9
clf_loss: 0.499251
contrastive_loss: 0.256479
recon_loss: 21.322411
epoch 10
clf_loss: 0.499819
contrastive_loss: 0.257185
recon_loss: 21.639521
epoch 11
clf_loss: 0.499323
contrastive_loss: 0.214050
recon_loss: 21.056996
epoch 12
clf_loss: 0.499664
contrastive_loss: 0.160882
recon_loss: 21.196803
epoch 13


In [22]:
# Test-mode datasets, built from the same CSV files as the training datasets
# but with mode='test'.
RNA_test = RNA_Dataset(
    datadir="./data/" + args.input_dir + "/rna.csv",
    labeldir="./data/" + args.input_dir + "/label.csv",
    mode='test',
)
ATAC_test = ATAC_Dataset(
    datadir="./data/" + args.input_dir + "/atac.csv",
    labeldir="./data/" + args.input_dir + "/label.csv",
    mode='test',
)

In [27]:
# Load the saved RNA and ATAC checkpoint files tagged "80" (presumably the
# epoch number - confirm against the training code).
# The directory comes from args.save_dir rather than a hard-coded 'demo/', so
# this cell keeps working when the run is configured with a different -sd.
# Assumes con.train() writes its checkpoints into args.save_dir - TODO confirm.
con.load_model(
    os.path.join(args.save_dir, 'netRNA_DE_80.pth'),
    os.path.join(args.save_dir, 'netATAC_DE_80.pth'),
)

In [28]:
# Evaluate the loaded model on the test-mode RNA and ATAC datasets.
# NOTE(review): the values printed below are unlabeled; check conAAE.test to
# see which metric each line reports.
con.test(RNA_test,ATAC_test)

0.11987545407368967
0.20083030617540218
0.27088738972496107
0.33471717695900366
0.39231966787752987
26.0
0.7976128697457188


In [None]:
# The PBMC data can be downloaded from https://drive.google.com/drive/folders/1gOrPV-npNhrubNWjpFvD4LTz_WoMXPSN?usp=share_link
# The loading cells above read rna.csv, atac.csv and label.csv from ./data/PBMC/.