In [1]:
import torch as t
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from train import train
from models import Transformer, Low_rank, AoT
from utils import generate_data, entropy, power_unif_law
from tqdm import tqdm

In [2]:
"""Training Transformer."""

# Seed torch's global RNG before anything else in this cell. The plain constants
# below consume no randomness, so `power_unif_law` still runs right after seeding.
t.manual_seed(2222)

# Model hyper-parameters; they are handed positionally to `Transformer(...)` below.
N = 50          # presumably the vocabulary size (1/N is used as the random-guess accuracy) — TODO confirm
d = 10
nb_layers = 1
width = 0
depth = 1
para = 22       # appears, multiplied with d_head, in the accuracy lower bounds plotted later
d_head = 8
nb_head = 1
n_gram = 3
context_window = n_gram  # the context length is tied to the n-gram order

# Parameters of the data distribution `pi`.
alphas = [1, 1, 1]
nb_tokens = [N, N, 1]
pi = power_unif_law(alphas, nb_tokens, N)

# Learning parameters for the Transformer.
batch_size = 2**9
num_batch = 1000
epoches = 20
lr = 1e-3
Data = generate_data(batch_size=batch_size, num_batch=num_batch, pi=pi, context_window=context_window)

# Training the Transformer. `train` returns a dict; its 'Loss' entry is plotted here
# and its 'Acc' entry is plotted in the accuracy figure further down.
model = Transformer(d, N, nb_layers, width, depth, para, d_head, nb_head, context_window, pi)
Dict = train(model, Data, epoches, lr=lr, next_token=True)
# The curve needs a label: plt.legend() skips unlabeled artists, so the main
# training curve would otherwise be missing from the legend.
plt.plot(Dict['Loss'], label='Transformer')

# Upper bound: we compute the divergence with the uniform predictor.
ent = entropy(pi)
# The baseline is a constant, so compute it once and repeat it to get a
# horizontal line as long as the loss curve.
uniform_divergence = np.log(N) - ent
plt.plot([uniform_divergence for _ in Dict['Loss']], label='Uniform baseline', color='red')

# Learning parameters for the sequence encoder.
low_batch_size = 2**10
low_num_batch = 1000
low_lr = 1e-3
epochs = 4

# Lower bound: we train the best sequence encoder, the divergence of which sets the attainable lower bound.
model_low = Low_rank(d, N, context_window, pi)
# NOTE: this rebinds `Data`; the Transformer's training data is no longer reachable under that name.
Data = generate_data(low_batch_size, low_num_batch, pi, context_window)
dict_low = train(model_low, Data, epochs, lr=low_lr)

# Average the final (up to) 100 recorded losses. The slice [-100:] includes the most
# recent loss; the earlier [-101:-1] dropped it and was shifted back by one step.
# Dividing by the actual window length keeps the mean right when fewer than 100
# losses exist; max(..., 1) keeps the empty case at 0 instead of raising.
last_losses = dict_low['Loss'][-100:]
best_loss = sum(last_losses) / max(len(last_losses), 1)
plt.plot([best_loss for _ in Dict['Loss']], label='Optimal baseline', color='green')

plt.legend()
plt.xlabel("Batch number")
plt.ylabel("Divergence")
plt.title("Transformer's learning dynamics")
plt.show()

# Accuracy figure: the Transformer's accuracy, the accuracy of the random predictor,
# and the lower bound from our paper next to the previous SOTA bound (in accuracy, not in worst-case).
nb_points = len(Dict['Acc'])

# Every reference level is a constant, drawn as a horizontal line over the whole run.
random_acc = 1/N
our_bound = 1/N+(1-1/N)*para*d_head/(N**(n_gram-1))
previous_bound = 1/N+(1-1/N)*(para*(d_head-1)+1)/(N**(n_gram-1))

plt.plot(Dict['Acc'], label='Next token')
plt.plot([random_acc] * nb_points, color='black', label='Random baseline')
plt.plot([our_bound] * nb_points, label='Our Lower bound')
plt.plot([previous_bound] * nb_points, label='Previous lower bound')

plt.legend()
plt.xlabel("Batch number")
plt.ylabel("Accuracy")
# Leave a 0.1 margin around the [0, 1] accuracy range.
plt.ylim(bottom=-0.1, top=1.1)
plt.title("Transformer's learning dynamics")
plt.show()

 25%|██▌       | 5/20 [01:19<03:59, 15.94s/it]