In [1]:
import gzip
import itertools
import logging
import os
import sys
from argparse import Namespace
from collections import Counter

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import tqdm.notebook  # `import tqdm` alone does not load the notebook submodule used by train_dann
from torch.autograd import Function

In [2]:
### GLOBALS
# --- source species: the genome whose data the model is trained on ---
SOURCE_GENOME = "mm10"
SOURCE_GENOME_FASTA = '../../genomes/mm10_no_alt_analysis_set_ENCODE.fasta'

# --- target species: the second domain fed to the domain discriminator ---
TARGET_GENOME = "hg38"
TARGET_GENOME_FASTA = "../../genomes/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"

# --- experiment switches ---
TF = "CEBPA"
MODEL_NAME = "dann"
PILOT_STUDY = False
TRAIN = True
PYTORCH_DEVICE = "cuda"

# Suffix appended to the log / checkpoint names so pilot runs get their own files.
if PILOT_STUDY:
    MODEL_STORAGE_SUFFIX = "_pilot"
else:
    MODEL_STORAGE_SUFFIX = ""

# Genome Dataset, Sampler and Utility functions

In [3]:
# The project's `utils` package lives one directory above this notebook,
# so that directory has to be on the import path before importing from it.
sys.path.append("../")
from utils import (
    datasets,
    samplers,
    models,
    utils,
)

In [4]:
# Logger config: DEBUG-level run log at ./log/<TF>_<MODEL_NAME><suffix>.log, truncated on each run.
# basicConfig opens the file immediately and raises FileNotFoundError if ./log is missing,
# so make sure the directory exists first.
# NOTE: basicConfig is a no-op if the root logger already has handlers (e.g. on a cell re-run).
os.makedirs('./log', exist_ok=True)
logging.basicConfig(filename=f'./log/{TF}_{MODEL_NAME}{MODEL_STORAGE_SUFFIX}.log', filemode='w', level=logging.DEBUG)

# Define run configuration (`argparse.Namespace` of paths and hyper-parameters)

In [5]:
# Run configuration: a single Namespace holding paths, hyper-parameters and runtime switches.
args = Namespace(
    # --- data and path information ---
    model_state_file=f'{MODEL_NAME}{MODEL_STORAGE_SUFFIX}.pth',
    source_csv=f'../data/{SOURCE_GENOME}/{TF}/split_data.csv.gz',
    source_genome_fasta=SOURCE_GENOME_FASTA,
    target_csv=f'../data/{TARGET_GENOME}/{TF}/split_data.csv.gz',
    target_genome_fasta=TARGET_GENOME_FASTA,
    model_save_dir=f'../torch_models/{SOURCE_GENOME}/{TF}/{MODEL_NAME}/',
    results_save_dir=f'../results/{SOURCE_GENOME}/{TF}/',
    feat_size=(4, 500),  # (channels, length); feat_size[0] is the CNN channel count

    # --- model hyper-parameters ---
    conv_filters=240,
    conv_kernelsize=20,
    maxpool_strides=15,
    maxpool_size=15,
    lstm_outnodes=32,
    linear1_nodes=1024,
    dropout_prob=0.5,

    # --- training hyper-parameters ---
    batch_size=128,
    early_stopping_criteria=5,
    learning_rate=0.001,
    num_epochs=15,
    tolerance=1e-3,
    seed=1337,

    # --- runtime options ---
    catch_keyboard_interrupt=True,
    cuda=(PYTORCH_DEVICE == "cuda"),
    expand_filepaths_to_save_dir=True,
    pilot=PILOT_STUDY,  # 2% of original dataset
    train=TRAIN,
    test_batch_size=int(2e3),
)

# Place the checkpoint file inside the model directory.
if args.expand_filepaths_to_save_dir:
    args.model_state_file = os.path.join(args.model_save_dir, args.model_state_file)
    print("Expanded filepaths: ")
    print(f"\t{args.model_state_file}")

# CUDA stays enabled only if it was requested AND a GPU is actually visible.
args.cuda = args.cuda and torch.cuda.is_available()
print(f"Using CUDA: {args.cuda}")
args.device = torch.device("cuda" if args.cuda else "cpu")

# Seed RNGs via the project helper for reproducibility.
utils.set_seed_everywhere(args.seed, args.cuda)

