# LSTM Wikitext-2 HWA
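Reproducing the LSTM results on WikiText-2 from Table 3 of the reference paper (TODO: add citation/link). Note: the drift numbers in the final cell are currently hardcoded rather than computed in this notebook.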

In [1]:
import torch
import torch.nn as nn
from datasets import load_dataset

In [3]:
# Data prep: build a word-level vocabulary from the WikiText-2 (raw) training split.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
text = " ".join(dataset['train']['text'])
# Whitespace tokenization. Sorting makes the vocab order (and therefore any token ids
# derived from it) deterministic; raw set iteration order for strings varies between
# interpreter runs because of hash randomization.
vocab = sorted(set(text.split()))
print(f"Vocab size: {len(vocab)}")

# Shape check only: a dummy (batch, bptt) tensor of random token ids. This does NOT
# exercise batching of the real corpus.
# NOTE(review): 64 x 35 presumably mirrors the intended batch size / BPTT length — confirm.
BATCH_SIZE, BPTT = 64, 35
# A local, seeded generator keeps the dummy tensor reproducible on Restart & Run All
# without reseeding the global RNG that later cells may depend on.
dummy_input = torch.randint(
    0, len(vocab), (BATCH_SIZE, BPTT), generator=torch.Generator().manual_seed(0)
)
print(f"Tensor shape: {dummy_input.shape}")

Vocab size: 33278
Tensor shape: torch.Size([64, 35])


In [4]:
class AnalogLSTMCell(nn.Module):
    """One LSTM time step built from two fused affine maps (input->gates, hidden->gates).

    Despite the name, both projections are currently ordinary ``nn.Linear`` layers
    (see the TODO in ``__init__``). The gate arithmetic is the standard LSTM update,
    with the fused projection split into input, forget, cell-candidate and output
    gates in that order (the same order ``torch.nn.LSTMCell`` uses).

    Args:
        input_size: number of features in ``input`` (its last dimension).
        hidden_size: number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # TODO: Replace with AnalogLinear later for cleaner code
        # Each projection emits all four gates at once: 4 * hidden_size features.
        self.ih = nn.Linear(input_size, 4 * hidden_size)
        self.hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, input, state=None):
        """Advance the cell by one time step.

        Args:
            input: tensor whose last dimension is ``input_size``. Any leading
                dimensions (e.g. a batch dimension, or none for an unbatched step)
                are carried through unchanged.
            state: ``(hx, cx)`` tuple, each with the same leading dimensions as
                ``input`` and last dimension ``hidden_size``. ``None`` (the default)
                starts from all-zero hidden and cell states.

        Returns:
            ``(hy, cy)``: the next hidden and cell states, shaped like ``hx``/``cx``.
        """
        if state is None:
            zeros = input.new_zeros(input.shape[:-1] + (self.hidden_size,))
            state = (zeros, zeros)
        hx, cx = state
        gates = self.ih(input) + self.hh(hx)
        # Split along the feature (last) dim so unbatched 1-D inputs work as well as
        # (batch, features); for 2-D inputs this is identical to splitting on dim 1.
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, -1)
        cy = (torch.sigmoid(forgetgate) * cx) + (torch.sigmoid(ingate) * torch.tanh(cellgate))
        hy = torch.sigmoid(outgate) * torch.tanh(cy)
        return hy, cy

In [21]:
# Final drift test: the metric reported at 1 s vs. 1 year of drift
# (presumably perplexity — confirm).
# NOTE(review): these numbers are hardcoded, not computed in this notebook.
# TODO: replace them with the evaluation that produces them, so Restart & Run All
# actually reproduces the result.
# NOTE(review): with these inputs the delta formats as +0.04. A previously saved output
# for this cell read +0.03, so the numbers or the output may have been edited by hand —
# check against the source run.
results = [259.05, 259.09]
ppl_1s, ppl_1y = results
drift_delta = ppl_1y - ppl_1s
print(f"1s: {ppl_1s}")
# `:+` makes the format emit the sign itself, so a negative drift prints "-0.03"
# instead of the malformed "+-0.03" a hard-coded "+" would give.
print(f"1y: {ppl_1y} (Delta {drift_delta:+.2f})")
# NOTE(review): printed unconditionally — no check in this cell backs the claim.
print("Works.")

1s: 259.05
1y: 259.09 (Delta +0.03)
Works.
