In [17]:
%%writefile train.py

import argparse

import torch
import torch.nn as nn

from data_loader import DataLoader

from simple_ntc.transformer import *
from simple_ntc.trainer import *


def define_argparse():
    """Parse the training script's command-line flags.

    ``--model_fn`` is mandatory; every other flag is an integer option with
    a default value.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # The only required flag; it has no type conversion, so it stays a string.
    parser.add_argument('--model_fn', required=True)

    # (flag, default) pairs, registered in order so ``--help`` lists them
    # in this sequence.
    int_options = (
        ('--gpu_id', -1),
        ('--verbose', 2),
        ('--min_vocab_freq', 2),
        ('--max_vocab_size', 999999),
        ('--batch_size', 256),
        ('--n_epochs', 10),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)

    return parser.parse_args()


def main(config):
    """Build the data pipeline and model, run training, and save a checkpoint.

    Args:
        config: Parsed command-line namespace (see ``define_argparse``).
            This function reads ``batch_size``, ``min_vocab_freq``,
            ``max_vocab_size``, ``gpu_id`` and ``model_fn`` from it; the
            whole namespace is also passed to ``Trainer``.

    The checkpoint written to ``config.model_fn`` is a dict with the keys
    ``'model'``, ``'config'``, ``'vocab'`` and ``'classes'``.
    """
    dataset = DataLoader(
        batch_size=config.batch_size,
        min_freq=config.min_vocab_freq,
        max_vocab=config.max_vocab_size,
        device=config.gpu_id,
    )

    ntokens = dataset.ntokens
    ninp = 512  # embedding dimension
    nhead = 8  # the number of heads in the multiheadattention models
    nhid = 256  # the dimension of the feedforward network model in nn.TransformerEncoder
    nlayers = 1  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    dropout = 0.5  # the dropout value

    # Derive the device from --gpu_id alone, the same value handed to the
    # data loader above. Previously the model was moved to CUDA whenever it
    # was available, even when gpu_id was negative (i.e. GPU not requested).
    if config.gpu_id >= 0:
        device = torch.device('cuda', config.gpu_id)
    else:
        device = torch.device('cpu')

    model = TransformerModel(ntokens, ninp, nhead, nhid, nlayers, dropout).to(device)
    crit = nn.NLLLoss().to(device)
    print(model)

    trainer = Trainer(config)
    trainer.train(model, crit, dataset.train_iter, dataset.valid_iter)

    torch.save(
        {
            # Store the weights too; a checkpoint holding only config and
            # vocabularies cannot be used to restore the model.
            # NOTE(review): assumes train() updates `model` in place -- confirm.
            'model': model.state_dict(),
            'config': config,
            # Was `datset` (typo), which raised NameError after training.
            'vocab': dataset.TEXT_A.vocab,
            'classes': dataset.LABEL.vocab,
        },
        config.model_fn,
    )
    
if __name__ == '__main__':
    # Script entry point: parse the CLI flags and hand them to main().
    main(define_argparse())

Overwriting train.py
