In [63]:
import torch
import torch.nn as nn
import torch.jit as jit

from models import LFADS_GRUCell, LFADS_GenGRUCell
    
class LFADS_GenGRULoop(nn.Module):
    def __init__(self, cell):
        super(LFADS_GenGRULoop, self).__init__()
        self.cell = cell
        
    def forward(self, input, initial_state):
        prev_state = initial_state.clone()
        state = initial_state.unsqueeze(0)
        for t in range(len(input)):
            new_state = self.cell(input[t], prev_state)
            state = torch.cat((state, new_state.clone().unsqueeze(0)), dim=0)
            prev_state = new_state
        return state[1:], prev_state
                
class LFADS_GeneratorCell(nn.Module):
    def __init__(self, generator_size, factor_size, encoder_size, controller_size, perturbation_size, dropout=0.0):
        super(LFADS_Generator, self).__init__()
        
        self.encoder_size      = encoder_size
        self.perturbation_size = perturbation_size
        self.factor_size       = factor_size
        self.generator_size    = generator_size
        self.controller_size   = controller_size
        
        if self.perturbation_size>0:
            self.controller   = LFADS_GenGRUCell(input_size= self.encoder_size * 2 + self.factor_size, hidden_size= self.controller_size)
            self.con_to_umean   = nn.Linear(in_features=self.controller_size, out_features=self.perturbation_size)
            self.con_to_ulogvar = nn.Linear(in_features=self.controller_size, out_features=self.perturbation_size)
            
        self.dropout = nn.Dropout(dropout)
        self.generator = LFADS_GenGRUCell(input_size=self.perturbation_size, hidden_size=self.generator_size)
        self.gen_to_fac = nn.Linear(in_features=self.generator_size, out_features=self.factor_size)
        
    def forward(self, input, state):
        gen, fac, con = state
        enc, eps = input        
        if self.perturbation_size > 0:
            con     = self.controller(self.dropout(torch.cat((enc, fac), dim=1)), con)
            umean   = self.con_to_umean(con)
            ulogvar = self.con_to_ulogvar(con)
            u       = umean + eps*torch.exp(0.5*ulogvar)
        
        else:
            con = None
            u = None
            
        gen     = self.generator(u, gen)
        fac     = self.gen_to_fac(self.dropout(gen))
        return gen, fac, con
    
# class LFADS_GeneratorLoop(nn.Module)

In [17]:
device = 'cuda'

In [84]:
G = LFADS_Generator(generator_size=100, factor_size=20, encoder_size=100, controller_size=100, perturbation_size=1, dropout=0.05).to(device)

In [3]:
gen = torch.randn(20, 100).to(device)
fac = torch.randn(20, 20).to(device)
enc = torch.randn(20, 200).to(device)
con = torch.randn(20, 100).to(device)
eps = torch.randn(20, 1).to(device)

state = (gen, fac, con)
input = (enc, eps)

# tG = jit.trace(G, example_inputs=((input,state)), check_trace=False)

In [86]:
# %timeit G(input, state)

1.15 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [87]:
# %timeit tG(input, state)

546 µs ± 5.43 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [88]:
H = LFADS_GRUCell(input_size=200, hidden_size=100).to(device)

In [89]:
tH = jit.trace(H, example_inputs=((enc, con)))

In [90]:
%timeit H(enc, con)

374 µs ± 4.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [91]:
%timeit tH(enc, con)

235 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [4]:
J = LFADS_GenGRUCell(input_size=200, hidden_size=100).to(device)

tJ = jit.trace(J, example_inputs = ((enc, con)))

In [5]:
%timeit J(enc[:50], con[:50])

398 µs ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [6]:
%timeit tJ(enc[:50], con[:50])

210 µs ± 36.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [69]:
G = LFADS_GenGRUCell(input_size=200, hidden_size=100).to(device)
tG = jit.trace(G, example_inputs = ((enc, con)))
GL = LFADS_GenGRULoop(G)

In [65]:
enc_T = torch.randn(100, 20, 200).to(device)

In [67]:
%timeit GL(enc_T, con)

48.2 ms ± 251 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [71]:
tGL = LFADS_GenGRULoop(tG)

In [80]:
%timeit tGL(enc_T, con)

26.8 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [76]:
stGL = jit.trace(LFADS_GenGRULoop(tG), example_inputs=((enc_T, con)))

  from ipykernel import kernelapp as app


In [79]:
%timeit stGL(enc_T, con)

22 ms ± 137 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [83]:
torch.jit.script(LFADS_GenGRUCell(input_size=200, hidden_size=100))

TypeError: LFADS_GenGRUCell(
  (fc_x_ru): Linear(in_features=200, out_features=200, bias=False)
  (fc_x_c): Linear(in_features=200, out_features=100, bias=False)
  (fc_h_ru): Linear(in_features=100, out_features=200, bias=True)
  (fc_rh_c): Linear(in_features=100, out_features=100, bias=True)
) is not a module, class, method, function, traceback, frame, or code object

In [85]:
torch.lgamma(torch.Tensor([10]))

tensor([12.8018])

In [86]:
inbuilt_ll = torch.nn.PoissonNLLLoss

[0;31mInit signature:[0m [0mtorch[0m[0;34m.[0m[0mnn[0m[0;34m.[0m[0mPoissonNLLLoss[0m[0;34m([0m[0mlog_input[0m[0;34m=[0m[0;32mTrue[0m[0;34m,[0m [0mfull[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m [0msize_average[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0meps[0m[0;34m=[0m[0;36m1e-08[0m[0;34m,[0m [0mreduce[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mreduction[0m[0;34m=[0m[0;34m'mean'[0m[0;34m)[0m[0;34m[0m[0m
[0;31mSource:[0m        
[0;32mclass[0m [0mPoissonNLLLoss[0m[0;34m([0m[0m_Loss[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""Negative log likelihood loss with Poisson distribution of target.[0m
[0;34m[0m
[0;34m    The loss can be described as:[0m
[0;34m[0m
[0;34m    .. math::[0m
[0;34m        \text{target} \sim \mathrm{Poisson}(\text{input})[0m
[0;34m[0m
[0;34m        \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input})[0m
[0;34m                         