In [2]:
import inferno
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import TensorDataset
from torch.autograd import Variable

In [3]:
from sine_data import train_dataset, valid_dataset

In [4]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import insp

In [5]:
import visdom
# Module-level visdom client; ReconModel.forward pushes activation and gate-parameter
# heatmaps through it. NOTE(review): presumably needs a visdom server reachable at
# visdom's default address — confirm before running the notebook.
vis = visdom.Visdom()

In [6]:
# Generate the sine train / validation sets and convert them to torch tensors.
# NOTE(review): the targets are cast to LongTensor, but the training cells below use
# X[:, :-1] -> X[:, 1:] as (input, target) and never read y_train / y_valid.
X_train, y_train = train_dataset(points=200)
X_train = torch.Tensor(X_train)
y_train = torch.LongTensor(y_train)

X_valid, y_valid = valid_dataset(points=400)
X_valid = torch.Tensor(X_valid)
y_valid = torch.LongTensor(y_valid)


def _sequence_loader(features, targets, shuffle):
    """Wrap paired tensors in a DataLoader that yields batches of 64."""
    return DataLoader(TensorDataset(features, targets), batch_size=64, shuffle=shuffle)


# NOTE(review): these loaders are not used by the inferno trainer below.
sine_train_loader = _sequence_loader(X_train, y_train, shuffle=True)
sine_valid_loader = _sequence_loader(X_valid, y_valid, shuffle=False)

