## Attention Codes
The goal of this notebook is to train a neural network with a single multihead attention layer that takes a numerical input and says whether the right number was put in. Furthermore, _if_ the right number was put in, the attention matrices should form a picture that will serve as a hint for the puzzle.

In the end, we will provide the model code and the pre-trained weights, and let the user find out what it does.

Definition of the model (we somewhat awkwardly take apart the attention layer to make it easier to train and use).

In [None]:
import torch

class AttentionMatrix(torch.nn.Module):
    """Score stage of one attention head.

    Projects the embeddings with a query layer and a key layer and returns
    their pairwise dot products. The result is the raw score matrix: no
    scaling and no softmax are applied here.
    """

    def __init__(self, n_hidden):
        super().__init__()
        self.query_layer = torch.nn.Linear(n_hidden, n_hidden)
        self.key_layer = torch.nn.Linear(n_hidden, n_hidden)

    def forward(self, embedding):
        """Return ``Q @ K^T`` for ``embedding``, swapping axes 1 and 2 of K."""
        queries = self.query_layer(embedding)
        keys = self.key_layer(embedding)
        # Axes 1 and 2 are swapped, so the input must have at least three axes.
        return torch.matmul(queries, keys.transpose(1, 2))
    

class AttentionOutput(torch.nn.Module):
    """Output stage of one attention head.

    Turns a raw score matrix into attention weights (softmax over the last
    axis), mixes the value vectors with those weights, and maps the result
    through the value layer a second time.
    """

    def __init__(self, n_hidden):
        super().__init__()
        self.value_layer = torch.nn.Linear(n_hidden, n_hidden)
        self.softmax = torch.nn.Softmax(-1)

    def forward(self, embedding, attention_matrix):
        """Return the head output for ``embedding`` given its raw ``attention_matrix``."""
        values = self.value_layer(embedding)
        weights = self.softmax(attention_matrix)
        mixed = torch.matmul(weights, values)
        # NOTE: the same value_layer is reused as the output projection, so it
        # is applied twice per forward pass. Trained weights rely on this.
        return self.value_layer(mixed)
    

class NeuralNetwork(torch.nn.Module):
    """Binary classifier built from a single multi-head attention layer.

    Token ids are embedded, every head produces a raw score matrix
    (``AttentionMatrix``) and a head output (``AttentionOutput``), the head
    outputs are concatenated and projected back to ``n_hidden``, and the
    vector at sequence position 0 is mapped to ``number_classes`` logits.
    """

    number_heads = 3
    number_classes = 2

    def __init__(self, n_hidden):
        super().__init__()

        # Vocabulary of 11 ids: 0-9 are the digits, 10 is the CLS token.
        self.embedding = torch.nn.Embedding(11, n_hidden)

        self.attention_matrix_list = torch.nn.ModuleList(
            [AttentionMatrix(n_hidden) for _ in range(self.number_heads)]
        )
        self.attention_output_list = torch.nn.ModuleList(
            [AttentionOutput(n_hidden) for _ in range(self.number_heads)]
        )
        self.projection = torch.nn.Linear(self.number_heads * n_hidden, n_hidden)

        self.output = torch.nn.Linear(n_hidden, self.number_classes)

    def get_attention_matrices(self, embeddings):
        """Return one raw score matrix per head, in head order, as a list."""
        return [score_stage(embeddings) for score_stage in self.attention_matrix_list]

    def _get_attention_output(self, embedding, attention_matrices):
        """Run each head's output stage, join the results on the last axis and project them."""
        head_outputs = []
        for scores, output_stage in zip(attention_matrices, self.attention_output_list):
            head_outputs.append(output_stage(embedding, scores))
        return self.projection(torch.cat(head_outputs, dim=-1))

    def _get_logits_from_attention_matrices(self, embeddings, attention_matrices):
        """Return the class logits computed from the given per-head score matrices."""
        mixed = self._get_attention_output(embeddings, attention_matrices)
        # Only sequence position 0, where the CLS token is expected, feeds the classifier.
        return self.output(mixed[:, 0])

    def forward(self, tokens):
        """Return, per sequence, the softmax probability of class 1 ("correct input").

        The first token of every sequence must be CLS (id 10); this is not
        checked here.
        """
        embeddings = self.embedding(tokens)
        score_matrices = self.get_attention_matrices(embeddings)
        logits = self._get_logits_from_attention_matrices(embeddings, score_matrices)
        return logits.softmax(dim=-1)[:, 1]


The code above is copied into the puzzle/classifier folder.

In [None]:
# Pick the best available accelerator: Apple MPS first, then CUDA. Fall back
# to the CPU so the notebook still runs on machines without either one.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


Let's see if the forward pass works:

