In [1]:
import torch
from torch import nn

class SMRNN(nn.Module):
    def __init__(self,state_size,input_size,outputs) -> None:
        super().__init__()
        self.state = torch.zeros(state_size)
        self.ln1 = nn.Linear(input_size+state_size,state_size + outputs)
    
    def forward(self,inputs):
        y = torch.cat((self.state,inputs))
        x = self.ln1(y)
        self.state = x[:self.state.shape[0]]
        return x[self.state.shape[0]:]

    def cleanstate(self):
        self.state = torch.zeros_like(self.state)


In [4]:
a = SMRNN(3,3,3)
print(list(a.parameters()))

[Parameter containing:
tensor([[-0.2919, -0.0905,  0.2469,  0.1309, -0.2009,  0.3051],
        [ 0.0539, -0.3844,  0.2943, -0.2397, -0.1728, -0.1549],
        [ 0.0808,  0.1423,  0.3615,  0.2258,  0.0390,  0.1980],
        [ 0.1342, -0.3178,  0.2004, -0.2211,  0.2693, -0.0691],
        [ 0.0055, -0.0575,  0.1959, -0.0310,  0.1793, -0.3305],
        [-0.0303, -0.1960,  0.3597, -0.2945, -0.3965, -0.2889]],
       requires_grad=True), Parameter containing:
tensor([-0.4078, -0.3800, -0.1178, -0.3384,  0.1782, -0.0525],
       requires_grad=True)]


In [2]:
import torch.nn.functional as F
from torch import optim

a = SMRNN(3,3,3)
opt = optim.SGD(a.parameters(),lr=0.01)
loss_func = F.mse_loss
epochs = 100
step_size = 5
for i in range(epochs):
    prev_inpt = torch.zeros(3)
    result = torch.zeros(3)
    for j in range(step_size):
        inpt = torch.rand(3)
        y = a.forward(inpt)
        result += inpt
        prev_inpt = inpt
    a.cleanstate()
    loss = loss_func(y,result)
    print(loss)
    loss.backward()
    opt.step()
    opt.zero_grad()

tensor(5.8353, grad_fn=<MseLossBackward0>)
tensor(7.4490, grad_fn=<MseLossBackward0>)
tensor(6.2432, grad_fn=<MseLossBackward0>)
tensor(7.5064, grad_fn=<MseLossBackward0>)
tensor(5.4746, grad_fn=<MseLossBackward0>)
tensor(6.2934, grad_fn=<MseLossBackward0>)
tensor(5.5678, grad_fn=<MseLossBackward0>)
tensor(9.4560, grad_fn=<MseLossBackward0>)
tensor(10.6814, grad_fn=<MseLossBackward0>)
tensor(7.2727, grad_fn=<MseLossBackward0>)
tensor(6.9119, grad_fn=<MseLossBackward0>)
tensor(5.5695, grad_fn=<MseLossBackward0>)
tensor(3.9041, grad_fn=<MseLossBackward0>)
tensor(7.5735, grad_fn=<MseLossBackward0>)
tensor(6.2432, grad_fn=<MseLossBackward0>)
tensor(8.8434, grad_fn=<MseLossBackward0>)
tensor(5.1797, grad_fn=<MseLossBackward0>)
tensor(6.9900, grad_fn=<MseLossBackward0>)
tensor(3.9345, grad_fn=<MseLossBackward0>)
tensor(4.9002, grad_fn=<MseLossBackward0>)
tensor(3.7712, grad_fn=<MseLossBackward0>)
tensor(3.9045, grad_fn=<MseLossBackward0>)
tensor(3.6350, grad_fn=<MseLossBackward0>)
tensor(5.3