This is the code for model training. There are 3 important functions here:

1) init() : It returns the training function together with the config dictionary from task/pose.py, which stores the values of all the parameters required during the training of the model.

2) train() : responsible for the training of the model.

3) main() : calls the functions init() and train() 


In [0]:
import os
import tqdm
from os.path import dirname
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.enabled = True
import torch
import importlib
import argparse
from datetime import datetime
from pytz import timezone

In [0]:
def parse_command_line():
    """Parse the command-line options of the training script.

    Options:
        -c / --continue_exp: string, no default (``None`` when omitted).
        -e / --exp: string, defaults to ``'pose'``.
        -m / --max_iters: int, defaults to ``250`` (in thousands of iterations).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # One (flags, keyword-arguments) entry per supported option.
    option_table = (
        (('-c', '--continue_exp'),
         {'type': str, 'help': 'continue exp'}),
        (('-e', '--exp'),
         {'type': str, 'default': 'pose', 'help': 'experiments name'}),
        (('-m', '--max_iters'),
         {'type': int, 'default': 250, 'help': 'max number of iterations (thousands)'}),
    )

    cli = argparse.ArgumentParser()
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()

In [0]:
def save_checkpoint(state, is_best, filename='checkpoint.pt'):
    """
    from pytorch/examples

    Serialize ``state`` to ``filename`` with ``torch.save``, creating the
    parent directory of ``filename`` first when it has one.

    Args:
        state: object handed to ``torch.save`` unchanged.
        is_best: when true, the written file is additionally copied to
            ``'model_best.pt'`` (relative to the current working directory).
        filename: destination path of the checkpoint.
    """
    # Local import: the module-level imports of this file do not include shutil.
    import shutil

    target_dir = dirname(filename)
    # dirname() returns '' for a bare file name such as the default
    # 'checkpoint.pt', and os.makedirs('') raises, so only create a directory
    # when the path actually has one. exist_ok avoids the check-then-create race.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pt')

In [0]:
def save(config):
    """Write the current network, optimizer and epoch to the experiment checkpoint.

    The checkpoint is stored as ``exp/<name>/checkpoint.pt`` where ``<name>``
    is ``config['opt'].exp``, except when ``exp`` equals ``'pose'`` and
    ``continue_exp`` is set, in which case ``continue_exp`` is used instead.

    Args:
        config: dictionary providing ``config['opt']`` (with ``exp`` and
            ``continue_exp``), ``config['inference']['net']``,
            ``config['train']['optimizer']`` and ``config['train']['epoch']``.
    """
    opt = config['opt']
    write_to_continue_dir = opt.exp == 'pose' and opt.continue_exp is not None
    exp_name = opt.continue_exp if write_to_continue_dir else opt.exp
    checkpoint_path = os.path.join('exp', exp_name, 'checkpoint.pt')

    snapshot = {
        'state_dict': config['inference']['net'].state_dict(),
        'optimizer': config['train']['optimizer'].state_dict(),
        'epoch': config['train']['epoch'],
    }
    # This periodic checkpoint is never flagged as the best model.
    save_checkpoint(snapshot, False, filename=checkpoint_path)
    print('=> save checkpoint')

In [0]:
def train(train_func, data_func, config, post_epoch=None):
    """Run alternating 'train' and 'valid' phases epoch after epoch.

    Each epoch prints the epoch counter, runs ``config['train']['train_iters']``
    steps of the 'train' phase and ``config['train']['valid_iters']`` steps of
    the 'valid' phase, increments ``config['train']['epoch']`` and calls
    ``save(config)``.

    The loop ends when ``config['train']['epoch_num']`` exists and the epoch
    counter exceeds it, or (returning without a final save) when the first
    step index of a phase exceeds ``config['opt'].max_iters * 1000``.

    Args:
        train_func: called as ``train_func(step, config, phase, **batch)``.
        data_func: called as ``data_func(phase)``; must return an iterator
            yielding keyword-argument dicts for ``train_func``.
        config: configuration dictionary; ``config['train']['epoch']`` is
            updated in place.
        post_epoch: accepted but not used by this function.
    """
    while True:
        print('epoch: ', config['train']['epoch'])
        epoch_budget_spent = (
            'epoch_num' in config['train']
            and config['train']['epoch'] > config['train']['epoch_num']
        )
        if epoch_budget_spent:
            break

        for phase in ('train', 'valid'):
            steps_in_phase = config['train'][phase + '_iters']
            batches = data_func(phase)
            print('start', phase, config['opt'].exp)

            progress = tqdm.tqdm(range(steps_in_phase), total=steps_in_phase, ascii=True)
            # Global index of this phase's first step.
            first_step = steps_in_phase * config['train']['epoch']
            if first_step > config['opt'].max_iters * 1000:
                return
            for offset in progress:
                batch = next(batches)
                train_func(first_step + offset, config, phase, **batch)

        config['train']['epoch'] += 1
        save(config)

The function given below does the following 2 things:
1) It imports the config dictionary from task/pose.py, which contains all the necessary variables for training, and returns this dictionary in the variable "config".
2) It calls "make_network", also defined in task/pose.py, to build the training function
and returns it in the variable "func".


In [0]:
def init():
    """
    task.__config__ contains the variables that control the training and testing
    make_network (which is a function in task/pose.py) builds a function which can do forward and backward propagation

    Steps: parse the command line, import ``task.pose``, make sure the
    ``exp/<opt.exp>`` directory exists, store the parsed options under
    ``config['opt']``, replace the ``config['data_provider']`` module name by
    the imported module, build the training function and call
    ``reload(config)``.

    Returns:
        Tuple ``(func, config)``: the value returned by
        ``task.make_network(config)`` and the populated config dictionary.
    """
    opt = parse_command_line()
    task = importlib.import_module('task.pose')

    # All training-related variables defined in task.pose live in this dict.
    config = task.__config__

    # The experiment directory may already exist from an earlier run.
    try:
        os.makedirs(os.path.join('exp', opt.exp))
    except FileExistsError:
        pass

    config['opt'] = opt
    # The entry initially holds a module name; swap it for the imported module.
    provider_module_name = config['data_provider']
    config['data_provider'] = importlib.import_module(provider_module_name)

    func = task.make_network(config)
    # NOTE(review): `reload` is not a Python 3 builtin and is not defined in the
    # visible part of this file; confirm it is provided elsewhere (presumably a
    # helper that restores checkpoint state and initialises
    # config['train']['epoch'], which train() reads).
    reload(config)
    return func, config

In [0]:
def main():
    """Set up training via init(), run train(), then print the current time in EST."""
    train_step, config = init()
    batch_source = config['data_provider'].init(config)
    train(train_step, batch_source, config)

    finished_at = datetime.now(timezone('EST'))
    print(finished_at)

In [0]:
# Start training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()