In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

![image](https://i.imgur.com/Xp0Jmr1.png)

In [2]:
class LSTM(nn.Module):
    """Minimal scalar LSTM written out gate by gate.

    The cell has a single input feature and a single hidden unit, so every
    weight and bias is a 0-dim ``nn.Parameter``.  Each gate owns three
    parameters: ``w?h`` (weight on the short-term memory), ``w?i`` (weight on
    the current input) and ``b?`` (bias), where ``?`` is the colour letter
    used in the comments below (b, g, y, p).

    Weights are drawn from a standard normal distribution; biases start at 0.
    """

    def __init__(self):
        super().__init__()
        mean = torch.tensor(0.0)
        std = torch.tensor(1.0)

        # NOTE: parameters are created in a fixed order so that a given RNG
        # seed always yields the same initial weights.

        # Forget gate (blue - b): how much of the long-term memory is kept.
        self.wbh = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wbi = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bb = nn.Parameter(torch.tensor(0.), requires_grad=True)

        # Input gate (green - g): how much of the candidate is written.
        self.wgh = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wgi = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bg = nn.Parameter(torch.tensor(0.), requires_grad=True)

        # Candidate long-term memory (yellow - y): the tanh branch.
        self.wyh = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wyi = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.by = nn.Parameter(torch.tensor(0.), requires_grad=True)

        # Output gate (purple - p): how much of the memory is exposed.
        self.wph = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wpi = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bp = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def _lstm_unit(self, input_val, long_mem, short_mem):
        """Advance the cell by one time step.

        Args:
            input_val: input at the current time step.
            long_mem: long-term memory (cell state) from the previous step.
            short_mem: short-term memory (hidden state) from the previous step.

        Returns:
            ``[long_mem, short_mem]`` after the step.
        """
        # Blue part: fraction of the old long-term memory to keep.
        long_mem_forget = torch.sigmoid((short_mem * self.wbh) +
                                        (input_val * self.wbi) +
                                        self.bb)
        # Green part: fraction of the candidate to add.
        long_mem_rememb = torch.sigmoid((short_mem * self.wgh) +
                                        (input_val * self.wgi) +
                                        self.bg)
        # Yellow part: candidate value, squashed into (-1, 1).
        long_mem_update = torch.tanh((short_mem * self.wyh) +
                                     (input_val * self.wyi) +
                                     self.by)
        # Update long_mem
        long_mem = (long_mem * long_mem_forget) + (long_mem_rememb * long_mem_update)

        # Purple part: fraction of the squashed long-term memory to output.
        short_mem_rememb = torch.sigmoid((short_mem * self.wph) +
                                         (input_val * self.wpi) +
                                         self.bp)
        # Update short_mem
        short_mem = torch.tanh(long_mem) * short_mem_rememb

        return [long_mem, short_mem]

    def forward(self, x):
        """Run the cell over every time step of ``x``.

        Args:
            x: sequence indexed by time along its first dimension (for example
                a 1-D tensor).  Any length is accepted.

        Returns:
            The short-term memory after the last time step.  For an empty
            sequence this is the zero initial state.
        """
        # Both memories start at zero, on the same dtype/device as the weights.
        long_mem = torch.zeros((), dtype=self.wbh.dtype, device=self.wbh.device)
        short_mem = torch.zeros((), dtype=self.wbh.dtype, device=self.wbh.device)

        # Unroll over the whole sequence instead of a fixed four steps.
        for input_val in x:
            long_mem, short_mem = self._lstm_unit(input_val, long_mem, short_mem)

        return short_mem

## Demo 1: fit a single sequence

In [3]:
# Toy training pair: a four-step input sequence and the scalar target that
# the model's final short-term memory should reach.
x = torch.tensor([1.0, 0.5, 0.3, 0.6])
y = torch.tensor(1.0)

In [4]:
# Seed the RNG before building the model: LSTM draws its weights from a
# normal distribution, so without a seed every Restart & Run All starts from
# different weights and prints a different training trace.
torch.manual_seed(0)

model = LSTM()
criterion = nn.MSELoss()

# Large learning rate on purpose: there are only 12 scalar parameters and a
# single training example.
optimizer = optim.Adam(model.parameters(), lr=0.1)

In [5]:
model.train()

# Ten Adam updates on the single (x, y) pair.  The value printed for each
# step is the prediction computed *before* that step's weight update.
for step in range(10):
    optimizer.zero_grad()

    output = model(x)
    loss = criterion(output, y)
    loss.backward()

    optimizer.step()
    print(f"Step {step}: {output}")

Step 0: -0.168497234582901
Step 1: -0.09421718120574951
Step 2: -0.04200140759348869
Step 3: -0.005702103953808546
Step 4: 0.019625108689069748
Step 5: 0.037882205098867416
Step 6: 0.052193496376276016
Step 7: 0.06497599184513092
Step 8: 0.07804455608129501
Step 9: 0.0928073450922966


## Demo 2: separate two sequences that differ only in their first value

In [6]:
# Two four-step sequences that are identical except for the first element,
# with opposite targets.
x1 = torch.tensor([1.0, 0.5, 0.3, 0.6])
y1 = torch.tensor(1.0)

x2 = torch.tensor([0.0, 0.5, 0.3, 0.6])
y2 = torch.tensor(0.0)

In [7]:
# Seed the RNG before building the model: LSTM draws its weights from a
# normal distribution, so without a seed every Restart & Run All starts from
# different weights and prints a different training trace.
torch.manual_seed(0)

# Fresh model and optimizer state for this demo.
model = LSTM()
criterion = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=0.03)

In [8]:
model.train()

# 200 joint Adam updates: both sequences go through the same weights and
# their MSE losses are summed into a single objective.
for step in range(200):
    optimizer.zero_grad()

    output1 = model(x1)
    loss1 = criterion(output1, y1)
    output2 = model(x2)
    loss2 = criterion(output2, y2)

    loss = loss1 + loss2
    loss.backward()
    optimizer.step()

    # Report only every tenth step to keep the output short.
    if step % 10 != 0:
        continue
    print(f"Step {step}: {output1:.3f} {output2:.3f}")

Step 0: 0.001 0.001
Step 10: 0.311 0.309
Step 20: 0.666 0.648
Step 30: 0.523 0.513
Step 40: 0.469 0.458
Step 50: 0.556 0.530
Step 60: 0.549 0.503
Step 70: 0.570 0.475
Step 80: 0.634 0.399
Step 90: 0.795 0.110
Step 100: 0.967 0.032
Step 110: 0.983 0.032
Step 120: 0.986 0.002
Step 130: 0.987 -0.010
Step 140: 0.988 0.008
Step 150: 0.988 -0.004
Step 160: 0.988 0.003
Step 170: 0.988 -0.002
Step 180: 0.988 0.001
Step 190: 0.988 -0.000
