In [38]:
# Text processing library
import torchtext
from torchtext.vocab import Vectors
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
# Debug switch: when True, the data-loading cell builds the vocabulary with
# max_size=1000 instead of using the default build_vocab(train) call.
debug = False

In [66]:
# Our input $x$
TEXT = torchtext.data.Field()

# Data distributed with the assignment.
# NOTE: `test` is read from valid.txt as well, so `test` / `test_iter` cover the
# same text as the validation split.
train, val, test = torchtext.datasets.LanguageModelingDataset.splits(
    path=".", 
    train="train.txt", validation="valid.txt", test="valid.txt", text_field=TEXT)

# Build the vocabulary exactly once. Debug mode passes max_size=1000 to keep
# the vocabulary small; previously the full vocabulary was built first and
# then thrown away by the second build_vocab call.
if debug:
    TEXT.build_vocab(train, max_size=1000)
else:
    TEXT.build_vocab(train)

# Batches of 10 parallel streams, each covering 32 consecutive tokens
# (batch.text has size [bptt_len, batch_size]); repeat=False makes one pass
# over the data finite.
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits(
    (train, val, test), batch_size=10, device=-1, bptt_len=32, repeat=False)

In [40]:
# Pull the first training batch and look at its shape and contents.
it = iter(train_iter)
batch = next(it)
print(batch.text.size())
# print(batch.text[:,3])

# Decode two of the batch columns back into words via the vocabulary's
# index-to-string table.
for column in (5, 6):
    words = [TEXT.vocab.itos[i] for i in batch.text[:, column].data]
    print(' '.join(words))

torch.Size([32, 10])
the government can ensure the same flow of resources and reduce the current deficit <eos> predictably guarantees outstanding have risen by $ N billion since N while direct loans outstanding have fallen
comes across as a <unk> executive mr. phillips has a <unk> <unk> <eos> during time off mr. roman tends to his garden mr. phillips <unk> to a <unk> for among other things


Perplexity goals:
count: 120-200
feedforward: 100-150
recurrent: below 100 (between 80-100)

In [95]:
# NOTE(review): Trigram is defined in the cell that follows this one in the
# file, so that cell has to be executed first on a fresh kernel.
# Build the count-based model and accumulate its n-gram counts from the
# training iterator.
tgram = Trigram(TEXT)
tgram.train_counts(train_iter)

Iteration 0
Iteration 1000
Iteration 2000


