In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

In [2]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'

In [3]:
# Fix the global RNG seed so the noisy series (and everything drawn from the
# torch RNG afterwards, such as weight initialisation and DataLoader shuffling)
# is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

TRAIN_STEPS = 4000
VALID_STEPS = TRAIN_STEPS // 2
TOTAL_STEPS = TRAIN_STEPS + VALID_STEPS
# Sine wave sampled on [0, 25] plus Gaussian noise with standard deviation 0.3
X = torch.sin(torch.linspace(0, 25, TOTAL_STEPS)) + torch.randn(TOTAL_STEPS) * 0.3

In [4]:
# Plot the training and validation segments of the series on one shared time axis
fig, ax = plt.subplots(figsize=(12,10))
ax.plot(range(TRAIN_STEPS), X[:TRAIN_STEPS], alpha=1.0, label='X (train)')
ax.plot(range(TRAIN_STEPS, TRAIN_STEPS + VALID_STEPS), X[TRAIN_STEPS:], label='X (valid)')
# Title and axis labels so the figure stands alone when the notebook is skimmed
ax.set(title='Noisy sine wave: train / validation split', xlabel='Time step', ylabel='X')
# Trailing semicolon suppresses the Legend repr in the cell output
ax.legend();

In [5]:
# Chronological split: the first TRAIN_STEPS points form the training series,
# all remaining points form the validation series.
x_train = X[:TRAIN_STEPS]
x_valid = X[TRAIN_STEPS:]

In [6]:
# Window length (number of consecutive time steps per sample) passed to
# TimeSeriesDataset as `n_steps`.
N_STEPS_BACK = 30

In [7]:
class TimeSeriesDataset(Dataset):
    """Sliding-window dataset for next-step prediction on a time series tensor.

    Item ``idx`` is the pair ``(X, Y)`` where ``X`` is the window
    ``series[idx : idx + n_steps]`` and ``Y`` is the same window shifted one
    step into the future, ``series[idx + 1 : idx + n_steps + 1]``. Both get a
    trailing ``n_features`` dimension of size 1, so for a 1-D series each has
    shape ``(n_steps, 1)``.

    Args:
        series: tensor holding the time series (indexed along its first axis).
        n_steps: number of time steps in each window.
    """
    def __init__(self, series, n_steps):
        super().__init__()
        # store the time series
        self.series = series
        # store the number of steps back we go
        self.n_steps = n_steps

    def __len__(self):
        # Every sample consumes n_steps + 1 consecutive values (the window plus
        # one extra value for the shifted target), so the last valid start
        # index is len(series) - n_steps - 1 and there are len(series) - n_steps
        # samples in total. Clamp at zero so a series that is too short gives
        # an empty dataset rather than a negative length.
        return max(0, len(self.series) - self.n_steps)

    def __getitem__(self, idx):
        n_items = len(self)
        # Support Python-style negative indices.
        if idx < 0:
            idx += n_items
        # Without this check an out-of-range index would silently yield a
        # truncated (or empty) window, and plain iteration would never stop.
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} is out of range for dataset of length {n_items}")
        stop = idx + self.n_steps
        # Our X is the window of values ranging from idx to idx + self.n_steps.
        # We unsqueeze here to create a `n_features` dimension.
        X = self.series[idx:stop].unsqueeze(-1)
        # Our Y is X shifted by 1 in the future.
        Y = self.series[idx + 1:stop + 1].unsqueeze(-1)
        return X, Y

In [8]:
# Wrap each split in a sliding-window dataset with windows of N_STEPS_BACK steps
ts_train = TimeSeriesDataset(x_train, N_STEPS_BACK)
ts_valid = TimeSeriesDataset(x_valid, N_STEPS_BACK)

In [9]:
# Sanity check: show the first (input, target) pair with the size-1 feature
# dimension squeezed away; the target is the input shifted by one step.
x, y = ts_train[0]
x.squeeze(), y.squeeze()

