In [131]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import product
import random
import math
import pickle
from tqdm.notebook import tqdm_notebook
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
sys.path.append('../')
import fn
from addition_dataset import GroupAddition

In [3]:
# Re-import edited local modules (e.g. fn, addition_dataset) automatically before each cell executes
%load_ext autoreload
%autoreload 2

## Data preparation

In [528]:
# Problem configuration
b = 4            # base of the digits
depth = 3        # number of digits
# Carry table: entry (i, j) is 1 exactly when digit i + digit j >= b (the sum overflows the base)
table = (np.add.outer(np.arange(b), np.arange(b)) >= b).astype(int)
batch_size = 16

In [529]:
# Get indices of training and testing data.
# Seed every RNG the notebook relies on so that the split below, the DataLoader
# shuffling and the model initialisation are reproducible on Restart & Run All.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

split = 0.9                # fraction of ids used for training
N = b**depth               # ids are drawn from range(N)
ids = random.sample(range(N), math.ceil(split * N))
heldout_ids = set(range(N)) - set(ids)
# Fail early instead of dividing by zero in the evaluation cell
assert heldout_ids, 'split leaves no held-out ids; lower `split` or increase b/depth'

In [530]:
# Each split is its GroupAddition dataset repeated num_passes times, served by a shuffled loader
num_passes = 1000

def _make_split(split_ids):
    '''Return (dataset, dataloader) for the examples selected by `split_ids`.'''
    single_pass = GroupAddition(table, depth, ids=split_ids, interleaved=True, digit_order='reversed')
    repeated = torch.utils.data.ConcatDataset([single_pass] * num_passes)
    loader = torch.utils.data.DataLoader(repeated, batch_size=batch_size, shuffle=True)
    return repeated, loader

# Training split
training_dataset, training_dataloader = _make_split(ids)

# Held-out split.
# NOTE(review): repeating the held-out set num_passes times re-evaluates the same
# examples; presumably this only multiplies evaluation cost — confirm it is intended.
testing_dataset, testing_dataloader = _make_split(heldout_ids)

## LSTM

### Define model, loss function, and optimizer

In [531]:
# Define model
class LSTMModel(nn.Module):
    '''Simple LSTM model for testing purposes.

    A stacked LSTM over inputs with `b` features per time step, followed by a
    linear head mapping each time step's hidden state to a single scalar.
    '''
    def __init__(self, b, hidden, layers):
        '''Initialize model with specified parameters.

        Args:
            b: number of input features per time step (used as the LSTM input size).
            hidden: LSTM hidden-state size.
            layers: number of stacked LSTM layers.
        '''
        super().__init__()
        self.b = b
        self.hidden = hidden
        self.layers = layers
        self.lstm = nn.LSTM(b, hidden, layers, batch_first=True)
        self.linear = nn.Linear(hidden, 1)

    def forward(self, X):
        '''Return one scalar prediction per time step of X.

        Args:
            X: batched input of shape (batch, seq, b) (the LSTM is batch_first),
                or a single unbatched (seq, b) sequence.

        Returns:
            Tensor of shape (batch, seq) for batched input, (seq,) for unbatched input.
        '''
        X_out, _ = self.lstm(X)
        # Remove only the trailing size-1 output dim of the linear head. A bare
        # squeeze() would also drop the batch dim when batch == 1 (e.g. a short
        # final DataLoader batch) and the seq dim when seq == 1.
        X_out = self.linear(X_out).squeeze(-1)
        return X_out

In [532]:
# Two-layer LSTM whose hidden size equals the base b
model = LSTMModel(b=b, hidden=b, layers=2)

