#### Learning the Fibonacci numbers

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

In [2]:
def fib(n):
    """Return the n-th Fibonacci number, 1-indexed (fib(1) == fib(2) == 1).

    Args:
        n: Positive integer position in the sequence.

    Returns:
        The n-th Fibonacci number as an int.

    Raises:
        AssertionError: If ``n`` is not greater than 0.
    """
    assert n > 0, 'n must be > 0'

    # Walk the sequence forward instead of recursing on both branches:
    # linear in n rather than exponential, and no recursion-depth limit.
    prev, curr = 1, 1
    for _ in range(n - 2):  # empty for n <= 2, so the seed value 1 is returned
        prev, curr = curr, prev + curr
    return curr
    
# First nine Fibonacci numbers, fib(1) through fib(9), used as the training sequence
data = list(map(fib, range(1, 10)))

Define the RNN model

In [3]:
class SimpleRNN(nn.Module):
    """A recurrent layer followed by a linear read-out applied at every time step.

    Args:
        input_size: Number of features per time step of the input.
        hidden_size: Width of the recurrent hidden state.
        output_size: Number of features produced per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in the order rnn -> fc so that parameter
        # initialisation draws from the random generator in the same order.
        # batch_first=True: inputs/outputs are laid out (batch, seq_len, features).
        self.rnn = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x, h=None):
        """Run the recurrent layer and project every time step.

        Args:
            x: Input tensor passed straight to ``nn.RNN``.
            h: Optional initial hidden state; ``None`` lets ``nn.RNN`` use its default.

        Returns:
            A tuple ``(out, h)``: the per-step projection of the recurrent
            outputs and the final hidden state returned by the recurrent layer.
        """
        features, h_last = self.rnn(x, h)
        return self.fc(features), h_last

Design the input and target sequences

In [2]:
# Pair each number with its successor: inputs drop the last element,
# targets drop the first, so target_seq[i] is the number after input_seq[i].
#
# NOTE(review): SimpleRNN builds its nn.RNN with batch_first=True, so the
# (-1, 1, 1) shape below is read as (batch=N, seq_len=1, features=1): every
# number is an independent length-1 sequence and no hidden state is carried
# from one number to the next. Feeding a single sequence of length N would be
# shape (1, -1, 1); code that iterates over the leading dimension of these
# tensors would have to be adapted at the same time.
input_seq = torch.tensor(data[:-1], dtype=torch.float).reshape(-1, 1, 1)
target_seq = torch.tensor(data[1:], dtype=torch.float).reshape(-1, 1, 1)

NameError: name 'data' is not defined

Define the hyperparameters, build the model, and train it

In [5]:
# Hyperparameters
input_size = 1
hidden_size = 8
output_size = 1
num_epochs = 250
learning_rate = 0.1
log_every = 50  # print the loss once every `log_every` epochs
seed = 0

# Fix the random generator before the model is built so the weight
# initialisation (and therefore the reported losses) is the same on every
# Restart & Run All.
torch.manual_seed(seed)

# Model, Loss, and Optimizer
model = SimpleRNN(input_size, hidden_size, output_size)
loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: full-batch gradient descent, one optimizer step per epoch
for epoch in tqdm(range(num_epochs)):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    output, _ = model(input_seq)
    loss = loss_func(output, target_seq)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % log_every == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

  0%|          | 0/250 [00:00<?, ?it/s]

Epoch [50/250], Loss: 5.6069
Epoch [100/250], Loss: 0.1186
Epoch [150/250], Loss: 0.1228
Epoch [200/250], Loss: 0.0918
Epoch [250/250], Loss: 0.1040


Check out the predictions

In [6]:
# Inference only: no gradients are needed to inspect the fitted model.
with torch.no_grad():
    pred_seq, _ = model(input_seq)

# Walk the three tensors in lock-step along their leading dimension; each
# slice holds a single value, which .item() turns into a Python float.
rows = zip(input_seq, target_seq, pred_seq)
for current, actual, pred in rows:
    print(f"Current: {current.item():.2f}, Next: {actual.item():.2f}, Predicted: {pred.item():.2f}")


Current: 1.00, Next: 1.00, Predicted: 1.47
Current: 1.00, Next: 2.00, Predicted: 1.47
Current: 2.00, Next: 3.00, Predicted: 3.33
Current: 3.00, Next: 5.00, Predicted: 4.73
Current: 5.00, Next: 8.00, Predicted: 8.10
Current: 8.00, Next: 13.00, Predicted: 13.15
Current: 13.00, Next: 21.00, Predicted: 21.17
Current: 21.00, Next: 34.00, Predicted: 34.02
