In [21]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import skorch
import numpy as np

from torch.autograd import Variable

In [22]:
%matplotlib inline
from matplotlib import pyplot as plt

In [23]:
import enwik8_data

In [24]:
# Load the three data splits plus the word -> id vocabulary from ./data/penn/.
# `train` is indexed by position below and its items are vocabulary ids.
# NOTE(review): `test` is not used in the cells visible here.
train, valid, test, word_to_id = enwik8_data.ptb_raw_data('./data/penn/')

In [25]:
# Inverse vocabulary: integer id -> word string.
id_to_word = dict((token_id, word) for word, token_id in word_to_id.items())

In [26]:
# Sanity check: decode the first ten training ids back into words.
[id_to_word[train[n]] for n in range(10)]

['pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as']

In [27]:
def _collect_xy(tokens, batch_size=1, num_steps=15):
    """Run ``enwik8_data.data_iterator`` once and stack its batches.

    The iterator is consumed a single time so that inputs and targets are
    guaranteed to come from the same pass (the previous code iterated twice
    per split, once for X and once for y).

    Parameters
    ----------
    tokens : sequence of token ids for one split.
    batch_size, num_steps : forwarded positionally to ``data_iterator``.
        NOTE(review): presumably batch size and time steps per sample;
        confirm against ``enwik8_data.data_iterator``.

    Returns
    -------
    (inputs, targets) : two int64 numpy arrays, each the concatenation of the
        corresponding element of every ``(x, y)`` pair the iterator yields.
    """
    batches = list(enwik8_data.data_iterator(tokens, batch_size, num_steps))
    inputs = np.concatenate([bx for bx, _ in batches]).astype('int64')
    targets = np.concatenate([by for _, by in batches]).astype('int64')
    return inputs, targets


X_train, y_train = _collect_xy(train)
X_valid, y_valid = _collect_xy(valid)

In [28]:
# Keep only the first `limit` rows of every array so experiments stay fast.
limit = 1000
X_train, y_train = X_train[:limit], y_train[:limit]
X_valid, y_valid = X_valid[:limit], y_valid[:limit]

In [29]:
class GaussClocking(nn.Module):
    """Recurrent layer whose state update is gated by a noisy, learned clock.

    At every time step a gate ``sigmoid(noise * s + m)`` is computed from
    standard-normal noise and the learnable per-unit parameters ``s`` (scale)
    and ``m`` (offset). The gate interpolates between the previous hidden
    state and the freshly computed candidate state.

    Parameters
    ----------
    input_dim : size of the last dimension of the input.
    hidden_dim : number of hidden units.
    act : ``'tanh'`` or ``'relu'``; any other key raises ``KeyError``.
    """

    def __init__(self, input_dim, hidden_dim, act='tanh'):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Input and recurrent projections (the recurrent one has no bias).
        self.i2h = nn.Linear(input_dim, hidden_dim)
        self.h2h = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Clock offset starts at 0.5 and clock scale at 1.0 for every unit.
        self.m = nn.Parameter(torch.ones(hidden_dim) * 0.5)
        self.s = nn.Parameter(torch.ones(hidden_dim))

        activations = {'relu': nn.ReLU, 'tanh': nn.Tanh}
        self.act = activations[act]()

    def forward(self, x, h=None):
        """Run the layer over a sequence.

        ``x`` is ``(batch, time, features)``. When ``h`` is omitted the state
        starts as a 1-D zero vector of length ``hidden_dim`` and broadcasts
        against the batch on the first step.

        Returns a 3-tuple:
        * activations of every step stacked along dim 1,
        * the final (pre-activation) hidden state,
        * the gate values of every step stacked along dim 0.

        Fresh noise is drawn on every call, and the noise has no batch
        dimension, so all samples in a batch share the same gates.
        """
        on_gpu = x.is_cuda
        n_steps = x.size(1)

        if h is None:
            h = skorch.utils.to_var(torch.zeros(self.hidden_dim), on_gpu)

        noise = skorch.utils.to_var(torch.randn(n_steps, self.hidden_dim), on_gpu)

        outputs, gates = [], []
        for step in range(n_steps):
            gate = F.sigmoid(noise[step] * self.s + self.m)
            candidate = self.i2h(x[:, step]) + self.h2h(h)
            # Convex combination of the candidate and the previous state.
            h = gate * candidate + (1 - gate) * h
            outputs.append(self.act(h))
            gates.append(gate)

        return torch.stack(outputs, dim=1), h, torch.stack(gates, dim=0)