In [94]:
class Trigram(nn.Module):
    """Count-based language model mixing unigram, bigram and trigram estimates.

    ``cnts[n]`` maps a context -- a tuple of ``n`` word ids (``n`` in 0..2) --
    to a numpy vector of length ``len(TEXT.vocab)`` holding how often each
    word followed that context. ``alphas[n]`` is the interpolation weight
    applied to the estimate obtained from a context of length ``n``.
    """

    def __init__(self, TEXT, args=None):
        """Store the vocabulary size and create the empty count tables.

        TEXT must expose ``TEXT.vocab`` supporting ``len()``. ``args`` is
        accepted but not used.
        """
        super(Trigram, self).__init__()
        self._TEXT = TEXT
        self._text_vocab_len = len(TEXT.vocab)
        
        # Use dictionaries since we don't want to have to 
        # store the vast majority of bi/tri gram contexts 
        # which are never observed.
        self.cnts = [dict(), dict(), dict()]

        # Interpolation weights; must be provided through set_alpha()
        # before forward() is called.
        self.alphas = None
        
    def set_alpha(self, *args):
        """Set the interpolation weights (unigram, bigram, trigram).

        With three values they are stored as given. With two values the third
        is filled in as ``1 - sum(args)`` so the weights sum to one. Fewer
        than two values trips the assertion.
        """
        # Copy into a list: *args arrives as an immutable tuple, so the
        # missing weight has to be appended rather than assigned by index.
        alphas = list(args)
        if len(alphas) < 3:
            assert len(alphas) == 2
            alphas.append(1 - sum(alphas))
        self.alphas = alphas
        
    def train_counts(self, train_iter):
        """Accumulate n-gram counts from every batch of ``train_iter``.

        Each batch must expose ``batch.text`` as a 2-D tensor; it is
        transposed before counting so that each row handed to
        update_trigram_cnts is one contiguous token stream. Progress is
        printed every 1000 batches.
        """
        # zip against range(len(...)) keeps the original bound of
        # len(train_iter) batches while also stopping cleanly if the
        # iterator runs out early.
        for i, batch in zip(range(len(train_iter)), train_iter):
            if i % 1000 == 0:
                print('Iteration %d' % i)
            self.update_trigram_cnts(torch.t(batch.text).data.numpy())
            
    def forward(self, batch):
        """Return the interpolated next-word distribution for each row.

        ``batch`` is a 2-D array of word ids with one row per example, given
        as a torch tensor or a numpy array. For each row the last ``n`` ids
        (``n`` = 0, 1, 2) are looked up in ``cnts[n]``; the stored counts are
        normalised to probabilities and added with weight ``alphas[n]``.
        Contexts that were never observed, or that are longer than the row,
        contribute nothing, so a row may sum to less than one.

        Returns a float tensor of size [number of rows, vocabulary size].
        Raises RuntimeError if set_alpha() has not been called.
        """
        if self.alphas is None:
            raise RuntimeError('call set_alpha() before forward()')

        # Work on plain integers so that lookup keys hash exactly like the
        # keys stored by update_trigram_cnts.
        if isinstance(batch, np.ndarray):
            rows = batch
        else:
            rows = getattr(batch, 'data', batch).cpu().numpy()
        rows = np.asarray(rows)

        num_rows, row_len = rows.shape[0], rows.shape[1]
        ret_arr = np.zeros((num_rows, self._text_vocab_len))
        for i in range(num_rows):
            for n in range(0, 3):
                if n > row_len:
                    continue
                # row_len - n (not -n) so that n == 0 yields the empty context
                # instead of the whole row.
                key = tuple(rows[i, row_len - n:].tolist())
                counts = self.cnts[n].get(key)
                if counts is None:
                    continue
                total = counts.sum()
                if total > 0:
                    ret_arr[i, :] += self.alphas[n] * (counts / total)
        return torch.from_numpy(ret_arr).float()
                
    def update_trigram_cnts(self, batch):
        """Add the n-gram counts found in ``batch`` to ``self.cnts``.

        ``batch`` is a numpy array of word ids with size [batch_size, bptt_len].
        """
        # We don't glue rows together since they may be shuffled 
        # (this is all kind of silly since ideally we'd just do 
        # this in one big 'sentence', but perhaps we want a 'fair 
        # comparison'...)
        batch = np.asarray(batch)
        row_len = batch.shape[1]
        for row in batch:
            for n in range(0, 3):
                table = self.cnts[n]
                for k in range(row_len - n):
                    dict_key = tuple(row[k:k + n].tolist())
                    counts = table.get(dict_key)
                    if counts is None:
                        counts = table[dict_key] = np.zeros(self._text_vocab_len)
                    # Here's where we increment the count of the word
                    # that follows the context.
                    counts[row[k + n]] += 1
                    
    def get_ngrams(self, arr, n=3):
        """Return all overlapping windows of ``n`` consecutive entries of ``arr``.

        ``arr`` is a 1-D numpy array; the result has one row per window.
        Currently not used by the rest of the class.
        """
        len_ngrams = max(arr.shape[0] - n + 1, 0)
        # Broadcasting a column of window starts against a row of offsets
        # gives the index of every element of every window.
        ngram_inds = np.arange(len_ngrams)[:, None] + np.arange(n)[None, :]
        return np.take(arr, ngram_inds)
    
    
            

In [100]:
# Length of the count vector stored for the empty (unigram) context; each
# such vector is allocated with one slot per vocabulary entry.
print(len(tgram.cnts[0][()]))

10001


3527539