In [649]:
import torch
from torch import nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
import math
import random
import sys
from numpy import newaxis

In [650]:
class RNNModel(nn.Module):
    def __init__(self, rnn_layer, num_inputs, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.num_inputs = num_inputs
        self.num_hiddens = self.rnn.hidden_size
        self.num_directions = 1
        self.linear = nn.Linear(self.num_hiddens, self.num_inputs)

    def forward(self, inputs, state):
        X = inputs 
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size):
        return (torch.zeros((self.num_directions * self.rnn.num_layers,
                             batch_size, self.num_hiddens),
                            device=device, requires_grad=True),
                torch.zeros((self.num_directions * self.rnn.num_layers,
                             batch_size, self.num_hiddens),
                            device=device, requires_grad=True))

In [664]:
num_inputs = 1
num_hiddens = 8
lstm_layer = nn.LSTM(num_inputs, num_hiddens)

In [665]:
model = RNNModel(lstm_layer, num_inputs)
device = torch.device('cpu')
model = model.to(device)

In [666]:
loss = nn.MSELoss()
num_epochs = 10000

In [667]:
lr = 5e-2
updater = torch.optim.SGD(model.parameters(), lr)

In [668]:
X = []
Y = []
for i in range(100):
    X.append(round(math.sin(0.1*i), 1))
    Y.append(round(math.sin(0.1*(i+1)), 1))

In [656]:
def seq_data_iter_sequential():
    for i in range(32):
        offset1 = random.randint(0, 90)
        X_input = torch.tensor([X[offset1:offset1+10]], requires_grad=True).T
        Y_input = torch.tensor([Y[offset1:offset1+10]], requires_grad=True).T
        X_input = X_input[:,:,newaxis]
        Y_input = Y_input[:,:,newaxis]
        yield X_input, Y_input

class SeqDataLoader:
    def __iter__(self):
        return seq_data_iter_sequential()

In [669]:
train_iter = SeqDataLoader()

In [658]:
num_directions = 1
num_layers = 1
num_hiddens = 8

In [670]:
def train_epoch(net, train_iter, loss, updater, device):
    state = None
    for X_input, Y_input in train_iter:
        if state is None:
            state = net.begin_state(device=device, batch_size = 1)       
        else:
            for s in state:
                s.detach_()   
        y = Y_input.reshape(-1)
        X_input, y = X_input.to(device), y.to(device)
        y_hat, state = net(X_input, state)
        l = loss(y_hat, y)
        updater.zero_grad()
        l.backward()
        updater.step()        

In [671]:
def predict_one_input(prefix, num_preds, net, device):
    state = net.begin_state(batch_size=1, device=device)    
    outputs = [prefix[0]]
    get_input = lambda: torch.reshape(torch.tensor([outputs[-1]], device=device),
                                    (1, 1, 1))
    for y in prefix[1:]:  
        _, state = net(get_input(), state)
        outputs.append(y)
    for _ in range(num_preds):  
        y, state = net(get_input(), state)        
        outputs.append(round(y.item(),1))
    return outputs

In [672]:
predict = lambda prefix: predict_one_input(prefix, 5, model, device)

In [673]:
predict_input = [0.1,0.0,-0.1,-0.2,-0.3]

In [675]:
num_epochs = 5000

In [676]:
for epoch in range(num_epochs):
    train_epoch(model, train_iter, loss, updater, device)
    if (epoch + 1) % 100 == 0:
        print(predict(predict_input))

[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.5, -0.4, -0.2]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.5, -0.4, -0.3]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.5, -0.5, -0.5, -0.4, -0.3]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.3, -0.1, 0.2]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.5, -0.5, -0.4, -0.3]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.5, -0.6, -0.6, -0.6, -0.6]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.3, -0.2]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.3, -0.2]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.5, -0.5, -0.5, -0.5, -0.5]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.5, -0.6, -0.6, -0.6, -0.5]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4]
[0.1, 0.0, -0.1, -0.2, -0.3, -0.4, -0.4, -0.4, -0.4, -0.4

In [678]:
print(predict([0.1,0,-0.1.8,0.9,0.9]))

[0.7, 0.8, 0.8, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8]
