In [1]:
from train import *
import math

In [2]:
# Command-line interface: (flag, add_argument keyword arguments), registered in this order.
_cli_spec = (
    ('--model_dir', dict(default='experiments/base_model',
                         help="Directory containing params.json")),
    ('--data_dir', dict(default='data/', help="Directory containing the dataset")),
    ('--label_type', dict(default='cum', help="Label types, cum only or cum with forward label")),
    # '--restore_dir' uses the *string* 'None' (not the None object) as its "unset" value.
    ('--restore_dir', dict(default='None',
                           help="Optional, directory containing weights to reload before training")),
    ('--device', dict(default='0', help="CUDA_DEVICE")),
    # training settings
    ('--max_epoch', dict(default=100, type=int)),
    ('--patience', dict(default=20, type=int)),
    ('--batch_size', dict(default=1, type=int)),
    ('--learning_rate', dict(default=1e-4, type=float, help='1e-3, 1e-4, 1e-5')),
    ('--weight_decay', dict(default=1e-5, type=float, help='1e-4, 1e-5, 1e-6, l2 regularization')),
    ('--gamma', dict(default=0.9, type=float,
                     help='exponential learning rate scheduler, lr = lr_0 * gamma ^ epoch')),
)

parser = argparse.ArgumentParser()
for _flag, _options in _cli_spec:
    parser.add_argument(_flag, **_options)

# Fix the graph-level random seed so that experiments are reproducible.
tf.set_random_seed(230)

# Parse a fixed argument list (notebook run) instead of sys.argv.
args = parser.parse_args([
    '--model_dir', '../experiments/test_fim',
    '--data_dir', '/home/cwlin/explainable_credit/data/index_01_cum_for/index_fold_01/',
    '--label_type', 'cum',
])

# Select the visible CUDA device(s); 'cpu' hides every GPU.
if args.device == 'cpu':
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device

# Load the hyperparameters from params.json inside model_dir.
json_path = os.path.join(args.model_dir, 'params.json')
assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
params = Params(json_path)

In [3]:
# Command-line training settings override the corresponding values loaded from params.json.
# NOTE(review): `args.label_type` is parsed but not copied onto `params`; the dataset file
# names are built from `params.label_type` (i.e. params.json) — confirm this is intended.
for _param_name, _cli_value in (
    ('num_epochs', args.max_epoch),
    ('patience', args.patience),
    ('batch_size', args.batch_size),
    ('learning_rate', args.learning_rate),
    ('weight_decay', args.weight_decay),
    ('decay_rate', args.gamma),
):
    setattr(params, _param_name, _cli_value)

# main_valid.py

In [4]:
"""Tensorflow utility functions for training"""

import logging
import os

from tqdm import trange
import tensorflow as tf
import numpy as np

from model.utils import save_dict_to_json
from model.evaluation import evaluate_sess


def train_sess(sess, model_spec, num_steps, writer, params):
    """Train the model on `num_steps` batches and log the resulting train metrics.

    Every `params.save_summary_steps` steps the summary op is also evaluated and
    written to `writer`. After the loop, the metric values are fetched and logged.
    Metrics whose name contains 'auc_' are averaged into an extra 'auc' entry,
    which is added only when at least one such metric exists.

    Args:
        sess: (tf.Session) current session
        model_spec: (dict) contains the graph operations or nodes needed for training
            ('loss', 'train_op', 'update_metrics', 'metrics', 'summary_op',
            'iterator_init_op', 'metrics_init_op')
        num_steps: (int) train for this number of batches
        writer: (tf.summary.FileWriter) writer for summaries
        params: (Params) hyperparameters; must define `save_summary_steps`
    """
    # Get relevant graph operations or nodes needed for training
    loss = model_spec['loss']
    train_op = model_spec['train_op']
    update_metrics = model_spec['update_metrics']
    metrics = model_spec['metrics']
    summary_op = model_spec['summary_op']
    global_step = tf.train.get_global_step()

    # Load the training dataset into the pipeline and initialize the metrics local variables
    sess.run(model_spec['iterator_init_op'])
    sess.run(model_spec['metrics_init_op'])

    # Use tqdm for progress bar
    t = trange(num_steps)
    for i in t:
        # Evaluate summaries for tensorboard only once in a while
        if i % params.save_summary_steps == 0:
            # Perform a mini-batch update and also fetch the summaries
            _, _, loss_val, summ, global_step_val = sess.run(
                [train_op, update_metrics, loss, summary_op, global_step])
            # Write summaries for tensorboard
            writer.add_summary(summ, global_step_val)
        else:
            _, _, loss_val = sess.run([train_op, update_metrics, loss])
        # Log the loss in the tqdm progress bar
        t.set_postfix(loss='{:05.3f}'.format(loss_val))

    # Only element [0] of each metrics entry is fetched. Presumably these are TF1
    # `tf.metrics.*` (value, update_op) pairs — TODO confirm in model_fn.
    metrics_values = {k: v[0] for k, v in metrics.items()}
    metrics_val = sess.run(metrics_values)

    # Average the per-label AUCs into one summary 'auc'. Guard the empty case,
    # which previously raised ZeroDivisionError.
    aucs = [v for k, v in metrics_val.items() if 'auc_' in k]
    if aucs:
        metrics_val['auc'] = sum(aucs) / len(aucs)

    metrics_string = "; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items())
    logging.info("- Train metrics: " + metrics_string)


