In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pandas as pd
import sys
import os
import numpy as np
import logging

from argparse import Namespace
import tqdm
import itertools
from collections import Counter
import gzip

In [2]:
### GLOBALS
# Source genome and transcription factor
SOURCE_GENOME = "mm10"
TF = "CEBPA"
SOURCE_GENOME_FASTA = '../../genomes/mm10_no_alt_analysis_set_ENCODE.fasta'

# Target genome
TARGET_GENOME = "hg38"
TARGET_GENOME_FASTA = "../../genomes/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"

# Run configuration
PILOT_STUDY = False
MODEL_NAME = "cogan"
PYTORCH_DEVICE = "cuda"
TRAIN = True
# Suffix appended to the log and checkpoint file names of pilot runs.
MODEL_STORAGE_SUFFIX = "_pilot" if PILOT_STUDY else ""

In [3]:
# Make the project's `utils` package (located in the parent of the working directory) importable.
sys.path.append("../")
from utils import datasets,samplers,models,utils

In [4]:
# Logger config: debug output for this run goes to ./log/<TF>_<MODEL_NAME><suffix>.log,
# overwritten on every run (filemode='w').
# basicConfig opens the file immediately and raises FileNotFoundError if ./log is missing,
# so make sure the directory exists first.
os.makedirs('./log', exist_ok=True)
logging.basicConfig(filename=f'./log/{TF}_{MODEL_NAME}{MODEL_STORAGE_SUFFIX}.log', filemode='w', level=logging.DEBUG)

# Define namespace arguments

In [5]:
args = Namespace(
    # --- data and path information ---
    model_state_file=f'{MODEL_NAME}{MODEL_STORAGE_SUFFIX}.pth',
    source_csv=f'../data/{SOURCE_GENOME}/{TF}/split_data.csv.gz',
    source_genome_fasta=SOURCE_GENOME_FASTA,
    target_csv=f'../data/{TARGET_GENOME}/{TF}/split_data.csv.gz',
    target_genome_fasta=TARGET_GENOME_FASTA,
    model_save_dir=f'../torch_models/{SOURCE_GENOME}/{TF}/{MODEL_NAME}/',
    results_save_dir=f'../results/{SOURCE_GENOME}/{TF}/',
    feat_size=(4, 500),  # feat_size[0] is used as the input channel count by the model

    # --- model hyper-parameters ---
    conv_filters=240,
    conv_kernelsize=20,
    maxpool_strides=15,
    maxpool_size=15,
    lstm_outnodes=32,
    linear1_nodes=1024,
    dropout_prob=0.5,

    # --- training hyper-parameters ---
    batch_size=128,
    early_stopping_criteria=5,
    learning_rate=0.001,
    num_epochs=15,
    tolerance=1e-3,
    seed=1337,

    # --- runtime options ---
    catch_keyboard_interrupt=True,
    cuda=PYTORCH_DEVICE == "cuda",
    expand_filepaths_to_save_dir=True,
    pilot=PILOT_STUDY,  # 2% of original dataset
    train=TRAIN,
    test_batch_size=int(2e3),
)

# Place the checkpoint file inside the model directory.
if args.expand_filepaths_to_save_dir:
    args.model_state_file = os.path.join(args.model_save_dir, args.model_state_file)
    print("Expanded filepaths: ")
    print("\t{}".format(args.model_state_file))

# Fall back to the CPU when CUDA was requested but is not available.
if not torch.cuda.is_available():
    args.cuda = False
print("Using CUDA: {}".format(args.cuda))
args.device = torch.device("cuda" if args.cuda else "cpu")

# Seed RNGs for reproducibility (see utils.set_seed_everywhere).
utils.set_seed_everywhere(args.seed, args.cuda)

# Output directories (utils.handle_dirs presumably creates them when missing).
utils.handle_dirs(args.model_save_dir)
utils.handle_dirs(args.results_save_dir)

Expanded filepaths: 
	../torch_models/mm10/CEBPA/cogan/cogan.pth
Using CUDA: False


# CoGAN CNN-RNN Classifier/Discriminator Model

