Associative retrieval task from [Using Fast Weights to Attend to the Recent Past](https://arxiv.org/abs/1610.06258) (Ba et al., 2016), the "fast weights RNN" paper:
> To solve this task, a standard RNN has to end up with hidden activities that somehow store all of the
> key-value pairs after the keys and values are presented sequentially. This makes it a significant
> challenge for models only using slow weights.

In [None]:
import math
import os
import random

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

In [None]:
class Dictionary:
    """Two-way lookup between task symbols and integer ids.

    The vocabulary is ordered as the lowercase letters ``a``-``z`` (ids 0-25),
    the digits ``0``-``9`` (ids 26-35) and finally the ``?`` gap marker (id 36).
    """

    def __init__(self):
        symbols = [chr(code) for code in range(ord("a"), ord("z") + 1)]
        symbols.extend(str(digit) for digit in range(10))
        # The gap marker goes last so ``idx2word[:-1]`` is every non-gap symbol.
        symbols.append("?")
        self.idx2word = symbols
        self.word2idx = dict(zip(symbols, range(len(symbols))))

    def __getitem__(self, key):
        """Return the symbol for an ``int`` id, or the id for a ``str`` symbol.

        Raises:
            TypeError: if ``key`` is neither an ``int`` nor a ``str``.
        """
        if isinstance(key, str):
            return self.word2idx[key]
        if isinstance(key, int):
            return self.idx2word[key]
        raise TypeError

    def __len__(self):
        """Return the vocabulary size."""
        return len(self.idx2word)

In [None]:
class AssocRetrievalDataset(Dataset):
    """Associative-retrieval examples loaded from a ``<task>,<target>`` text file.

    Each non-blank line holds one example: ``<task>`` is a string of
    single-character symbols and ``<target>`` is a single symbol. Every symbol
    is converted to an integer id with ``dictionary[symbol]``, so any mapping
    from symbol to id (such as ``Dictionary``) works.

    Attributes:
        inputs: ``torch.long`` tensor with one row of symbol ids per example;
            all tasks must have the same length so the rows can be stacked.
        targets: ``torch.long`` tensor with one target id per example.

    Raises:
        ValueError: if a non-blank line is not exactly ``<task>,<target>``.
    """

    def __init__(self, filename: str, dictionary: "Dictionary"):
        inputs, targets = [], []
        with open(filename, "r", encoding="utf-8") as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.strip()
                # Blank lines (e.g. an extra newline at EOF) carry no example.
                if not line:
                    continue
                try:
                    task, target = line.split(",")
                except ValueError:
                    raise ValueError(
                        f"{filename}:{lineno}: expected '<task>,<target>', got {line!r}"
                    ) from None
                inputs.append([dictionary[ch] for ch in task])
                targets.append(dictionary[target])
        self.inputs = torch.tensor(inputs, dtype=torch.long)
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_id)`` pair at ``idx``."""
        inp = self.inputs[idx]
        target = self.targets[idx]
        return inp, target

    def __len__(self):
        """Return the number of examples."""
        return self.inputs.size(0)

In [None]:
class LayerNormLSTMCell(nn.Module):
    """LSTM cell with layer normalisation of the gate pre-activations.

    Mirrors the call signature of ``nn.LSTMCell``: ``forward(x, hx=None)``
    takes an input of shape ``(batch, input_size)`` and an optional state
    ``(h, c)``, each ``(batch, hidden_size)``. It returns the new ``(h, c)``.
    A single ``nn.LayerNorm`` normalises all ``4 * hidden_size`` gate
    pre-activations jointly. The gates are ordered forget, input, output,
    candidate.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # One fused projection of the concatenated [x, h] to all four gates.
        self.weight = nn.Parameter(
            torch.empty(input_size + hidden_size, 4 * hidden_size)
        )
        self.bias = nn.Parameter(torch.empty(4 * hidden_size))
        self.ln = nn.LayerNorm(4 * hidden_size)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the cell.

        Weight and bias are drawn from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size)).
        The layer norm is reset to its default affine parameters, so calling
        this after training fully re-initialises the cell.
        """
        stdv = 1 / math.sqrt(self.hidden_size)
        nn.init.uniform_(self.weight, -stdv, stdv)
        nn.init.uniform_(self.bias, -stdv, stdv)
        self.ln.reset_parameters()

    def forward(self, x, hx=None):
        """Advance the cell by one time step.

        Args:
            x: input of shape ``(batch, input_size)``.
            hx: optional ``(h, c)`` state tuple; ``None`` means a zero state,
                as in ``nn.LSTMCell``.

        Returns:
            The updated ``(h, c)`` tuple.
        """
        if hx is None:
            zeros = x.new_zeros(x.size(0), self.hidden_size)
            hx = (zeros, zeros)
        h, c = hx
        xh = torch.cat([x, h], dim=1)
        gates = self.ln(xh @ self.weight + self.bias)
        f_pre, i_pre, o_pre, g_pre = torch.chunk(gates, 4, dim=1)
        f = torch.sigmoid(f_pre)
        i = torch.sigmoid(i_pre)
        o = torch.sigmoid(o_pre)
        g = torch.tanh(g_pre)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, c

In [None]:
class Model(pl.LightningModule):
    """Recurrent classifier for the associative retrieval task.

    Each example is a sequence of symbol ids. The sequence is embedded, fed one
    step at a time through a recurrent cell that starts from a learned initial
    state, and the final hidden state is mapped by a two-layer MLP to
    log-probabilities over the vocabulary.

    Args:
        model_type: recurrent cell, ``"lstm"`` (``nn.LSTMCell``) or
            ``"ln-lstm"`` (``LayerNormLSTMCell``).
        embed_size: symbol embedding size.
        num_cells: hidden/cell state size of the recurrent cell.
        hidden_size: width of the hidden layer of the output MLP.
        input_length: number of distinct symbols before the gap in generated tasks.
        gap_length: number of ``?`` markers between those symbols and the query.
        data_path: directory where ``prepare_data`` writes, and the dataloaders
            read, ``assoc_train.txt`` / ``assoc_valid.txt`` / ``assoc_test.txt``.
            NOTE: the default is evaluated once, when this class is defined.
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes.
        lr: Adam learning rate.

    Raises:
        ValueError: if ``model_type`` is not a supported cell type.
    """

    def __init__(
        self,
        model_type: str = "lstm",
        embed_size: int = 100,
        num_cells: int = 50,
        hidden_size: int = 100,
        input_length: int = 8,
        gap_length: int = 2,
        data_path: str = os.getcwd(),
        batch_size: int = 128,
        num_workers: int = 4,
        lr: float = 1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.embed_size = embed_size
        self.num_cells = num_cells
        self.input_length = input_length
        self.gap_length = gap_length
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr

        self.dictionary = Dictionary()
        self.embed = nn.Embedding(len(self.dictionary), embed_size)
        if model_type == "lstm":
            self.rnn = nn.LSTMCell(embed_size, num_cells)
        elif model_type == "ln-lstm":
            self.rnn = LayerNormLSTMCell(embed_size, num_cells)
        else:
            # Fail fast rather than leaving ``self.rnn`` unset until forward().
            raise ValueError(
                f"unknown model_type {model_type!r}; expected 'lstm' or 'ln-lstm'"
            )
        # Learned initial (h, c), shared by every sequence in a batch.
        self.h_init = nn.Parameter(0.01 * torch.randn(1, num_cells))
        self.c_init = nn.Parameter(0.01 * torch.randn(1, num_cells))
        self.fc1 = nn.Linear(num_cells, hidden_size)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden_size, len(self.dictionary))
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()

    def forward(self, x):
        """Return log-probabilities over the vocabulary for a batch of tasks.

        Args:
            x: symbol ids laid out time-major, ``(seq_len, batch)``. The
                training/eval steps transpose the loader's batches with ``.T``
                before calling.

        Returns:
            Tensor of shape ``(batch, len(self.dictionary))``.
        """
        batch_size = x.size(1)
        # expand() broadcasts the (1, num_cells) initial state without copying.
        h = self.h_init.expand(batch_size, -1)
        c = self.c_init.expand(batch_size, -1)
        for x_t in x:
            h, c = self.rnn(self.embed(x_t), (h, c))
        y = self.relu(self.fc1(h))
        return self.log_softmax(self.fc2(y))

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def prepare_data(self):
        """Generate fresh task files under ``data_path``, overwriting old ones.

        Writes 100000 / 10000 / 20000 ``<task>,<target>`` lines to
        ``assoc_train.txt`` / ``assoc_valid.txt`` / ``assoc_test.txt``, the
        same paths the dataloaders read. Uses the global ``random`` module, so
        the data depends on its seed.
        """
        if self.data_path:
            os.makedirs(self.data_path, exist_ok=True)
        self._write_assoc_split("assoc_train.txt", 100000)
        self._write_assoc_split("assoc_valid.txt", 10000)
        self._write_assoc_split("assoc_test.txt", 20000)

    def _generate_assoc_task(self):
        """Return one random ``(task, target)`` string pair.

        ``task`` is ``input_length`` distinct non-``?`` symbols, then
        ``gap_length`` ``?`` markers, then a query symbol copied from one of
        the first ``input_length - 1`` positions. ``target`` is the symbol that
        immediately followed the query.
        """
        # NOTE(review): keys and values are drawn from one mixed pool of letters
        # and digits, and the query may sit at any position except the last;
        # confirm this variant of the task is intended.
        symbols = self.dictionary.idx2word[:-1]  # every symbol except "?"
        inputs = random.sample(symbols, k=self.input_length)
        query_idx = random.randrange(self.input_length - 1)
        task = "".join(inputs + ["?"] * self.gap_length + [inputs[query_idx]])
        return task, inputs[query_idx + 1]

    def _write_assoc_split(self, filename, num_examples):
        """Write ``num_examples`` generated tasks to ``data_path/filename``."""
        path = os.path.join(self.data_path, filename)
        with open(path, "w", encoding="utf-8") as f:
            for _ in range(num_examples):
                task, target = self._generate_assoc_task()
                f.write(f"{task},{target}\n")

    def _assoc_loader(self, filename, shuffle):
        """Build a DataLoader over the examples in ``data_path/filename``."""
        dataset = AssocRetrievalDataset(
            os.path.join(self.data_path, filename), self.dictionary
        )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over ``assoc_train.txt``."""
        return self._assoc_loader("assoc_train.txt", shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over ``assoc_valid.txt``."""
        return self._assoc_loader("assoc_valid.txt", shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over ``assoc_test.txt``."""
        return self._assoc_loader("assoc_test.txt", shuffle=False)

    def training_step(self, batch, batch_idx):
        """Compute the NLL loss on one batch of ``(data, target)``."""
        data, target = batch
        # Loader batches are (batch, seq_len); forward() expects time-major.
        output = self(data.T)
        loss = self.loss(output, target)
        return pl.TrainResult(loss)

    def __eval_step(self, batch, batch_idx, prefix):
        """Shared validation/test step: log ``<prefix>_loss`` and ``<prefix>_acc``."""
        data, target = batch
        output = self(data.T)
        loss = self.loss(output, target)
        preds = torch.argmax(output, dim=1)
        acc = accuracy(preds, target)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log(f"{prefix}_loss", loss, prog_bar=True)
        result.log(f"{prefix}_acc", acc, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch."""
        return self.__eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch."""
        return self.__eval_step(batch, batch_idx, "test")

In [None]:
# Model builds only model_type "lstm" (nn.LSTMCell) or "ln-lstm" (LayerNormLSTMCell).
model = Model(model_type="ln-lstm", lr=1e-3)
# Train on a single GPU when one is present, otherwise fall back to CPU.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0)
trainer.fit(model)

In [None]:
# Run Model.test_step over the test dataloader (assoc_test.txt); as the cell's
# last expression, the returned results are displayed as the cell output.
trainer.test()

In [None]:
# IPython magics: load the TensorBoard notebook extension and show the logs
# under lightning_logs/ (presumably where the Trainer above writes by
# default — confirm if a custom logger is configured).
%load_ext tensorboard
%tensorboard --logdir lightning_logs/