In [None]:
# Smoke test: push a dummy token sequence through an untrained network.
# `Module.to` returns the module itself, so construction and move are chained.
nn = NeuralNetwork(10).to(device)

tokens = torch.tensor([[0, 1, 2, 3, 4]], dtype=torch.long, device=device)

nn(tokens)

## Attention targets
We have three attention targets: the ordinal "1." (i.e. "first"), the "circle above a cross" symbol for the female gender, and a 5x5 grid where every prime number is 1 and every other number is 0. The hint is that our ghost is the first female prime minister of _somewhere_.

In [None]:
# Target pictures for the three attention heads. Each one is a 5x5 grid of
# 0/1 values created on `device`; 5 is the length of a CLS token plus four digits.

# Head 1: a vertical stroke in column 1 with a flag at (1, 0) and a dot at
# (4, 3), i.e. the "1." described in the notebook text.
attention_target1 = torch.tensor([
    [0, 1, 0, 0, 0],
    [1, 1, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 1, 0, 0, 0],
    [0, 1, 0, 1, 0]
], device=device)

# Head 2: the "circle above a cross" (female) symbol.
attention_target2 = torch.tensor([
    [0, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 1, 0, 0]
], device=device)

# Head 3: prime grid. Reading the cells row by row as the numbers 1..25, the
# set cells are 1, 2, 3, 5, 7, 11, 13, 17, 19 and 23.
# NOTE(review): the cell for 1 is set although 1 is not a prime; confirm this is intended.
attention_target3 = torch.tensor([
    [1, 1, 1, 0, 1],
    [0, 1, 0, 0, 0],
    [1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0],
    [0, 0, 1, 0, 0],
], device=device)

# List order is the head order: entry i is the target for attention head i.
attention_targets = [
    attention_target1, attention_target2, attention_target3
]

from matplotlib import pyplot as plt

