In [None]:
%%bash
# Export models.ipynb as a plain script via nbconvert; presumably this keeps the
# `models` module imported below in sync with the notebook -- TODO confirm.
jupyter nbconvert models.ipynb --to script

In [None]:
import torch
from torch import nn
from torch.autograd import Variable

import losswise
from prettytable import PrettyTable
from tqdm import tqdm
import numpy as np

import os
import sys
import random
import importlib
import pickle

from datasets import BurstDataset, ShuffledBatchSequentialSampler, FakeBurstDataset
from prep_dataset import BurstDatasetStandardizer
from models import Encoder, Decoder
from train_functions import *
from eval_functions import plot_autoencoding, autoencode

In [None]:
def get_max_saved_epoch(save_dir):
    """Return the latest epoch that has both an encoder and a decoder checkpoint.

    Checkpoints are recognised by the exact file names ``epoch<N>_enc.pkl`` and
    ``epoch<N>_dec.pkl``. Any other file in the directory (pickled datasets,
    params, files that merely contain the word "epoch") is ignored instead of
    being parsed.

    Args:
        save_dir: directory to scan. May be ``None`` or a path that does not
            exist (yet).

    Returns:
        int: the highest epoch number for which both checkpoint files exist,
        or 0 when there is no such epoch.
    """
    import re

    if save_dir is None or not os.path.isdir(save_dir):
        # the first epoch which is run is epoch 1
        return 0

    checkpoint_pattern = re.compile(r'epoch(\d+)_(enc|dec)\.pkl$')
    epochs_by_kind = {'enc': set(), 'dec': set()}
    for file_name in os.listdir(save_dir):
        match = checkpoint_pattern.match(file_name)
        if match is not None:
            epochs_by_kind[match.group(2)].add(int(match.group(1)))

    # An epoch is only resumable when both halves of the model were saved for
    # it; comparing the two maxima alone could name an epoch whose encoder (or
    # decoder) file does not exist.
    resumable_epochs = epochs_by_kind['enc'] & epochs_by_kind['dec']
    return max(resumable_epochs) if resumable_epochs else 0