# Output directories (handle_dirs presumably creates them when missing).
utils.handle_dirs(args.model_save_dir)
utils.handle_dirs(args.results_save_dir)

Expanded filepaths: 
	../torch_models/mm10/CEBPA/dann/dann.pth
Using CUDA: True


## DANN (Domain-Adversarial Neural Network)

In [6]:
class GRL(Function):
    """Gradient reversal layer (Ganin & Lempitsky, 2015).

    Forward: identity (returns a view of ``x``).
    Backward: the incoming gradient is negated and scaled by ``lambda_``;
    ``lambda_`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Remember the scale for the backward pass.
        ctx.scale = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.scale
        # One gradient per forward input: x gets the reversed gradient, lambda_ gets none.
        return reversed_grad, None
    

In [7]:
class TFDANN(nn.Module):
    """Domain-adversarial TF-binding model.

    A shared convolutional featurizer (``models.TFCNN``) feeds two heads:

    * ``classifier`` (``models.TFLSTM``): the binding-prediction head;
    * ``discriminator`` (``models.TFMLP``): the domain head, which receives the
      features through ``GRL`` so its gradient reaches the featurizer reversed.
    """

    def __init__(self, args):
        super().__init__()

        # NOTE: submodules are built in this fixed order (featurizer, classifier,
        # discriminator) so parameter initialisation consumes the seeded RNG the same way.
        self.featurizer = models.TFCNN(
            channels=args.feat_size[0],
            conv_filters=args.conv_filters,
            conv_kernelsize=args.conv_kernelsize,
            maxpool_size=args.maxpool_size,
            maxpool_strides=args.maxpool_strides,
        )

        self.classifier = models.TFLSTM(
            input_features=args.conv_filters,
            lstm_nodes=args.lstm_outnodes,
            fc1_nodes=args.linear1_nodes,
        )

        # Flattened featurizer output size for the discriminator input.
        # NOTE(review): this formula does not involve conv_kernelsize; confirm it
        # matches the actual output length of models.TFCNN.
        pooled_len = np.floor(
            (args.feat_size[1] - args.maxpool_size - 2) / args.maxpool_strides + 1
        )
        discriminator_in = int(pooled_len * args.conv_filters)

        self.discriminator = models.TFMLP(
            input_features=discriminator_in,
            fc1_nodes=args.linear1_nodes,
            dropout_prob=0,
        )

    def forward(self, x_in, lambda_=1):
        """Run both heads on a batch.

        Args:
            x_in: input batch for the featurizer (presumably one-hot sequences
                shaped (batch, *args.feat_size); confirm against models.TFCNN).
            lambda_: gradient-reversal scale applied before the domain head.

        Returns:
            ``(binding_out, domain_out)``: outputs of the binding head and the
            domain head. ``train_dann`` treats both as logits (BCEWithLogitsLoss).
        """
        features = self.featurizer(x_in)
        binding_out = self.classifier(features)
        domain_out = self.discriminator(GRL.apply(features, lambda_))
        return binding_out, domain_out

# Train Model

In [8]:
def train_dann(args):
    """Train a TFDANN model with adversarial domain adaptation.

    Each optimisation step sums two losses:

    * a binding loss on a batch drawn from the source train split with the
      sampler returned by ``samplers.make_train_samplers``;
    * a domain loss on an unweighted source batch (label 0) concatenated with an
      unweighted target batch (label 1). The gradient reversal layer inside
      TFDANN turns this into an adversarial signal for the shared featurizer.

    After every epoch the model is scored on the source validation split (BCE
    loss and average precision). The train state is then updated through
    ``utils.update_train_state`` and the LR scheduler is stepped on the
    validation loss.

    Args:
        args: run configuration Namespace (paths, hyper-parameters, ``device``,
            ``pilot``, ``catch_keyboard_interrupt``, ...).

    Returns:
        The train-state dict from ``utils.make_train_state``, with per-epoch
        ``train_loss``, ``val_loss`` and ``val_aps`` appended.

    Raises:
        KeyboardInterrupt: re-raised when ``args.catch_keyboard_interrupt`` is
            False; otherwise the interrupt is logged and the partial state is returned.
    """
    logging.debug("loading dataset...")
    dataset, target_dataset = datasets.load_data(args)

    # Initializing model
    logging.debug('Initializing model...')
    classifier = TFDANN(args).to(args.device)
    model_params = utils.get_n_params(classifier)
    logging.debug(f"The model has {model_params} parameters.")

    # Loss, optimizer, and a scheduler that halves the LR when validation loss plateaus.
    loss_func = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate, eps=1e-7)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                     mode='min', factor=0.5,
                                                     patience=1)

    # Samplers for the binding task (train / valid).
    train_sampler, valid_sampler = samplers.make_train_samplers(dataset, args)
    tr_samples = train_sampler.num_samples
    dataset.set_split("valid")
    va_samples = len(dataset)

    # Unweighted samplers over the source and target train splits for the domain task.
    dataset.set_split("train")
    source_sampler = samplers.get_sampler(dataset, weighted=False, mini=args.pilot)
    target_dataset.set_split("train")
    target_sampler = samplers.get_sampler(target_dataset, weighted=False, mini=args.pilot)

    logging.debug(f"Training {model_params} parameters with {tr_samples} instances at a rate of {round(tr_samples/model_params, 6)} instances per parameter.")

    train_state = utils.make_train_state(args)

    val_batch_size = int(args.batch_size * 1e1)

    # tqdm progress bars
    epoch_bar = tqdm.notebook.tqdm(desc='training routine',
                                   total=args.num_epochs,
                                   position=0)
    train_bar = tqdm.notebook.tqdm(desc='split=train',
                                   total=tr_samples // args.batch_size,
                                   position=1,
                                   leave=True)
    val_bar = tqdm.notebook.tqdm(desc='split=valid',
                                 total=va_samples // val_batch_size,
                                 position=1,
                                 leave=True)

    ##### Training Routine #####
    try:
        for epoch_index in range(args.num_epochs):
            train_state['epoch_index'] = epoch_index

            # ---------------- training pass ----------------
            dataset.set_split('train')
            target_dataset.set_split('train')
            batch_generator = utils.generate_batches(dataset, sampler=train_sampler,
                                                     batch_size=args.batch_size,
                                                     device=args.device)
            source_batch_generator = utils.generate_batches(dataset, sampler=source_sampler,
                                                            batch_size=args.batch_size,
                                                            device=args.device)
            target_batch_generator = utils.generate_batches(target_dataset, sampler=target_sampler,
                                                            batch_size=args.batch_size,
                                                            device=args.device)

            running_loss = 0.0
            running_domainacc = 0.0
            classifier.train()

            # zip() stops at the shortest of the three generators.
            for batch_index, (batch_dict, source_batch_dict, target_batch_dict) in enumerate(
                    zip(batch_generator, source_batch_generator, target_batch_generator)):

                # step 1. zero the gradients
                optimizer.zero_grad()

                # step 2. binding-classifier loss on the source training batch
                y_pred, _ = classifier(x_in=batch_dict['x_data'].float())
                loss_class = loss_func(y_pred, batch_dict['y_target'].float())

                # step 3. domain loss on source (label 0) + target (label 1) batches.
                source_x = source_batch_dict['x_data'].float()
                target_x = target_batch_dict['x_data'].float()
                domain_in = torch.cat((source_x, target_x))
                # Size labels from the real batches: a short final batch would otherwise
                # not match a label vector of length 2 * args.batch_size.
                domain_label = torch.cat((
                    torch.zeros(source_x.size(0), dtype=torch.float, device=args.device),
                    torch.ones(target_x.size(0), dtype=torch.float, device=args.device),
                ))
                _, domain_pred = classifier(x_in=domain_in)
                loss_domain = loss_func(domain_pred, domain_label)

                # step 4. backpropagate the combined loss
                loss = loss_class + loss_domain
                loss.backward()

                # step 5. gradient step
                optimizer.step()

                # running mean of the binding loss
                running_loss += (loss_class.item() - running_loss) / (batch_index + 1)

                # domain_pred holds logits (BCEWithLogitsLoss), so the p=0.5 decision
                # boundary is logit 0, not 0.5.
                domain_hat = (domain_pred > 0).float()
                acc_domain = (domain_hat == domain_label).float().mean().item()
                running_domainacc += (acc_domain - running_domainacc) / (batch_index + 1)

                train_bar.set_postfix(loss=running_loss,
                                      dacc=running_domainacc,
                                      epoch=epoch_index)
                train_bar.update()

            train_state['train_loss'].append(running_loss)

            # ---------------- validation pass ----------------
            dataset.set_split('valid')
            batch_generator = utils.generate_batches(dataset, sampler=valid_sampler,
                                                     batch_size=val_batch_size,
                                                     device=args.device)
            running_loss = 0.
            # Predictions are streamed to disk so APS can be computed over the full split.
            tmp_filename = f"./{TF}_dann_tmp.tmp"
            classifier.eval()

            try:
                # no_grad: evaluation needs no autograd graph (saves memory and time).
                with open(tmp_filename, "wb") as tmp_file, torch.no_grad():
                    for batch_index, batch_dict in enumerate(batch_generator):
                        y_pred, _ = classifier(x_in=batch_dict['x_data'].float())
                        y_target = batch_dict['y_target'].float()

                        loss_t = loss_func(y_pred, y_target).item()
                        running_loss += (loss_t - running_loss) / (batch_index + 1)

                        probs = torch.sigmoid(y_pred).cpu().numpy()
                        for yp, yt in zip(probs, y_target.cpu().numpy()):
                            tmp_file.write(f"{yp},{yt}\n".encode("utf-8"))

                        val_bar.set_postfix(loss=running_loss,
                                            epoch=epoch_index,
                                            early_stop=train_state['early_stopping_step'])
                        val_bar.update()

                val_aps = utils.compute_aps_from_file(tmp_filename)
            finally:
                # Never leave the scratch file behind, even if validation is interrupted.
                if os.path.exists(tmp_filename):
                    os.remove(tmp_filename)

            train_state['val_loss'].append(running_loss)
            train_state['val_aps'].append(val_aps)

            train_state = utils.update_train_state(args=args, model=classifier,
                                                   train_state=train_state)

            scheduler.step(train_state['val_loss'][-1])

            logging.debug(f"Epoch: {epoch_index}, Validation Loss: {running_loss}, Validation APS: {val_aps}")

            train_bar.n = 0
            val_bar.n = 0
            epoch_bar.update()

            if train_state['stop_early']:
                logging.debug("Early stopping criterion fulfilled!")
                break

    except KeyboardInterrupt:
        logging.warning("Exiting loop")
        # Honor the runtime switch: only swallow the interrupt when asked to.
        if not getattr(args, 'catch_keyboard_interrupt', True):
            raise

    return train_state

In [None]:
# Train only when the TRAIN switch (args.train) is on. With TRAIN=False the
# notebook skips training and goes straight to the evaluation section below.
if __name__ == "__main__" and args.train:
    info = train_dann(args)


training routine:   0%|          | 0/15 [00:00<?, ?it/s]

split=train:   0%|          | 0/9644 [00:00<?, ?it/s]

split=valid:   0%|          | 0/2645 [00:00<?, ?it/s]

  self.padding, self.dilation, self.groups)
  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


# Test Model

## Source dataset

In [None]:
# Build the evaluation model on the same device train_dann uses.
# NOTE(review): this is a freshly initialised TFDANN. Confirm that utils.eval_model
# loads the trained weights (e.g. from args.model_state_file). Otherwise the
# scores below are for an untrained network.
classifier = TFDANN(args).to(args.device)

In [None]:
# Source-genome dataset (args.source_csv + args.source_genome_fasta), one-hot encoded.
source_dataset = datasets.TFDataset.load_dataset_and_vectorizer_from_path(
    args.source_csv,
    args.source_genome_fasta,
    ohe=True,
)

In [None]:
# Evaluate the classifier on the source-genome dataset.
utils.eval_model(
    classifier,
    source_dataset,
    args,
    dataset_type="src",
    model="dann",
)

## Target dataset

In [None]:
# Target-genome dataset (args.target_csv + args.target_genome_fasta), one-hot encoded.
target_dataset = datasets.TFDataset.load_dataset_and_vectorizer_from_path(
    args.target_csv,
    args.target_genome_fasta,
    ohe=True,
)

In [None]:
# Evaluate the classifier on the target-genome dataset.
utils.eval_model(
    classifier,
    target_dataset,
    args,
    dataset_type="tgt",
    model="dann",
)