In [150]:
class TFCoCD(nn.Module):
    """CoGAN discriminator with an auxiliary classification head.

    The featurizer (``models.TFCNN``), recurrent layer (``models.TFLSTM_``) and
    MLP (``models.TFMLP_``) are shared by both input streams ``x_a`` and ``x_b``.
    Two single-layer heads sit on top of the MLP output:

    * ``d_slp`` -- real/fake discrimination (used by ``forward``);
    * ``c_slp`` -- classification (used by ``classify_a`` / ``classify_b``).

    Every head returns logits unless ``apply_sigmoid=True`` is passed.
    """

    def __init__(self, args):
        super(TFCoCD, self).__init__()
        # Submodules are created in the original order so that seeded
        # parameter initialisation is unchanged.
        self.featurizer = models.TFCNN(channels=args.feat_size[0],
                                       conv_filters=args.conv_filters,
                                       conv_kernelsize=args.conv_kernelsize,
                                       maxpool_size=args.maxpool_size,
                                       maxpool_strides=args.maxpool_strides)
        self.cd_lstm = models.TFLSTM_(input_features=args.conv_filters,
                                      lstm_nodes=args.lstm_outnodes)
        self.cd_mlp = models.TFMLP_(input_features=args.lstm_outnodes,
                                    fc1_nodes=args.linear1_nodes)
        # NOTE(review): assumes TFMLP_ emits linear1_nodes // 4 features -- confirm in utils.models.
        self.c_slp = models.TFSLP(input_features=args.linear1_nodes//4)
        self.d_slp = models.TFSLP(input_features=args.linear1_nodes//4)

    def _embed(self, x):
        """Shared CNN -> LSTM trunk applied to a single input batch."""
        return self.cd_lstm(self.featurizer(x))

    def _classify(self, x, apply_sigmoid):
        """Shared trunk + MLP + classification head; optional sigmoid on the output."""
        out_class = self.c_slp(self.cd_mlp(self._embed(x)))
        if apply_sigmoid:
            out_class = torch.sigmoid(out_class)
        return out_class

    def forward(self, x_a, x_b, apply_sigmoid=False):
        """Discriminate the two streams jointly.

        Args:
            x_a: batch for stream a (the training loop passes source data here).
            x_b: batch for stream b (the training loop passes target data here).
            apply_sigmoid: return probabilities instead of logits.

        Returns:
            ``(out_dscm, feat_a, feat_b)``:
            * ``out_dscm`` is the discriminator output for the LSTM features of
              ``x_a`` and ``x_b``, concatenated along dim 0 (a first, then b).
            * ``feat_a`` and ``feat_b`` are the per-stream LSTM features.
        """
        x_a = self._embed(x_a)
        x_b = self._embed(x_b)
        out_dscm = self.d_slp(self.cd_mlp(torch.cat((x_a, x_b))))
        if apply_sigmoid:
            out_dscm = torch.sigmoid(out_dscm)
        return out_dscm, x_a, x_b

    def classify_a(self, x_a, apply_sigmoid=False):
        """Classification output for stream-a inputs (logits unless ``apply_sigmoid``).

        ``apply_sigmoid`` used to be referenced without being defined, which
        raised NameError on every call; it is now an optional keyword.
        """
        return self._classify(x_a, apply_sigmoid)

    def classify_b(self, x_b, apply_sigmoid=False):
        """Classification output for stream-b inputs (logits unless ``apply_sigmoid``)."""
        return self._classify(x_b, apply_sigmoid)

In [151]:
class TFCoGen(nn.Module):
    """Coupled generator producing a pair of sequences from one noise vector.

    The model has two parts:

    * Six shared upsampling stages, each ``ConvTranspose1d -> BatchNorm1d(affine=False)
      -> PReLU`` (registered as ``dconv{i}`` / ``bn{i}`` / ``prelu{i}``).
    * Two stream-specific heads (``dconv6_a`` / ``dconv6_b``) with sigmoid outputs.

    For a ``(batch, latent_dims)`` input the sequence length grows
    1 -> 15 -> 32 -> 63 -> 125 -> 249 -> 497 -> 500, so each of the two outputs
    has shape ``(batch, 4, 500)``.
    """

    def __init__(self, latent_dims):
        super(TFCoGen, self).__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) of the shared stages.
        stage_specs = [
            (latent_dims, 1024, 15, 1, 0),
            (1024, 512, 6, 2, 1),
            (512, 256, 3, 2, 1),
            (256, 128, 3, 2, 1),
            (128, 64, 3, 2, 1),
            (64, 32, 3, 2, 1),
        ]
        # Same attribute names and registration order as an explicit listing,
        # so state_dict keys and seeded initialisation are identical.
        for idx, (c_in, c_out, ksize, step, pad) in enumerate(stage_specs):
            setattr(self, f"dconv{idx}",
                    nn.ConvTranspose1d(c_in, c_out, kernel_size=ksize, stride=step, padding=pad))
            setattr(self, f"bn{idx}", nn.BatchNorm1d(c_out, affine=False))
            setattr(self, f"prelu{idx}", nn.PReLU())
        # Stream-specific heads, 4 output channels each.
        self.dconv6_a = nn.ConvTranspose1d(32, 4, kernel_size=6, stride=1, padding=1)
        self.dconv6_b = nn.ConvTranspose1d(32, 4, kernel_size=6, stride=1, padding=1)
        self.sig6_a = nn.Sigmoid()
        self.sig6_b = nn.Sigmoid()

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dims) to ``(out_a, out_b)``."""
        hidden = z.view(z.size(0), z.size(1), 1)
        for idx in range(6):
            deconv = getattr(self, f"dconv{idx}")
            norm = getattr(self, f"bn{idx}")
            act = getattr(self, f"prelu{idx}")
            hidden = act(norm(deconv(hidden)))
        return self.sig6_a(self.dconv6_a(hidden)), self.sig6_b(self.dconv6_b(hidden))

In [152]:
def set_requires_grad(model, requires_grad=True):
    """Freeze or unfreeze every parameter of ``model``.

    The training loop uses this to alternate discriminator and generator updates:
    parameters whose flag is False receive no gradient on ``backward()``.

    Args:
        model: any ``nn.Module``.
        requires_grad: value assigned to ``requires_grad`` on all parameters.

    Returns:
        None.
    """
    flag = requires_grad
    for tensor in model.parameters():
        tensor.requires_grad = flag

In [7]:
def train_cogan(args):
    """Train the CoGAN generator/discriminator pair and its classifier head.

    For every training batch two updates are made:

    1. Discriminator (``TFCoCD``) update, with the generator frozen. The loss is:
       * BCE(real source/target pairs -> 1);
       * + BCE(generated pairs -> 0);
       * + 0.01 * per-sample summed squared difference between the
         discriminator features of the two generated streams;
       * + 10 * BCE of ``classify_a`` on weighted source batches.
    2. Generator (``TFCoGen``) update, with the discriminator frozen: BCE
       pushing the discriminator to label fresh generated pairs as real.

    After each epoch the classifier head is evaluated on the source validation
    split. Validation loss and average precision are appended to the train state,
    which ``utils.update_train_state`` updates; ``train_state['stop_early']``
    ends training.

    NOTE(review): both optimizers use a fixed lr of 2e-4; ``args.learning_rate``
    is not read here.

    Args:
        args: Namespace from the configuration cell. Read directly here:
            ``batch_size``, ``device``, ``num_epochs`` and
            ``catch_keyboard_interrupt``. The rest is consumed by the
            ``datasets`` / ``samplers`` / ``utils`` helpers.

    Returns:
        The train-state dict from ``utils.make_train_state``, as updated each epoch.
    """
    noise_dims = 100  # length of the generator's latent vector (TFCoGen input)

    # ---- data ----
    logging.debug('Loading source and target data...')
    src_dataset, tgt_dataset = datasets.load_data(args)

    # ---- models ----
    logging.debug('Initializing model...')
    discriminator = TFCoCD(args)
    generator = TFCoGen(noise_dims)
    discriminator.to(args.device)
    generator.to(args.device)
    model_params = utils.get_n_params(discriminator) + utils.get_n_params(generator)
    logging.debug(f"The model has {model_params} parameters.")

    # ---- losses & optimizers ----
    # The model heads are called without apply_sigmoid, i.e. they return logits.
    bce_loss_func = nn.BCEWithLogitsLoss()
    mse_loss_func = nn.MSELoss()
    opt_dscm = optim.Adam(discriminator.parameters(), lr=0.0002, eps=1e-7, weight_decay=0.0005)
    opt_gen = optim.Adam(generator.parameters(), lr=0.0002, eps=1e-7, weight_decay=0.0005)

    # ---- samplers ----
    # Weighted train / unweighted valid samplers for the classifier part of the model.
    train_sampler, valid_sampler = samplers.make_train_samplers(src_dataset, args)
    nsamples = train_sampler.num_samples

    # Unweighted samplers of the source and target train splits for the adversarial part.
    src_dataset.set_split("train")
    src_sampler = samplers.get_sampler(src_dataset, weighted=False, mini=False)
    tgt_dataset.set_split("train")
    tgt_sampler = samplers.get_sampler(tgt_dataset, weighted=False, mini=False)

    train_state = utils.make_train_state(args)

    # ---- progress bars ----
    epoch_bar = tqdm.notebook.tqdm(desc='training routine',
                                   total=args.num_epochs,
                                   position=0)
    train_bar = tqdm.notebook.tqdm(desc='split=train',
                                   total=nsamples//args.batch_size,
                                   position=1,
                                   leave=True)
    src_dataset.set_split('valid')
    val_bar = tqdm.notebook.tqdm(desc='split=valid',
                                 total=len(src_dataset)//int(args.batch_size*1e1),
                                 position=2,
                                 leave=True)

    try:
        for epoch_index in range(args.num_epochs):
            train_state['epoch_index'] = epoch_index

            # ---- training epoch ----
            src_dataset.set_split('train')
            tgt_dataset.set_split('train')
            # Weighted source batches (2x batch size) for the classifier loss.
            batch_generator = utils.generate_batches(src_dataset, sampler=train_sampler,
                                                     batch_size=int(2*args.batch_size),
                                                     device=args.device)
            # Unweighted real source / target batches for the adversarial loss.
            src_batch_generator = utils.generate_batches(src_dataset, sampler=src_sampler,
                                                         batch_size=args.batch_size,
                                                         device=args.device)
            # Bug fix: target batches were previously drawn from src_dataset.
            tgt_batch_generator = utils.generate_batches(tgt_dataset, sampler=tgt_sampler,
                                                         batch_size=args.batch_size,
                                                         device=args.device)

            dscm_running_loss = 0.0
            gen_running_loss = 0.0
            running_dscmacc = 0.0
            discriminator.train()
            generator.train()

            batches = zip(batch_generator, src_batch_generator, tgt_batch_generator)
            for batch_index, (batch_dict, src_batch_dict, tgt_batch_dict) in enumerate(batches):

                # ---- discriminator update (generator frozen) ----
                set_requires_grad(discriminator, requires_grad=True)
                set_requires_grad(generator, requires_grad=False)
                opt_dscm.zero_grad()

                # Real (source, target) pairs -> label 1. Labels are sized from the
                # prediction so a short final batch does not cause a shape mismatch.
                real_pred, _, _ = discriminator(src_batch_dict["x_data"].float(),
                                                tgt_batch_dict["x_data"].float())
                real_labels = torch.ones(real_pred.size(0), dtype=torch.float, device=args.device)
                dscm_real_loss = bce_loss_func(real_pred, real_labels)

                # Generated pairs -> label 0.
                noise = torch.randn((args.batch_size, noise_dims), device=args.device)
                fake_data_a, fake_data_b = generator(noise)
                fake_pred, fake_feat_a, fake_feat_b = discriminator(fake_data_a, fake_data_b)
                fake_labels = torch.zeros(fake_pred.size(0), dtype=torch.float, device=args.device)
                dscm_fake_loss = bce_loss_func(fake_pred, fake_labels)

                # Penalise feature differences between the two generated streams.
                # The MSE (a mean over all elements) is rescaled by dims 1 and 2, which
                # gives the batch mean of the per-sample sum of squares (3-D features).
                feat_diff = fake_feat_a - fake_feat_b
                mse_loss = (mse_loss_func(feat_diff, torch.zeros_like(feat_diff))
                            * feat_diff.size(1) * feat_diff.size(2))

                # Supervised classification loss on weighted source batches.
                class_pred = discriminator.classify_a(batch_dict["x_data"].float())
                loss_class = bce_loss_func(class_pred, batch_dict['y_target'].float())

                loss = dscm_real_loss + dscm_fake_loss + 0.01*mse_loss + 10*loss_class
                loss.backward()
                opt_dscm.step()

                dscm_running_loss += (loss.item() - dscm_running_loss) / (batch_index + 1)

                # Predictions are logits, so the decision boundary is 0 (sigmoid 0.5),
                # not 0.5 as before.
                n_correct = (real_pred > 0).sum() + (fake_pred < 0).sum()
                acc = n_correct.item() / (len(real_pred) + len(fake_pred))
                running_dscmacc += (acc - running_dscmacc) / (batch_index + 1)

                # ---- generator update (discriminator frozen) ----
                set_requires_grad(discriminator, requires_grad=False)
                set_requires_grad(generator, requires_grad=True)
                opt_gen.zero_grad()

                # Fresh generated pairs should be labelled real (1) by the discriminator.
                noise = torch.randn((args.batch_size, noise_dims), device=args.device)
                fake_data_a, fake_data_b = generator(noise)
                fake_pred, _, _ = discriminator(fake_data_a, fake_data_b)
                gen_labels = torch.ones(fake_pred.size(0), dtype=torch.float, device=args.device)
                gen_fake_loss = bce_loss_func(fake_pred, gen_labels)
                gen_fake_loss.backward()
                opt_gen.step()

                gen_running_loss += (gen_fake_loss.item() - gen_running_loss) / (batch_index + 1)

                train_bar.set_postfix(dscm_loss=dscm_running_loss,
                                      gen_loss=gen_running_loss,
                                      acc=running_dscmacc,
                                      epoch=epoch_index)
                train_bar.update()

            # ---- validation epoch: classifier head on the source valid split ----
            src_dataset.set_split('valid')
            batch_generator = utils.generate_batches(src_dataset, sampler=valid_sampler,
                                                     batch_size=int(args.batch_size*1e1),
                                                     device=args.device)
            running_loss = 0.
            # Predictions/targets are streamed to a temp file so APS covers the whole split.
            tmp_filename = f"./{TF}_cogan_tmp.tmp"
            discriminator.eval()
            generator.eval()

            try:
                with open(tmp_filename, "wb") as tmp_file, torch.no_grad():
                    for batch_index, batch_dict in enumerate(batch_generator):
                        y_pred = discriminator.classify_a(batch_dict["x_data"].float())
                        y_target = batch_dict['y_target'].float()
                        # Bug fix: the loss previously used the stale training-step `class_pred`.
                        loss = bce_loss_func(y_pred, y_target)
                        running_loss += (loss.item() - running_loss) / (batch_index + 1)

                        for yp, yt in zip(torch.sigmoid(y_pred).cpu().numpy(), y_target.cpu().numpy()):
                            tmp_file.write(bytes(f"{yp},{yt}\n", "utf-8"))

                        val_bar.set_postfix(loss=running_loss,
                                            epoch=epoch_index,
                                            early_stop=train_state['early_stopping_step'])
                        val_bar.update()

                train_state['val_loss'].append(running_loss)
                val_aps = utils.compute_aps_from_file(tmp_filename)
            finally:
                # Do not leave the temp file behind if validation or APS computation fails.
                if os.path.exists(tmp_filename):
                    os.remove(tmp_filename)

            train_state['val_aps'].append(val_aps)

            # Bug fix: `classifier` is not defined in this scope; the discriminator
            # carries the classifier head being validated.
            train_state = utils.update_train_state(args=args, model=discriminator,
                                                   train_state=train_state)

            logging.debug(f"Epoch: {epoch_index}, Validation Loss: {running_loss}, Validation APS: {val_aps}")

            train_bar.n = 0
            val_bar.n = 0
            epoch_bar.update()

            if train_state['stop_early']:
                logging.debug("Early stopping criterion fulfilled!")
                break

    except KeyboardInterrupt:
        # Honour args.catch_keyboard_interrupt; a missing flag keeps the old always-catch behaviour.
        if not getattr(args, 'catch_keyboard_interrupt', True):
            raise
        logging.warning("Exiting loop")

    return train_state

In [8]:
# Train only when the TRAIN flag (args.train) is set. Previously the flag was
# never consulted, so training always ran.
if __name__ == "__main__" and args.train:
    train_state = train_cogan(args)

training routine:   0%|          | 0/15 [00:00<?, ?it/s]

split=train:   0%|          | 0/9644 [00:00<?, ?it/s]

split=valid:   0%|          | 0/2645 [00:00<?, ?it/s]

  self.padding, self.dilation, self.groups)
  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


# Test Model

## Source dataset

In [9]:
# NOTE(review): TFHybrid is not defined anywhere in this notebook (it must come from leftover
# kernel state), so this cell fails under Restart & Run All. It also constructs a model and loads
# no trained weights (e.g. from args.model_state_file) before the evaluation cells below. The saved
# outputs reference `hybrid_*` result files, which suggests they are stale.
# TODO: build/load the trained CoGAN classifier here instead.
classifier = TFHybrid(args)

In [10]:
# Source-genome dataset (ohe=True presumably selects one-hot encoded sequences).
source_dataset = datasets.TFDataset.load_dataset_and_vectorizer_from_path(
    args.source_csv,
    args.source_genome_fasta,
    ohe=True,
)

In [11]:
# Evaluate the classifier on the source dataset; the cell displays the returned value.
utils.eval_model(
    classifier, source_dataset, args,
    dataset_type="src",
)

split=test:   0%|          | 0/1582 [00:00<?, ?it/s]

  recall = tps / tps[-1]


'../results/mm10/CEBPA/hybrid_src.csv.gz'

## Target dataset

In [12]:
# Target-genome dataset, loaded the same way as the source dataset.
target_dataset = datasets.TFDataset.load_dataset_and_vectorizer_from_path(
    args.target_csv,
    args.target_genome_fasta,
    ohe=True,
)

In [13]:
# Evaluate the classifier on the target dataset; the cell displays the returned value.
utils.eval_model(
    classifier, target_dataset, args,
    dataset_type="tgt",
)

split=test:   0%|          | 0/2169 [00:00<?, ?it/s]

'../results/mm10/CEBPA/hybrid_tgt.csv.gz'