In [None]:
def train_wrapper(min_burst_secs, max_burst_secs, min_episode_mins, max_episode_mins,
                  pad_length, batch_by_len, robust_scale, downsample_factor, data_dir, max_num_patients,
                  max_num_bursts_per_episode, train_split, dev_split,
                  hidden_size, input_size, bidirectional, num_layers, extra_input_dim,
                  batch_size, num_epochs, lr, weight_decay, teacher_forcing_slope, train_reversed,
                  save_dir, use_losswise, run_tag, config_name):
    """Build (or resume) the encoder/decoder and datasets, then run training.

    If ``save_dir`` already holds checkpoints for more than 10 epochs, the
    latest encoder/decoder weights and the pickled datasets are loaded and
    training continues from the following epoch. Otherwise the datasets are
    built from ``data_dir``, split, standardized (fit on the training split
    only) and, when ``save_dir`` is given, pickled together with the parameter
    dictionary and the standardizer.

    Args:
        min_burst_secs, max_burst_secs, min_episode_mins, max_episode_mins,
        pad_length, downsample_factor, max_num_patients,
        max_num_bursts_per_episode: forwarded to ``BurstDataset.init_dataset``.
        batch_by_len: forwarded to ``dataset.split`` (as ``split_sort_len``)
            and to ``train_model``.
        robust_scale: forwarded to ``BurstDatasetStandardizer.transform``.
        data_dir: directory passed to ``BurstDataset``.
        train_split, dev_split: forwarded to ``dataset.split``.
        hidden_size, input_size, bidirectional, num_layers, extra_input_dim:
            model hyper-parameters forwarded to ``Encoder`` / ``Decoder``.
        batch_size, num_epochs, lr, weight_decay, teacher_forcing_slope,
        train_reversed: forwarded to ``train_model``.
        save_dir: directory for checkpoints and pickles, or ``None`` to skip
            resuming and skip writing the pickles here.
        use_losswise: when true, open a Losswise session and pass its loss
            graph to ``train_model``.
        run_tag: tag of the Losswise session.
        config_name: recorded in the parameter dictionary.

    Returns:
        tuple: ``(train_dataset, dev_dataset, test_dataset, encoder, decoder)``.
    """
    # Resume only if we've saved a decent amount; otherwise, just rerun it.
    min_saved_epochs_to_resume = 10

    torch.manual_seed(1)
    np.random.seed(1)
    random.seed(1)

    # initialize model
    encoder = Encoder(input_size, hidden_size, bidirectional=bidirectional, num_layers=num_layers)
    decoder = Decoder(hidden_size, input_size, extra_input_dim=extra_input_dim, encoder_bidirectional=bidirectional,
                      num_layers=num_layers)
    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    # load saved model/initialize dataset
    # Without a save directory there is nothing to scan or resume from.
    max_saved_epoch = get_max_saved_epoch(save_dir) if save_dir is not None else 0
    resume = max_saved_epoch > min_saved_epochs_to_resume
    start_epoch_num = 1
    standardizer = None
    if resume:
        print('loading saved models and datasets...')

        # Deserialize onto the CPU first so checkpoints written on a GPU can be
        # resumed on a CPU-only machine; load_state_dict then copies the values
        # into the parameters on whatever device the model already lives on.
        def keep_on_cpu(storage, location):
            return storage

        encoder.load_state_dict(torch.load(os.path.join(save_dir, 'epoch{}_enc.pkl'.format(max_saved_epoch)),
                                           map_location=keep_on_cpu))
        decoder.load_state_dict(torch.load(os.path.join(save_dir, 'epoch{}_dec.pkl'.format(max_saved_epoch)),
                                           map_location=keep_on_cpu))
        # NOTE(review): unpickling executes arbitrary code; only resume from a
        # save_dir you trust. Pickles must be opened in binary mode.
        with open(os.path.join(save_dir, 'datasets.pkl'), 'rb') as datasets_file:
            (train_dataset, dev_dataset, test_dataset) = pickle.load(datasets_file)
        start_epoch_num = max_saved_epoch + 1
    else:
        # initialize datasets
        print('initializing datasets...')
        dataset = BurstDataset(data_dir, sort_len=False)
        dataset.init_dataset(pad_length, min_burst_secs=min_burst_secs, max_burst_secs=max_burst_secs,
                             min_episode_mins=min_episode_mins, max_episode_mins=max_episode_mins,
                             downsample_factor=downsample_factor, max_num_patients=max_num_patients,
                             max_num_bursts_per_episode=max_num_bursts_per_episode)
        train_dataset, dev_dataset, test_dataset = dataset.split(train_split, dev_split,
                                                                 split_sort_len=batch_by_len)
        standardizer = BurstDatasetStandardizer()
        print('scaling datasets...')
        # Fit on the training split only, then apply the same transform to all splits.
        standardizer.fit(train_dataset)
        standardizer.transform(train_dataset, robust_scale)
        standardizer.transform(dev_dataset, robust_scale)
        standardizer.transform(test_dataset, robust_scale)

    # get list of all the params
    params_dict = {
        # dataset filtering
        'min_burst_secs': min_burst_secs, 'max_burst_secs': max_burst_secs,
        'min_episode_mins': min_episode_mins, 'max_episode_mins': max_episode_mins,
        # dataset size
        'max_num_patients': max_num_patients, 'max_num_bursts_per_episode': max_num_bursts_per_episode,
        'len(train_data)': len(train_dataset),
        'train split': train_split, 'dev split': dev_split,
        # dataset other
        'pad_length': pad_length,
        'robust_scale': robust_scale,
        'downsample_factor': downsample_factor,
        # model params
        'hidden_size': hidden_size, 'bidirectional': bidirectional, 'num_layers': num_layers,
        'extra_input_dim': extra_input_dim,
        # training params
        'batch_by_len': batch_by_len, 'batch_size': batch_size, 'num_epochs': num_epochs,
        'lr': lr, 'weight decay': weight_decay,
        'teacher_forcing_slope': teacher_forcing_slope, 'train reversed': train_reversed,
        'save dir': save_dir, 'config_name': config_name, 'start_epoch_num': start_epoch_num}

    if save_dir is not None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        if not resume:
            # Fresh run: persist everything needed to resume or evaluate later.
            # Binary mode is required for pickle; `with` closes the handles.
            with open(os.path.join(save_dir, "params_dict.pkl"), "wb") as params_file:
                pickle.dump(params_dict, params_file)
            with open(os.path.join(save_dir, "datasets.pkl"), "wb") as datasets_file:
                pickle.dump((train_dataset, dev_dataset, test_dataset), datasets_file)
            with open(os.path.join(save_dir, "standardizer.pkl"), "wb") as standardizer_file:
                pickle.dump(standardizer, standardizer_file)

    # set up losswise
    session = None
    if use_losswise:
        # NOTE(review): the API key for "coma-eeg" is hardcoded as a fallback;
        # prefer supplying it through the LOSSWISE_API_KEY environment variable.
        losswise.set_api_key(os.environ.get('LOSSWISE_API_KEY', 'W2TAMB3SZ'))
        session = losswise.Session(tag=run_tag, max_iter=num_epochs,
                                   params=params_dict)
        losswise_graph = session.graph('loss', kind='min')
    else:
        losswise_graph = None

    try:
        train_model(train_dataset, dev_dataset, test_dataset, encoder, decoder, save_dir,
                    num_epochs=num_epochs, start_epoch_num=start_epoch_num,
                    batch_size=batch_size, lr=lr, weight_decay=weight_decay,
                    teacher_forcing_slope=teacher_forcing_slope, train_reversed=train_reversed,
                    batch_by_len=batch_by_len,
                    losswise_graph=losswise_graph, params_dict=params_dict)
    except KeyboardInterrupt:
        # Deliberate: a manual interrupt still returns the models trained so far.
        pass
    finally:
        # Close the Losswise session even when training fails with another error.
        if session is not None:
            session.done()

    return train_dataset, dev_dataset, test_dataset, encoder, decoder