In [10]:
# Batches of 8 windows each; only the training loader shuffles its samples
train_dl = DataLoader(ts_train, batch_size=8, shuffle=True)
valid_dl = DataLoader(ts_valid, batch_size=8, shuffle=False)

In [11]:
class RNNCell(nn.Module):
    """Plain recurrent cell: ``h_t = tanh(W_hh h_{t-1} + b_h + W_xh x_t + b_x)``.

    Args:
        n_inputs: number of features in the input ``x_t``.
        n_outputs: size of the hidden state ``h_t``.
    """
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        # Hidden-to-hidden projection, applied to h_{t-1}
        self.h = nn.Linear(n_outputs, n_outputs)
        # Input-to-hidden projection, applied to x_t
        self.x = nn.Linear(n_inputs, n_outputs)

    def forward(self, input, hidden):
        """Return the next hidden state for one time step."""
        from_hidden = self.h(hidden)
        from_input = self.x(input)
        pre_activation = from_hidden + from_input
        return torch.tanh(pre_activation)

In [12]:
# A cell with one input feature and one hidden unit, for trying a single step by hand
r = RNNCell(1, 1)

In [13]:
# Fetch a batch of data
# Fetch one batch of data
x, y = next(iter(train_dl))
# Keep only the first time step of every sequence in the batch
inputs = x[:, 0, :]
# Zero initial hidden state with one row per sample and a single hidden unit,
# matching the n_outputs=1 of the cell `r`
hidden = torch.zeros(inputs.shape[0], 1)

In [14]:
# Display the first-time-step inputs of the batch
inputs

In [15]:
# Display the all-zero initial hidden state
hidden

In [16]:
# Run a single recurrence step of the cell on the first time step
outputs = r(inputs, hidden)

In [17]:
# Display the hidden state returned by the cell
outputs

In [18]:
class RNN(nn.Module):
    """Single-layer recurrent network that unrolls a cell over the time axis.

    Args:
        n_inputs: number of features per time step.
        n_outputs: size of the hidden state.
        cell: cell class (not instance); it is constructed as
            ``cell(n_inputs, n_outputs)`` and called as
            ``cell(x_t, h_prev) -> h_t``.
    """
    def __init__(self, n_inputs, n_outputs, cell):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.cell = cell(n_inputs, n_outputs)

    def forward(self, inputs):
        """Run the cell over every time step of ``inputs``.

        Args:
            inputs: tensor indexed as ``(batch, time, features)``.

        Returns:
            Tuple ``(outputs, hidden_state)``: the hidden states of all time
            steps stacked along dim 1, and the final hidden state of shape
            ``(batch, n_outputs)``.
        """
        batch_size, n_steps = inputs.shape[0], inputs.shape[1]
        # Initialize the hidden state as zeros with the same device and dtype
        # as the inputs, so the module does not depend on a global device and
        # works wherever the model and its inputs live.
        hidden_state = inputs.new_zeros(batch_size, self.n_outputs)
        # torch.stack cannot stack an empty list; return an empty sequence.
        if n_steps == 0:
            return inputs.new_zeros(batch_size, 0, self.n_outputs), hidden_state

        outputs = []
        # For each time step...
        for i in range(n_steps):
            # The hidden state is now the output of the cell for x_t and h_t-1
            hidden_state = self.cell(inputs[:, i], hidden_state)
            # Store the hidden state
            outputs.append(hidden_state)

        # Return a tensor of the historical outputs and the final hidden state
        return torch.stack(outputs, dim=1), hidden_state

In [19]:
# Hand-written RNN and PyTorch's built-in nn.RNN, both with one input feature
# and one hidden unit, placed on DEVICE for a side-by-side comparison
our_model = RNN(1, 1, RNNCell).to(DEVICE)
torch_model = nn.RNN(input_size=1, hidden_size=1, num_layers=1, batch_first=True).to(DEVICE)

In [20]:
# Forward the current batch through the hand-written model and inspect the
# shapes of the per-step outputs and the final hidden state
seq, state = our_model(x.to(DEVICE))
seq.shape, state.shape

