In [1]:
import sys
import os

import argparse
import json
import random
import shutil
import copy

import torch
from torch import cuda
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch.nn.functional as F
import numpy as np
import time
import logging
from data import Dataset
from models import RNNLM, RNNG
from utils import *

In [2]:
parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--train_file', default='data/ptb-train.pkl')
parser.add_argument('--val_file', default='data/ptb-val.pkl')
parser.add_argument('--test_file', default='data/ptb-test.pkl')
parser.add_argument('--train_from', default='')
# Model options
parser.add_argument('--w_dim', default=650, type=int, help='hidden dimension for LM')
parser.add_argument('--h_dim', default=650, type=int, help='hidden dimension for LM')
parser.add_argument('--num_layers', default=2, type=int, help='number of layers in LM and the stack LSTM (for RNNG)')
parser.add_argument('--dropout', default=0.6, type=float, help='dropout rate')
# Optimization options
parser.add_argument('--count_eos_ppl', default=0, type=int, help='whether to count eos in val PPL')
parser.add_argument('--test', default=0, type=int, help='')
parser.add_argument('--save_path', default='urnng.pt', help='where to save the data')
parser.add_argument('--num_epochs', default=80, type=int, help='number of training epochs')
parser.add_argument('--min_epochs', default=30, type=int, help='do not decay learning rate for at least this many epochs')
parser.add_argument('--lr', default=0.1, type=float, help='starting learning rate')
parser.add_argument('--decay', default=0.5, type=float, help='')
parser.add_argument('--param_init', default=0.1, type=float, help='parameter initialization (over uniform)')
parser.add_argument('--max_grad_norm', default=5, type=float, help='gradient clipping parameter')
parser.add_argument('--gpu', default=2, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=500, help='print stats after this many batches')

# Distillation stuff
parser.add_argument('--distill', type=int, default=1, help='Whether you want to train this RNNLM with distillation from another model. 0 means no, 1 means yes')
parser.add_argument('--temp', type=float, default=2.0, help='temperature to scale logits by')

_StoreAction(option_strings=['--temp'], dest='temp', nargs=None, const=None, default=2.0, type=<class 'float'>, choices=None, help='temperature to scale logits by', metavar=None)

In [3]:
def init_dist_model():
    loaded_data = torch.load('rnng.pt')
    model_args = loaded_data['args']
    model_state_dict = loaded_data['model'].state_dict()
        
    rnng = RNNG(
        vocab=len(loaded_data['word2idx']),
        w_dim=model_args['w_dim'],           # Dimensionality of word embeddings
        h_dim=model_args['h_dim'],           # Dimensionality of hidden states
        q_dim=model_args['q_dim'],           # Dimensionality of 'q' vector
        num_layers=model_args['num_layers'], # Number of layers
        dropout=model_args['dropout'],       # Dropout rate
        max_len=250
    )
        
    rnng.load_state_dict(model_state_dict)
    rnng.eval()
    rnng.cuda()
    
    return rnng

