In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to imported .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import collections 
from collections import defaultdict
import os
import copy
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter('ignore')

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torchvision.transforms import Compose, ToTensor
import pytorch_lightning as pl

from types import MethodType
from pytorch_lightning.callbacks import LearningRateLogger, EarlyStopping, ModelCheckpoint

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default figure size (width, height) for every plot created in this notebook.
plt.rcParams['figure.figsize'] = (17, 8);
# Ask the inline backend for 'retina' (high-DPI) figure output.
%config InlineBackend.figure_format = 'retina'

In [3]:
# Use the first CUDA GPU when torch reports one is available; otherwise run on CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [5]:
from models import *
from utils import *
from config import *
from datasets import DoubleMotifs

In [6]:
def make_hparams(epochs=100, second_hidden=32, learning_rate=3e-4,
                 grouped_lr=False, weight_decay=0, batch_size=48, warmup=0.25, lr_multiplier=1, scheduler=False):
    """Build the nested hyper-parameter dict used to construct the model.

    Args:
        epochs: Stored under 'epochs' and reused as the scheduler's 'epochs'.
        second_hidden: Stored under 'second_hidden'.
        learning_rate: Stored under hparams['optimizer']['learning_rate'].
        grouped_lr: Stored under hparams['optimizer']['grouped_lr'].
            Previously this argument was ignored and the entry was always
            False; it is now passed through (the default is still False).
        weight_decay: Stored under hparams['optimizer']['weight_decay'].
        batch_size: Stored under 'batch_size'; also used to derive the
            scheduler's 'steps_per_epoch'.
        warmup: Used as the scheduler's 'pct_start' when ``scheduler`` is truthy.
        lr_multiplier: Accepted for interface compatibility; not used by
            this function.
        scheduler: When falsy, hparams['scheduler'] is False; otherwise it is
            a dict of scheduler settings.

    Returns:
        dict: flat model/training/transform/path entries plus the nested
        'optimizer' dict and the 'scheduler' entry.
    """
    ### model dimensions
    hparams = {
        'input_dim1': len(AmEnc),
        'input_dim2': len(NucEnc),
        'main_hidden1': 112,
        'main_hidden2': 96,
        'second_hidden': second_hidden,
        'groups': 4,
    }

    ### training
    hparams.update({'epochs': epochs, 'Positive_rate': 0.5, 'batch_size': batch_size})

    ### transforms
    hparams.update({'FIXED_LEN': 100, 'PREFIX_prob': 0.3, 'MAX_ROLL': 25, 'ROLL_prob': 0.2})

    ### paths
    hparams.update({
        'train_data': '../Data/small_Train_particle_0.csv',
        'val_data': '../Data/small_Val_particel_0.csv',
        'test_data': '../Data/small_Test_particle.csv',
    })

    ### optimizers
    hparams['optimizer'] = {
        'learning_rate': learning_rate,
        'weight_decay': weight_decay,
        # Honor the caller's choice instead of hard-coding False.
        'grouped_lr': grouped_lr,
        # Per-module settings; both encoders are marked 'Freeze'.
        'param_groups': {
            'ProteinEncoder': {'lr': 'Freeze'},
            'DNAEncoder': {'lr': 'Freeze'},
        },
    }

    ### scheduler
    if not scheduler:
        hparams['scheduler'] = False
    else:
        hparams['scheduler'] = {
            'pct_start': warmup,
            'div_factor': 20,
            'final_div_factor': 800,
            'anneal_strategy': 'cos',
            # NOTE(review): 60*32 is presumably the number of training
            # samples per epoch — confirm against the dataset size.
            'steps_per_epoch': 60*32//batch_size,
            'epochs': hparams['epochs'],
        }

    return hparams

In [7]:
# Baseline hyper-parameters: 100 epochs, batch size 64, everything else at its default.
hparams = make_hparams(epochs=100, batch_size=64)

In [8]:
# Instantiate the model from the hyper-parameter dict.
M = COUPLER(hparams)

In [9]:
# NOTE(review): presumably loads the datasets named in hparams — confirm in COUPLER.prepare_data.
M.prepare_data()

In [10]:
# Smoke test: pull a single batch from the training dataloader.
x1,x2,y,s,idx = next(iter(M.train_dataloader()))

In [11]:
# Forward pass on that batch; only x1 and x2 are fed to the model.
yy = M(x1,x2)

In [12]:
# Display the first element of the model output.
yy[0]

tensor([0.6369], grad_fn=<SelectBackward>)

In [13]:
def build_trainer(i, epochs, name='SatgeI_small_baseline', version='v1',
                  checkpoint_dir='/home/moonstrider/Bioinfo/DNA_to_Protein_Motifs/BindFind/chkpt/'):
    """Create a Lightning Trainer with early stopping, checkpointing and TensorBoard logging.

    Args:
        i: Index embedded in the checkpoint file name
            ('small_baseline_cross_val_{i}-...').
        epochs: Maximum number of training epochs. Previously this argument
            was ignored and 100 was always used.
        name: Experiment name passed to the TensorBoard logger.
        version: Version label passed to the TensorBoard logger.
        checkpoint_dir: Directory that checkpoint files are written under.
            Defaults to the path that used to be hard-coded here.

    Returns:
        The configured ``pl.Trainer``.
    """
    early_stopping = EarlyStopping('val_loss', min_delta=0.003, patience=3)
    tb_logger = pl.loggers.TensorBoardLogger('lightning_logs', name=name, version=version)

    # Only the index is interpolated here; '{epoch}' and '{val_loss:.2f}' are
    # passed through literally (presumably filled in by ModelCheckpoint — confirm
    # against the installed Lightning version).
    checkpoint_path = os.path.join(checkpoint_dir,
                                   f'small_baseline_cross_val_{i}' + '-{epoch}-{val_loss:.2f}')
    model_checkpoint = ModelCheckpoint(checkpoint_path, verbose=True, period=5, save_top_k=5)

    trainer = pl.Trainer(max_epochs=epochs, gradient_clip=0.5, logger=tb_logger, gpus=1,
                         callbacks=[early_stopping], checkpoint_callback=model_checkpoint,
                         show_progress_bar=False)
    return trainer

In [14]:
# One independent copy of the base hyper-parameters per cross-validation fold,
# each pointing at that fold's train / validation CSV.
HPasramsList = [copy.deepcopy(hparams) for _ in range(5)]
for i, hp in enumerate(HPasramsList):
    # Rewrite only the fold index in the '_0.csv' suffix; a bare
    # replace('0', ...) would also change any other zero in the path.
    hp['train_data'] = hp['train_data'].replace('_0.csv', f'_{i}.csv')
    hp['val_data'] = hp['val_data'].replace('_0.csv', f'_{i}.csv')

In [1]:
### Cross-validation: train one fresh model per fold-specific hparams dict
for i, hp in enumerate(HPasramsList):
    M = COUPLER(hparams=hp)
    trainer = build_trainer(
        i,
        epochs=M.hparams['epochs'],
        name='Velaciraptor',
        version=f'crossVal{i}',
    )
    trainer.fit(M)
    # Drop the notebook's reference to this fold's model before the next fold starts.
    del M