In [None]:
import math

import torch
import matplotlib.pyplot as plt

In [None]:
### PyTorch options.
# Device on which every tensor in this notebook is allocated.
# Set to "cpu" or a CUDA device string such as "cuda:0".
DEVICE_NAME = "cpu"
device = torch.device(DEVICE_NAME)

Two stacked recurrent layers, applied at every time step $n$, with activation $\sigma = \tanh$:

$$ h_n = \sigma\left(x_n W_1 + h_{n-1} W_2 + b_1\right) $$
$$ y_n = \sigma\left(h_n W_3 + y_{n-1} W_4 + b_2\right) $$

In [None]:
### Parameters.
# Seed the RNG so the random weight initialisation (and therefore every
# result below) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

size_batch = 1
size_x     = 1
size_h     = 2
size_y     = 1  # width of the output layer y_n

# Input and recurrent-state buffers. torch.zeros is used instead of the legacy
# torch.Tensor(...) constructor: it returns initialised memory and accepts a
# non-CPU `device` (the legacy constructor rejects e.g. "cuda:0").
x  = torch.zeros(size_batch, size_x, device=device)
h1 = torch.zeros(size_batch, size_h, device=device)  # previous hidden state h_{n-1}
h2 = torch.zeros(size_batch, size_y, device=device)  # previous output y_{n-1}

# First layer: h_n = tanh(x_n W_1 + h_{n-1} W_2 + b_1).
# Note: biases have one row per batch row (not shared across the batch).
w1 = torch.randn(size_x    , size_h, device=device, requires_grad=True)
w2 = torch.randn(size_h    , size_h, device=device, requires_grad=True)
b1 = torch.randn(size_batch, size_h, device=device, requires_grad=True)

# Second layer: y_n = tanh(h_n W_3 + y_{n-1} W_4 + b_2).
w3 = torch.randn(size_h    , size_y, device=device, requires_grad=True)
w4 = torch.randn(size_y    , size_y, device=device, requires_grad=True)
b2 = torch.randn(size_batch, size_y, device=device, requires_grad=True)

In [None]:
### Data.
# A sine wave sampled at n_samples evenly spaced points over [0, 100*pi]
# (endpoints included), i.e. 50 full periods.
n_samples = 1000
T = torch.linspace(0, 100*math.pi, n_samples, device=device)
X = T.sin()

# Split X into consecutive chunks of size_batch samples, giving shape
# (n_batches, size_batch, 1); a trailing remainder shorter than size_batch is
# dropped. A single reshape replaces the former per-chunk torch.cat loop, which
# was quadratic and concatenated onto a legacy torch.Tensor(device=...) that
# fails for non-CPU devices.
n_batches = len(X) // size_batch
batches = X[: n_batches * size_batch].reshape(n_batches, size_batch, 1)

In [None]:
### Training.
lr     = 1e-1
epochs = 1

# Trainable parameters, updated by plain SGD below.
params = [w1, w2, w3, w4, b1, b2]

# Per-step predictions and losses are collected in Python lists and
# concatenated once after training; repeated torch.cat inside the loop is
# quadratic in the number of steps.
outputs = []
losses  = []

# Main loop.
for _ in range(epochs):
    # Start every epoch from a zero recurrent state.
    h1 = torch.zeros_like(h1)
    h2 = torch.zeros_like(h2)
    for x in batches:

        # Forward.
        y1 = ( x.mm(w1) + h1.mm(w2) + b1).tanh()
        y2 = (y1.mm(w3) + h2.mm(w4) + b2).tanh()

        # NOTE(review): the target is the current input x, not the next
        # sample, so this trains reconstruction rather than prediction —
        # confirm this is intended.
        loss = (x[:,0] - y2[:,0]).pow(2).sum()

        outputs.append(y2[:, 0].detach())
        losses.append(loss.detach())

        # Backward.
        loss.backward()

        # Update (SGD step), then clear gradients for the next step.
        with torch.no_grad():
            for param in params:
                param -= lr * param.grad
                param.grad.zero_()

        # RNN Magic: carry the new state forward *without* its autograd graph,
        # so the next backward() only reaches back one time step.
        h1 = y1.detach()
        h2 = y2.detach()

# output: shape (1, n_steps), predictions in time order (for size_batch > 1 the
# samples of each batch are laid out consecutively, matching X).
# output_loss: one scalar loss per processed batch, shape (n_batches * epochs,).
output      = torch.cat(outputs).reshape(1, -1) if outputs else torch.empty(1, 0, device=device)
output_loss = torch.stack(losses) if losses else torch.empty(0, device=device)

In [None]:
# Plot only the first 299 time steps so individual oscillations stay visible.
p = slice(0, 299)

# Move everything to the CPU before handing it to matplotlib (required when
# `device` is a CUDA device), and flatten `output` to one value per step
# instead of hard-coding its length.
t_plot    = T.detach().cpu()[p]
x_plot    = X.detach().cpu()[p]
y_plot    = output.detach().cpu().reshape(-1)[p]
loss_plot = output_loss.detach().cpu()[p]

fig, (ax_io, ax_loss) = plt.subplots(2, 1, figsize=(15, 8))

ax_io.plot(t_plot, x_plot, label="input x")
ax_io.plot(t_plot, y_plot, label="output y")
ax_io.set(title="Input, Output", xlabel="t", ylabel="value")
ax_io.legend()

# 4*x^2 = (x - (-x))^2, i.e. the squared error a prediction of exactly -x
# would incur; presumably drawn as a reference curve for the loss.
ax_loss.plot(t_plot, 4 * x_plot * x_plot, label="4·x² reference")
ax_loss.plot(t_plot, loss_plot, label="loss")
ax_loss.set(title="Loss", xlabel="t", ylabel="squared error")
ax_loss.legend()

plt.show()