In [7]:
class InhCWRNN(nn.Module):
    """Clockwork-style RNN whose hidden units are gated per module by a learned sinusoid.

    The ``output_dim`` hidden units are split into ``num_modules`` equal groups. At
    time step ``ti`` every unit of module ``m`` mixes its new tanh activation with its
    previous state using ``gate = sin(ti * module_periods[m] + module_shifts[m])``.

    Args:
        input_dim: size of the last dimension of each input time step.
        output_dim: number of hidden units; must be a multiple of ``num_modules``.
        num_modules: number of gated unit groups (>= 1).

    Raises:
        ValueError: if ``num_modules`` < 1 or ``output_dim`` is not divisible by it.
    """

    def __init__(self, input_dim, output_dim, num_modules, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Fail at construction time instead of at the first `view` inside `step`.
        if num_modules < 1 or output_dim % num_modules != 0:
            raise ValueError(
                "output_dim ({}) must be a positive multiple of num_modules ({})"
                .format(output_dim, num_modules))

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_modules = num_modules

        self.input_mod = nn.Linear(input_dim, output_dim)
        self.hidden_mod = nn.Linear(output_dim, output_dim, bias=False)

        # One learnable frequency (multiplies the time index) and phase shift per module.
        self.module_periods = nn.Parameter(torch.zeros(num_modules) + 1)
        self.module_shifts = nn.Parameter(torch.zeros(num_modules))

        self.f_mod = nn.Tanh()

    def step(self, ti, xi, h):
        """Advance the recurrence by one time step.

        Args:
            ti: integer time index, used as the phase of the gate sinusoid.
            xi: input at step ``ti``; last dimension must be ``input_dim``.
            h: previous hidden state, ``(batch, output_dim)`` or broadcastable to it.

        Returns:
            ``(y, y)``: the new hidden state, returned both as output and as next state.
        """
        module_size = self.output_dim // self.num_modules

        acts = self.f_mod(self.input_mod(xi) + self.hidden_mod(h))
        module_acts = acts.view(-1, self.num_modules, module_size)

        # One gate value per module, repeated for every unit inside that module.
        gate = torch.sin(ti * self.module_periods + self.module_shifts)
        gate = gate.view(1, -1, 1).expand_as(module_acts).contiguous()
        gate = gate.view(-1, self.output_dim)

        # NOTE(review): sin() ranges over [-1, 1]; for negative gate values this is an
        # extrapolation rather than a convex mix, so |y| is not bounded by |acts| and |h|.
        # If a convex mix is intended, (sin(...) + 1) / 2 would keep the gate in [0, 1].
        y = (1 - gate) * acts + gate * h

        return y, y

    def init_hidden(self, batch_size=None):
        """Return an all-zero initial hidden state.

        Args:
            batch_size: if None (default) the state has shape ``(output_dim,)`` and
                relies on broadcasting; otherwise ``(batch_size, output_dim)``.

        Returns:
            A zero tensor with the dtype and device of ``hidden_mod``'s weight, so the
            model also works after ``.double()`` / ``.cuda()``.
        """
        if batch_size is None:
            shape = (self.output_dim,)
        else:
            shape = (batch_size, self.output_dim)
        return Variable(self.hidden_mod.weight.data.new(*shape).zero_())

    def forward(self, x):
        """Run the recurrence over dimension 1 (time) of ``x``.

        Args:
            x: input indexed as ``x[:, ti]`` per step; each step's last dimension must
               be ``input_dim`` (e.g. ``(batch, time, input_dim)``).

        Returns:
            ``(ys, h)``: every step's output stacked along dim 1, and the final state.
        """
        h = self.init_hidden(x.size(0))
        ys = []
        for ti in range(x.size(1)):
            yi, h = self.step(ti, x[:, ti], h)
            ys.append(yi)
        return torch.stack(ys, dim=1), h

In [8]:
def time_flatten(t):
    """Merge the first two dimensions of ``t``.

    Returns a tensor of shape ``(t.size(0) * t.size(1), -1)``. ``contiguous()`` is
    applied first so that non-contiguous inputs (e.g. time slices like ``x[:, 1:]``)
    are accepted instead of making ``view`` raise.
    """
    return t.contiguous().view(t.size(0) * t.size(1), -1)

def time_unflatten(t, s):
    """Reshape ``t`` to ``(s[0], s[1], -1)``, undoing a merge of the first two dims.

    ``s`` is the size of the original tensor (e.g. ``x.size()``); only ``s[0]`` and
    ``s[1]`` are used. ``contiguous()`` makes this safe for non-contiguous inputs.
    """
    return t.contiguous().view(s[0], s[1], -1)

In [9]:
class ReconModel(nn.Module):
    """An InhCWRNN over 1-d inputs followed by a per-time-step linear read-out to 1 value.

    Args:
        num_hidden: hidden units of the recurrent layer.
        num_modules: number of gated modules the hidden units are split into.
        visualize: if True (default, the original behavior), every forward pass pushes
            heatmaps of the first sample's activations and of the gate parameters to
            the module-level visdom client ``vis``. Set False to skip that side effect.
    """

    def __init__(self, num_hidden=64, num_modules=8, visualize=True):
        super().__init__()

        self.visualize = visualize
        self.rnn = InhCWRNN(1, num_hidden, num_modules)
        self.clf = nn.Linear(num_hidden, 1)

    def _log_to_visdom(self, l0):
        """Push activation and gate-parameter heatmaps to the global visdom client."""
        # .cpu() so logging also works when the model lives on a GPU.
        vis.heatmap(l0[0].data.cpu().numpy(), win="act")
        vis.heatmap(self.rnn.module_periods.data.cpu().numpy().reshape(1, -1), win="periods")
        vis.heatmap(self.rnn.module_shifts.data.cpu().numpy().reshape(1, -1), win="shifts")

    def forward(self, x):
        """Return one read-out value per time step, shaped ``(x.size(0), x.size(1), 1)``."""
        l0, _ = self.rnn(x)

        if self.visualize:
            self._log_to_visdom(l0)

        # Apply the linear read-out to every (sample, step) pair at once.
        l1 = self.clf(time_flatten(l0))
        return time_unflatten(l1, x.size())

In [10]:
class Trainer(inferno.NeuralNet):
    """inferno ``NeuralNet`` that scores sequence predictions step by step.

    Before the criterion (``nn.MSELoss`` by default) is applied, both predictions and
    targets have their first two dimensions merged with ``time_flatten``.

    NOTE(review): ``criterion`` is the first positional parameter, so a call such as
    ``Trainer(ReconModel)`` would bind ``ReconModel`` to ``criterion``; always pass
    ``module=`` by keyword.
    """

    def __init__(self, criterion=nn.MSELoss, *args, **kwargs):
        kwargs["criterion"] = criterion
        super().__init__(*args, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, train=False):
        return super().get_loss(
            time_flatten(y_pred),
            time_flatten(y_true),
            X=X,
            train=train,
        )

### exp inhibition

In [None]:
SEED = 1337
torch.manual_seed(SEED)


def my_train_split(X, y):
    """Custom ``train_split``: train on (X, y) unchanged and validate on the global
    held-out sine set.

    The validation pair mirrors the training setup: inputs are every step of
    ``X_valid`` except the last, and targets are every step except the first.
    NOTE(review): the return order (X_train, X_valid, y_train, y_valid) is assumed
    to be what inferno expects from ``train_split`` — confirm against inferno.
    """
    valid_inputs = X_valid[:, :-1]
    valid_targets = X_valid[:, 1:]
    return X, valid_inputs, y, valid_targets


# NOTE(review): despite its name, this model contains no ReLU (InhCWRNN uses tanh);
# the name is kept because the following cells refer to it.
ef_relu = Trainer(
    module=ReconModel,
    optim=torch.optim.Adam,
    lr=0.005,
    max_epochs=30,
    train_split=my_train_split,
    module__num_modules=32,
    module__num_hidden=64,
)

In [None]:
# Next-step prediction: inputs are every time step except the last; targets are the
# same sequences shifted by one step (every step except the first).
# For post-mortem inspection after a failure, run `%debug` in a fresh cell rather
# than leaving automatic pdb switched on for the whole session.
ef_relu.fit(X_train[:, :-1], X_train[:, 1:])

Automatic pdb calling has been turned ON
  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m1.8167[0m        [32m0.8152[0m  0.5924
      2        [36m0.8063[0m        [32m0.3911[0m  0.6822
      3        [36m0.3626[0m        [32m0.2530[0m  0.6114
      4        [36m0.2260[0m        [32m0.1849[0m  0.7582
      5        [36m0.1709[0m        [32m0.1122[0m  0.6179
      6        [36m0.1120[0m        [32m0.0763[0m  0.6032
      7        [36m0.0816[0m        0.0988  0.5168
      8        0.1019        0.1276  0.6145
      9        0.1278        0.1145  0.6470
     10        0.1150        [32m0.0695[0m  0.7311
     11        [36m0.0724[0m        [32m0.0284[0m  0.6749
     12        [36m0.0337[0m        [32m0.0147[0m  0.6920
     13        [36m0.0209[0m        0.0253  0.7134
     14        0.0307        0.0415  0.5658
     15        0.0451        0.0483  0.6004
     16        0.0502        0.0440  0.5

In [None]:
data = X_valid
pred = ef_relu.predict_proba(data)

# The model was fit to map step t to step t + 1 (inputs X[:, :-1], targets X[:, 1:]),
# so pred[i][t] is a forecast of data[i][t + 1]. Plot it one step to the right so the
# two curves line up in time.
for i in range(pred.shape[0]):
    fig, ax = plt.subplots()
    ax.plot(data[i].numpy(), label="target")
    ax.plot(np.arange(1, len(pred[i]) + 1), pred[i], label="next-step prediction")
    ax.set(title="Validation sequence {}".format(i), xlabel="time step", ylabel="value")
    ax.legend()
    # Render this figure now (and release it) instead of accumulating all of them.
    plt.show()