In [71]:
import torch
import torch.nn as nn
import torch.optim as optim
class SRUCell(nn.Module):
    """A single Simple Recurrent Unit (SRU) cell.

    One step of the recurrence, for input x_t and previous cell state c_{t-1}:

        f_t = sigmoid(x_t @ W  + v  * c_{t-1} + b)        # forget gate
        r_t = sigmoid(x_t @ W' + v' * c_{t-1} + b')       # reset gate
        c_t = f_t * c_{t-1} + (1 - f_t) * (x_t @ W'')     # new cell state
        h_t = r_t * c_t + (1 - r_t) * skip(x_t)           # highway output

    ``skip(x_t)`` is ``x_t`` itself when it already has ``hidden_size`` features,
    or when it has a single feature (which broadcasts across all hidden units).
    For any other ``input_size`` the skip path is projected with an extra learned
    matrix ``W_highway`` so that the highway term has width ``hidden_size``.

    Args:
        input_size: number of features in each input vector.
        hidden_size: number of features in the cell / hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(SRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Input projections for the forget gate (W), reset gate (W_prime) and
        # candidate cell state (W_double_prime). Allocated uninitialised because
        # reset_parameters() overwrites every entry.
        self.W = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_prime = nn.Parameter(torch.empty(input_size, hidden_size))
        self.W_double_prime = nn.Parameter(torch.empty(input_size, hidden_size))
        # Gate biases.
        self.b = nn.Parameter(torch.empty(hidden_size))
        self.b_prime = nn.Parameter(torch.empty(hidden_size))
        # Element-wise weights applied to the previous cell state inside each gate.
        self.v = nn.Parameter(torch.empty(hidden_size))
        self.v_prime = nn.Parameter(torch.empty(hidden_size))
        # Highway projection: only needed when x cannot be combined with c_t
        # directly. Registering None keeps parameters()/state_dict() unchanged
        # for input_size in (1, hidden_size).
        if input_size in (1, hidden_size):
            self.register_parameter("W_highway", None)
        else:
            self.W_highway = nn.Parameter(torch.empty(input_size, hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise every parameter in place.

        Weight matrices and gate vectors are drawn from N(0, 0.1^2); biases are
        set to 0. The optional highway projection is drawn last, so the random
        draws for the other parameters happen in the same order either way.
        """
        for weight in (self.W, self.W_prime, self.W_double_prime, self.v, self.v_prime):
            nn.init.normal_(weight, 0, 0.1)
        if self.W_highway is not None:
            nn.init.normal_(self.W_highway, 0, 0.1)
        nn.init.constant_(self.b, 0)
        nn.init.constant_(self.b_prime, 0)

    def forward(self, x, c_previous=None):
        """Advance the cell by one time step.

        Args:
            x: input for this step; assumed shape (batch, input_size).
            c_previous: previous cell state of shape (batch, hidden_size);
                zeros are used when None.

        Returns:
            Tuple (h_t, c_t), each of shape (batch, hidden_size).
        """
        if c_previous is None:
            # new_zeros matches x's dtype and device.
            c_previous = x.new_zeros(x.size(0), self.hidden_size)
        # Forget and reset gates.
        f_t = torch.sigmoid(x @ self.W + self.v * c_previous + self.b)
        r_t = torch.sigmoid(x @ self.W_prime + self.v_prime * c_previous + self.b_prime)
        # New cell state: gated mix of the old state and the projected input.
        c_t = f_t * c_previous + (1 - f_t) * (x @ self.W_double_prime)
        # Highway connection; project x only if its width differs from hidden_size.
        skip = x if self.W_highway is None else x @ self.W_highway
        h_t = r_t * c_t + (1 - r_t) * skip
        return h_t, c_t

