# Word-level Language Modeling

## Overview

In word-level language modeling tasks, each element of the sequence is a word, where the model is expected to predict the next incoming in the test. We evaluate the temporal convolutional network as a word-level language model on three datasets: PennTreebank (PTB), Wikitext-103. and LAMBADA.

Because the evaluation of LAMBADA has different requirement (predicting only the very last word based on a broader context), we put it in another script. See `lambada_language.ipynb`.

## Setting

In [1]:
import torch as th
import torch.nn as nn
import math
from tqdm.notebook import tqdm
import torch.nn.functional as F
BATCH_SIZE = 16
DEVICE = "cuda:0"
DROPOUT = 0.45
EMB_DROPOUT = 0.25
CLIP = 0.35
EPOCHS = 10
KSIZE = 3
DATA_ROOT = "/home/densechen/dataset/penn"
EMSIZE = 600
LEVELS = 4
LR = 4
NHID = 600
SEED = 1111
TIED = True
OPTIM = "SGD"
VALID_SEQ_LEN = 40
SEQ_LEN = 80
CORPUS = False

CHANNEL_SIZES = [NHID] * (LEVELS - 1) + [EMSIZE]
EVAL_BATCH_SIZE = 10

th.manual_seed(SEED)

<torch._C.Generator at 0x7ff4c8115990>

## Data Generator

The meaning of batch_size in PTB is different from that in MNIST example. In MNIST, 
batch_size is the # of sample data that is considered in each iteration; in PTB, however,
it is the number of segments to speed up computation. 

The goal of PTB is to train a language model to predict the next word.

You should download the dataset from [here](https://github.com/locuslab/TCN/tree/master/TCN/word_cnn/data/penn), and then place it under `DATA_ROOT`.

In [2]:
import os
import pickle

def data_generator():
    if os.path.exists(os.path.join(DATA_ROOT, "corpus")) and not CORPUS:
        corpus = pickle.load(open(os.path.join(DATA_ROOT, 'corpus'), 'rb'))
    else:
        corpus = Corpus(DATA_ROOT)
        pickle.dump(corpus, open(os.path.join(DATA_ROOT, 'corpus'), 'wb'))
    return corpus


class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = th.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        return ids


def batchify(data, batch_size):
    """The output should have size [L x batch_size], where L could be a long sequence length"""
    # Work out how cleanly we can divide the dataset into batch_size parts (i.e. continuous seqs).
    nbatch = len(data) // batch_size
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * batch_size)
    # Evenly divide the data across the batch_size batches.
    data = data.view(batch_size, -1)
    data = data.to(DEVICE)
    return data


def get_batch(source, i, evaluation=False):
    seq_len = min(SEQ_LEN, source.size(1) - 1 - i)
    data = source[:, i:i+seq_len]
    target = source[:, i+1:i+1+seq_len]     # CAUTION: This is un-flattened!
    return data, target

print("Producing data...")
corpus = data_generator()
n_words = len(corpus.dictionary)

eval_batch_size = 10
train_data = batchify(corpus.train, BATCH_SIZE)
val_data = batchify(corpus.valid, EVAL_BATCH_SIZE)
test_data = batchify(corpus.test, EVAL_BATCH_SIZE)
print("Finished.")

Producing data...
Finished.


## Define model

In [3]:
from core.tcn import TemporalConvNet

class TCN(nn.Module):

    def __init__(self, input_size, output_size, num_channels,
                 kernel_size=2, dropout=0.3, emb_dropout=0.1, tied_weights=False):
        super(TCN, self).__init__()
        self.encoder = nn.Embedding(output_size, input_size)
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size, dropout=dropout)

        self.decoder = nn.Linear(num_channels[-1], output_size)
        if tied_weights:
            if num_channels[-1] != input_size:
                raise ValueError('When using the tied flag, nhid must be equal to emsize')
            self.decoder.weight = self.encoder.weight
            print("Weight tied")
        self.drop = nn.Dropout(emb_dropout)
        self.emb_dropout = emb_dropout
        
    def forward(self, input):
        """Input ought to have dimension (N, C_in, L_in), where L_in is the seq_len; here the input is (N, L, C)"""
        emb = self.drop(self.encoder(input))
        y = self.tcn(emb.transpose(1, 2)).transpose(1, 2)
        y = self.decoder(y)
        return y.contiguous()

