In [1]:
import torch
import torch.nn as nn
import skorch
import numpy as np

from torch.autograd import Variable

In [2]:
%matplotlib inline
from matplotlib import pyplot as plt

In [3]:
import enwik8_data

In [4]:
# Word-level corpus loader. The directory, the `ptb_raw_data` name and the
# sample tokens printed below suggest Penn Treebank, despite the module name.
PTB_DIR = './data/penn/'
train, valid, test, word_to_id = enwik8_data.ptb_raw_data(PTB_DIR)

In [5]:
# Reverse vocabulary lookup: token id -> word.
id_to_word = dict(zip(word_to_id.values(), word_to_id.keys()))

In [6]:
# Sanity check: decode the first ten training token ids back into words.
[id_to_word[train[pos]] for pos in range(10)]

['pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as']

In [7]:
def make_windows(raw_ids, batch_size=1, num_steps=15):
    """Turn a flat token-id sequence into stacked (inputs, targets) arrays.

    Consumes ``enwik8_data.data_iterator(raw_ids, batch_size, num_steps)`` exactly
    once and concatenates the yielded ``(x, y)`` chunks along axis 0, cast to int64.
    A single pass guarantees that inputs and targets come from the same iteration.
    The previous code iterated twice per split.
    NOTE(review): targets are presumably the inputs shifted by one token; confirm
    against ``data_iterator``.

    Returns:
        (inputs, targets) as int64 numpy arrays.
    """
    chunks = list(enwik8_data.data_iterator(raw_ids, batch_size, num_steps))
    inputs = np.concatenate([cx for cx, _ in chunks]).astype('int64')
    targets = np.concatenate([cy for _, cy in chunks]).astype('int64')
    return inputs, targets

X_train, y_train = make_windows(train)
X_valid, y_valid = make_windows(valid)

In [8]:
# Keep only the first `limit` windows of each split so an epoch stays quick.
limit = 1000
X_train, y_train = X_train[:limit], y_train[:limit]
X_valid, y_valid = X_valid[:limit], y_valid[:limit]

