In [1]:
import skorch
import torch
import torch.nn as nn

import enwik8_data
import imp
import models

In [2]:
# Module-level visdom client; ReconModel.forward pushes heatmaps through it.
# NOTE(review): assumes a visdom server is reachable at the client's default
# address — the connection tracebacks in the training output further down are
# raised from visdom's _send when posting fails.
import visdom
vis = visdom.Visdom()

In [3]:
# Load enwik8 from ./data/; the next cell unpacks the result as
# (train split, valid split, test split, unique symbols).
raw_data = enwik8_data.hutter_raw_data(data_path='./data/')

In [4]:
# Data splits consumed by the Enwik8*Loader classes, plus the symbol vocabulary.
TRAIN_DATA, VALID_DATA, TEST_DATA, unique_syms = raw_data

In [5]:
# Vocabulary size: number of rows in ReconModel's embedding table and number
# of outputs of its final linear classifier.
EMBEDDING_SIZE = len(unique_syms)

In [18]:
def collate(g):
    """Lazily convert each (x, y) numpy pair yielded by `g` into int64 tensors.

    Yields:
        (torch.LongTensor, torch.LongTensor) built with torch.from_numpy(...).long().
    """
    def _as_long(array):
        return torch.from_numpy(array).long()

    for inputs, targets in g:
        yield _as_long(inputs), _as_long(targets)

class Enwik8Loader:
    """Base batch iterator for skorch over one enwik8 split.

    Subclasses must define a class attribute `dataset`; the dataset argument
    given to the constructor is accepted for interface compatibility but not
    stored or used. Unknown keyword arguments are accepted and dropped.
    """
    def __init__(self, _dataset, batch_size=128, num_steps=32, max_samples=None, **kwargs):
        self.batch_size, self.num_steps = batch_size, num_steps
        # None keeps the whole split; an int truncates it to that many items.
        self.max_samples = max_samples

    def __iter__(self):
        corpus = self.dataset[0:self.max_samples]
        batches = enwik8_data.data_iterator(corpus, self.batch_size, self.num_steps)
        return collate(batches)

class Enwik8TrainLoader(Enwik8Loader):
    # Iterates over the training split; used as skorch's iterator_train.
    dataset = TRAIN_DATA
    
class Enwik8ValidLoader(Enwik8Loader):
    """Enwik8Loader that iterates over the validation split (VALID_DATA).

    Must subclass Enwik8Loader: it relies on the inherited
    __init__(dataset, batch_size=..., num_steps=..., max_samples=...) and
    __iter__. Without the base class it could not be constructed with those
    arguments and was not iterable.
    """
    dataset = VALID_DATA

In [7]:
def time_flatten(t):
    """Merge the first two axes of `t`: (d0, d1, ...) -> (d0 * d1, -1).

    A 2-D input (d0, d1) becomes (d0 * d1, 1). `.contiguous()` is applied
    first because `.view` raises on non-contiguous tensors (e.g. after a
    transpose); for an already-contiguous tensor it returns `t` unchanged.
    """
    return t.contiguous().view(t.size(0) * t.size(1), -1)

def time_unflatten(t, s):
    """Inverse of time_flatten: reshape `t` to (s[0], s[1], -1).

    `s` is any indexable shape (e.g. a torch.Size or tuple) whose first two
    entries give the leading axes; further entries are ignored. `.contiguous()`
    guards `.view` against non-contiguous input and is a no-op otherwise.
    """
    return t.contiguous().view(s[0], s[1], -1)

In [38]:
class ReconModel(nn.Module):
    """Embedding -> models.SurprisalCWRNN -> linear classifier -> log-softmax.

    Input `x` is a 2-D batch of integer symbol ids (cast with `.long()`);
    the output has shape (x.size(0), x.size(1), EMBEDDING_SIZE) and holds
    log-probabilities over the symbol vocabulary.

    Args:
        num_hidden: embedding width and RNN input/hidden size.
        num_modules: number of modules passed to SurprisalCWRNN.
        vis_every: push visdom heatmaps on every `vis_every`-th forward call.
            The default of 1 visualizes on every call; 0 or None disables
            visualization (each update is a blocking HTTP post plus a
            device-to-host copy, so throttling speeds up training).
    """
    def __init__(self, num_hidden=64, num_modules=8, vis_every=1):
        super().__init__()

        self.emb = nn.Embedding(EMBEDDING_SIZE, num_hidden)
        self.rnn = models.SurprisalCWRNN(num_hidden, num_hidden, num_modules)
        self.clf = nn.Linear(num_hidden, EMBEDDING_SIZE)

        # NOTE(review): no explicit dim; applied to the 2-D flattened logits,
        # so the class axis is dim 1 — consider LogSoftmax(dim=1) if the
        # installed torch version accepts it.
        self.softmax = nn.LogSoftmax()

        self.vis_every = vis_every
        self._forward_calls = 0  # counts forward passes for vis throttling

    def _visualize(self, l0, m0):
        """Send heatmaps of the first batch element's RNN outputs and the module shifts to visdom."""
        # assumes l0[0] / m0[0] are 2-D (steps, features) — TODO confirm against SurprisalCWRNN
        vis.heatmap(l0[0].data.cpu().numpy(), win="act")
        vis.heatmap(m0[0].data.cpu().numpy(), win="periods")
        vis.heatmap(self.rnn.module_shifts.data.cpu().numpy().reshape(1, -1), win="shifts")

    def forward(self, x):
        x_emb = self.emb(x.long())
        # The RNN returns three tensors; l0 feeds the classifier, m0 is only visualized.
        l0, h0, m0 = self.rnn(x_emb)

        self._forward_calls += 1
        if self.vis_every and self._forward_calls % self.vis_every == 0:
            self._visualize(l0, m0)

        # Classify every timestep independently, then restore (batch, steps, vocab).
        l1 = self.clf(time_flatten(l0))
        l1_sm = self.softmax(l1)

        return time_unflatten(l1_sm, x.size())

In [39]:
class Trainer(skorch.NeuralNet):
    """skorch NeuralNet whose loss is computed on time-flattened sequences.

    `criterion` defaults to nn.NLLLoss, which matches the log-probabilities
    produced by ReconModel's LogSoftmax output layer.
    """
    def __init__(self,
                 *args,
                 criterion=nn.NLLLoss,
                 **kwargs):
        # `criterion` is keyword-only so a positional first argument reaches
        # NeuralNet (as its `module`) instead of being captured as the
        # criterion, which the previous `(criterion=..., *args)` order did.
        super().__init__(*args, criterion=criterion, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, train=False):
        """Flatten batch and time axes of predictions and targets, then delegate.

        Assumes y_pred is (batch, steps, n_classes) and y_true is
        (batch, steps) — TODO confirm. After time_flatten, targets are
        (batch * steps, 1) and squeezed to (batch * steps,) as NLLLoss expects.
        """
        pred = time_flatten(y_pred)
        true = time_flatten(y_true).squeeze(-1)
        return super().get_loss(pred, true, X=X, train=train)

In [40]:
import time
import sys

class BatchPrinter(skorch.callbacks.Callback):
    """Write a self-overwriting progress line every `update_interval` batches.

    The line shows the batch index within the epoch, the total number of
    batches per epoch (None until the first epoch has finished and been
    counted), the last batch's wall-clock duration and its loss from
    net.history.
    """
    def __init__(self, update_interval=5):
        self.batches_per_epoch = None
        self.batch_counter = 0
        self.update_interval = update_interval
        # Length of the last progress line written. Shorter lines are padded to
        # it so the '\r' overwrite does not leave stale characters behind.
        self._last_line_len = 0

    def on_batch_begin(self, *args, **kwargs):
        self.batch_start_time = time.time()

    def on_batch_end(self, net, *args, train=True, **kwargs):
        self.batch_end_time = time.time()
        self.batch_counter += 1
        if self.batch_counter % self.update_interval != 0:
            return

        k = 'train_loss' if train else 'valid_loss'
        loss = '{}: {:.3}'.format(k, net.history[-1, 'batches', -1, k])

        # Fixed-point duration: the general '{:.2}' format switches to
        # exponent notation (e.g. '1.2e+01') for batches taking >= 10 s.
        line = "Batch {}/{} complete ({:.2f}s), {}.".format(
            self.batch_counter,
            self.batches_per_epoch,
            self.batch_end_time - self.batch_start_time,
            loss,
        )
        sys.stdout.write(line.ljust(self._last_line_len) + "\r")
        sys.stdout.flush()
        self._last_line_len = len(line)

    def on_epoch_end(self, *args, **kwargs):
        # The first finished epoch fixes the batches-per-epoch total.
        if self.batches_per_epoch is None:
            self.batches_per_epoch = self.batch_counter
        self.batch_counter = 0
        # The epoch summary printed around this point ends the current line,
        # so the next progress line starts on a fresh one.
        self._last_line_len = 0

In [41]:
# Seed torch's global RNG so parameter initialisation is repeatable.
torch.manual_seed(1337)

# Trainer is the NeuralNet subclass above (criterion defaults to NLLLoss).
ef = Trainer(module=ReconModel,
             optim=torch.optim.Adam,
             lr=0.005,
             max_epochs=5,

             # No internal validation split; the epoch tables in the outputs show only train_loss.
             train_split=None,
             # Custom iterators ignore the X/y passed to fit and read TRAIN_DATA / VALID_DATA.
             iterator_train=Enwik8TrainLoader,
             iterator_train__batch_size=32,
             iterator_train__num_steps=32,
             iterator_test=Enwik8ValidLoader,
             iterator_test__batch_size=32,
             iterator_test__num_steps=32,

             # NOTE(review): assumes a CUDA device is available.
             use_cuda=True,

             # Forwarded to ReconModel.__init__.
             module__num_modules=8,
             module__num_hidden=64,

             # Progress line every 5 batches (BatchPrinter's default interval).
             callbacks=[BatchPrinter()]
            )

In [42]:
# `imp` is deprecated (removed in Python 3.12); importlib.reload is its replacement.
# ReconModel looks up models.SurprisalCWRNN when it is instantiated, so the
# reload affects modules built after this point.
import importlib
importlib.reload(models)

<module 'models' from '/home/nemo/Code/pytorch/work/membank/models.py'>

In [43]:
# X and y are placeholders: the Enwik8*Loader iterators ignore the dataset
# skorch hands them and read TRAIN_DATA / VALID_DATA directly.
# (Use %debug after a failure instead of leaving %pdb on, which would stop
# an unattended Run All inside the debugger.)
ef.fit(torch.zeros((10,1)), torch.zeros((10,)))

Automatic pdb calling has been turned ON
  epoch    train_loss         dur, train_loss: 2.18.
-------  ------------  ----------
      1        [36m2.0787[0m  16946.4112
      2        2.1018  16848.5859), train_loss: 2.14.
Exception in user code:ete (0.18s), train_loss: 2.11.
------------------------------------------------------------


Traceback (most recent call last):
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/visdom/__init__.py", line 240, in _send
    data=json.dumps(msg),
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/api.py", line 112, in post
    return request('post', url, data=data, json=json, **kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/api.py", line 58, in request
    return session.request(method=method, url=url, **kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/sessions.py", line 508, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/sessions.py", line 618, in send
    r = adapter.send(request, **kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/adapters.py", line 440, in send
    timeout=timeout
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/urllib3/connectionpool.p

Exception in user code:ete (0.23s), train_loss: 2.27.
------------------------------------------------------------
Batch 35415/87890 complete (0.2s), train_loss: 2.13.

Traceback (most recent call last):
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/visdom/__init__.py", line 240, in _send
    data=json.dumps(msg),
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/api.py", line 112, in post
    return request('post', url, data=data, json=json, **kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/api.py", line 58, in request
    return session.request(method=method, url=url, **kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/sessions.py", line 508, in request
    resp = self.send(prep, **send_kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/sessions.py", line 618, in send
    r = adapter.send(request, **kwargs)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-packages/requests/adapters.py", line 405, in send
    conn = self.get_connection(request.url, proxies)
  File "/home/nemo/Code/pytorch/env/lib/python3.5/site-

<__main__.Trainer at 0x7fba520ef7f0>

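# Clocking CWRNN

Note: the cell below was executed as `In [21]`, i.e. *before* the training run above (`In [43]`), so its output reflects that earlier kernel state.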

In [21]:
# X and y are placeholders: the Enwik8*Loader iterators ignore the dataset
# skorch hands them and read TRAIN_DATA / VALID_DATA directly.
# (Use %debug after a failure instead of leaving %pdb on, which would stop
# an unattended Run All inside the debugger.)
ef.fit(torch.zeros((10,1)), torch.zeros((10,)))

Automatic pdb calling has been turned ON
  epoch    train_loss         dur, train_loss: 2.38..
-------  ------------  ----------
      1        [36m2.3205[0m  10506.3287
      2        [36m2.2261[0m  10184.8781loss: 2.11..
      3        [36m2.0944[0m  10190.3140loss: 2.15..
      4        2.1289  10192.3899), train_loss: 2.16..
      5        2.1656  10702.8350), train_loss: 2.22..


<__main__.Trainer at 0x7fba66043ac8>