In the experiment on [memorization](https://colab.research.google.com/drive/1mcHTLZIWPs1n4V5CmdJxMSYN6_KFZ0Y4?usp=sharing),
I chose a symmetric matrix, because the transformer has no positional encoding and thus can't learn a non-symmetric matrix.

However, in my original experiments, which used a very slightly different transformer architecture, the model was able to learn even a non-symmetric matrix.
The same thing happens with the final architecture once the network has 2 layers. This phenomenon was first observed in larger models in [Haviv et al. 2022](https://arxiv.org/abs/2203.16634).

The two architectures differ in what they predict: the model in the original experiment predicts the series $(j, M_{ij})$ from $(i, j)$, while the model in the linked notebook predicts the series $(i, j, M_{ij})$ from $(i, j)$, using a third sequence item that is initialized to the zero vector.
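To make the two setups concrete, here is a minimal sketch of what a single training example looks like in each. The helper `make_example` and the toy matrix are hypothetical names used only for illustration. The zero vector is appended after embedding, so here it shows up only as an extra target position.

```python
def make_example(i, j, matrix, padding):
    """Return (input tokens, target tokens) for one lookup of matrix[i][j]."""
    inputs = [i, j]
    if padding:
        # Linked notebook: a zero vector is appended after embedding, so there
        # are three output positions and the targets are (i, j, M_ij).
        targets = [i, j, matrix[i][j]]
    else:
        # Original experiment: two output positions, targets are (j, M_ij).
        targets = [j, matrix[i][j]]
    return inputs, targets


toy_matrix = [[5, 7], [8, 6]]  # deliberately non-symmetric
print(make_example(0, 1, toy_matrix, padding=True))   # ([0, 1], [0, 1, 7])
print(make_example(0, 1, toy_matrix, padding=False))  # ([0, 1], [1, 7])
```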

This matters because in the first model, the residual stream of the last token is initialized with the embedding of $j$. The model is therefore able to distinguish between a "1" that was passed in through the attention and a "1" that was already there in the residual stream. In the second model, the residual stream of the last token is initialized with the zero vector, and the model receives both indices through the attention, so it's unable to tell which of the two came first.
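The argument can be checked on a single attention head with random, untrained weights. This is only a sketch under simplifying assumptions (one head, no MLP, no layer norm), and every name in it (`W_q`, `W_k`, `W_v`, `b_q`, `last_token_stream`) is made up for the illustration rather than taken from the notebook.

```python
import torch

torch.manual_seed(0)
d_model = 8
# Random single-head attention weights (illustrative, untrained).
W_q, W_k, W_v = (torch.randn(d_model, d_model) / d_model**0.5 for _ in range(3))
b_q = torch.randn(d_model)  # query bias, so the zero vector still produces a query


def last_token_stream(seq):
    """Residual stream of the last position after one attention head.

    seq has shape [seq_len, d_model]. Under a causal mask the last position
    may attend to every position, so no explicit mask is needed here.
    """
    q = seq[-1] @ W_q + b_q
    scores = (seq @ W_k) @ q / d_model**0.5
    weights = torch.softmax(scores, dim=0)
    return seq[-1] + weights @ (seq @ W_v)  # residual + attention output


e_i, e_j = torch.randn(d_model), torch.randn(d_model)
zero = torch.zeros(d_model)

# Second model: the last position starts as the zero vector.
padded_ij = last_token_stream(torch.stack([e_i, e_j, zero]))
padded_ji = last_token_stream(torch.stack([e_j, e_i, zero]))

# First model: the last position starts as the embedding of the second index.
plain_ij = last_token_stream(torch.stack([e_i, e_j]))
plain_ji = last_token_stream(torch.stack([e_j, e_i]))

padded_same = torch.allclose(padded_ij, padded_ji, atol=1e-5)
plain_same = torch.allclose(plain_ij, plain_ji, atol=1e-5)
print(padded_same, plain_same)
```

With the zero-initialized position, swapping the two inputs only permutes the terms of the attention sum, so the two results agree. Without it, the residual term `seq[-1]` differs between the two orders, which is what lets the first model tell them apart.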

We now show that even in the second case, a 2-layer model can still learn positional information. The intuitive explanation is that in the second layer, the second token already contains information that was passed from the first token through the attention, so the model learns to distinguish between a $(1, *)$ and a $(*, 1)$ because the "1" in the latter carries "extra information" from the first token.
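Before training anything, the same effect can be seen on randomly initialized encoders. The sketch below is an assumption-laden illustration, not the notebook's model: `build_encoder`, `padded_last_output`, and the indices 3 and 7 are arbitrary choices. It compares the output at the zero-padded position for $(i, j)$ and $(j, i)$ at depth 1 and depth 2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_tokens = 16, 30
embed = nn.Embedding(n_tokens, d_model)


def build_encoder(num_layers):
    layer = nn.TransformerEncoderLayer(
        d_model=d_model, nhead=2, dim_feedforward=64, dropout=0.0, batch_first=True
    )
    return nn.TransformerEncoder(layer, num_layers=num_layers, enable_nested_tensor=False)


def padded_last_output(encoder, i, j):
    """Encoder output at the zero-padded third position for the input (i, j)."""
    x = embed(torch.tensor([[i, j]]))        # [1, 2, d_model]
    x = nn.functional.pad(x, (0, 0, 0, 1))   # append a zero vector -> [1, 3, d_model]
    # Boolean causal mask: True marks positions that may not be attended to.
    mask = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)
    return encoder(x, mask=mask)[0, -1].detach()


gaps = {}
for depth in (1, 2):
    encoder = build_encoder(depth)
    diff = padded_last_output(encoder, 3, 7) - padded_last_output(encoder, 7, 3)
    gaps[depth] = diff.abs().max().item()
    print(f"layers={depth}  max |out(i,j) - out(j,i)| = {gaps[depth]:.2e}")
```

At depth 1 the gap should sit at the level of floating-point noise, while at depth 2 it should be clearly non-zero: the first layer has already written order-dependent content into the first two positions, and the second layer reads it.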

In [1]:
import torch
torch.set_printoptions(precision=2, sci_mode=False, linewidth=200)
n_tokens = 30
memory = torch.randint(0, n_tokens, (n_tokens, n_tokens))
memory = memory.long()
memory

tensor([[14, 17,  8, 24, 26, 12,  4, 12, 13, 17, 16, 18,  6, 22,  2,  6,  6, 27, 14, 17, 13,  8, 22, 14, 26,  6, 21, 27, 27,  7],
        [ 9, 20,  2,  2, 29, 10, 16, 14,  2, 10, 12,  9, 21, 15, 27, 17, 23,  4, 14,  2, 14, 11,  0,  5,  7, 23, 25,  3, 18, 28],
        [21,  3, 26, 25,  6,  7,  7, 15, 22, 19, 18, 20, 20, 15, 27,  1, 20, 11,  3, 20, 10, 10, 28, 29, 10, 10, 16,  7, 10, 25],
        [ 9,  9, 22, 29,  3, 17, 10, 11, 23,  2,  2,  6, 14,  6, 17, 14, 26, 14,  2, 21, 22,  0,  9, 17, 22,  9, 27, 21,  6, 12],
        [ 4,  3,  1,  6, 17, 12,  1,  4, 18, 27, 10, 12, 29, 24, 21, 15, 25,  4, 14, 17,  1, 22, 17, 29, 16, 13,  5, 25, 19,  0],
        [ 5, 15,  0, 14,  2, 12,  0, 28, 17, 14, 12, 22, 15,  8, 16, 20, 18, 10,  0,  1, 19, 22, 23, 27, 18,  6, 14,  9, 29,  7],
        [ 3,  0, 14, 15,  1, 24,  6, 28, 26,  4, 20, 26,  2,  6,  5, 22,  6, 24, 18, 10,  7,  9,  7, 10, 25, 16, 17, 29,  4, 24],
        [12, 17, 21,  4,  2, 23, 22, 20, 14, 22, 23, 13, 21,  2, 12,  0,  0, 10, 20, 16,  

In [55]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

class Transformer(nn.Module):
    def __init__(self, layers, d_model, n_head, dim_feedforwrad, padding=True):
        super().__init__()
        self.tokens = list(range(n_tokens))
        self.padding = padding
        self.embed = nn.Embedding(n_tokens, d_model)
        self.encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforwrad, batch_first=True, dropout=0),
            num_layers=layers
        )
        self.unembed = nn.Linear(d_model, n_tokens)
    
    def forward(self, x):
        x = self.embed(x)
        if self.padding:
            # pad x along the sequence dimension
            x = nn.functional.pad(x, (0, 0, 0, 1), mode='constant', value=0)

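        # Boolean causal mask: True marks the (future) positions a token may NOT attend to.
        # A float mask of ones and zeros would be added to the attention scores instead of masking them.
        seq_len = x.shape[1]
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=self.device), diagonal=1)
        x = self.encoder(x, is_causal=True, mask=attn_mask)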
        x = self.unembed(x)
        return x

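    def train(self, lr=1e-3, batch_size=128, n_epochs=1000):
        # NOTE: this overrides nn.Module.train(mode), which model.eval() relies on,
        # so avoid calling model.eval() on this class.
        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in tqdm(range(n_epochs)):
            batch = self.generate_data(batch_size)  # columns: i, j, M[i, j]
            optimizer.zero_grad()
            inputs = batch[:, :-1]  # the model only sees (i, j)
            output = self(inputs)
            if not self.padding:
                # Without padding there are two output positions, which predict (j, M[i, j]).
                batch = batch[:, 1:]
            loss = criterion(output.reshape(-1, len(self.tokens)), batch.reshape(-1))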

            loss.backward()
            optimizer.step()

        print('loss: ', loss.item())
    
    def generate_data(self, batch_size):
        random_indices = torch.randint(0, n_tokens, (batch_size, 2))  # [batch_size, 2]
        next_tokens = memory[random_indices[:, 0], random_indices[:, 1]].unsqueeze(1)  # [batch_size, 1]
        tensor = torch.cat([random_indices, next_tokens], dim=1)
        return tensor.to(self.device)

    def evaluate(self):
        samples = 1000
        data = self.generate_data(samples)
        with torch.no_grad():
            # Only the last position, which predicts M[i, j], counts towards accuracy.
            output = self(data[:, :-1])[:, -1, :].argmax(dim=-1)
        print('Accuracy: ', output.eq(data[:, -1]).sum().item() / samples)
    
    @property
    def device(self):
        return next(self.parameters()).device


In [56]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [57]:
model = Transformer(layers=1, d_model=16, n_head=2, dim_feedforwrad=64).to(device)
model.train(lr=1e-2, n_epochs=3000, batch_size=300)
model.evaluate()


100%|██████████| 3000/3000 [00:25<00:00, 118.22it/s]

loss:  0.3708096742630005
Accuracy:  0.506





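That's slightly more than 0.5 because the model is able to learn the diagonal. A diagonal entry is sampled with probability $\frac{\mathrm{n\_tokens}}{\mathrm{n\_tokens}^2} = \frac{30}{900}$ and can always be predicted, while for an off-diagonal pair an order-blind model can match only one of $M_{ij}$ and $M_{ji}$, so the theoretical accuracy is $\frac{30}{900} + \frac{1}{2} \cdot \frac{870}{900} = \frac{465}{900} \approx 0.517$. It might be slightly more because some elements are symmetric by chance.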
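The counting argument for an order-blind predictor can be checked directly. The sketch below uses a hypothetical helper, `symmetric_accuracy_ceiling`, and a freshly sampled `demo_matrix` rather than the notebook's `memory` tensor. It computes the best accuracy available to any predictor that sees only the unordered pair $\{i, j\}$, with $(i, j)$ sampled uniformly.

```python
import random


def symmetric_accuracy_ceiling(matrix):
    """Best accuracy of a predictor that cannot tell (i, j) from (j, i)."""
    n = len(matrix)
    correct = n  # every diagonal entry can be predicted
    for i in range(n):
        for j in range(i + 1, n):
            # One answer serves both (i, j) and (j, i): it matches both only
            # if the two entries happen to be equal, otherwise just one.
            correct += 2 if matrix[i][j] == matrix[j][i] else 1
    return correct / (n * n)


random.seed(0)
n_tokens = 30
demo_matrix = [[random.randrange(n_tokens) for _ in range(n_tokens)] for _ in range(n_tokens)]
ceiling = symmetric_accuracy_ceiling(demo_matrix)
print(ceiling)
```

For a $30 \times 30$ matrix with no accidental symmetry this gives $465/900 \approx 0.517$. Each off-diagonal pair matches by chance with probability $1/30$, which raises the expected ceiling for a uniformly random matrix to about 0.53. An accuracy measured on 1000 random samples will scatter around that value.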

In [58]:
model = Transformer(layers=2, d_model=16, n_head=2, dim_feedforwrad=64).to(device)
model.train(lr=1e-2, n_epochs=3000, batch_size=300)
model.evaluate()


100%|██████████| 3000/3000 [00:39<00:00, 75.30it/s]

loss:  0.031408943235874176
Accuracy:  0.973





We now show that without the zero-padding, even a 1-layer model can learn the positional information.

In [59]:
model = Transformer(layers=1, padding=False, d_model=16, n_head=2, dim_feedforwrad=64).to(device)
model.train(lr=1e-2, n_epochs=3000, batch_size=300)
model.evaluate()


100%|██████████| 3000/3000 [00:20<00:00, 143.33it/s]

loss:  1.9394631385803223
Accuracy:  0.927
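
The loss here is much higher than in the padded runs even though the accuracy is high. A plausible reading, based on the training code: without padding, the first output position has to predict $j$ from $i$ alone, and since $j$ is sampled uniformly and independently of $i$, that position cannot beat a uniform guess. A quick calculation of the resulting floor on the mean loss, assuming the loss is averaged equally over the two output positions:

```python
import math

n_tokens = 30
# Cross-entropy of a uniform guess over the n_tokens possible values of j.
floor_first_position = math.log(n_tokens)
# The reported loss is the mean over two positions: j and M[i, j].
floor_mean_loss = floor_first_position / 2
print(round(floor_mean_loss, 3))  # 1.701
```

That accounts for roughly 1.70 of the reported 1.94, leaving the remainder to the position that predicts $M_{ij}$.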