In [9]:
class Clocking(nn.Module):
    """Recurrent layer whose per-unit state update is gated by a learned sinusoidal clock.

    At step ``ti`` every hidden unit k gets a gate
    ``g_k = (sin(ti / period_k * pi + pi / 2) + 1) / 2 == (1 + cos(pi * ti / period_k)) / 2``,
    which equals 1 at ti == 0 and sweeps 1 -> 0 -> 1 every ``2 * |period_k|`` steps.
    The state is then interpolated: ``h = g * (i2h(x_t) + h2h(h)) + (1 - g) * h``.
    The gate depends only on the step index (and the optional incoming ``clock``),
    never on the input values.
    """
    def __init__(self, input_dim: int, hidden_dim: int, act: str = 'tanh', update_state_with_output: bool = False):
        """
        Args:
            input_dim: size of the last axis of the input fed to ``i2h``.
            hidden_dim: number of hidden units, and of learned periods.
            act: output activation, 'relu' or 'tanh'; any other value raises KeyError.
            update_state_with_output: if True, the activated output replaces the
                (pre-activation) state after every step.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.update_state_with_output = update_state_with_output
        
        self.i2h = nn.Linear(input_dim, hidden_dim)
        self.h2h = nn.Linear(hidden_dim, hidden_dim, bias=False)
        # One learnable period per hidden unit.
        # NOTE(review): randn + 1 can yield periods close to 0. Then ti / period is
        # huge and its gradient w.r.t. period (-ti / period**2) can explode.
        # Negative periods give the same gate as |period| because cos is even.
        self.period = nn.Parameter(torch.randn(hidden_dim) + 1)
        # Disabled alternative initialisation (all periods 0.5):
        #self.period = nn.Parameter(torch.ones(hidden_dim) - 0.5)
        self.act = {'relu': nn.ReLU, 'tanh': nn.Tanh}[act]()
        
    def forward(self, x, clock=None, h=None): # x is (b, t, u) 
        """Run the layer over every time step of ``x``.

        Args:
            x: input, (b, t, u) per the inline note; read one step at a time as x[:, ti].
                Must have at least one time step (torch.stack of an empty list raises).
            clock: optional per-step gates, e.g. the third return value of a previous
                Clocking layer; indexed as clock[ti] and averaged 50/50 with this
                layer's own gate.
            h: optional initial state. Defaults to a 1-D zero vector of size
                hidden_dim, which broadcasts against the batch on the first step.

        Returns:
            Tuple of (outputs stacked on dim 1, final state h, gates stacked on dim 0).
            For x of shape (b, t, u) and clock None (or from another Clocking layer)
            these are (b, t, hidden_dim), (b, hidden_dim) and (t, hidden_dim).
        """
        if h is None:
            h = torch.zeros(self.hidden_dim)
            # Presumably wraps h for autograd and moves it to the GPU when x is there
            # (old skorch helper).
            h = skorch.utils.to_var(h, x.is_cuda)
            
        ys = []
        cs = []
        for ti in range(x.size(1)):
            # Shape (hidden_dim,): one gate value per unit, 1 at ti == 0.
            clock_gate = (torch.sin(ti/self.period * np.pi + np.pi/2) + 1) / 2
            # enable stacking of clocking functions
            if clock is not None:
                clock_gate = (clock_gate + clock[ti]) / 2

            h_new = self.i2h(x[:, ti]) + self.h2h(h)
            # Leaky update: gated units take the candidate, the rest keep their old value.
            h = clock_gate * h_new + (1 - clock_gate) * h
            y = self.act(h)
            if self.update_state_with_output:
                h = y
            ys.append(y)
            cs.append(clock_gate)
        return torch.stack(ys, dim=1), h, torch.stack(cs, dim=0)

In [10]:
import visdom
# Global visdom client that ReconModel uses to push heatmaps. Built with the
# Visdom() constructor defaults.
# NOTE(review): a visdom server presumably has to be reachable for those calls;
# confirm host/port.
vis = visdom.Visdom()

In [11]:
def time_flatten(t):
    """Collapse the leading batch and time axes of ``t`` into one.

    (B, T, ...) becomes (B * T, -1); a 2-D (B, T) tensor becomes (B * T, 1).
    Uses ``view``, so ``t`` must be viewable without a copy.
    """
    batch, steps = t.size(0), t.size(1)
    return t.view(batch * steps, -1)

def time_unflatten(t, s):
    """Inverse of the batch/time flattening: reshape ``t`` to (s[0], s[1], -1).

    ``s`` is any indexable whose first two entries are the target batch and
    time sizes, e.g. the ``.size()`` of the original (B, T) input.
    """
    batch, steps = s[0], s[1]
    return t.view(batch, steps, -1)

In [12]:
class ReconModel(nn.Module):
    """Token model: embedding -> stacked Clocking layers -> linear -> log-softmax.

    Args:
        num_tokens: vocabulary size, used for the embedding table and the output layer.
        num_hidden: embedding size and hidden size of every Clocking layer.
        act: activation name forwarded to each Clocking layer.
        num_layers: number of stacked Clocking layers.
        inherit_clocks: if True, each layer's clock gates are passed as ``clock``
            to the next layer; otherwise every layer runs on its own clock.
        visualize: if True (the default, matching the previous behaviour), every
            forward pass pushes activation / period / clock heatmaps to the global
            visdom client ``vis``. Set to False (``module__visualize=False`` through
            skorch) to skip those network calls during training and prediction.
    """
    def __init__(self, num_tokens, num_hidden=64, act='tanh', num_layers=1, inherit_clocks=True, visualize=True):
        super().__init__()

        # Plain list for indexed access. Each layer is also registered through
        # add_module as 'rnn<i>' so its parameters are tracked by nn.Module.
        self.rnn = []
        self.num_tokens = num_tokens
        self.num_layers = num_layers
        self.inherit_clocks = inherit_clocks
        self.visualize = visualize

        self.emb = nn.Embedding(num_tokens, num_hidden)

        for i in range(num_layers):
            self.rnn.append(Clocking(num_hidden, num_hidden, act=act))
            self.add_module('rnn' + str(i), self.rnn[-1])

        self.clf = nn.Linear(num_hidden, num_tokens)
        # Explicit dim: the input is always 2-D (batch*time, num_tokens) after
        # time_flatten, and dim=1 is what the deprecated implicit default chose.
        self.softmax = nn.LogSoftmax(dim=1)

    def _plot_layer(self, i, act, clock):
        """Push heatmaps of layer ``i`` (first sample's activations, periods, clock gates) to visdom."""
        name = 'rnn' + str(i)
        vis.heatmap(skorch.utils.to_numpy(act[0]), opts={'title': 'act ' + name}, win='act ' + name)
        vis.heatmap(skorch.utils.to_numpy(self.rnn[i].period).reshape(1, -1),
                    opts={'title': 'periods ' + name}, win='periods ' + name)
        vis.heatmap(skorch.utils.to_numpy(clock), opts={'title': 'clock ' + name}, win='clock ' + name)

    def forward(self, x):
        """Map token ids ``x`` (presumably (batch, time)) to log-probabilities.

        Returns a tensor of shape (x.size(0), x.size(1), num_tokens).
        """
        li = self.emb(x.long())
        ci = None
        for i in range(self.num_layers):
            if not self.inherit_clocks:
                ci = None
            li, _, ci = self.rnn[i](li, clock=ci)
            if self.visualize:
                self._plot_layer(i, li, ci)
        l1 = self.clf(time_flatten(li))
        l1 = self.softmax(l1)
        return time_unflatten(l1, x.size())

In [13]:
class Trainer(skorch.NeuralNet):
    """skorch NeuralNet for per-time-step outputs.

    Defaults the criterion to ``nn.NLLLoss``, which matches a module that ends in
    LogSoftmax. It also flattens the batch and time axes before computing the loss.
    """
    def __init__(self, *args, criterion=nn.NLLLoss, **kwargs):
        # `criterion` is keyword-only. With the previous signature
        # (self, criterion=..., *args, **kwargs), a module passed positionally,
        # e.g. Trainer(ReconModel, ...), was silently bound to `criterion`.
        super().__init__(*args, criterion=criterion, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Flatten batch/time axes and defer to ``NeuralNet.get_loss``.

        y_pred (batch, time, C) -> (batch*time, C);
        y_true (batch, time) -> (batch*time, 1) -> squeezed to (batch*time,).
        """
        pred = time_flatten(y_pred)
        true = time_flatten(y_true).squeeze(-1)
        return super().get_loss(pred, true, X=X, training=training)

In [14]:
torch.manual_seed(1337)


def my_train_split(X, y):
    """Use all of (X, y) for training and the global X_valid / y_valid for validation."""
    return X, X_valid, y, y_valid


ef = Trainer(
    module=ReconModel,
    module__num_tokens=len(word_to_id),
    module__num_hidden=32,
    module__act='relu',
    module__num_layers=2,
    module__inherit_clocks=True,
    optimizer=torch.optim.Adam,
    lr=0.01,
    max_epochs=10,
    batch_size=16,
    use_cuda=False,
    train_split=my_train_split,
    callbacks=[skorch.callbacks.ProgressBar()],
)

In [15]:
# Train on the subsampled windows; validation data is supplied by my_train_split.
# (The `%pdb on` debugging magic was dropped: it should not stay enabled in a
# notebook meant to be re-run top to bottom.)
ef.fit(X_train, y_train)

Automatic pdb calling has been turned ON





  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        [36m7.5952[0m        [32m7.0461[0m  24.8386



      2        [36m6.4820[0m        7.1880  25.7292



      3        [36m6.3590[0m        7.2500  25.0049



      4        [36m6.2793[0m        7.3777  25.4389



      5        6.3259        7.9880  24.7627



      6        [36m6.2623[0m        8.3791  25.1098



      7        [36m6.2414[0m        8.4663  25.4330



      8        6.2600        7.5666  24.5356



      9        6.2681        7.6816  25.5026



     10        [36m6.2155[0m        7.8414  24.9745


<class '__main__.Trainer'>[initialized](
  module_=ReconModel(
    (emb): Embedding(15488, 32)
    (rnn0): Clocking(
      (i2h): Linear(in_features=32, out_features=32)
      (h2h): Linear(in_features=32, out_features=32)
      (act): ReLU()
    )
    (rnn1): Clocking(
      (i2h): Linear(in_features=32, out_features=32)
      (h2h): Linear(in_features=32, out_features=32)
      (act): ReLU()
    )
    (clf): Linear(in_features=32, out_features=15488)
    (softmax): LogSoftmax()
  ),
)

In [16]:
# NOTE: ReconModel ends in LogSoftmax, so these values are presumably
# log-probabilities despite the method name. argmax is unaffected.
pred = ef.predict_proba(X=X_train)



In [17]:
# Decode the highest-scoring token at every position of window 4.
[id_to_word[token_id] for token_id in pred[4].argmax(axis=-1)]

['the',
 'the',
 'the',
 'the',
 'the',
 'the',
 'the',
 'the',
 'of',
 'the',
 'the',
 'the',
 'the',
 'of',
 'the']