print("Building model...")

model = TCN(EMSIZE, n_words, CHANNEL_SIZES, 
            dropout=DROPOUT, emb_dropout=EMB_DROPOUT, kernel_size=KSIZE, tied_weights=TIED)
model = model.to(DEVICE)

# May use adaptive softmax to speed up training
optimizer = getattr(th.optim, OPTIM)(model.parameters(), lr=LR)
print("Finished.")

Building model...
Weight tied
Finished.


## Run

In [5]:
def evaluate(data_source):
    model.eval()
    total_loss = 0
    processed_data_size = 0
    with th.no_grad():
        for i in range(0, data_source.size(1) - 1, VALID_SEQ_LEN):
            if i + SEQ_LEN - VALID_SEQ_LEN >= data_source.size(1) - 1:
                continue
            data, targets = get_batch(data_source, i, evaluation=True)
            output = model(data)

            # Discard the effective history, just like in training
            eff_history = SEQ_LEN - VALID_SEQ_LEN
            final_output = output[:, eff_history:].contiguous().view(-1, n_words)
            final_target = targets[:, eff_history:].contiguous().view(-1)

            loss = F.cross_entropy(final_output, final_target)

            # Note that we don't add TAR loss here
            total_loss += (data.size(1) - eff_history) * loss.item()
            processed_data_size += data.size(1) - eff_history
    return total_loss / processed_data_size


def train(ep):
    # Turn on training mode which enables dropout.
    model.train()
    process = tqdm(range(0, train_data.size(1) - 1, VALID_SEQ_LEN))
    for i in process:
        if i + SEQ_LEN - VALID_SEQ_LEN >= train_data.size(1) - 1:
            continue
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        output = model(data)

        # Discard the effective history part
        eff_history = SEQ_LEN - VALID_SEQ_LEN
        if eff_history < 0:
            raise ValueError("Valid sequence length must be smaller than sequence length!")
        final_target = targets[:, eff_history:].contiguous().view(-1)
        final_output = output[:, eff_history:].contiguous().view(-1, n_words)
        loss = F.cross_entropy(final_output, final_target)

        loss.backward()
        if CLIP > 0:
            th.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        optimizer.step()
        
        process.set_description(f"Train Epoch: {ep:2d}, loss: {loss.item():.6f}")

for epoch in range(1, EPOCHS+1):
    train(epoch)
    val_loss = evaluate(val_data)
    test_loss = evaluate(test_data)

    print('-' * 89)
    print(f'| end of epoch {epoch:3d} | valid loss {val_loss:5.2f}')
    print(f'| end of epoch {epoch:3d} | test loss {test_loss:5.2f}')
    print('-' * 89)


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   1 | valid loss 53.68
| end of epoch   1 | test loss 52.49
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   2 | valid loss 44.76
| end of epoch   2 | test loss 43.53
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   3 | valid loss 38.78
| end of epoch   3 | test loss 37.64
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   4 | valid loss 34.69
| end of epoch   4 | test loss 33.50
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   5 | valid loss 31.75
| end of epoch   5 | test loss 30.63
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   6 | valid loss 29.38
| end of epoch   6 | test loss 28.31
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   7 | valid loss 27.87
| end of epoch   7 | test loss 26.88
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   8 | valid loss 26.76
| end of epoch   8 | test loss 25.79
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch   9 | valid loss 25.38
| end of epoch   9 | test loss 24.48
-----------------------------------------------------------------------------------------


  0%|          | 0/1453 [00:00<?, ?it/s]

-----------------------------------------------------------------------------------------
| end of epoch  10 | valid loss 23.91
| end of epoch  10 | test loss 23.03
-----------------------------------------------------------------------------------------