In [21]:
# Same check for the built-in model, to compare its output shapes with ours
seq, state = torch_model(x.to(DEVICE))
seq.shape, state.shape

In [22]:
# One Adam optimizer (default hyperparameters) per model, so the two models
# are trained independently of each other
our_opt = optim.Adam(our_model.parameters())
torch_opt = optim.Adam(torch_model.parameters())

In [23]:
# Train the hand-written RNN and the built-in nn.RNN side by side on the same
# batches, then evaluate both on the validation set after every epoch.
N_EPOCHS = 10
for epoch in range(N_EPOCHS):
    # ---- train loop ----
    our_epoch_loss = 0.
    torch_epoch_loss = 0.
    n_batches = 0
    our_model.train()
    torch_model.train()
    for x, y in train_dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Hand-written model: clear old gradients, forward, backward, update
        our_opt.zero_grad()
        seq, state = our_model(x)
        loss = F.mse_loss(seq, y)
        loss.backward()
        our_opt.step()
        our_epoch_loss += loss.item()

        # Built-in model: same procedure with its own optimizer
        torch_opt.zero_grad()
        seq, state = torch_model(x)
        loss = F.mse_loss(seq, y)
        loss.backward()
        torch_opt.step()
        torch_epoch_loss += loss.item()

        n_batches += 1

    our_train_loss = our_epoch_loss/n_batches
    torch_train_loss = torch_epoch_loss/n_batches

    # ---- valid loop ----
    # The accumulators are reused for validation; the training averages are
    # already saved in our_train_loss / torch_train_loss above.
    our_epoch_loss = 0.
    torch_epoch_loss = 0.
    n_batches = 0
    our_model.eval()
    torch_model.eval()
    # No gradients and no optimizer steps here: validation must never change
    # the model parameters.
    with torch.no_grad():
        for x, y in valid_dl:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            seq, state = our_model(x)
            our_epoch_loss += F.mse_loss(seq, y).item()

            seq, state = torch_model(x)
            torch_epoch_loss += F.mse_loss(seq, y).item()

            n_batches += 1

    our_valid_loss = our_epoch_loss/n_batches
    torch_valid_loss = torch_epoch_loss/n_batches

    # Report the per-batch averages computed above (train and valid separately,
    # each divided by its own batch count exactly once).
    print(f"Epoch: {epoch}, Our train MSE: {our_train_loss:.05f}, Torch train MSE: {torch_train_loss:.05f}, Our valid MSE: {our_valid_loss:.05f}, Torch valid MSE: {torch_valid_loss:.05f}")


In [24]:
# Grab a fresh training batch and move it to DEVICE for the timing cells below
x, y = next(iter(train_dl))
x = x.to(DEVICE)
y = y.to(DEVICE)

In [25]:
%%timeit
# Time a forward pass of the hand-written model on one batch
our_model(x)

In [26]:
%%timeit
# Time a forward pass of the built-in model on the same batch
torch_model(x)