def train_and_evaluate(train_model_spec, eval_model_spec, model_dir, params, restore_from=None):
    """Train the model and evaluate it on the validation set after every epoch.

    After each epoch the weights are saved under `model_dir/last_weights`, and the
    validation metrics are written to `metrics_eval_last_weights.json`. A new best
    checkpoint is saved under `model_dir/best_weights`, together with
    `metrics_eval_best_weights.json`, whenever the validation AUC is at least the
    best seen so far. The AUC is `metrics['auc']` as returned by `evaluate_sess`.

    Args:
        train_model_spec: (dict) contains the graph operations or nodes needed for training
        eval_model_spec: (dict) contains the graph operations or nodes needed for evaluation
        model_dir: (string) directory containing config, weights and log
        params: (Params) contains hyperparameters of the model.
                Must define: num_epochs, train_size, batch_size, eval_size, save_summary_steps
        restore_from: (string) directory or file containing weights to restore the graph.
                If a directory, its latest checkpoint is restored. Training then resumes
                from the epoch number encoded in that checkpoint's name.

    Raises:
        FileNotFoundError: if `restore_from` is a directory that contains no checkpoint.
    """
    # NOTE(review): params.patience is set by the caller but is not read here, so no
    # early stopping happens; all params.num_epochs epochs are always run.

    # Initialize tf.Saver instances to save weights during training
    last_saver = tf.train.Saver()  # default max_to_keep=5: keeps the last 5 epochs
    best_saver = tf.train.Saver(max_to_keep=1)  # only keep 1 best checkpoint (best on eval)
    begin_at_epoch = 0

    # Loop-invariant values: number of batches per pass and output paths.
    train_steps = (params.train_size + params.batch_size - 1) // params.batch_size
    eval_steps = (params.eval_size + params.batch_size - 1) // params.batch_size
    last_save_prefix = os.path.join(model_dir, 'last_weights', 'after-epoch')
    best_save_prefix = os.path.join(model_dir, 'best_weights', 'after-epoch')
    best_json_path = os.path.join(model_dir, "metrics_eval_best_weights.json")
    last_json_path = os.path.join(model_dir, "metrics_eval_last_weights.json")

    with tf.Session() as sess:
        # Initialize model variables
        sess.run(train_model_spec['variable_init_op'])

        # Reload weights from directory if specified
        if restore_from is not None:
            logging.info("Restoring parameters from {}".format(restore_from))
            if os.path.isdir(restore_from):
                latest = tf.train.latest_checkpoint(restore_from)
                if latest is None:
                    # latest_checkpoint returns None when no checkpoint exists; fail clearly
                    # instead of crashing on `None.split` below.
                    raise FileNotFoundError(
                        "No checkpoint found in directory {}".format(restore_from))
                restore_from = latest
                # Checkpoints are saved as '<prefix>-<epoch>' (global_step=epoch + 1 below),
                # so the trailing number is the count of completed epochs.
                begin_at_epoch = int(restore_from.split('-')[-1])
            last_saver.restore(sess, restore_from)

        # For tensorboard (takes care of writing summaries to files)
        train_writer = tf.summary.FileWriter(os.path.join(model_dir, 'train_summaries'), sess.graph)
        eval_writer = tf.summary.FileWriter(os.path.join(model_dir, 'eval_summaries'), sess.graph)

        try:
            best_eval_auc = 0.0
            end_epoch = begin_at_epoch + params.num_epochs
            for epoch in range(begin_at_epoch, end_epoch):
                # Run one epoch (one full pass over the training set)
                logging.info("Epoch {}/{}".format(epoch + 1, end_epoch))
                train_sess(sess, train_model_spec, train_steps, train_writer, params)

                # Save weights
                last_saver.save(sess, last_save_prefix, global_step=epoch + 1)

                # Evaluate for one epoch on validation set
                metrics = evaluate_sess(sess, eval_model_spec, eval_steps, eval_writer)

                # Model selection is based on validation AUC
                eval_auc = metrics['auc']
                if eval_auc >= best_eval_auc:
                    best_eval_auc = eval_auc
                    best_save_path = best_saver.save(sess, best_save_prefix, global_step=epoch + 1)
                    logging.info("- Found new best auc, saving in {}".format(best_save_path))
                    # Save best eval metrics in a json file in the model directory
                    save_dict_to_json(metrics, best_json_path)

                # Save latest eval metrics in a json file in the model directory
                save_dict_to_json(metrics, last_json_path)
        finally:
            # FileWriter writes events asynchronously; close() flushes anything still
            # pending so the last summaries are not lost.
            train_writer.close()
            eval_writer.close()