In [72]:
class SRU(nn.Module):
    """Single-layer SRU followed by a linear read-out of the final hidden state.

    Args:
        input_size: features per time step.
        hidden_size: width of the SRU cell / hidden state (kept as an attribute
            so forward() can build the initial cell state).
        output_size: features produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(SRU, self).__init__()
        self.hidden_size = hidden_size
        self.sru_cell = SRUCell(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the cell over every time step and map the last hidden state.

        Args:
            x: input batch, indexed as x[:, t]; assumed shape
                (batch, seq_len, input_size).

        Returns:
            Tensor of shape (batch, output_size).

        Raises:
            ValueError: if the sequence dimension (x.size(1)) is empty.
        """
        seq_len = x.size(1)
        if seq_len == 0:
            raise ValueError("SRU.forward expects at least one time step (x.size(1) == 0)")
        c_t = x.new_zeros(x.size(0), self.hidden_size)
        # Only the last hidden state feeds the head, so earlier ones are not stored.
        for t in range(seq_len):
            h_t, c_t = self.sru_cell(x[:, t], c_t)
        return self.linear(h_t)

def create_dataset(n):
    """Return (inputs, targets) for a "predict the next integer" toy task.

    inputs  -- float column tensor of shape (n, 1) holding 1., 2., ..., n
    targets -- float column tensor of shape (n, 1) holding 2., 3., ..., n + 1
    so that targets[i] == inputs[i] + 1 for every row.
    """
    inputs = torch.arange(1., n + 1).unsqueeze(-1)
    targets = torch.arange(2., n + 2).unsqueeze(-1)
    return inputs, targets



In [73]:
# Model instantiation. Seed first so that parameter initialisation (and hence
# the whole training run) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

input_size = 1
hidden_size = 50
output_size = 1
model = SRU(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())  # default Adam hyperparameters

# Toy counting task: learn x -> x + 1 for x in 1..100.
x_train, y_train = create_dataset(100)

In [74]:
# Training loop: full-batch optimisation over the whole training set each epoch.
# NOTE(review): the saved output below stops at epoch 2500 although
# epochs = 5000, so it was likely produced by a shorter or interrupted run;
# re-run this cell to refresh it.
epochs = 5000
log_every = 100
# Present each example as a length-1 sequence: (batch, 1) -> (batch, 1, 1).
x_train_seq = x_train.unsqueeze(1)
model.train()  # loop-invariant: nothing inside the loop changes the mode
for epoch in range(epochs):
    optimizer.zero_grad()
    output = model(x_train_seq)
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()
    if epoch % log_every == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

Epoch 0, Loss: 141.85227966308594
Epoch 100, Loss: 0.44399428367614746
Epoch 200, Loss: 0.10344808548688889
Epoch 300, Loss: 0.047301020473241806
Epoch 400, Loss: 0.026141751557588577
Epoch 500, Loss: 0.017730966210365295
Epoch 600, Loss: 0.014206832274794579
Epoch 700, Loss: 0.012348949909210205
Epoch 800, Loss: 0.011044936254620552
Epoch 900, Loss: 0.009982094168663025
Epoch 1000, Loss: 0.009069297462701797
Epoch 1100, Loss: 0.008265282027423382
Epoch 1200, Loss: 0.0075443340465426445
Epoch 1300, Loss: 0.006890471559017897
Epoch 1400, Loss: 0.006295027211308479
Epoch 1500, Loss: 0.005753648933023214
Epoch 1600, Loss: 0.005264329258352518
Epoch 1700, Loss: 0.004825145471841097
Epoch 1800, Loss: 0.004433753434568644
Epoch 1900, Loss: 0.004086928442120552
Epoch 2000, Loss: 0.003780721453949809
Epoch 2100, Loss: 0.0035107627045363188
Epoch 2200, Loss: 0.003272715490311384
Epoch 2300, Loss: 0.0030623087659478188
Epoch 2400, Loss: 0.002875684294849634
Epoch 2500, Loss: 0.002709328196942806

In [75]:
# Evaluate on an input just outside the training range (1..100).
model.eval()
test_input = torch.tensor([[101.0]])
with torch.no_grad():  # inference only: don't build an autograd graph
    predicted_output = model(test_input.unsqueeze(1))
predicted_output.item()  # This should be close to 102 if the model has learned correctly

102.01143646240234