In [1]:
import time
import math
import numpy as np
import torch
import torch.nn as nn
from model import WD_LSTM
from data import Corpus
from randomize_bptt import get_bptt_sequence_lengths
from helpers import Config, repackage_hidden, batchify, get_batch

In [2]:
SEED = 42
DATA = '/floyd/home/sein/'
CUDA = True
LOG_INTERVAL = 100
LR_ANNEALING_RATE = 0.5
CONFIG_NAME = 'language_model_base'
device = torch.device("cuda" if CUDA else "cpu")
args = Config(CONFIG_NAME)

In [3]:
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
corpus = Corpus(DATA)

In [5]:
ntokens = len(corpus.dictionary)
train_data = batchify(corpus.train, args.batch_size, device)
valid_data = batchify(corpus.valid, args.eval_batch_size, device)
test_data = batchify(corpus.test, args.test_batch_size, device)

In [6]:
model = WD_LSTM(
    ntoken=ntokens, 
    ninp=args.emsize,
    nhid=args.nhid, 
    nlayers=args.nlayers, 
    dropout=args.dropout,
    dropout_h=args.dropout_h,
    dropout_i=args.dropout_i,
    dropout_e=args.dropout_e,
    weight_drop=args.weight_drop, 
    weight_tying=args.weight_tying
).to(device)
model

WD_LSTM(
  (variational_dropout): VariationalDropout()
  (encoder): Embedding(18818, 400)
  (rnns): ModuleList(
    (0): WeightDrop(
      (module): LSTM(400, 1150)
    )
    (1): WeightDrop(
      (module): LSTM(1150, 1150)
    )
    (2): WeightDrop(
      (module): LSTM(1150, 400)
    )
  )
  (decoder): Linear(in_features=400, out_features=18818, bias=True)
)

In [7]:
lr = args.lr
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
criterion = nn.CrossEntropyLoss()

In [8]:
def evaluate(data_source, batch_size):
    model.eval()  # disable dropout
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt_seq_len):
            data, targets = get_batch(data_source, i, args.bptt_seq_len)
            output, hidden, _, _ = model(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [9]:
def ar_tar_regularization_loss(outputs, dropped_outputs):
    reg_loss = 0
    for i in range(len(dropped_outputs[-1])):
        reg_loss += args.alpha * dropped_outputs[-1][i].pow(2).mean()
        if i >= 1:
            reg_loss += args.beta * (outputs[-1][i] - outputs[-1][i - 1]).pow(2).mean()
    return reg_loss

In [10]:
def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    for batch, (i, seq_len, lr_scale) in enumerate(get_bptt_sequence_lengths(
        train_data.size(0), 
        args.bptt_seq_len, 
        args.bptt_random_scaling, 
        args.bptt_p, 
        args.bptt_s, 
        args.bptt_min_len
    )):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr * lr_scale
        data, targets = get_batch(train_data, i, seq_len)
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()
        output, hidden, outputs, dropped_outputs = model(data, hidden)
        output_flat = output.view(-1, ntokens)
        unregularized_loss = criterion(output_flat, targets)
        regularized_loss = unregularized_loss + ar_tar_regularization_loss(outputs, dropped_outputs)
        regularized_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        total_loss += unregularized_loss.item()
        if batch % LOG_INTERVAL == 0 and batch > 0:
            cur_loss = total_loss / LOG_INTERVAL
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:3.2E} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // args.bptt_seq_len, lr,
                elapsed * 1000 / LOG_INTERVAL, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

In [11]:
best_val_loss = 1e20
try:
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(valid_data, args.eval_batch_size)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
            epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)))
        print('-' * 89)
        if val_loss < best_val_loss:
            with open(args.save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            lr *= LR_ANNEALING_RATE
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

  result = self.forward(*input, **kwargs)


| epoch   1 |   100/  523 batches | lr 1.00E-03 | ms/batch 250.88 | loss  6.20 | ppl   494.79
| epoch   1 |   200/  523 batches | lr 1.00E-03 | ms/batch 239.22 | loss  4.98 | ppl   145.60
| epoch   1 |   300/  523 batches | lr 1.00E-03 | ms/batch 235.55 | loss  4.73 | ppl   113.71
| epoch   1 |   400/  523 batches | lr 1.00E-03 | ms/batch 243.45 | loss  4.64 | ppl   103.29
| epoch   1 |   500/  523 batches | lr 1.00E-03 | ms/batch 235.72 | loss  4.52 | ppl    92.26
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 143.37s | valid loss  4.45 | valid ppl    85.98
-----------------------------------------------------------------------------------------
| epoch   2 |   100/  523 batches | lr 1.00E-03 | ms/batch 239.76 | loss  4.39 | ppl    80.37
| epoch   2 |   200/  523 batches | lr 1.00E-03 | ms/batch 234.91 | loss  4.26 | ppl    70.51
| epoch   2 |   300/  523 batches | lr 1.00E-03 | ms/batch 243.75 | loss  4.23 | ppl   

In [12]:
# Load the best saved model.
with open(args.save, 'rb') as f:
    model = torch.load(f, map_location=device)

In [13]:
# Run on test data.
test_loss = evaluate(test_data, args.test_batch_size)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
print('=' * 89)

KeyboardInterrupt: 