In [2]:
import errno
import glob
import json
import os
import re
import torch
from config import ConstantConfig, DynamicConfig
from torch.utils.tensorboard import SummaryWriter
from common_utils import read_csv, create_holdout_loader, stratify_split, make_holdout_df

from logger import init_logger
from train_utilities.trainer import Trainer

In [3]:
# Load IPython's autoreload extension and reload imported modules before each
# cell runs, so edits to the project's .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [4]:
def main(experiment_name, config, resume=False, train_fold=None):
    """Set up an experiment (new or resumed) and run training.

    For a new experiment the training CSV is read, split into folds plus a
    holdout set, and the folds, holdout and config are written to the
    experiment directory. When resuming, those saved files are loaded instead.

    Args:
        experiment_name (str): Name of the experiment; used as the
            sub-directory of ``config.save_dir`` and as the TensorBoard run name.
        config: Settings object. This function reads ``save_dir``, ``data_dir``,
            ``debug``, ``fold_num``, ``seed``, ``target_col`` and
            ``train_img_dir`` from it and serialises ``config.__dict__`` to JSON.
        resume (bool): Forwarded to ``get_checkpoint_params``.
        train_fold: Forwarded to ``get_checkpoint_params``.
    """
    import copy  # local import: only needed when restoring a saved config

    experiment_dir = config.save_dir + f'/{experiment_name}'
    tb_writer = None  # defined up front so `finally` can tell whether it was created

    try:
        # -------- SETUP --------
        # if resuming, get checkpoint parameters
        checkpoint_params = get_checkpoint_params(experiment_dir, resume, train_fold)
        logger = init_logger()
        tb_writer = SummaryWriter(f'./runs/{experiment_name}')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if not checkpoint_params:
            make_experiment_directory(experiment_dir)

            # -------- LOAD DATA FROM TRAIN FILE --------
            data_df = read_csv(config.data_dir + '/train.csv', config.debug)
            holdout_df = make_holdout_df(data_df)
            folds_df = stratify_split(data_df, config.fold_num, config.seed, config.target_col)

            # -------- SAVE FILES (for experiment state) --------
            folds_df.to_csv(experiment_dir + '/folds.csv', index=False)
            # save holdout to a csv file for final inference (so we don't run inference on training examples)
            holdout_df.to_csv(experiment_dir + '/holdout.csv', index=False)
            # save the settings for this experiment to its directory
            with open(experiment_dir + '/experiment_config.json', 'w') as f:
                json.dump(config.__dict__, f)
        else:
            # -------- LOAD DATA FROM SAVED FILES --------
            with open(experiment_dir + '/experiment_config.json', 'r') as f:
                saved_settings = json.load(f)
            # json.load returns a plain dict, but the rest of this function (and
            # the Trainer) use attribute access. Restore the saved values onto a
            # shallow copy so the caller's config object is left untouched.
            config = copy.copy(config)
            for key, value in saved_settings.items():
                setattr(config, key, value)
            folds_df = read_csv(experiment_dir + '/folds.csv', config.debug)
            holdout_df = read_csv(experiment_dir + '/holdout.csv', config.debug)

        holdout_loader, holdout_targets = create_holdout_loader(holdout_df, config.train_img_dir)

        trainer = Trainer(experiment_dir=experiment_dir,
                          folds_df=folds_df, holdout_loader=holdout_loader,
                          logger=logger, tensorboard_writer=tb_writer,
                          device=device, checkpoint_params=checkpoint_params,
                          config=config,)
        trainer.fit()
    finally:
        # Flush and release the TensorBoard event file even if training failed.
        if tb_writer is not None:
            tb_writer.close()
        torch.cuda.empty_cache()

def make_experiment_directory(basename):
    """Create the experiment directory (including missing parents).

    If the directory already exists, a warning is printed and the function
    returns normally. Any other ``OSError`` raised by ``os.makedirs``
    propagates unchanged, without the misleading "already exists" message.

    Args:
        basename (str): Path of the directory to create.
    """
    try:
        os.makedirs(basename)
    except FileExistsError:
        # FileExistsError is the OSError subclass for errno.EEXIST.
        print('Experiment already exists. Be sure to resume training appropriately or start a new experiment.')


def get_checkpoint_params(basename, resume, train_fold):
    """
    We can restart from the middle of a fold or start from the beginning of a fold.

    checkpoint_params: {"restart_from": fold, "start_beginning_of": fold, "checkpoint_file_path": file}
        restart_from (int): start from middle of a fold - typically used when a training session was cancelled mid fold
            checkpoint_file_path (str) is required in this case
        start_beginning_of (int): train a particular fold

    Args:
        basename (str): Experiment directory that is searched for ``*.pth`` files.
        resume (bool): When falsy, no checkpoint lookup is done and None is returned.
        train_fold (int | None): Fold to train from its beginning. When None,
            training restarts from the highest-numbered fold that has a checkpoint.

    Returns:
        dict | None: None when not resuming; otherwise either
        ``{"start_beginning_of": train_fold}`` or
        ``{"restart_from": fold, "checkpoint_file_path": path}``.

    Raises:
        ValueError: If ``train_fold`` already has a checkpoint in ``basename``.
        FileNotFoundError: If resuming without ``train_fold`` and no fold
            checkpoint exists in ``basename``.
    """
    if not resume:
        return None

    # Map fold number -> checkpoint path. Fold ids are parsed as ints so that
    # fold 10 sorts after fold 9, and so they compare equal to an int train_fold.
    checkpoints = {}
    for path in glob.glob(basename + '/*.pth'):
        # Match on the file name only, so a "foldN" in a directory name is ignored.
        match = re.search(r'fold(\d+)', os.path.basename(path))
        if match is None:
            continue  # a .pth file that is not a per-fold checkpoint
        fold = int(match.group(1))
        current = checkpoints.get(fold)
        # If several files exist for one fold, keep the most recently modified.
        if current is None or os.path.getmtime(path) > os.path.getmtime(current):
            checkpoints[fold] = path

    if train_fold is not None:
        if int(train_fold) in checkpoints:
            raise ValueError(
                f'Fold {train_fold} already has a checkpoint in {basename}; '
                'choose an untrained fold or resume without train_fold.')
        return {'start_beginning_of': train_fold}

    if not checkpoints:
        raise FileNotFoundError(f'No fold checkpoints (*.pth) found in {basename}; nothing to resume from.')

    most_recent_fold = max(checkpoints)
    return {
        'restart_from': most_recent_fold,
        # Use the file actually found on disk rather than rebuilding its name
        # from a global config object.
        'checkpoint_file_path': checkpoints[most_recent_fold],
    }

In [None]:
# Entry point: build a config in debug mode and start a fresh 'test1' experiment.
if __name__ == '__main__':
    try:
        debug = True
        print('Running in debug mode:', debug)
        # NOTE(review): `Configuration` is not among the names imported in the
        # import cell (only ConstantConfig / DynamicConfig are) - confirm which
        # class is intended, otherwise this line raises NameError on a fresh kernel.
        config = Configuration()
        config.debug = debug
        main(experiment_name='test1', resume=False, config=config)
    except KeyboardInterrupt:
        # Stopping the cell by hand is expected; report it instead of hiding it.
        print('Training interrupted by user.')