In [None]:
# Choose which config module to load: a fixed test config inside a notebook
# kernel, otherwise the first command-line argument.
if 'ipykernel' in sys.argv[0]:
    # running in nb
    config_name = 'configtest'
else:
    if len(sys.argv) < 2:
        # Stop here with the usage message instead of falling through to an
        # IndexError on sys.argv[1].
        sys.exit('must run with argument indicating config number')
    config_name = sys.argv[1]

In [None]:
# Load configs/<config_name> and copy every attribute whose name does not start
# with an underscore into this notebook's global namespace.
config = importlib.import_module('configs.{}'.format(config_name))
globals().update(
    (setting_name, setting_value)
    for setting_name, setting_value in vars(config).items()
    if not setting_name.startswith('_')
)

In [None]:
# Give each config its own sub-directory. Start from the value in the config
# module rather than from the SAVE_DIR global itself, so that re-running this
# cell does not append config_name a second time. A None SAVE_DIR is passed
# through unchanged.
SAVE_DIR = None if config.SAVE_DIR is None else os.path.join(config.SAVE_DIR, config_name)

In [None]:
# Run (or resume) training with the settings injected from the config module.
train_dataset, dev_dataset, test_dataset, encoder, decoder = train_wrapper(
    # dataset filtering
    min_burst_secs=MIN_BURST_SECS,
    max_burst_secs=MAX_BURST_SECS,
    min_episode_mins=MIN_EPISODE_MINS,
    max_episode_mins=MAX_EPISODE_MINS,
    # dataset construction
    pad_length=PAD_LENGTH,
    batch_by_len=BATCH_BY_LEN,
    robust_scale=robust_scale,
    downsample_factor=downsample_factor,
    data_dir=DATA_DIR,
    max_num_patients=MAX_NUM_PATIENTS,
    max_num_bursts_per_episode=MAX_NUM_BURSTS_PER_EPISODE,
    train_split=train_split,
    dev_split=dev_split,
    # model
    hidden_size=HIDDEN_SIZE,
    input_size=INPUT_SIZE,
    bidirectional=BIDIRECTIONAL,
    num_layers=NUM_LAYERS,
    extra_input_dim=EXTRA_INPUT_DIM,
    # training
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS,
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    teacher_forcing_slope=TEACHER_FORCING_SLOPE,
    train_reversed=TRAIN_REVERSED,
    # bookkeeping
    save_dir=SAVE_DIR,
    use_losswise=USE_LOSSWISE,
    run_tag=RUN_TAG,
    config_name=config_name,
)

In [None]:
%%time
# Spot check: pass one item of the training set through the encoder/decoder
# returned by train_wrapper and plot the result.
# Index 8 is a hardcoded pick; assumes train_dataset has at least 9 items.
sample = train_dataset[8]
# NOTE(review): the return value is stored as `mse`, presumably the
# reconstruction error of this sample -- confirm in eval_functions.
mse = plot_autoencoding(sample, encoder, decoder, toss_encoder_output=False, reverse=TRAIN_REVERSED)