# Overcoming a Theoretical Limitation of Self-Attention 

Replication of the experiments on learning the FIRST language from [Overcoming a Theoretical Limitation of Self-Attention (Chiang and Cholak, 2022)](https://arxiv.org/pdf/2202.12172.pdf).
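
For orientation, FIRST is the language of binary strings whose first symbol is `1`. The sketch below shows the membership test and a labeled random sample. The names `in_first` and `sample_labeled` are hypothetical helpers for illustration only; they are not part of this repository's `src` package, and the actual `Dataset` class may generate and encode its data differently.

```python
import random
from typing import Tuple

ALPHABET = ("0", "1")  # binary alphabet over which FIRST is defined


def in_first(w: str) -> bool:
    """Return True if w belongs to FIRST, i.e. its first symbol is '1'."""
    return len(w) > 0 and w[0] == "1"


def sample_labeled(n: int, rng: random.Random) -> Tuple[str, bool]:
    """Draw a uniformly random binary string of length n with its FIRST label.

    Hypothetical helper; the repository's Dataset may sample differently.
    """
    w = "".join(rng.choice(ALPHABET) for _ in range(n))
    return w, in_first(w)


rng = random.Random(42)
for _ in range(3):
    print(sample_labeled(10, rng))
```

The label depends only on position 1, which is why the task is trivial to state yet becomes harder for soft attention as strings grow longer.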

In [None]:
from src.transformer import FirstTransformer
import torch
import pandas as pd

## Learning FIRST

Define the training parameters as in the original paper. Quoting Chiang and Cholak (2022):
> We used `d_model` = 16 for word encodings, self-attention, and FFNN outputs, and `d_FFNN` = 64 for FFNN hidden layers. We used layer normalization (ε = 10^−5) after residual connections. We used PyTorch’s default initialization and trained using Adam (Kingma and Ba, 2015) with learning rate 3 × 10^−4 (Karpathy, 2016). We did not use dropout, as it did not seem to help.
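
To make the role of ε concrete, here is a minimal plain-Python sketch of layer normalization over a single vector. It assumes the usual formulation (biased variance, ε added inside the square root) and omits the learnable gain and bias; it is an illustration, not the code used in `src.transformer`.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance.

    eps is added to the variance so that a constant vector does not
    cause a division by zero.
    """
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in x]


print(layer_norm([1.0, 2.0, 3.0, 4.0]))
print(layer_norm([5.0, 5.0, 5.0]))  # constant input: all zeros, no error
```

With `eps = 0` the second call would divide by zero, which is the case the small constant guards against.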

In [None]:
vocab = ["0", "1", "$"]

epochs = 20
layers = 2
heads = 1 
d_model = 16
d_ffnn = 64  
eps = 1e-5 # value added to denominator in layer normalization
scaled = False
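
The paper proposes multiplying attention logits by log n (log-length scaled attention) so that attention does not dilute as strings get longer. Assuming the `scaled` flag above toggles that option, which this notebook does not confirm, the following sketch illustrates the effect in a simplified setting: position 1 has logit `c` and the other n − 1 positions have logit 0. The function name `weight_on_first` and the value `c = 2.0` are illustrative assumptions.

```python
import math


def weight_on_first(n: int, c: float, log_scaled: bool) -> float:
    """Softmax weight on position 1 when its logit is c and all others are 0.

    With log_scaled=True every logit is multiplied by log(n) before the
    softmax (the zero logits stay zero).
    """
    scale = math.log(n) if log_scaled else 1.0
    top = math.exp(c * scale)
    return top / (top + (n - 1))


for n in (10, 100, 1000):
    plain = weight_on_first(n, 2.0, log_scaled=False)
    scaled_w = weight_on_first(n, 2.0, log_scaled=True)
    print(f"n={n:5d}  unscaled={plain:.4f}  log-scaled={scaled_w:.4f}")
```

Without scaling, the weight on the first position shrinks roughly like 1/n; with log-length scaling and c > 1 it approaches 1 as n grows.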

### Generalization experiment

Initialize the Transformer to learn FIRST.

In [None]:
transformer = FirstTransformer(len(vocab), layers, heads, d_model, d_ffnn, scaled, eps)
optim = torch.optim.Adam(transformer.parameters(), lr=0.0003)

Define the model trainer and train the transformer.

In [None]:
from src.trainer import Trainer
from src.dataset import Dataset

trainset = Dataset(0, 100, 10, random_seed=42, train=True, data_type='first', variable_lenght=False)
testset = Dataset(0, 100, 1000,  random_seed=42,  train=False, data_type='first', variable_lenght=False)

trainer = Trainer(0, transformer, optim, vocab, epochs, trainset, testset, verbose=1)
train_l, val_l, train_acc, val_acc = trainer.train()

In [None]:
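import matplotlib.pyplot as plt

# Training vs. validation loss per epoch (log scale)
fig = plt.figure()
plt.plot(range(epochs), train_l, color='blue', lw=2, label='train loss')
plt.plot(range(epochs), val_l, color='orange', lw=2, label='validation loss')
plt.yscale('log')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(frameon=False)
plt.show()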

In [None]:
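# Training vs. validation accuracy per epoch
fig = plt.figure()
plt.plot(range(epochs), train_acc, color='blue', lw=2, label='train accuracy')
plt.plot(range(epochs), val_acc, color='orange', lw=2, label='validation accuracy')
plt.ylim([0, 1.1])
plt.xlabel('epoch')
plt.ylabel('accuracy')

ax = plt.gca()

# The curves need labels; without them the legend below would be empty
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, frameon=False, loc='lower center', ncol=4)
plt.show()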

### FIRST exact
Validation of the exact solution for FIRST.

In [1]:
from src.transformer import FirstExactTransformer

First, define the model parameters.

In [2]:
vocab = ["0", "1", "$"]
d_model = 6 # do not change this!

Define the exact transformer model.

In [3]:
transformer = FirstExactTransformer(len(vocab), d_model) 

Validate the model on strings of increasing length in the interval [2, 1000].

In [4]:
from src.validation import Validator
from src.dataset import Dataset

# One validation set per string length, from 2 to 1000 inclusive
for length in range(2, 1001):
    valset = Dataset(0, 100, length, random_seed=42, train=False, data_type='first', variable_lenght=False)
    validator = Validator(0, transformer, vocab, valset, verbose=1)
    validator.validate()

[Validation length 2] Loss: 43.280200362205505, Accuracy: 1.0
[Validation length 3] Loss: 51.92541769146919, Accuracy: 0.89
[Validation length 4] Loss: 51.75462920963764, Accuracy: 0.81
[Validation length 5] Loss: 58.24698895215988, Accuracy: 0.59
[Validation length 6] Loss: 58.367703914642334, Accuracy: 0.74
[Validation length 7] Loss: 63.11378167569637, Accuracy: 0.63
[Validation length 8] Loss: 62.00501102209091, Accuracy: 0.6
[Validation length 9] Loss: 58.084898859262466, Accuracy: 0.77
[Validation length 10] Loss: 66.40276417136192, Accuracy: 0.53
[Validation length 11] Loss: 64.61968816816807, Accuracy: 0.56
[Validation length 12] Loss: 68.72446469962597, Accuracy: 0.53
[Validation length 13] Loss: 68.65705946087837, Accuracy: 0.54
[Validation length 14] Loss: 68.25248025357723, Accuracy: 0.54
[Validation length 15] Loss: 67.8843321800232, Accuracy: 0.55
[Validation length 16] Loss: 66.64244858920574, Accuracy: 0.55
[Validation length 17] Loss: 74.3667231798172, Accuracy: 0.44
[

KeyboardInterrupt: 