In [28]:
import skorch
import torch
import torch.nn as nn

import enwik8_data
import imp
import models

In [83]:
import visdom
vis = visdom.Visdom()

In [3]:
raw_data = enwik8_data.hutter_raw_data(data_path='./data/')

In [19]:
TRAIN_DATA, VALID_DATA, TEST_DATA, unique_syms = raw_data

In [20]:
EMBEDDING_SIZE = len(unique_syms)

In [69]:
def collate(g):
    for x, y in g:
        yield torch.from_numpy(x).long(), torch.from_numpy(y).long()

class Enwik8TrainLoader:
    def __init__(self, _dataset, batch_size=128, num_steps=32, **kwargs):
        self.batch_size = batch_size
        self.num_steps = num_steps
    def __iter__(self):
        return collate(enwik8_data.data_iterator(TRAIN_DATA, self.batch_size, self.num_steps))

class Enwik8ValidLoader:
    def __init__(self, _dataset, batch_size=128, num_steps=32, **kwargs):
        self.batch_size = batch_size
        self.num_steps = num_steps
    def __iter__(self):
        return collate(enwik8_data.data_iterator(VALID_DATA, self.batch_size, self.num_steps))

In [48]:
def time_flatten(t):
    return t.view(t.size(0) * t.size(1), -1)

def time_unflatten(t, s):
    return t.view(s[0], s[1], -1)

In [88]:
class ReconModel(nn.Module):
    def __init__(self, num_hidden=64, num_modules=8):
        super().__init__()
        
        self.emb = nn.Embedding(EMBEDDING_SIZE, num_hidden)
        self.rnn = models.ClockingCWRNN(num_hidden, num_hidden, num_modules)
        self.clf = nn.Linear(num_hidden, EMBEDDING_SIZE)
        
        self.softmax = nn.LogSoftmax()
        
    def forward(self, x):
        x_emb = self.emb(x.long())
        l0, h0 = self.rnn(x_emb)
        
        vis.heatmap(l0[0].data.numpy(), win="act")
        vis.heatmap(self.rnn.module_periods.data.numpy().reshape(1, -1), win="periods")
        vis.heatmap(self.rnn.module_shifts.data.numpy().reshape(1, -1), win="shifts")

        l1 = self.clf(time_flatten(l0))
        l1_sm = self.softmax(l1)
        
        return time_unflatten(l1_sm, x.size())

In [89]:
class Trainer(skorch.NeuralNet):
    def __init__(self, 
                 criterion=nn.NLLLoss,
                 *args, 
                 **kwargs):
        super().__init__(*args, criterion=criterion, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, train=False):
        pred = time_flatten(y_pred)
        true = time_flatten(y_true).squeeze(-1)
        return super().get_loss(pred, true, X=X, train=train)

In [99]:
import time
import sys

class BatchPrinter(skorch.callbacks.Callback):
    def __init__(self):
        self.batches_per_epoch = None
        self.batch_counter = 0
    def on_batch_begin(self, *args, **kwargs):
        self.batch_start_time = time.time()
    def on_batch_end(self, *args, **kwargs):
        self.batch_end_time = time.time()
        self.batch_counter += 1
        sys.stdout.write("Batch {}/{} complete ({:.2}s).\r".format(
            self.batch_counter, 
            self.batches_per_epoch,
            self.batch_end_time - self.batch_start_time,
        ))
        sys.stdout.flush()
    def on_epoch_end(self, *args, **kwargs):
        if self.batches_per_epoch is None:
            self.batches_per_epoch = self.batch_counter

In [104]:
torch.manual_seed(1337)

ef = Trainer(module=ReconModel,
             optim=torch.optim.Adam,
             lr=0.005,
             max_epochs=60,
                  
             train_split=None,
             iterator_train=Enwik8TrainLoader,
             iterator_train__batch_size=32,
             iterator_train__num_steps=32,
             iterator_test=Enwik8ValidLoader,
             iterator_test__batch_size=32,
             iterator_test__num_steps=32,
             
             module__num_modules=8,
             module__num_hidden=64,
             
             callbacks=[BatchPrinter()]
            )

In [105]:
%pdb on
ef.fit(torch.zeros((10,1)), torch.zeros((10,)))

Automatic pdb calling has been turned ON
  epoch    train_loss        dur)..
-------  ------------  ---------
      1        [36m2.3305[0m  9805.0095
      2        [36m2.2666[0m  9877.3649
      3        [36m2.2161[0m  9878.4157
Batch 274438/87890 complete (0.1s).).

<__main__.Trainer at 0x7f4ac588fe48>