In [None]:
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)  
    vocab_size = int(train_data.vocab_size)
    
    print('Train: %d sents / %d batches, Val: %d sents / %d batches' % 
          (train_data.sents.size(0), len(train_data), val_data.sents.size(0), 
           len(val_data)))
    
    print('Vocab size: %d' % vocab_size)
    
    cuda.set_device(args.gpu)
    if args.train_from == '':
        model = RNNLM(vocab = vocab_size,
                      w_dim = args.w_dim, 
                      h_dim = args.h_dim,
                      dropout = args.dropout,
                      num_layers = args.num_layers)
        if args.param_init > 0:
            for param in model.parameters():    
                param.data.uniform_(-args.param_init, args.param_init)      
    else:
        print('loading model from ' + args.train_from)
        checkpoint = torch.load(args.train_from)
        model = checkpoint['model']
        
    if args.distill == 1:
        rnng = init_dist_model()
        ce_loss = nn.CrossEntropyLoss()
        kl_div_loss = nn.KLDivLoss(reduction='batchmean')
        temp = args.temp
        
    print("model architecture")
    print(model)
    
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
#     scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

    model.train()
    model.cuda()
    epoch = 0
    decay= 0
    if args.test == 1:
        test_data = Dataset(args.test_file)  
        test_ppl = eval(test_data, model, count_eos_ppl = args.count_eos_ppl)
        sys.exit(0)
    best_val_ppl = eval(val_data, model, count_eos_ppl = args.count_eos_ppl)
    while epoch < args.num_epochs:
        print("Learning rate: ", args.lr)
        print()
        start_time = time.time()
        epoch += 1  
        print('Starting epoch %d' % epoch)
        train_nll = 0.
        num_sents = 0.
        num_words = 0.
        b = 0
        for i in np.random.permutation(len(train_data)):
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = train_data[i]
            if length == 1:
                continue
                
            # Target is second to last element, since last is the end of sentence token </s>
            targets = sents[:, -2]
            targets = targets.cuda()
            sents = sents.cuda()
            b += 1
            optimizer.zero_grad()
            optimizer.zero_grad()
            
            # Get outputs from student and teacher models
            # Student output is in log_probs, teacher in regular probs for KLDivergence
            student_logits = model(sents)
            teacher_logits = rnng(sents)
            
            _, pred_idx = torch.max(student_logits, 1)
            correct_pred = (pred_idx == targets)
            accuracy = correct_pred.sum().item() / targets.size(0)
            
            student_log_probs = F.log_softmax(student_logits / temp, dim=1)
            teacher_probs = F.softmax(teacher_logits / temp, dim=1)
            
            loss = ce_loss(student_logits, targets)
            kl_loss = kl_div_loss(student_log_probs, teacher_probs)
            
            total_loss = loss + kl_loss
            print("Total loss: ", total_loss)
            print("Accuracy: ", accuracy*100)
            print()
            
            total_loss.backward()
            
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)  
                
            optimizer.step()
            num_sents += batch_size
            num_words += batch_size * length
            
            if b % args.print_every == 0:
                param_norm = sum([p.norm()**2 for p in model.parameters()]).item()**0.5
                print('Epoch: %d, Batch: %d/%d, LR: %.4f, TrainPPL: %.2f, |Param|: %.4f, BestValPerf: %.2f, Throughput: %.2f examples/sec' % 
                      (epoch, b, len(train_data), args.lr, np.exp(train_nll / num_words), 
                       param_norm, best_val_ppl, num_sents / (time.time() - start_time)))
        print('--------------------------------')
        print('Checking validation perf...')    
        val_ppl = eval(val_data, model,  count_eos_ppl = args.count_eos_ppl)
        print('--------------------------------')
        if val_ppl < best_val_ppl:
            best_val_ppl = val_ppl
            checkpoint = {
                'args': args.__dict__,
                'model': model.cpu(),
                'word2idx': train_data.word2idx,
                'idx2word': train_data.idx2word
            }
            print('Saving checkpoint to %s' % args.save_path)
            torch.save(checkpoint, args.save_path)
            model.cuda()
        else:
            if epoch > args.min_epochs:
                decay = 1
        if decay == 1:
            args.lr = args.decay*args.lr
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr
        if args.lr < 0.03:
            break
        print("Finished training")
#         scheduler.step(best_val_ppl)

In [8]:
def eval(data, model, count_eos_ppl = 0):
    model.eval()
    num_words = 0
    total_nll = 0.
    with torch.no_grad():
        for i in list(reversed(range(len(data)))):
            sents, length, batch_size, gold_actions, gold_spans, gold_binary_trees, other_data = data[i] 
            if length == 1: #we ignore length 1 sents in URNNG eval so do this for LM too
                continue
            if args.count_eos_ppl == 1:
                length += 1 
            else:
                sents = sents[:, :-1] 
            sents = sents.cuda()
            num_words += length * batch_size
            nll = -model(sents).mean()
            total_nll += nll.item()*batch_size
    ppl = np.exp(total_nll / num_words)
    print('PPL: %.2f' % (ppl))
    model.train()
    return ppl

In [10]:
if __name__ == '__main__':
    args = parser.parse_args()
    main(args)

usage: ipykernel_launcher.py [-h] [--train_file TRAIN_FILE]
                             [--val_file VAL_FILE] [--test_file TEST_FILE]
                             [--train_from TRAIN_FROM] [--w_dim W_DIM]
                             [--h_dim H_DIM] [--num_layers NUM_LAYERS]
                             [--dropout DROPOUT]
                             [--count_eos_ppl COUNT_EOS_PPL] [--test TEST]
                             [--save_path SAVE_PATH] [--num_epochs NUM_EPOCHS]
                             [--min_epochs MIN_EPOCHS] [--lr LR]
                             [--decay DECAY] [--param_init PARAM_INIT]
                             [--max_grad_norm MAX_GRAD_NORM] [--gpu GPU]
                             [--seed SEED] [--print_every PRINT_EVERY]
                             [--distill DISTILL] [--temp TEMP]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/jts75596/.local/share/jupyter/runtime/kernel-4386bad8-dad5-443c-a1b9-fb2090cf8c79.json


SystemExit: 2