In [30]:
import visdom
# Module-level visdom client; ReconModel.forward uses it to push heatmaps.
# NOTE(review): presumably needs a running visdom server to be reachable — confirm.
vis = visdom.Visdom()

In [31]:
def time_flatten(t):
    """Merge the first two dimensions of ``t`` into one.

    The result is 2-D: ``(t.size(0) * t.size(1), rest)``, where ``rest``
    absorbs whatever elements remain (1 for a 2-D input).
    """
    n_rows = t.size(0) * t.size(1)
    return t.view(n_rows, -1)

def time_unflatten(t, s):
    """Reshape ``t`` to ``(s[0], s[1], rest)``.

    Only the first two entries of ``s`` are read; the trailing dimension is
    inferred from the number of elements in ``t``.
    """
    dim0, dim1 = s[0], s[1]
    return t.view(dim0, dim1, -1)

In [32]:
class ReconModel(nn.Module):
    """Token model: embedding -> stacked GaussClocking layers -> log-softmax.

    Parameters
    ----------
    num_tokens : vocabulary size (embedding rows and classifier outputs).
    num_hidden : embedding width and hidden size of every recurrent layer.
    act : activation name forwarded to ``GaussClocking``.
    num_layers : number of stacked ``GaussClocking`` layers.
    visualize : when true (the default, matching the previous behaviour),
        every forward pass pushes heatmaps of the first sample's activations
        and of each layer's ``m`` / ``s`` parameters to the module-level
        visdom client ``vis``. Pass ``False`` to skip the plotting.
    """

    def __init__(self, num_tokens, num_hidden=64, act='tanh', num_layers=1,
                 visualize=True):
        super().__init__()

        self.rnn = []
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.visualize = visualize

        self.emb = nn.Embedding(num_tokens, num_hidden)

        # The layers live in a plain list for indexing and are additionally
        # registered as 'rnn0', 'rnn1', ... so their parameters are tracked.
        for i in range(num_layers):
            self.rnn.append(GaussClocking(num_hidden, num_hidden, act=act))
            self.add_module('rnn'+str(i), self.rnn[-1])

        self.clf = nn.Linear(num_hidden, num_tokens)
        # The classifier output is 2-D (see forward), so normalise over dim 1
        # explicitly instead of relying on the implicit-dimension choice.
        self.softmax = nn.LogSoftmax(dim=1)

    def _plot_layer(self, i, activations):
        """Send heatmaps for layer ``i`` to the visdom client ``vis``."""
        layer = self.rnn[i]
        vis.heatmap(skorch.utils.to_numpy(activations[0]), opts={'title': 'act rnn'+str(i)}, win="act rnn"+str(i))
        vis.heatmap(skorch.utils.to_numpy(layer.m).reshape(1, -1), opts={'title': 'mu rnn'+str(i)}, win="mu rnn"+str(i))
        vis.heatmap(skorch.utils.to_numpy(layer.s).reshape(1, -1), opts={'title': 's rnn'+str(i)}, win="s rnn"+str(i))

    def forward(self, x):
        """Return log-probabilities shaped ``(x.size(0), x.size(1), -1)``.

        ``x`` holds token ids; it is cast to long before the embedding lookup.
        """
        li = self.emb(x.long())
        for i in range(self.num_layers):
            # Only the per-step activations feed the next layer; the final
            # state and the gate values returned by the layer are not used.
            li, _, _ = self.rnn[i](li)
            if self.visualize:
                self._plot_layer(i, li)
        l1 = self.clf(time_flatten(li))
        l1 = self.softmax(l1)
        return time_unflatten(l1, x.size())