## Check that we are not overwriting some previous experiment
## Comment these lines if you are developing your model and don't care about overwriting
## (--restore_dir defaults to the *string* "None", not the None object)
#model_dir_has_best_weights = os.path.isdir(os.path.join(args.model_dir, "best_weights"))
#overwriting = model_dir_has_best_weights and args.restore_dir == "None"
#assert not overwriting, "Weights found in model_dir, aborting to avoid overwrite"

# Set the logger
set_logger(os.path.join(args.model_dir, 'train.log'))


def _count_gz_lines(path, chunk_size=1 << 20):
    """Return the number of newline characters in the gzip file at `path`.

    This matches `zcat path | wc -l`, but no shell is spawned, so paths with
    spaces or shell metacharacters are safe. A missing or corrupt file raises an
    OSError instead of an opaque ValueError from parsing empty command output.
    """
    import gzip

    count = 0
    with gzip.open(path, 'rb') as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b''):
            count += chunk.count(b'\n')
    return count


# Create the input data pipeline
logging.info("Creating the datasets...")

## CW: train_subset and valid_subset are used for the early stopping epoch and patience
# train_data = "train_{}.gz".format(params.label_type)
train_data = "train_subset_{}.gz".format(params.label_type)
# test_data  =  "test_{}.gz".format(params.label_type)
test_data = "valid_subset_{}.gz".format(params.label_type)
# Paths to data
train_data = os.path.join(args.data_dir, train_data)
eval_data = os.path.join(args.data_dir, test_data)
# Each file starts with one header line, which is not an example.
params.train_size = _count_gz_lines(train_data) - 1
params.eval_size = _count_gz_lines(eval_data) - 1
logging.info('Train size: {}, Eval size: {}'.format(params.train_size, params.eval_size))

# CW: set decay_steps to the number of batches in ONE epoch, so that the learning rate
# decays by `decay_rate` (gamma) once per epoch: lr = lr_0 * gamma ^ epoch, as in PyTorch.
# The previous value (batches_per_epoch * num_epochs) applied gamma only once over the
# whole run. Presumably model_fn feeds this to an exponential-decay schedule — confirm.
num_batches_per_epoch = math.ceil(params.train_size / params.batch_size)
# max(1, ...) avoids a zero decay_steps (division by zero) on an empty training set.
params.decay_steps = max(1, int(num_batches_per_epoch))

# Create the two iterators over the two datasets
train_inputs = input_fn(True, train_data, params)
eval_inputs = input_fn(False, eval_data, params)
logging.info("- done.")

# Define the models (2 different set of nodes that share weights for train and eval)
logging.info("Creating the model...")
train_model_spec = model_fn(True, train_inputs, params)
eval_model_spec = model_fn(False, eval_inputs, params, reuse=True)
logging.info("- done.")

# --restore_dir uses the string "None" as its "not given" default.
if args.restore_dir != "None":
    restore_dir = os.path.join(args.model_dir, args.restore_dir)
else:
    restore_dir = None

# Train the model
logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
train_and_evaluate(train_model_spec, eval_model_spec, args.model_dir, params, restore_dir)