# Show the three target pictures side by side in a 1x3 grid.
plt.figure(figsize=(8, 3))
for position, target in enumerate(attention_targets, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(target.detach().cpu().numpy())
plt.show()

## Training
The secret code is 2013, the year of death of our ghost, and the year of the word2vec publication. The year can be found as a hint from the brabbler.

In [None]:
secret_code = 2013  # year of death (and year of word2vec publication)

def number_to_tensor(number, device):
    """Tokenise ``number`` for the network.

    Returns a tensor of shape (1, number of digits + 1) on ``device``: the
    CLS id 10 followed by the decimal digits of ``number``.
    """
    token_ids = [10]
    token_ids.extend(int(character) for character in str(number))
    return torch.tensor([token_ids], device=device)


# Negatives: every number below 100000 except the secret code, tokenised once
# and kept on the device. The numbers below the code come first, then those above.
negatives = [
    number_to_tensor(candidate, device)
    for span in (range(secret_code), range(secret_code + 1, 100000))
    for candidate in span
]

# The single positive example is the secret code itself.
positive = number_to_tensor(secret_code, device)

# Class labels: 0 for a wrong number, 1 for the secret code.
negative_label = torch.tensor([0], device=device)
positive_label = torch.tensor([1], device=device)

Just using the negative examples from above kind of works, but very similar numbers (such as 2012) are also predicted to be correct. We construct a dataset of "close negatives" to oversample those.

In [None]:
def all_digits_except(digit):
    """Return the digits 0-9 in ascending order, leaving out ``digit``."""
    return list(filter(lambda candidate: candidate != digit, range(10)))

def construct_close_negatives(device):
    """Build every token sequence that differs from the code 2013 in exactly one digit.

    Returns a list of 36 tensors of shape (1, 5) on ``device``, ordered by
    the varied digit position (first to fourth) and, within a position, by
    ascending replacement digit.
    """
    code_tokens = [10, 2, 0, 1, 3]
    close = []
    # Index 0 holds the CLS token; indices 1-4 hold the digits of the code.
    for position in range(1, len(code_tokens)):
        for replacement in all_digits_except(code_tokens[position]):
            variant = list(code_tokens)
            variant[position] = replacement
            close.append(torch.tensor([variant], device=device))
    return close

close_negatives = construct_close_negatives(device)

Batching would require implementing a layer mask. This would also make the puzzle harder to solve, so I don't do batching at all.

Training is done by always showing a negative, a "close" negative and a positive example. For the positive example, we force the attention matrices to be as defined above in terms of MSE loss.

In [None]:
def bce_loss(prediction, target):
    """Mean binary cross-entropy between probabilities and 0/1 targets.

    ``prediction`` must be a floating-point tensor of probabilities and
    ``target`` a tensor of 0/1 labels that broadcasts against it.

    The probabilities are clamped to ``[eps, 1 - eps]`` before the logarithm.
    Without the clamp, a saturated prediction (exactly 0.0 or 1.0) yields
    ``0 * log(0) = NaN`` or an infinite loss, which poisons every weight on
    the next optimiser step. Clamping the probability rather than the
    log keeps the gradient finite as well; a clamped element gets a zero
    gradient.
    """
    eps = torch.finfo(prediction.dtype).eps
    clamped = prediction.clamp(min=eps, max=1 - eps)
    return torch.mean(-(target * torch.log(clamped) + (1 - target) * torch.log(1 - clamped)))

def mse_loss(prediction, target):
    """Return the mean of the squared element-wise differences (broadcasting applies)."""
    difference = prediction - target
    return difference.pow(2).mean()

import random

train_iterations = 10000

# Seed both random sources: `random` picks the training examples and torch
# initialises the weights. Seeding only `random` leaves the run irreproducible.
random.seed(121)
torch.manual_seed(121)

nn = NeuralNetwork(10)
nn.to(device)

optim = torch.optim.AdamW(nn.parameters(), lr=3e-3)

# No batching: each iteration takes three separate optimiser steps, one on a
# random negative, one on a "close" negative and one on the positive example.
for i in range(train_iterations):
    print(f"\rIteration {i:7d}", end="")

    # negative
    optim.zero_grad()
    negative_example = random.choice(negatives)
    prediction = nn(negative_example)
    loss = bce_loss(prediction, negative_label)
    loss.backward()
    optim.step()

    # close negative (differs from the code in one digit)
    optim.zero_grad()
    negative_example = random.choice(close_negatives)
    prediction = nn(negative_example)
    loss = bce_loss(prediction, negative_label)
    loss.backward()
    optim.step()

    # positive
    optim.zero_grad()
    prediction = nn(positive)
    prediction_loss = bce_loss(prediction, positive_label)

    # attention matrices: the raw (pre-softmax) scores of each head for the
    # positive example are pulled towards that head's target picture
    attention = nn.get_attention_matrices(nn.embedding(positive))

    attention_loss = sum([
        mse_loss(attention_matrix, attention_target)
        for attention_matrix, attention_target
        in zip(attention, attention_targets)
    ])

    loss = attention_loss + prediction_loss
    loss.backward()
    optim.step()



## Make sure that it works

In [None]:
# Correct Example: CLS followed by the digits of the secret code.
codeword = torch.tensor([[10, 2, 0, 1, 3]], device=device)

print("Correct:", nn(codeword))

predicted_attentions = nn.get_attention_matrices(nn.embedding(codeword))

# One panel per head; index [0] drops the leading axis of size one.
plt.figure(figsize=(8, 3))
for position, head_attention in enumerate(predicted_attentions, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(head_attention.detach().cpu().numpy()[0])
plt.show()

In [None]:
# Incorrect Example: differs from the secret code in the last digit only.
codeword = torch.tensor([[10, 2, 0, 1, 2]], device=device)

# The printed value is the predicted probability that the input is correct.
print("Correct:", nn(codeword))

predicted_attentions = nn.get_attention_matrices(nn.embedding(codeword))

# One panel per head; index [0] drops the leading axis of size one.
plt.figure(figsize=(8, 3))
for position, head_attention in enumerate(predicted_attentions, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(head_attention.detach().cpu().numpy()[0])
plt.show()

## Save
And make sure that it works upon reloading. Save it into the puzzle subfolder.

In [None]:
# Store the weights as CPU tensors. A state dict saved straight from an
# MPS/CUDA model records that device, so a plain `torch.load` on a machine
# without it fails, and puzzle users load this file on their own hardware.
cpu_state_dict = {name: tensor.detach().cpu() for name, tensor in nn.state_dict().items()}
torch.save(cpu_state_dict, "../puzzle/classifier/torch_state_dict")

In [None]:
# Rebuild the network and restore the saved weights. `map_location="cpu"`
# makes loading independent of the device the checkpoint was written from;
# the model is moved to `device` afterwards.
reloaded = NeuralNetwork(10)
reloaded.load_state_dict(
    torch.load("../puzzle/classifier/torch_state_dict", map_location="cpu", weights_only=True)
)
reloaded.to(device)

In [None]:
# Correct Example, this time evaluated with the reloaded network.
codeword = torch.tensor([[10, 2, 0, 1, 3]], device=device)

print("Correct:", reloaded(codeword))

predicted_attentions = reloaded.get_attention_matrices(reloaded.embedding(codeword))

# One panel per head; index [0] drops the leading axis of size one.
plt.figure(figsize=(8, 3))
for position, head_attention in enumerate(predicted_attentions, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(head_attention.detach().cpu().numpy()[0])
plt.show()