## How does Othello-GPT know which cells are empty?
*Ilia Shchurov (Ilya Schurov, ilya.schurov@gmail.com)*

In [1]:
from othello_nb_utils import (
    get_model,
    get_focus_games,
    get_linear_probe,
    imshow,
    line,
    string_to_label,
    einops,
    plot_square_as_board,
    blank_index,
    plot_single_board, int_to_label,
    OthelloBoardState
)
import numpy as np
import torch
import pandas as pd



In [2]:
model = get_model()
focus_games_int, focus_games_string, focus_states, focus_valid_moves = get_focus_games()
linear_probe = get_linear_probe()

In [3]:
focus_logits, focus_cache = model.run_with_cache(focus_games_int[:, :-1])
focus_cache.compute_head_results()

## Problem statement
How does the model decide which cells are blank and which are not?

### Theoretical considerations
A cell is not blank at move $t$ if and only if either some move with index $i \le t$ plays this cell, or it is one of the four central cells that are occupied from the beginning of the game. Thus one may expect that it would be easy for a Transformer to deduce this kind of information at the first layer: it doesn't require any "sequential" reasoning. In fact, it is easy to invent an algorithm that a Transformer can (theoretically) use to check the emptiness of all cells simultaneously with just one attention operation (with one head).

Indeed, assume that in the move embeddings, there is a region that stores the one-hot encoded cell id. Consider an attention head that produces an equal attention score for every pair of query and key. (This can be done by considering degenerate query and key matrices that send any move embedding into a vector $[1, 1, \ldots, 1]$.) For each move, the output of this attention head is an average of all previous move embeddings (including the current one). Thus the output already encodes the non-emptiness of all cells (except the four central ones) as a scaled multi-hot vector: the entry for a cell is nonzero exactly when that cell has already been played. Then the attention output matrix can send each one-hot vector to the corresponding linear probe vector.
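The idealized algorithm can be sketched in a few lines of NumPy. This is only an illustration under the assumptions stated above (a pure one-hot cell-id region and perfectly uniform causal attention); the cell ids in `moves` are made up and nothing here is taken from the trained model.

```python
import numpy as np

n_cells = 64
moves = [19, 37, 26, 44]  # hypothetical cell ids played so far

# Idealized one-hot "cell id" region of each move embedding (an assumption).
embeddings = np.eye(n_cells)[moves]  # (n_moves, n_cells)

# Uniform causal attention: query t spreads its weight equally over keys 0..t.
n_moves = len(moves)
causal = np.tril(np.ones((n_moves, n_moves)))
attn = causal / causal.sum(axis=1, keepdims=True)

# The head output is a running average of the one-hot vectors.
head_output = attn @ embeddings

# A cell has been played by move t exactly when its averaged entry is nonzero.
occupied = head_output > 0
for t in range(n_moves):
    assert set(np.flatnonzero(occupied[t])) == set(moves[: t + 1])
```

Note that the nonzero entries shrink as $1/(t+1)$, so any readout from such a head has to tolerate this scaling.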

Thus, one may state the following hypotheses:

1. Model deduces the emptiness of all cells after an application of the first attention layer (before the first MLP) with high accuracy.
2. There is an attention head at the first layer that produces similar attention scores for all pairs of moves. It is possible that such a head is not unique due to the redundancy induced by dropout.
3. The composition of input and output matrices of this attention head maps each move embedding into the vector that is anti-aligned with the linear probe that checks emptiness of the corresponding cell.
4. If we corrupt these heads, the model would not be able to deduce the emptiness of cells (except for the four central cells that are always non-empty).

Below we test each of these hypotheses.

### Hypothesis 1. Model knows which cell is empty after the first attention
I will concentrate on the linear probe that checks the emptiness of cells and ignore the other probes. Thus, to test whether the model believes that some cell is empty, I will use a slightly different procedure from the one discussed in the original notebook. Instead of finding the probabilities of all three options and taking `argmax`, I will just evaluate the `is_blank` probe on the vector of interest and decide that the model believes the corresponding cell is empty if this value is positive. We know that the three probes are not actually independent, and thus the `is_blank` probe is not uniquely defined. Nevertheless, I choose the one already calculated and hope that it works.
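To make the non-uniqueness concrete, here is a toy sketch. It assumes the three-way prediction is a softmax over the three probe outputs; all arrays are random stand-ins with made-up sizes, not the real probe. Adding the same direction to all three class probes leaves the three-way probabilities unchanged, but it does change the raw value of a single probe, which is the quantity thresholded at zero here.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_classes = 16, 3  # toy sizes, not the model's
probe = rng.normal(size=(d_model, n_classes))  # columns: hypothetical class directions
x = rng.normal(size=d_model)  # stand-in for a residual-stream vector


def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


# Adding the same direction to every class column shifts all logits equally ...
shift = rng.normal(size=(d_model, 1))
shifted_probe = probe + shift

probs = softmax(x @ probe)
shifted_probs = softmax(x @ shifted_probe)

# ... so the three-way probabilities are unchanged,
assert np.allclose(probs, shifted_probs)

# while the raw value of a single column (what gets compared with zero) moves.
blank = 0  # hypothetical index of the blank class
raw_before = x @ probe[:, blank]
raw_after = x @ shifted_probe[:, blank]
print(raw_before, raw_after)
```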

In [4]:
linear_probe_is_blank = linear_probe[..., blank_index]
focus_is_blank = focus_states[:, :-1] == blank_index

In [5]:
def correct_predictions(model_vector: torch.Tensor, option=blank_index) -> np.ndarray:
    """
    Applies the linear probe to the model vector and returns a boolean array
    indicating whether the probe correctly predicted the blank squares.

    Shape of the output:
    (game, move, row, col)
    """
    probe_predictions_value = einops.einsum(
        model_vector,
        linear_probe[..., option],
        "game move i, i row col -> game move row col",
    )

    correct_predictions_ = (probe_predictions_value > 0).detach().cpu().numpy() == (
        focus_states[:, :-1] == option
    )
    return correct_predictions_



Let's check `blocks.0.hook_resid_mid` output: can we extract the emptiness information with our probe from it?

In [6]:
correct_predictions(focus_cache["blocks.0.hook_resid_mid"]).mean()

0.9356673728813559

93% accuracy — is it good or bad? Let us check how imbalanced our sample is.

In [7]:
(focus_is_blank).mean()

0.46875

We see that the sample is almost balanced, so accuracy is a meaningful measure and 93% is rather good. Let's look at how it changes across the layers.
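The value $0.46875$ can be reproduced by simple counting. The sketch below assumes that every position corresponds to exactly one disc placed and that all games span the same 59 positions (the final move is dropped by the `[:, :-1]` slice). It also gives the accuracy of always guessing the majority class, which is the natural baseline for the 93% figure.

```python
from fractions import Fraction

n_cells = 64
n_positions = 59  # positions kept per game (assumed equal for all games)

# After position t (0-indexed), the 4 starting discs plus t + 1 played discs
# are on the board, so 64 - 4 - (t + 1) cells are still blank.
blank_fraction_per_move = [
    Fraction(n_cells - 4 - (t + 1), n_cells) for t in range(n_positions)
]
blank_fraction = sum(blank_fraction_per_move) / n_positions

# Accuracy of a predictor that always outputs the more frequent label.
majority_baseline = max(blank_fraction, 1 - blank_fraction)

print(float(blank_fraction))     # 0.46875
print(float(majority_baseline))  # 0.53125
```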

In [8]:
line(
    torch.tensor(
        [correct_predictions(focus_cache[f"blocks.0.hook_resid_pre"]).mean()] + [
            correct_predictions(focus_cache[f"blocks.{layer}.hook_resid_mid"]).mean()
            for layer in range(8)
        ]
    ),
    xaxis='layer',
    yaxis='accuracy',
)

Here we begin with `blocks.0.hook_resid_pre` to investigate the effect of each layer, starting from the first one. We see that even before the first block, the accuracy is already $73\%$, and the first attention layer yields the largest accuracy gain. After that, it increases slightly, so the following MLP and the other layers do not remove this information. Interestingly, the accuracy decreases at the last layer: perhaps this information is no longer needed at that stage?

So far, the results are in agreement with the first hypothesis: it seems that the first attention layer (almost) solves our problem. However, the relatively high accuracy on `blocks.0.hook_resid_pre` seems mysterious. How can the model predict the emptiness so well even *before* the first attention, i.e. before it could aggregate any information over the moves? Let's dig into it.

We know that before the first Transformer block there are only two layers: embedding and positional embedding. Let's check how they interact with our probe.

Begin with the very first layer: embeddings. Of course, embedding of a move contains information about the emptiness of a cell that corresponds to that move (it is not empty), but one cannot expect that we can extract much information about the emptiness of the rest of the board from that. Let's check:

In [9]:
correct_predictions(focus_cache['hook_embed']).mean()

0.5462764830508474

The accuracy is just slightly above $50\%$, so there is essentially no information here. Not surprising at all. Now let's check the positional embeddings.

In [10]:
correct_predictions(focus_cache['pos_embed']).mean()

0.7576694915254237

A-ha! That's it. In fact, it is quite understandable: for example, if we know that it's just the second move, we can be sure that the corners are empty. The model extracts this kind of information from the positional encoding. Good to know! Let's investigate further: how does the accuracy depend on the specific moment?
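As a rough reference point for how much the move index alone can tell us, consider a deliberately crude toy model, which is an assumption and not how Othello works: the cells played so far form a uniformly random subset of the 60 non-central cells. The sketch computes the accuracy of the best predictor that sees only the move index.

```python
n_cells, n_central, n_positions = 64, 4, 59
n_playable = n_cells - n_central

per_move_accuracy = []
for t in range(n_positions):
    # Toy assumption: the t + 1 played cells are a uniformly random subset
    # of the 60 non-central cells (real Othello games are not like this).
    p_blank = (n_playable - (t + 1)) / n_playable
    # Best guess knowing only t: the more likely label for every non-central
    # cell, and "not blank" for the four central cells (always correct).
    acc = (n_playable * max(p_blank, 1 - p_blank) + n_central) / n_cells
    per_move_accuracy.append(acc)

mean_accuracy = sum(per_move_accuracy) / n_positions
print(round(mean_accuracy, 4))
```

Under this toy model the position-only accuracy is highest at the two ends of the game and lowest in the midgame, and its average is in the same ballpark as the value measured above. Since real games do not play cells in uniformly random order, the comparison should not be over-interpreted.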

In [11]:
line(
    torch.from_numpy(
        correct_predictions(focus_cache["pos_embed"]).mean(axis=(0, 2, 3))
    ),
    xaxis="move",
    yaxis="accuracy (pos_embed)",
)

This picture is also understandable: the emptiness of cells is almost determined at the beginning and at the end of the game, and the most variable part is in the middle. One thing, however, looks strange: at the very first move, the accuracy is about $0.3$, i.e. worse than coin-tossing!

Probably, the subsequent layers will fix that? Let's see.

In [12]:
line(
    torch.tensor(
        [correct_predictions(focus_cache[f"blocks.0.hook_resid_pre"])[:, 0].mean()] + [
            correct_predictions(focus_cache[f"blocks.{layer}.hook_resid_mid"])[:, 0].mean()
            for layer in range(8)
        ]
    ),
    xaxis='layer',
    yaxis='accuracy (first move)',
)

Nope, across all layers, the accuracy for the first move is worse than coin-tossing. It seems that the model simply doesn't care about it. Probably, it uses some other information to predict the second move.

Finally, let us investigate the first attention output.

In [13]:
correct_predictions(focus_cache['blocks.0.hook_attn_out']).mean()

0.9341101694915255

In [14]:
line(
    torch.from_numpy(
        correct_predictions(focus_cache['blocks.0.hook_attn_out']).mean(axis=(0, 2, 3))
    ),
    xaxis="move",
    yaxis="accuracy (attention output)",
)

Here we see that the first attention layer indeed writes the emptiness information to the residual stream. Again, the first move is rather bad (but better than coin-tossing), and several subsequent moves are also not perfect, as well as the endgame. Why? Let's try to figure it out by visualizing the average accuracy for specific cells at different moments.

In [15]:
plot_square_as_board(
    correct_predictions(focus_cache["blocks.0.hook_attn_out"])[:, [0, 1, 2, 30, 55, 58]].mean(
        axis=(0)
    ),
    facet_col=0,
)

Looking at this picture and comparing the timelines for positional embedding and attention output above, one may suggest these two mechanisms compensate each other's weaknesses. Indeed, at the opening and in the endgame, the positional embedding contains most of the information about the emptiness of the center and the border of the board, respectively. At these moments, the attention output does not care enough about these parts of the board, because it “relies” on the positional embedding mechanism. In the midgame, positional embedding is not so useful, and the attention mechanism has to do its job well over the whole board.

Let's make sure that this compensation works by plotting the timeline of the accuracy obtained from the `blocks.0.hook_resid_mid`.

In [16]:
line(
    torch.from_numpy(
        correct_predictions(focus_cache['blocks.0.hook_resid_mid']).mean(axis=(0, 2, 3))
    ),
    xaxis="move",
    yaxis="accuracy (0.resid_mid)",
)

Indeed, the quality of prediction here is mostly uniform with respect to the move, except for the first two moves.

### Hypothesis 1: take-aways
- Hypothesis 1 seems to be confirmed. Indeed, after the first attention layer, the `is_blank` linear probe has good accuracy (about $93\%$).
- Moreover, an accuracy of about $75\%$ can be obtained even without attention, just from the positional embeddings.
- Positional embeddings and attention work together and rely on each other. In those time-space regions where one of them works pretty well, the other can “relax” and give slightly worse predictions.
- The quality of prediction is much worse for the first two moves (and especially the first move) than for the rest of the game.

### Hypothesis 2. There exists `is_empty` attention head
Now let us consider specific attention heads of the first layer. As we discussed in the section [Theoretical considerations](#Theoretical_consideration), one can expect that there exists an attention head that assigns a similar attention score to each pair of query and key and then calculates the emptiness simply by summing all the previous vectors. Let's try to find such a head.

In [17]:
head_vs_accuracy = pd.DataFrame([
    {
        "head": head,
        "accuracy": correct_predictions(
            focus_cache["blocks.0.attn.hook_result"][:, :, head]
        )[:, 1:].mean(),
    }
    for head in range(8)
]).sort_values("accuracy", ascending=False)
head_vs_accuracy

| head | accuracy |
|------|----------|
| 4 | 0.786519 |
| 5 | 0.784181 |
| 1 | 0.741897 |
| 3 | 0.739176 |
| 0 | 0.709838 |
| 2 | 0.686207 |
| 6 | 0.683728 |
| 7 | 0.67292 |


Okay, it seems that there is no single head that gives high accuracy. Let's investigate them all. We begin with the attention scores.

In [18]:
def get_attention_scores(game: int, head: int):
    attn_scores = (
        focus_cache["blocks.0.attn.hook_attn_scores"][game, head, :, :]
        .detach()
        .cpu()
        .numpy()
        .copy()
    )
    attn_scores[np.triu_indices(attn_scores.shape[0], k=1)] = np.nan
    return attn_scores

In [19]:
imshow(
    [
        get_attention_scores(game=2, head=head)
        for head in head_vs_accuracy["head"].values
    ],
    facet_col=0,
    facet_col_wrap=3,
    height=1080,
    aspect=1,
    xaxis="query",
    yaxis="key",
)

Here each facet visualizes one head, the horizontal axis is the query token (i.e. move) index and the vertical axis is the key index; the heads are arranged in descending order of the accuracy of the `is_blank` probe. We see that the heads have rather different attention score structures. The first three heads (the best at predicting emptiness), as we expected, give similar (though not even close to equal) scores to various pairs of queries and keys. The rest of the heads clearly demonstrate a "checkerboard" pattern, which suggests that they attend to "mine" or "theirs" moves only. This is not the kind of attention that we would expect to calculate the emptiness.

Okay, probably, the first three heads will give good prediction scores if taken together? Let's see.

In [20]:
correct_predictions(
    focus_cache["blocks.0.attn.hook_result"][:, :, head_vs_accuracy["head"].values[:3]].sum(axis=2)
).mean()

0.8193167372881356

So, $82\%$, “not good, not terrible”. Can we claim that these three heads are mainly responsible for the calculation of the emptiness data, and the rest calculate something else? To do so, we have to show at least that they are better at this job than the other heads. Is it true? Let's compare with the three worst heads:

In [21]:
correct_predictions(
    focus_cache["blocks.0.attn.hook_result"][:, :, head_vs_accuracy["head"].values[-3:]].sum(axis=2)
).mean()

0.8165889830508475

Oops, again almost $82\%$! How is it possible that the attention heads that attend to "mine" and "their" moves can provide such good predictions of emptiness? That's easy: everything is linear, and the last two heads compensate for each other: one looks for "mine" moves and the other one for "their" moves, so if their output matrices are appropriately aligned, they can work together to aggregate the full set of moves. To visualize this fact, let us draw the sum of the corresponding attention scores (it doesn't have a direct meaning, but demonstrates the possibility of “compensation”).

In [22]:
imshow(
    np.sum(
        [
            get_attention_scores(game=2, head=head)
            for head in head_vs_accuracy["head"].values[-2:]
        ],
        axis=0,
    ),
    aspect="equal",
    xaxis="query",
    yaxis="key",
)

We see that the two checkerboards compensate for each other.
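The compensation argument can be illustrated with a toy computation. The sketch below is an idealization and not the model's actual weights: it uses one-hot move embeddings, unnormalized 0/1 attention weights instead of softmax, and the same (identity) output map for both heads; the cell ids are made up.

```python
import numpy as np

n_moves, n_cells = 6, 64
moves = [19, 37, 26, 44, 20, 45]  # hypothetical cell ids
embeddings = np.eye(n_cells)[moves]  # idealized one-hot move embeddings

query = np.arange(n_moves)[:, None]
key = np.arange(n_moves)[None, :]
causal = key <= query
same_parity = (query - key) % 2 == 0  # "mine" moves, as seen from the query

# Two idealized heads with unnormalized 0/1 attention weights.
head_mine = (causal & same_parity).astype(float)
head_theirs = (causal & ~same_parity).astype(float)

# With a shared (here: identity) output map, the head outputs simply add up.
combined = head_mine @ embeddings + head_theirs @ embeddings
full = causal.astype(float) @ embeddings

# The two checkerboards together cover every previous move exactly once.
assert np.allclose(combined, full)
```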

Finally, let us look at how the accuracy increases when we consider the cumulative sum of the heads, ordered by decreasing individual head accuracy.

In [23]:
line(
    torch.tensor(
        [
            correct_predictions(
                focus_cache["blocks.0.attn.hook_result"][
                    :, :, head_vs_accuracy["head"].values[:i]
                ].sum(axis=2)
            ).mean()
            for i in range(1, 9)
        ]
    )
)

It looks from the graph as if at least five heads contribute to our task. Note that the overall accuracy is not $93\%$, as I expected, but slightly smaller. Probably, this is because of the layer normalization I didn't take into account here.
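Another possible contributor to this gap, not verified on this model, is an output bias. In the usual multi-head formulation, the layer output is the sum of the per-head results plus a bias vector shared by all heads. If the cached per-head results exclude that bias (an assumption about the cache), their sum differs from the full attention output by a constant vector. The toy sketch below, with random stand-in arrays and made-up sizes, shows the identity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, n_heads, d_head, d_model = 5, 8, 4, 16  # toy sizes

z = rng.normal(size=(n_pos, n_heads, d_head))  # per-head mixed values
W_O = rng.normal(size=(n_heads, d_head, d_model))  # per-head output maps
b_O = rng.normal(size=d_model)  # output bias shared by all heads

# Per-head contributions to the residual stream (no bias included).
head_results = np.einsum("phd,hdm->phm", z, W_O)

# Full attention output: sum over heads plus the bias.
attn_out = head_results.sum(axis=1) + b_O

# The sum of the heads misses exactly the bias vector at every position.
gap = attn_out - head_results.sum(axis=1)
assert np.allclose(gap, b_O)
```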

### Hypothesis 2: take-aways
- Hypothesis 2 is not confirmed: it appears that several heads contribute to the calculation of the emptiness, and it is not even clear how to distinguish between those that contribute and those that do not, despite the fact that there are heads with clearly different attention patterns.
- The linear nature of the problem allows different heads to “work together”: even in cases where different heads attend to only part of the tokens, if their attention spans together cover all the tokens, the sum of their results can solve the task (for appropriately aligned output matrices).
- All the attention heads give large weight to the first (several) move(s). I do not understand why and how it is related to our task. Specifically, I am not sure whether it is somehow related to the fact that after the first move we have relatively bad predictions of emptiness.

### Hypothesis 3. The output matrix maps move embeddings into the non-emptiness probes
To test this hypothesis, I need to apply the input/output projections of each attention head to each embedding and compare the output with the `is_blank` linear probe that corresponds to the cell of the move. Unfortunately, I do not have time to make the full test, so I will consider only the first move (where the attention input is just one token, and we can simply take the output from the cache).

For the first five games, I will find the dot product of the output of the attention head with each of the `is_blank` probes and visualize the results as a board. I expect that the most negative dot product will correspond to the cell of the move.

In [24]:
def apply_attention_to_first_move(head):
    move = 0
    plot_square_as_board(
        einops.einsum(
            focus_cache["blocks.0.attn.hook_result"][:5, move, head],
            linear_probe_is_blank,
            "game i, i row col ->  game row col",
        ),
        facet_col=0,
    )
    print("Moves", string_to_label(focus_games_string[:5, move]))

In [25]:
apply_attention_to_first_move(1)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


We see that the probe indeed reconstructs the move from the output of the attention, and the output is anti-aligned with the probe (the sign of the dot product is negative). Let's check the other heads.

In [26]:
apply_attention_to_first_move(0)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


In [27]:
apply_attention_to_first_move(2)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


In [28]:
apply_attention_to_first_move(3)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


In [29]:
apply_attention_to_first_move(4)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


In [30]:
apply_attention_to_first_move(5)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


In [31]:
apply_attention_to_first_move(6)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


In [32]:
apply_attention_to_first_move(7)

Moves ['C3', 'F4', 'E5', 'D2', 'C3']


We see from the pictures that a similar effect takes place for all heads. This suggests that indeed all the heads contribute to the calculation of the emptiness.

#### Do the attention matrices play any role at all?
Recalling that the “bare” positional encodings contribute to the emptiness probes without any matrix multiplications, one may ask: what if, instead of the attention outputs, we consider the embeddings themselves? Is it true that the embedding of a token is anti-aligned with the corresponding `is_blank` probe? Let's check it.

In [33]:
focus_cache['hook_embed'].shape

torch.Size([50, 59, 512])

In [34]:
plot_square_as_board(
    einops.einsum(
        focus_cache["hook_embed"][:5, 0],
        linear_probe_is_blank,
        "game i, i row col ->  game row col",
    ),
    facet_col=0,
)

No, we do not see a similar effect here. So it is not the embeddings themselves; it is indeed the effect of the attention matrices applied to the embeddings.

#### Hypothesis 3: take-aways
- We found evidence that supports Hypothesis 3 with the following correction: it is true for *all* attention heads of the first layer.
- A bit more work is needed to demonstrate this more rigorously.
- We also found that the embeddings themselves are not (anti-)aligned with the corresponding emptiness probes.
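The fuller test mentioned at the start of this section could be shaped roughly as follows. This is only a sketch of the shapes involved: all matrices are random stand-ins with made-up sizes and hypothetical names (`W_E`, `W_V`, `W_O`, `probe_is_blank`), and it ignores layer normalization and biases, which a real test would have to handle.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, d_head, rows, cols = 60, 32, 8, 8, 8  # toy sizes

# Hypothetical stand-ins; a real test would take these from the model and probe.
W_E = rng.normal(size=(n_tokens, d_model))  # token embeddings
W_V = rng.normal(size=(d_model, d_head))  # value projection of one head
W_O = rng.normal(size=(d_head, d_model))  # output projection of one head
probe_is_blank = rng.normal(size=(d_model, rows, cols))

# What the head writes when it attends fully to a single token.
ov_output = W_E @ W_V @ W_O  # (token, d_model)

# Dot product with the is_blank probe direction of every cell.
scores = np.einsum("ti,irc->trc", ov_output, probe_is_blank)

# For each token, the cell whose probe is most anti-aligned with the output.
most_negative_cell = scores.reshape(n_tokens, -1).argmin(axis=1)
print(most_negative_cell.shape)
```

With real weights, the hypothesis predicts that `most_negative_cell` for each token coincides with the cell that token stands for (the token-to-cell mapping is not shown here); with the random stand-ins above, no such agreement is expected.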

### Hypothesis 4
It is not clear how to “fix” Hypothesis 4, given that we do not have a distinguished attention head that solves the emptiness problem.

## Conclusion
We reconstructed the main mechanisms that allow the model to deduce which cells are blank and which are not after a move. We found that the model mostly follows our theoretical algorithm, with several exceptions. First, it explicitly uses the positional information. Second, it did not develop a specialized attention head to sum up all the embeddings. Instead, it spreads this task over all heads of the first layer.

It is still not clear why the prediction of blanks fails for the first move, why the first (several) moves have such large weights in all attention heads, and whether these two facts are related to each other.