In [57]:
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple
from torch import Tensor
import numbers

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [58]:
RNNState = namedtuple('RNNState', ['r', 'I'])

In [86]:
class NeuroRNNCell(jit.ScriptModule):
    
    def __init__(self, input_size, hidden_size, alpha_r, alpha_s, nonlinearity, bias=False, ratio=None):
        
        super(NeuroRNNCell, self).__init__()
        self.input_size = input_size
        self.nonlinearity = nonlinearity
        self.hidden_size = hidden_size 
        self.alpha_r = alpha_r
        self.alpha_s = alpha_s
        self.bias = bias 
        self.Win = Parameter(torch.Tensor(hidden_size, input_size))
        self.Wrec = Parameter(torch.Tensor(hidden_size, hidden_size))
        self.ratio = ratio
        
        if bias:
            self.bin = Parameter(torch.Tensor(hidden_size, hidden_size))
            self.brec = Parameter(torch.Tensor(hidden_size))
            
        # init weights 
        nn.init.orthogonal_(self.Win)
        nn.init.orthogonal_(self.Wrec)
        
    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        r, I = state 
        
        I = (1-self.alpha_s)*I + self.alpha_s*(torch.mm(self.Wrec, r) + torch.mm(self.Win, input.T))
        r = (1-self.alpha_r)*r + self.alpha_r*(torch.tanh(I))
        
        return r, (r, I) 

In [138]:
class NeuroRNNLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(NeuroRNNLayer, self).__init__()
        self.cell = cell(*cell_args)
        
        
    @jit.script_method
    def forward(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        inputs = input.unbind(1)
        outputs = torch.jit.annotate(List[Tensor], [])
        
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            print(out.shape)
            outputs += [out]
            
        return torch.stack(outputs).permute(2,0,1), state
    

In [139]:
seq_len = 4
batch = 10 
hidden_size = 5
input_size = 1

In [140]:
inp = torch.randn(batch, seq_len, input_size)
state = RNNState(torch.randn(hidden_size, batch), torch.randn(hidden_size, batch))
rnn = NeuroRNNLayer(NeuroRNNCell, input_size, hidden_size, 1, 1, 'relu')
out, out_state = rnn(inp, state)
print(out.shape)

[5, 10]
[5, 10]
[5, 10]
[5, 10]
torch.Size([10, 4, 5])