In [27]:
class OurGRUCell(nn.Module):
    """Gated recurrent unit (GRU) cell.

    For input ``x_t`` and previous hidden state ``h_{t-1}``::

        r_t  = sigmoid(W_r [x_t, h_{t-1}] + b_r)        # reset gate
        z_t  = sigmoid(W_z [x_t, h_{t-1}] + b_z)        # update gate
        h~_t = tanh(W_h [x_t, r_t * h_{t-1}] + b_h)     # candidate state
        h_t  = z_t * h_{t-1} + (1 - z_t) * h~_t

    Args:
        in_dim: number of input features.
        out_dim: size of the hidden state.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Reset gate: combines Wxr and Whr in one layer over [x_t, h_{t-1}]
        self.R = nn.Linear(in_dim + out_dim, out_dim)
        # Update gate: combines Wxz and Whz
        self.Z = nn.Linear(in_dim + out_dim, out_dim)
        # Candidate state: combines Wxh and Whh
        self.H = nn.Linear(in_dim + out_dim, out_dim)

    def forward(self, inputs, hidden_state):
        """Return the next hidden state for one time step."""
        xh = torch.cat((inputs, hidden_state), dim=-1)
        # Gates must be squashed into (0, 1): they act as soft switches. Without
        # the sigmoid they are unbounded and the update below is no longer an
        # interpolation between the old state and the candidate.
        reset_gate = torch.sigmoid(self.R(xh))
        update_gate = torch.sigmoid(self.Z(xh))
        # The reset gate decides how much of the old state feeds the candidate.
        xrh = torch.cat((inputs, reset_gate * hidden_state), dim=-1)
        candidate = torch.tanh(self.H(xrh))
        # Convex combination of the previous state and the candidate state.
        ht = update_gate * hidden_state + (1 - update_gate) * candidate
        return ht


In [28]:
# Stand-alone GRU cell with one input feature and OUT_DIM hidden units
OUT_DIM = 3
gru = OurGRUCell(1, OUT_DIM)

In [29]:
# Grab one batch of data to try the GRU cell on
x, y = next(iter(train_dl))

In [30]:
# Zero initial hidden state: one row of OUT_DIM units per sample in the batch
state = torch.zeros(x.shape[0], OUT_DIM)
# First time step of every sequence in the batch
timestep = x[:, 0, :]

In [31]:
# Run a single GRU step and display the resulting hidden state
gru(timestep, state)

In [49]:
# Reuse the generic RNN wrapper with the GRU cell, next to PyTorch's nn.GRU;
# both have one input feature and one hidden unit
our_model = RNN(1, 1, OurGRUCell).to(DEVICE)
torch_model = nn.GRU(input_size=1, hidden_size=1, num_layers=1, batch_first=True).to(DEVICE)

In [50]:
# Fresh Adam optimizers (default hyperparameters) bound to the new GRU models
our_opt = optim.Adam(our_model.parameters())
torch_opt = optim.Adam(torch_model.parameters())

In [51]:
# Train the hand-written GRU and the built-in nn.GRU side by side on the same
# batches, then evaluate both on the validation set after every epoch.
N_EPOCHS = 10
for epoch in range(N_EPOCHS):
    # ---- train loop ----
    our_epoch_loss = 0.
    torch_epoch_loss = 0.
    n_batches = 0
    our_model.train()
    torch_model.train()
    for x, y in train_dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Hand-written model: clear old gradients, forward, backward, update
        our_opt.zero_grad()
        seq, state = our_model(x)
        loss = F.mse_loss(seq, y)
        loss.backward()
        our_opt.step()
        our_epoch_loss += loss.item()

        # Built-in model: same procedure with its own optimizer
        torch_opt.zero_grad()
        seq, state = torch_model(x)
        loss = F.mse_loss(seq, y)
        loss.backward()
        torch_opt.step()
        torch_epoch_loss += loss.item()

        n_batches += 1

    our_train_loss = our_epoch_loss/n_batches
    torch_train_loss = torch_epoch_loss/n_batches

    # ---- valid loop ----
    # The accumulators are reused for validation; the training averages are
    # already saved in our_train_loss / torch_train_loss above.
    our_epoch_loss = 0.
    torch_epoch_loss = 0.
    n_batches = 0
    our_model.eval()
    torch_model.eval()
    # No gradients and no optimizer steps here: validation must never change
    # the model parameters.
    with torch.no_grad():
        for x, y in valid_dl:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            seq, state = our_model(x)
            our_epoch_loss += F.mse_loss(seq, y).item()

            seq, state = torch_model(x)
            torch_epoch_loss += F.mse_loss(seq, y).item()

            n_batches += 1

    our_valid_loss = our_epoch_loss/n_batches
    torch_valid_loss = torch_epoch_loss/n_batches

    # Report the per-batch averages computed above (train and valid separately,
    # each divided by its own batch count exactly once).
    print(f"Epoch: {epoch}, Our train MSE: {our_train_loss:.05f}, Torch train MSE: {torch_train_loss:.05f}, Our valid MSE: {our_valid_loss:.05f}, Torch valid MSE: {torch_valid_loss:.05f}")


In [35]:
def linear_with_act(n_inputs, n_outputs, act):
    """Build ``Linear(n_inputs, n_outputs)`` followed by the activation ``act()``.

    Args:
        n_inputs: input width of the linear layer.
        n_outputs: output width of the linear layer.
        act: activation module class (e.g. ``nn.Sigmoid``); instantiated here.

    Returns:
        An ``nn.Sequential`` holding the linear layer and the activation.
    """
    return nn.Sequential(nn.Linear(n_inputs, n_outputs), act())

class OurLSTMCell(nn.Module):
    """Long short-term memory (LSTM) cell.

    For input ``x_t``, previous hidden state ``h_{t-1}`` and previous memory
    cell ``C_{t-1}``::

        i_t  = sigmoid(W_i [x_t, h_{t-1}] + b_i)     # input gate
        f_t  = sigmoid(W_f [x_t, h_{t-1}] + b_f)     # forget gate
        o_t  = sigmoid(W_o [x_t, h_{t-1}] + b_o)     # output gate
        C~_t = tanh(W_c [x_t, h_{t-1}] + b_c)        # candidate memory
        C_t  = f_t * C_{t-1} + i_t * C~_t
        h_t  = o_t * tanh(C_t)

    Args:
        n_inputs: number of input features.
        n_outputs: size of the hidden state and of the memory cell.
    """
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs

        # Input, Forget, and Output gates
        self.I = linear_with_act(n_inputs + n_outputs, n_outputs, nn.Sigmoid)
        self.F = linear_with_act(n_inputs + n_outputs, n_outputs, nn.Sigmoid)
        self.O = linear_with_act(n_inputs + n_outputs, n_outputs, nn.Sigmoid)

        # Candidate memory cell
        self.HC = linear_with_act(n_inputs + n_outputs, n_outputs, nn.Tanh)

    def forward(self, inputs, hidden_state, memory_cell):
        """Return ``(h_t, C_t)`` for one time step."""
        xh = torch.cat((inputs, hidden_state), dim=-1)
        # Local names are spelled out so that the forget gate does not shadow
        # the conventional `F` alias for torch.nn.functional.
        input_gate = self.I(xh)
        forget_gate = self.F(xh)
        output_gate = self.O(xh)
        # The candidate memory is computed from the input and the previous
        # HIDDEN state, like the gates. The previous memory cell enters only
        # through the forget-gate term below.
        candidate = self.HC(xh)
        Ct = forget_gate * memory_cell + input_gate * candidate
        Ht = output_gate * torch.tanh(Ct)

        return Ht, Ct

In [36]:
# create the LSTM cell
# Create a stand-alone LSTM cell with one input feature and OUT_DIM hidden units
OUT_DIM = 3
lstm = OurLSTMCell(1, OUT_DIM)

In [37]:
# grab one batch of data
# Grab one batch of data to try the LSTM cell on
x, y = next(iter(train_dl))

In [38]:
# Zero initial hidden state (one row of OUT_DIM units per sample), a zero
# memory cell of the same shape, and the first time step of the batch.
state = torch.zeros(x.shape[0], OUT_DIM) # instantiate the initial hidden state
memory = torch.zeros_like(state) # instantiate the memory
timestep = x[:, 0, :] # isolate the first time step in the batch

In [39]:
# pass the timestep, initial state, and initial memory through the LSTM cell
# to obtain the hidden state and memory
# Run a single LSTM step: H is the new hidden state, C the new memory cell
H, C = lstm(timestep, state, memory)

In [40]:
# Display the hidden state produced by the LSTM step
H

In [41]:
# Display the memory cell produced by the LSTM step
C

In [42]:
class LSTM(nn.Module):
    """Single-layer LSTM that unrolls an ``OurLSTMCell`` over the time axis.

    Args:
        n_inputs: number of features per time step.
        n_outputs: size of the hidden state and of the memory cell.
    """
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.cell = OurLSTMCell(n_inputs, n_outputs)

    def forward(self, inputs):
        """Run the cell over every time step of ``inputs``.

        Args:
            inputs: tensor indexed as ``(batch, time, features)``.

        Returns:
            Tuple ``(outputs, hidden_state)``: the hidden states of all time
            steps stacked along dim 1, and the final hidden state of shape
            ``(batch, n_outputs)``. The memory cell is internal and not
            returned.
        """
        batch_size, n_steps = inputs.shape[0], inputs.shape[1]
        # Initialize the hidden state and memory as zeros with the same device
        # and dtype as the inputs, so the module does not depend on a global
        # device and works wherever the model and its inputs live.
        hidden_state = inputs.new_zeros(batch_size, self.n_outputs)
        memory = torch.zeros_like(hidden_state)
        # torch.stack cannot stack an empty list; return an empty sequence.
        if n_steps == 0:
            return inputs.new_zeros(batch_size, 0, self.n_outputs), hidden_state

        outputs = []
        # For each time step...
        for i in range(n_steps):
            # Advance both the hidden state and the memory cell by one step
            hidden_state, memory = self.cell(inputs[:, i], hidden_state, memory)
            # Store the hidden state
            outputs.append(hidden_state)

        # Return a tensor of the historical outputs and the final hidden state
        return torch.stack(outputs, dim=1), hidden_state

In [43]:
# Hand-written LSTM and PyTorch's built-in nn.LSTM, both with one input
# feature and one hidden unit, placed on DEVICE
our_model = LSTM(1, 1).to(DEVICE)
torch_model = nn.LSTM(input_size=1, hidden_size=1, num_layers=1, batch_first=True).to(DEVICE)

In [44]:
# Fresh Adam optimizers (default hyperparameters) bound to the new LSTM models
our_opt = optim.Adam(our_model.parameters())
torch_opt = optim.Adam(torch_model.parameters())

In [45]:
# Train the hand-written LSTM and the built-in nn.LSTM side by side on the same
# batches, then evaluate both on the validation set after every epoch.
N_EPOCHS = 10
for epoch in range(N_EPOCHS):
    # ---- train loop ----
    our_epoch_loss = 0.
    torch_epoch_loss = 0.
    n_batches = 0
    our_model.train()
    torch_model.train()
    for x, y in train_dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        # Hand-written model: clear old gradients, forward, backward, update
        our_opt.zero_grad()
        seq, state = our_model(x)
        loss = F.mse_loss(seq, y)
        loss.backward()
        our_opt.step()
        our_epoch_loss += loss.item()

        # Built-in model: it returns the hidden state and the memory cell as a
        # tuple, unpacked here as (state, memory)
        torch_opt.zero_grad()
        seq, (state, memory) = torch_model(x)
        loss = F.mse_loss(seq, y)
        loss.backward()
        torch_opt.step()
        torch_epoch_loss += loss.item()

        n_batches += 1

    our_train_loss = our_epoch_loss/n_batches
    torch_train_loss = torch_epoch_loss/n_batches

    # ---- valid loop ----
    # The accumulators are reused for validation; the training averages are
    # already saved in our_train_loss / torch_train_loss above.
    our_epoch_loss = 0.
    torch_epoch_loss = 0.
    n_batches = 0
    our_model.eval()
    torch_model.eval()
    # No gradients and no optimizer steps here: validation must never change
    # the model parameters.
    with torch.no_grad():
        for x, y in valid_dl:
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            seq, state = our_model(x)
            our_epoch_loss += F.mse_loss(seq, y).item()

            seq, (state, memory) = torch_model(x)
            torch_epoch_loss += F.mse_loss(seq, y).item()

            n_batches += 1

    our_valid_loss = our_epoch_loss/n_batches
    torch_valid_loss = torch_epoch_loss/n_batches

    # Report the per-batch averages computed above (train and valid separately,
    # each divided by its own batch count exactly once).
    print(f"Epoch: {epoch}, Our train MSE: {our_train_loss:.05f}, Torch train MSE: {torch_train_loss:.05f}, Our valid MSE: {our_valid_loss:.05f}, Torch valid MSE: {torch_valid_loss:.05f}")