In [33]:
class Trainer(skorch.NeuralNet):
    """``skorch.NeuralNet`` with an NLL default criterion and a sequence loss.

    Predictions and targets are flattened over their first two dimensions
    before the parent's loss computation is applied.
    """

    def __init__(self, criterion=nn.NLLLoss, *args, **kwargs):
        super().__init__(*args, criterion=criterion, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Flatten predictions and targets, then delegate to the parent."""
        flat_pred = time_flatten(y_pred)
        # Targets become a column after flattening; drop that last dimension.
        flat_true = time_flatten(y_true).squeeze(-1)
        return super().get_loss(flat_pred, flat_true, X=X, training=training)

In [44]:
# Fix the torch RNG so weight initialisation is repeatable.
torch.manual_seed(1337)


def my_train_split(X, y):
    """Return the given data as the training part and the notebook-level
    ``X_valid`` / ``y_valid`` arrays as the validation part.

    Order of the result: ``X, X_valid, y, y_valid``.
    """
    return X, X_valid, y, y_valid


ef = Trainer(
    # model, optimiser and schedule
    module=ReconModel,
    optimizer=torch.optim.Adam,
    lr=0.02,
    max_epochs=10,
    train_split=my_train_split,
    # arguments routed to ReconModel via the ``module__`` prefix
    module__num_tokens=len(word_to_id),
    module__num_hidden=32,
    module__act='relu',
    module__num_layers=2,
    # runtime settings
    use_cuda=False,
    batch_size=16,
    callbacks=[skorch.callbacks.ProgressBar()],
)

In [45]:
# NOTE(review): `%pdb on` is a leftover debugging switch; it turns automatic
# pdb calling on (see the cell output). Consider removing it before sharing.
%pdb on
# Train on the truncated training arrays; validation data comes from the
# train_split callable given to Trainer.
ef.fit(X_train, y_train)

Automatic pdb calling has been turned ON




  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        [36m7.7065[0m        [32m7.0409[0m  23.7901


      2        [36m6.4705[0m        7.1970  24.7283


      3        [36m6.2035[0m        7.4912  24.5867


      4        [36m6.0157[0m        7.9677  24.1973


      5        [36m5.8896[0m        8.2833  25.1713


      6        [36m5.7982[0m        7.5188  26.0449


      7        [36m5.6660[0m        7.6963  26.4944


      8        [36m5.5124[0m        7.7856  26.1957


      9        [36m5.3497[0m        8.2676  25.1216


     10        [36m5.2689[0m        8.7494  25.2007


<class '__main__.Trainer'>[initialized](
  module_=ReconModel(
    (emb): Embedding(15488, 32)
    (rnn0): GaussClocking(
      (i2h): Linear(in_features=32, out_features=32)
      (h2h): Linear(in_features=32, out_features=32)
      (act): ReLU()
    )
    (rnn1): GaussClocking(
      (i2h): Linear(in_features=32, out_features=32)
      (h2h): Linear(in_features=32, out_features=32)
      (act): ReLU()
    )
    (clf): Linear(in_features=32, out_features=15488)
    (softmax): LogSoftmax()
  ),
)

In [46]:
# Run the fitted model over the (truncated) training inputs.
# NOTE(review): ReconModel ends in LogSoftmax, so if predict_proba returns the
# raw module output these are log-probabilities rather than probabilities —
# confirm against the installed skorch version. Arg-max decoding is unaffected.
# GaussClocking.forward draws new random noise on every call, so re-running
# this cell can produce different predictions.
pred = ef.predict_proba(X_train)



In [51]:
# Decode the highest-scoring token at every time step of sample 4.
[id_to_word[n] for n in pred[4].argmax(-1)]

['<unk>',
 '<unk>',
 '<unk>',
 '<unk>',
 'of',
 'the',
 '<unk>',
 '<unk>',
 'the',
 'few',
 '<unk>',
 'the',
 '<unk>',
 'to',
 'the']