In [533]:
def prediction(X_out, ids):
    '''Select the model outputs at the positions given by `ids`.

    Args:
        X_out: model output, either batched (2-D, one row per sample) or a single
            unbatched 1-D sequence.
        ids: positions to select. In the batched case row k of `ids` indexes row k
            of `X_out`. Presumably the positions of the target digits in the
            interleaved sequence — confirm against GroupAddition.

    Returns:
        The selected outputs, stacked per sample in the batched case.
    '''
    if X_out.dim() == 2:
        # Batched: index each sample's outputs with that sample's own ids
        return torch.stack([row[row_ids] for row, row_ids in zip(torch.unbind(X_out), torch.unbind(ids))])
    # Unbatched: index the single output sequence directly.
    # (Previously this read the global `X`, i.e. the model *input*, by mistake.)
    return X_out[ids]

In [534]:
class Loss(nn.Module):
    '''Mean-squared error between the outputs selected by `ids` and the targets `s`.'''

    def __init__(self):
        super().__init__()
        # The MSE criterion is stateless, so a single instance is reused across calls
        self.mse = nn.MSELoss()

    def forward(self, X_out, s, ids):
        '''Return MSE(prediction(X_out, ids), s).'''
        selected = prediction(X_out, ids)
        return self.mse(selected, s)

In [535]:
# Criterion: MSE on the selected output positions; optimizer: Adam over all model parameters
criterion = Loss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

### Train and test model

In [536]:
# Train model
# The evaluation cell calls model.eval(); restore training mode so that re-running
# this cell after evaluation does not silently train in eval mode.
model.train()
log_every = 100  # print the loss every `log_every` batches
t = 0
for batch_idx, (X, s, ids) in enumerate(training_dataloader):

    # Compute the loss on the selected output positions and print it periodically
    loss = criterion(model(X), s.float(), ids)
    if t % log_every == 0:
        print(f't = {t}  loss = {loss.item():.6f}')

    # Zero gradients, perform a backward pass, and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Iterate counter
    t += 1

t = 0  loss = 3.520770
t = 100  loss = 1.402542
t = 200  loss = 1.189651
t = 300  loss = 0.917384
t = 400  loss = 1.004953
t = 500  loss = 1.176763
t = 600  loss = 1.189975
t = 700  loss = 0.861677
t = 800  loss = 0.742377
t = 900  loss = 0.466740
t = 1000  loss = 0.148483
t = 1100  loss = 0.142234
t = 1200  loss = 0.189211
t = 1300  loss = 0.072970
t = 1400  loss = 0.045127
t = 1500  loss = 0.032835
t = 1600  loss = 0.029066
t = 1700  loss = 0.027149
t = 1800  loss = 0.019293
t = 1900  loss = 0.015020
t = 2000  loss = 0.014234
t = 2100  loss = 0.011054
t = 2200  loss = 0.011769
t = 2300  loss = 0.008884
t = 2400  loss = 0.008357
t = 2500  loss = 0.005850
t = 2600  loss = 0.004397
t = 2700  loss = 0.003974
t = 2800  loss = 0.003003
t = 2900  loss = 0.004802
t = 3000  loss = 0.007744
t = 3100  loss = 0.003570
t = 3200  loss = 0.002575
t = 3300  loss = 0.005088
t = 3400  loss = 0.001941
t = 3500  loss = 0.006714
t = 3600  loss = 0.002756


In [537]:
# Test model
with torch.no_grad():

    # Set model to evaluation mode
    model.eval()

    # Perform evaluation: a sample counts as correct only if every selected output,
    # rounded to the nearest integer, equals its target
    total_correct = 0
    total_samples = 0
    for batch_idx, (X, s, ids) in enumerate(testing_dataloader):

        # Forward pass, then round the regression outputs to integers
        X_out = model(X)
        s_out = torch.round(prediction(X_out, ids))

        # Calculate accuracy
        total_correct += (s_out == s).all(dim=1).sum().item()
        # Count the real batch size: the final batch can be smaller than batch_size
        total_samples += s.shape[0]

    accuracy = total_correct / total_samples
    print(f'Accuracy on testing set: {accuracy:.4f}')

Accuracy on testing set: 1.0000


## Transformer