# Self-attention Capabilities

This notebook works through hands-on examples, using actual numbers, that compare the attention mechanism used in transformers to 1D convolutional neural networks (CNNs).
The examples shown here will all use self-attention.

This discussion is designed to follow an overview of the architecture of the attention mechanism used in transformer neural networks,
such as Jay Alammar's excellent [Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/) or Yan Xu's [Step by Step into Transformer](https://medium.com/@YanAIx/step-by-step-into-transformer-79531eb2bb84), so that the reader is already familiar with the components of the attention mechanism and how they are connected.

We also want to take a moment to comment on the limitation of seq2seq models using RNN or LSTM modules.
Lilian Weng has some amazing blog posts, and in the [one about attention](https://lilianweng.github.io/posts/2018-06-24-attention/), she shows the following diagram and explains, "A critical and apparent disadvantage of this fixed-length context vector design is incapability of remembering long sentences."
As sequences get longer, not only is it difficult to pack the meaning of all the words into this fixed-length context vector, but it also gets harder for the network to remember the meaning of early words that were processed so many iterations ago. \
<img src="https://lilianweng.github.io/posts/2018-06-24-attention/encoder-decoder-example.png" width=750>

Before we dive in, let's load the packages we'll be using.  All examples are in Python using PyTorch.

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

## 1. Pattern matching with CNNs

We will discuss convolutional neural networks (CNNs) first.

### 1.1 Refresher on CNNs

#### 1.1.1 Using 2D CNNs for image processing

We will assume the reader has been exposed to CNNs.  You probably have seen images like this one from [Arden Dertat's Towards Data Science article](https://towardsdatascience.com/applied-deep-learning-part-4-convolutional-neural-networks-584bc134c1e2). \
<img src="https://miro.medium.com/v2/resize:fit:828/1*VVvdh-BUKFh2pwDD0kPeRA@2x.gif" width=600> \
Here, the 3x3 filter in green is shown matching up with all the different positions on the blue image to compute values shown in red on the right.

#### 1.1.2 Channels and 2D CNNs

It's easy to find descriptions and images like the one above.  For simplicity, these images show a CNN extracting features from an image with only one channel, such as a grayscale image.  What those images don't explain is what happens when the image has multiple channels.

If an image has multiple channels, such as a typical 3-channel color image with RGB channels, then our filter is a three-dimensional tensor.  A 3x3 filter on a RGB image has shape 3x3x3 and consists of 27 parameters (not counting a bias term).  A 3x3 filter on a 256 channel image has shape 3x3x256 and consists of 2,304 parameters (again without a bias value).
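As a quick check of these counts, the sketch below asks PyTorch for the weight shape of a single 3x3 filter. It assumes PyTorch's `nn.Conv2d`, which stores weights as (out_channels, in_channels, height, width); `bias=False` matches the counts above, which leave out the bias.

```python
import torch.nn as nn

# One 3x3 filter over an RGB (3-channel) image; bias excluded, as in the text.
rgb_filter = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, bias=False)
print(rgb_filter.weight.shape)    # torch.Size([1, 3, 3, 3])
print(rgb_filter.weight.numel())  # 27

# One 3x3 filter over a 256-channel input.
deep_filter = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=3, bias=False)
print(deep_filter.weight.shape)    # torch.Size([1, 256, 3, 3])
print(deep_filter.weight.numel())  # 2304
```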

Here is an image from [Daphne Cornelisse's freeCodeCamp article](https://www.freecodecamp.org/news/an-intuitive-guide-to-convolutional-neural-networks-260c2de0a050/) that has an animation from [the Stanford cs231n course](https://cs231n.github.io/convolutional-networks/).
This animation shows two "3x3" filters (each actually of shape 3x3x3) operating on a 5x5 three-channel image with one pixel of zero padding (so the input acts like a 7x7x3 object).  The animation clearly shows a separate 3x3 set of filter parameters for each channel.
(The example also uses stride 2, which is why the filter jumps two pixels each time it moves right or down.) \
<img src="https://cdn-media-1.freecodecamp.org/images/gb08-2i83P5wPzs3SL-vosNb6Iur5kb5ZH43" width=600>

### 1.2 Using 1D CNNs for sequence processing

Most of the time when explaining multidimensional objects, it makes sense to start small and work toward larger numbers of dimensions.
In this case, we started with 2D CNNs because of how well their use in computer vision has been documented.

Our CNN code in this notebook will be using 1D CNNs operating on sequences.
In this case, we will look at sequences of words.
In the 1D case, a filter is described as being size 3 if it looks at 3 words at a time.
Just as a 2D CNN filter actually has one extra (channel) dimension, making it a 3D tensor, a 1D CNN filter is actually a 2D tensor.
If the embedding representation for each word is 8 floating point numbers, then the size 3 filter will actually be a 3x8 tensor.
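PyTorch's `nn.Conv1d` (used later in this notebook) stores its weights as (out_channels, in_channels, kernel_size), so the "3x8" filter described above is stored transposed, as 8x3. A minimal check, assuming one filter over 8-float embeddings:

```python
import torch.nn as nn

# One size-3 filter sliding over a sequence of 8-float word embeddings.
conv1d = nn.Conv1d(in_channels=8, out_channels=1, kernel_size=3, bias=False)
print(conv1d.weight.shape)       # torch.Size([1, 8, 3]): (filters, embedding floats, words)
print(conv1d.weight[0].T.shape)  # torch.Size([3, 8]): the "3 words x 8 floats" view
```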

#### 1.2.1 Our vocabulary and embeddings

We are going to create a small hand-crafted example to make it easy to see what is happening inside our models.

Our vocabulary will consist of just seven words.  We will embed each word as a vector of eight numbers.

In [None]:
# Our vocabulary
the  = [0., 0., 0., 1., 2., 0., 1., 2.]
boy  = [1., 0., 0., 0., 6., 0., 777., 888.]
said = [0., 0., 1., 0., 0., 5., 5., 6.]
he   = [0., 1., 0., 0., 0., 8., 33., 44.]
was  = [0., 0., 1., 0., 7., 0., 9., 0.]
good = [0., 0., 0., 1., 0., 3., 3., 4.]
now  = [0., 0., 0., 1., 4., 4., 7., 8.]

We will use pandas to nicely display all 48 floats in the representation of a six-word sentence:

In [None]:
# Visualizing a sequence of words
seq1 = [the, boy, said, he, was, good]
df = pd.DataFrame(seq1, columns=['noun', 'pronoun', 'verb', 'other', 'val1', 'val2', 'other1', 'other2'], \
                  index = ['the','boy','said','he','was','good'])
df.T

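|         | the | boy   | said | he   | was | good |
|---------|-----|-------|------|------|-----|------|
| noun    | 0.0 | 1.0   | 0.0  | 0.0  | 0.0 | 0.0  |
| pronoun | 0.0 | 0.0   | 0.0  | 1.0  | 0.0 | 0.0  |
| verb    | 0.0 | 0.0   | 1.0  | 0.0  | 1.0 | 0.0  |
| other   | 1.0 | 0.0   | 0.0  | 0.0  | 0.0 | 1.0  |
| val1    | 2.0 | 6.0   | 0.0  | 0.0  | 7.0 | 0.0  |
| val2    | 0.0 | 0.0   | 5.0  | 8.0  | 0.0 | 3.0  |
| other1  | 1.0 | 777.0 | 5.0  | 33.0 | 9.0 | 3.0  |
| other2  | 2.0 | 888.0 | 6.0  | 44.0 | 0.0 | 4.0  |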


In our made-up example, we have assigned human-interpretable meanings to each number in the embedding.
Real embeddings aren't likely to have values that are so easy to understand.
We are using this kind of structure to make it easy to build examples that show what is happening in our neural networks.

#### 1.2.2 Building a 1D CNN to detect where we have "boy" and "he" two words apart

We are going to build a simple 1D CNN with two size 3 filters.
The first filter is going to detect the pronoun "he" that comes two words after the noun "boy."
In fact, if all nouns have a 1 in their first float and all pronouns have a 1 in their second float, then this filter will match every time any pronoun follows a noun two words later.
In order to get a large output when we see this pattern, we are going to put 10's in our filter in the right places, and zeros everywhere else.
We will need a 10 in the first position of the row corresponding to the first embedding float.
We will also need a 10 in the third position of the row corresponding to the second embedding float.
When the first word under the filter is a noun and the third word under the filter is a pronoun, the filter will output the value 20.
For all other patterns, it will output a smaller number.
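To make the arithmetic concrete, here is a sketch of what the first filter computes for one window, done by hand: multiply the filter elementwise with the three embeddings it covers and sum. The names `filt` and `window` are illustrative; the embeddings are copied from the vocabulary above.

```python
import torch

# Embeddings copied from the vocabulary above:
# noun, pronoun, verb, other, val1, val2, other1, other2
boy  = torch.tensor([1., 0., 0., 0., 6., 0., 777., 888.])
said = torch.tensor([0., 0., 1., 0., 0., 5., 5., 6.])
he   = torch.tensor([0., 1., 0., 0., 0., 8., 33., 44.])

# First filter as an 8x3 grid: row = embedding float, column = word position in the window.
filt = torch.zeros(8, 3)
filt[0, 0] = 10.  # noun flag of the first word
filt[1, 2] = 10.  # pronoun flag of the third word

window = torch.stack([boy, said, he], dim=1)  # (8, 3), same layout as the filter
response = (filt * window).sum()              # elementwise multiply, then sum
print(response)  # tensor(20.)
```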

So you can see other patterns in action, we will also create a second filter which adds the `val1` value of the first word to the `val2` value of the third word.
Here is the code:

In [None]:
# Instantiate a 1D CNN and manually set its weights
conv = nn.Conv1d(in_channels=8, out_channels=2, kernel_size=(3,), stride=(1,), padding='valid', bias=False)

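# Filter 0: 10 * (noun flag of word 1) + 10 * (pronoun flag of word 3)
# Filter 1: (val1 of word 1) + (val2 of word 3)
# Shape (out_channels, in_channels, kernel_size) = (2, 8, 3); the trailing "0." makes the tensor float
wt = torch.tensor([[[10, 0, 0], [0, 0, 10], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
                   [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 0, 1], [0, 0, 0], [0, 0, 0.]]])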
conv.weight = nn.Parameter(wt)

print(conv)
print(conv.weight.shape)
print(conv.weight)

Conv1d(8, 2, kernel_size=(3,), stride=(1,), padding=valid, bias=False)
torch.Size([2, 8, 3])
Parameter containing:
tensor([[[10.,  0.,  0.],
         [ 0.,  0., 10.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]],

        [[ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.],
         [ 1.,  0.,  0.],
         [ 0.,  0.,  1.],
         [ 0.,  0.,  0.],
         [ 0.,  0.,  0.]]], requires_grad=True)


Now, let's send our six-word sentence into our CNN and see the output.
(Our CNN has padding turned off, so the output is going to be shorter than our input. In this case, it will be length 4.)
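With stride 1 and no padding, a size-k filter fits L - k + 1 times into a length-L sequence, so 6 - 3 + 1 = 4. The sketch below checks this with random input (only the shapes matter) and, for comparison, `padding='same'`, which zero-pads so that the output keeps the input length.

```python
import torch
import torch.nn as nn

seq_len, kernel = 6, 3
x = torch.randn(8, seq_len)  # unbatched input: (embedding floats, words)

valid = nn.Conv1d(8, 2, kernel_size=kernel, padding='valid', bias=False)
same = nn.Conv1d(8, 2, kernel_size=kernel, padding='same', bias=False)

print(valid(x).shape)  # torch.Size([2, 4]): seq_len - kernel + 1
print(same(x).shape)   # torch.Size([2, 6]): zero padding preserves the length
```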

In [None]:
input = torch.tensor(seq1).T
out = conv(input)
print(out)

tensor([[ 0., 20.,  0.,  0.],
        [ 7., 14.,  0.,  3.]], grad_fn=<SqueezeBackward1>)


Our first filter is detecting a noun followed by a pronoun two words later.
We see a big peak with the value 20 at the second position, where the filter's window starts on the word "boy."

The second filter is summing the `val1` from the current word with the `val2` from two words later.  It outputs different values at each position.

A subsequent layer could combine this information in interesting ways.

#### 1.2.3 A different pattern

If we add the word "now" into the middle, the words "boy" and "he" are no longer two words apart.

In [None]:
seq2 = [the, boy, now, said, he, was, good]
df = pd.DataFrame(seq2, columns=['noun', 'pronoun', 'verb', 'other', 'val1', 'val2', 'other1', 'other2'], \
                  index = ['the','boy','now','said','he','was','good'])
df.T

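|         | the | boy   | now | said | he   | was | good |
|---------|-----|-------|-----|------|------|-----|------|
| noun    | 0.0 | 1.0   | 0.0 | 0.0  | 0.0  | 0.0 | 0.0  |
| pronoun | 0.0 | 0.0   | 0.0 | 0.0  | 1.0  | 0.0 | 0.0  |
| verb    | 0.0 | 0.0   | 0.0 | 1.0  | 0.0  | 1.0 | 0.0  |
| other   | 1.0 | 0.0   | 1.0 | 0.0  | 0.0  | 0.0 | 1.0  |
| val1    | 2.0 | 6.0   | 4.0 | 0.0  | 0.0  | 7.0 | 0.0  |
| val2    | 0.0 | 0.0   | 4.0 | 5.0  | 8.0  | 0.0 | 3.0  |
| other1  | 1.0 | 777.0 | 7.0 | 5.0  | 33.0 | 9.0 | 3.0  |
| other2  | 2.0 | 888.0 | 8.0 | 6.0  | 44.0 | 0.0 | 4.0  |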


What happens when we pass this second sentence into our CNN?

In [None]:
input = torch.tensor(seq2).T
out = conv(input)
print(out)

tensor([[ 0., 10., 10.,  0.,  0.],
        [ 6., 11., 12.,  0.,  3.]], grad_fn=<SqueezeBackward1>)


This sentence gets weak signals in certain places because we still have the words "boy" and "he" in the sentence,
but we don't get the big peak (with value 20) anymore.

We could add a new filter to detect a noun followed by a pronoun three words later (which needs a window of at least four words).
Just like a 2D CNN used with images will have many filters to detect different patterns (such as vertical lines, horizontal lines, circles, etc.),
a sophisticated 1D CNN used with sentences will need many filters to detect different patterns of words.
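A sketch of such an extra filter, assuming we simply widen the window to four words and place the pronoun weight in the fourth position (the name `conv4` is ours, not part of the notebook's model):

```python
import torch
import torch.nn as nn

# Vocabulary from above: noun, pronoun, verb, other, val1, val2, other1, other2
the  = [0., 0., 0., 1., 2., 0., 1., 2.]
boy  = [1., 0., 0., 0., 6., 0., 777., 888.]
now  = [0., 0., 0., 1., 4., 4., 7., 8.]
said = [0., 0., 1., 0., 0., 5., 5., 6.]
he   = [0., 1., 0., 0., 0., 8., 33., 44.]
was  = [0., 0., 1., 0., 7., 0., 9., 0.]
good = [0., 0., 0., 1., 0., 3., 3., 4.]
seq2 = [the, boy, now, said, he, was, good]

# Hypothetical size-4 filter: a noun, then a pronoun three words later.
conv4 = nn.Conv1d(in_channels=8, out_channels=1, kernel_size=4, bias=False)
wt4 = torch.zeros(1, 8, 4)
wt4[0, 0, 0] = 10.  # noun flag of the first word in the window
wt4[0, 1, 3] = 10.  # pronoun flag of the fourth word in the window
conv4.weight = nn.Parameter(wt4)

out4 = conv4(torch.tensor(seq2).T)
print(out4)  # peak of 20 for the window starting at "boy"
```

This filter in turn gives no 20 peak on our first sentence, where "boy" and "he" are only two words apart, which is why a CNN needs a separate filter for each spacing it should detect.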

Now that we have seen what pattern matching in 1D CNNs looks like, let's try detecting the same pattern using self-attention.

## 2. Self-attention in transformers

<img src="https://lilianweng.github.io/posts/2018-06-24-attention/multi-head-attention.png" width=250>

### 2.1 Pattern matching with queries and keys

We will build our intuition about the power of self-attention by starting with just the query and key mechanism.
Let's create Q and K weight matrices that will allow us to detect our words "boy" and "he" that are two words apart,
just like we did with the 1D CNN.
From this example, we can see how the dot product of queries and keys does matching.

Because our word embedding size is 8, each of the Q, K, and V weight matrices is an 8x8 tensor.
PyTorch's `nn.MultiheadAttention` stacks them together into a single 24x8 tensor, `in_proj_weight`.
Recall that both the Query and Key matrices are multiplied with each input token, producing length-8 query and key vectors whose dot products become the attention scores.
We are going to put a weight of 10 in the second column of the Query matrix so that it matches the 1 that pronouns have in the second embedding float.
We will also put a weight of 10 in the first column of the Key matrix so it matches the 1 that nouns have in the first embedding float.
If we put both 10s in the first rows, they will line up when we take the dot product.
(We also have to set the weights for the Value matrix and for the output projection that comes after self-attention, but we will not focus on those weights yet.)

In [None]:
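# Instantiate an attention mechanism and manually set its weights
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, bias=False, batch_first=True)
wt = torch.zeros(24,8)          # rows 0:8 = Query, 8:16 = Key, 16:24 = Value
wt[0,1] = 10.                   # Query float 0 = 10 * pronoun flag
wt[8,0] = 10.                   # Key float 0 = 10 * noun flag
wt[16:24,0:8] = torch.eye(8)    # Value = identity (copies the embedding)
attn.in_proj_weight = nn.Parameter(wt)
attn.out_proj.weight = nn.Parameter(torch.eye(8))  # output projection = identity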

print(attn)
print("Query weights: \n", attn.in_proj_weight[0:8,:])
print("Key weights: \n", attn.in_proj_weight[8:16,:])

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
)
Query weights: 
 tensor([[ 0., 10.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], grad_fn=<SliceBackward0>)
Key weights: 
 tensor([[10.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], grad_fn=<SliceBackward0>)


In self-attention, the same sequence is passed in as the query, key, and value inputs.
Let's load our original six-word sentence into variables and get ready to pass them into our self-attention module.
Remember, "he" is the fourth word, and we want it to pay attention to "boy," which is the second word.

In [None]:
df = pd.DataFrame(seq1, columns=['noun', 'pronoun', 'verb', 'other', 'val1', 'val2', 'other1', 'other2'], \
                  index = ['the','boy','said','he','was','good'])
df.T

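|         | the | boy   | said | he   | was | good |
|---------|-----|-------|------|------|-----|------|
| noun    | 0.0 | 1.0   | 0.0  | 0.0  | 0.0 | 0.0  |
| pronoun | 0.0 | 0.0   | 0.0  | 1.0  | 0.0 | 0.0  |
| verb    | 0.0 | 0.0   | 1.0  | 0.0  | 1.0 | 0.0  |
| other   | 1.0 | 0.0   | 0.0  | 0.0  | 0.0 | 1.0  |
| val1    | 2.0 | 6.0   | 0.0  | 0.0  | 7.0 | 0.0  |
| val2    | 0.0 | 0.0   | 5.0  | 8.0  | 0.0 | 3.0  |
| other1  | 1.0 | 777.0 | 5.0  | 33.0 | 9.0 | 3.0  |
| other2  | 2.0 | 888.0 | 6.0  | 44.0 | 0.0 | 4.0  |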


In [None]:
q = torch.tensor(seq1)
k = torch.tensor(seq1)
v = torch.tensor(seq1)
print(k)

tensor([[  0.,   0.,   0.,   1.,   2.,   0.,   1.,   2.],
        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   1.,   0.,   0.,   5.,   5.,   6.],
        [  0.,   1.,   0.,   0.,   0.,   8.,  33.,  44.],
        [  0.,   0.,   1.,   0.,   7.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   3.,   3.,   4.]])


Now let's pass our input data in and see what the self-attention weights look like:

In [None]:
attn_output, attn_weights = attn(q, k , v, average_attn_weights=False)
print(torch.round(attn_weights, decimals=4))
print("Attention weights for the fourth word: \n", torch.round(attn_weights[0,3,:], decimals=4))

tensor([[[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]]],
       grad_fn=<RoundBackward1>)
Attention weights for the fourth word: 
 tensor([0., 1., 0., 0., 0., 0.], grad_fn=<RoundBackward1>)


Success!  Our fourth word "he" has put 100% (when rounded) of its attention on the word "boy."
With our simple hand-crafted weights, every word other than "he" has an all-zero query vector, so all of its attention scores are equal and its attention is spread evenly across all six words.
This might not be exactly the behavior we would want in a real language transformer, but remember this is a toy example designed to show how the attention mechanism works with real numbers.
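The attention weights above can be reproduced by hand. This sketch assumes, as PyTorch's `nn.MultiheadAttention` does, that scores are scaled by 1/sqrt(head_dim) (here head_dim = 8) before the softmax, and it compares the result against the module. Because every query except the one for "he" is the zero vector, those rows have all-zero scores, and a softmax over equal scores is uniform.

```python
import math
import torch
import torch.nn as nn

# "the boy said he was good", one row per word (vocabulary from above)
x = torch.tensor([
    [0., 0., 0., 1., 2., 0., 1., 2.],      # the
    [1., 0., 0., 0., 6., 0., 777., 888.],  # boy
    [0., 0., 1., 0., 0., 5., 5., 6.],      # said
    [0., 1., 0., 0., 0., 8., 33., 44.],    # he
    [0., 0., 1., 0., 7., 0., 9., 0.],      # was
    [0., 0., 0., 1., 0., 3., 3., 4.],      # good
])

# The same hand-set Query and Key weights as our attention module
w_q = torch.zeros(8, 8)
w_q[0, 1] = 10.  # query float 0 = 10 * pronoun flag
w_k = torch.zeros(8, 8)
w_k[0, 0] = 10.  # key float 0 = 10 * noun flag

q = x @ w_q.T                            # only "he" gets a nonzero query
k = x @ w_k.T                            # only "boy" gets a nonzero key
scores = q @ k.T / math.sqrt(8)          # scaled dot products
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
print(scores[3])                         # 100 / sqrt(8) against "boy", 0 elsewhere
print(torch.round(weights[3], decimals=4))

# Compare with nn.MultiheadAttention using the same Q/K weights (V = identity)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, bias=False, batch_first=True)
attn.in_proj_weight = nn.Parameter(torch.cat([w_q, w_k, torch.eye(8)]))
_, ref = attn(x, x, x, average_attn_weights=False)
print(torch.allclose(weights, ref[0], atol=1e-6))  # True
```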

Our 1D CNN filter, which matched when "boy" and "he" were two words apart, didn't work when they were three words apart.

Can our self-attention do better?

In [None]:
q2 = torch.tensor(seq2)
k2 = torch.tensor(seq2)
v2 = torch.tensor(seq2)
print(q2)

tensor([[  0.,   0.,   0.,   1.,   2.,   0.,   1.,   2.],
        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   0.,   1.,   4.,   4.,   7.,   8.],
        [  0.,   0.,   1.,   0.,   0.,   5.,   5.,   6.],
        [  0.,   1.,   0.,   0.,   0.,   8.,  33.,  44.],
        [  0.,   0.,   1.,   0.,   7.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   3.,   3.,   4.]])


In [None]:
attn_output, attn_weights = attn(q2, k2 , v2, average_attn_weights=False)
print(torch.round(attn_weights, decimals=4))
print("Attention weights for the fifth word: \n", torch.round(attn_weights[0,4,:], decimals=4))

tensor([[[0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429],
         [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429]]],
       grad_fn=<RoundBackward1>)
Attention weights for the fifth word: 
 tensor([0., 1., 0., 0., 0., 0., 0.], grad_fn=<RoundBackward1>)


Yes!  Without changing our weights, our attention module successfully placed 100% of the attention for the word "he" (now the fifth word) onto the word "boy."
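In fact, because nothing in our inputs encodes word position, this set of weights would cause the word "he" to attend to the word "boy" no matter how far apart they are, or whether "boy" comes before or after the word "he" (as long as "boy" is the only noun in the sentence).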
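To check the claim about word order, this sketch rebuilds the same single-head module and feeds it a hypothetical reordering, "he said the boy was good," in which "boy" comes after "he." Nothing in these inputs encodes position, so the attention pattern should follow the words rather than their offsets.

```python
import torch
import torch.nn as nn

# Same hand-set single-head attention as above: pronoun queries match noun keys.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, bias=False, batch_first=True)
wt = torch.zeros(24, 8)
wt[0, 1] = 10.                 # query: pronoun flag
wt[8, 0] = 10.                 # key: noun flag
wt[16:24, 0:8] = torch.eye(8)  # value: identity
attn.in_proj_weight = nn.Parameter(wt)
attn.out_proj.weight = nn.Parameter(torch.eye(8))

# Hypothetical reordered sentence: "he said the boy was good"
he   = [0., 1., 0., 0., 0., 8., 33., 44.]
said = [0., 0., 1., 0., 0., 5., 5., 6.]
the  = [0., 0., 0., 1., 2., 0., 1., 2.]
boy  = [1., 0., 0., 0., 6., 0., 777., 888.]
was  = [0., 0., 1., 0., 7., 0., 9., 0.]
good = [0., 0., 0., 1., 0., 3., 3., 4.]
x = torch.tensor([he, said, the, boy, was, good])

_, w = attn(x, x, x, average_attn_weights=False)
print(torch.round(w[0, 0], decimals=4))  # "he" (first word) attends to "boy" (fourth word)
```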

### 2.2 Attention output

So far, we've only focused on what each word in our self-attention module is paying attention to.
What is the output?

When we created our attention module, in addition to setting the query and key weights, we set the value weight matrix to be the identity matrix.
We also set the output projection to be the identity matrix.
This means that each row of the output is a weighted average of the input embeddings, with the weights given by that word's attention scores.
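In matrix form (a sketch using PyTorch's convention that a weight matrix $W$ maps a row vector $x$ to $xW^\top$, with no biases, as in our module), stacking the input sequence as the rows of $X$ and using one head of size $d$:

$$
A = \operatorname{softmax}\!\left(\frac{(X W_Q^\top)(X W_K^\top)^\top}{\sqrt{d}}\right), \qquad \text{output} = A\, X\, W_V^\top W_O^\top
$$

With $W_V = W_O = I$ this reduces to $\text{output} = AX$, so row $i$ of the output is $\sum_j A_{ij}\, x_j$. Since each row of $A$ sums to 1, this is a weighted average of the input embeddings.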

Let's pass our data in again, but this time examine our input and output tensors.

In [None]:
attn_output, attn_weights = attn(q2, k2 , v2, average_attn_weights=False)
print("Input data: \n", k2)
print("Output data: \n", torch.round(attn_output, decimals=0))
print("Second word of the input: \n", torch.round(k2[1,:], decimals=4))
print("Fifth word of the output: \n", torch.round(attn_output[4,:], decimals=4))

Input data: 
 tensor([[  0.,   0.,   0.,   1.,   2.,   0.,   1.,   2.],
        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   0.,   1.,   4.,   4.,   7.,   8.],
        [  0.,   0.,   1.,   0.,   0.,   5.,   5.,   6.],
        [  0.,   1.,   0.,   0.,   0.,   8.,  33.,  44.],
        [  0.,   0.,   1.,   0.,   7.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   3.,   3.,   4.]])
Output data: 
 tensor([[  0.,   0.,   0.,   0.,   3.,   3., 119., 136.],
        [  0.,   0.,   0.,   0.,   3.,   3., 119., 136.],
        [  0.,   0.,   0.,   0.,   3.,   3., 119., 136.],
        [  0.,   0.,   0.,   0.,   3.,   3., 119., 136.],
        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   0.,   0.,   3.,   3., 119., 136.],
        [  0.,   0.,   0.,   0.,   3.,   3., 119., 136.]],
       grad_fn=<RoundBackward1>)
Second word of the input: 
 tensor([  1.,   0.,   0.,   0.,   6.,   0., 777., 888.])
Fifth word of the output: 
 tensor([  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
       grad_fn=<RoundBackward1>)

Our self-attention module copied the embedding for the word "boy" into the fifth location of the sequence where the word "he" is!
Mechanisms like this one let transformer language models copy or blend word embeddings based on what each word is attending to.
If we use something other than an identity matrix for our value weights, we can also apply linear transformations to the input word embeddings as they are copied (affine transformations, if the projections have biases).
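As a sketch of such a transformation, the module below keeps our query and key weights but uses a hypothetical value matrix that swaps `other1` and `other2`. The word "he" should then receive a copy of "boy" with those two floats exchanged.

```python
import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, bias=False, batch_first=True)
wt = torch.zeros(24, 8)
wt[0, 1] = 10.                                # query: pronoun flag
wt[8, 0] = 10.                                # key: noun flag
w_v = torch.eye(8)[[0, 1, 2, 3, 4, 5, 7, 6]]  # hypothetical value matrix: swap other1/other2
wt[16:24] = w_v
attn.in_proj_weight = nn.Parameter(wt)
attn.out_proj.weight = nn.Parameter(torch.eye(8))

# "the boy now said he was good" (vocabulary from above)
x = torch.tensor([
    [0., 0., 0., 1., 2., 0., 1., 2.],      # the
    [1., 0., 0., 0., 6., 0., 777., 888.],  # boy
    [0., 0., 0., 1., 4., 4., 7., 8.],      # now
    [0., 0., 1., 0., 0., 5., 5., 6.],      # said
    [0., 1., 0., 0., 0., 8., 33., 44.],    # he
    [0., 0., 1., 0., 7., 0., 9., 0.],      # was
    [0., 0., 0., 1., 0., 3., 3., 4.],      # good
])

out, _ = attn(x, x, x)
print(torch.round(out[4], decimals=4))  # "boy" with other1 and other2 swapped
```

Because this module has no biases, the value step is a purely linear map applied to whatever each word attends to.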

### 2.3 Two attention heads

For our final example, we are going to use multi-headed attention, and show two different operations happening at the same time.
Remember that self-attention in a transformer is used in a residual network, which means the output is added to the input, then passed on to the next block.

We will instantiate a 2-headed attention module.
The first attention head will copy the `other1` and `other2` from the word "boy" to the position of the word "he."
This is similar behavior to what we've already seen.
The second attention head will make the word "he" attend to itself.
Through the output projection, it will contribute -1 times the word's own `other1` and `other2` values. Once the output is added back to the input by the residual connection, this erases the current values, so the first attention head effectively copies the values from the word "boy" into empty (zero) embedding slots!

The really interesting [Transformer Circuits work by Anthropic](https://transformer-circuits.pub/2021/framework/index.html) has seen this type of behavior.  They note, "we've seen hints that some MLP neurons and attention heads may perform a kind of 'memory management' role, clearing residual stream dimensions set by other layers by reading in information and writing out the negative version."
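Before building the two-head module, it helps to know how `nn.MultiheadAttention` splits its weights between heads. The sketch below reflects our reading of PyTorch's implementation and checks it numerically against the module with random weights: with `embed_dim=8` and `num_heads=2`, each head gets `head_dim = 4`; head h uses rows h*4:(h+1)*4 of each of the Q, K, and V blocks of `in_proj_weight`; and the per-head outputs are concatenated, so head h is read by *columns* h*4:(h+1)*4 of `out_proj.weight`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads  # 4 floats per head

mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False, batch_first=True)
x = torch.randn(7, embed_dim)  # unbatched (sequence length, embed_dim)

w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)  # each (8, 8)
heads = []
for h in range(num_heads):
    rows = slice(h * head_dim, (h + 1) * head_dim)  # this head's rows in Q, K, and V
    q, k, v = x @ w_q[rows].T, x @ w_k[rows].T, x @ w_v[rows].T
    a = torch.softmax(q @ k.T / head_dim ** 0.5, dim=-1)
    heads.append(a @ v)  # (7, head_dim)

# Concatenate heads; head h is then read by columns h*head_dim:(h+1)*head_dim of out_proj.
manual = torch.cat(heads, dim=-1) @ mha.out_proj.weight.T

out, _ = mha(x, x, x)
print(torch.allclose(out, manual, atol=1e-5))  # True
```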

In [None]:
attn2 = nn.MultiheadAttention(embed_dim=8, num_heads=2, bias=False, batch_first=True)

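wt = torch.zeros(24,8)          # rows: Q 0:8, K 8:16, V 16:24; head 1 = first 4 rows of each, head 2 = last 4
wt[0,1] = 10.                   # head 1 query: pronoun flag
wt[8,0] = 10.                   # head 1 key: noun flag -> "he" attends to "boy"
wt[4,1] = 10.                   # head 2 query: pronoun flag
wt[12,1] = 10.                  # head 2 key: pronoun flag -> "he" attends to itself
wt[16:20,4:8] = torch.eye(4)    # head 1 values: val1, val2, other1, other2
wt[20:24,4:8] = torch.eye(4)    # head 2 values: val1, val2, other1, other2
attn2.in_proj_weight = nn.Parameter(wt)

out_wt = torch.zeros(8,8)       # columns 0:4 read head 1, columns 4:8 read head 2
out_wt[6:8,2:4] = torch.eye(2)  # add head 1's other1/other2 into floats 7 and 8
out_wt[6:8,6:8] = -1. * torch.eye(2)  # subtract head 2's other1/other2 from floats 7 and 8
attn2.out_proj.weight = nn.Parameter(out_wt)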

print(attn2)
# print(attn2.in_proj_weight)
print("Q1: ",attn2.in_proj_weight[0:4,:])
print("K1: ",attn2.in_proj_weight[8:12,:])
print("V1: ",attn2.in_proj_weight[16:20,:])
print("Q2: ",attn2.in_proj_weight[4:8,:])
print("K2: ",attn2.in_proj_weight[12:16,:])
print("V2: ",attn2.in_proj_weight[20:24,:])

# out_proj mixes the concatenated head outputs: head 1 is read by columns 0:4, head 2 by columns 4:8
print("Out1: ",attn2.out_proj.weight[:,0:4])
print("Out2: ",attn2.out_proj.weight[:,4:8])


MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=False)
)
Q1:  tensor([[ 0., 10.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], grad_fn=<SliceBackward0>)
K1:  tensor([[10.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], grad_fn=<SliceBackward0>)
V1:  tensor([[0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1.]], grad_fn=<SliceBackward0>)
Q2:  tensor([[ 0., 10.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], grad_fn=<SliceBackw

Here are the attention values when we input our sentence.

In [None]:
attn_output, attn_weights = attn2(q2, k2 , v2, average_attn_weights=False)
print(torch.round(attn_weights, decimals=3))

tensor([[[0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430]],

        [[0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430],
         [0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430, 0.1430]]],
       grad_fn=<RoundBackward1>)


In the first attention head, the word "he" is putting 100% attention on "boy."
In the second attention head, the word "he" is 100% attending to itself.

And now we show the input, the raw output, and the result of adding the output to the input.

In [None]:
print("Input:")
print(torch.round(k2, decimals=4))
print("Output:")
print(torch.round(attn_output, decimals=4))
print("Output summed with input:")
print(torch.round(k2 + attn_output, decimals=4))


Input:
tensor([[  0.,   0.,   0.,   1.,   2.,   0.,   1.,   2.],
        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   0.,   1.,   4.,   4.,   7.,   8.],
        [  0.,   0.,   1.,   0.,   0.,   5.,   5.,   6.],
        [  0.,   1.,   0.,   0.,   0.,   8.,  33.,  44.],
        [  0.,   0.,   1.,   0.,   7.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   3.,   3.,   4.]])
Output:
tensor([[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0., 744., 844.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]],
       grad_fn=<RoundBackward1>)
Output summed with input:
tensor([[  0.,   0.,   0.,   1.,   2.,   0.,   1.,   2.],
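        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   0.,   1.,   4.,   4.,   7.,   8.],
        [  0.,   0.,   1.,   0.,   0.,   5.,   5.,   6.],
        [  0.,   1.,   0.,   0.,   0.,   8., 777., 888.],
        [  0.,   0.,   1.,   0.,   7.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   3.,   3.,   4.]],
       grad_fn=<RoundBackward1>)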

The combination of our two attention heads has cleared the values that the fifth word "he" had in the 7th and 8th embedding floats and copied the values from the word "boy" into those spots.

To see this one head at a time, we can zero out the output from one head, then the other, and finally sum them.

In [None]:
out_wt = torch.zeros(8,8)
out_wt[6:8,2:4] = torch.eye(2)
# out_wt[6:8,6:8] = -1. * torch.eye(2)      # zeroed output of second attention head
attn2.out_proj.weight = nn.Parameter(out_wt)

attn_output1, attn_weights1 = attn2(q2, k2 , v2, average_attn_weights=False)

print(torch.round(attn_output1, decimals=0))

tensor([[  0.,   0.,   0.,   0.,   0.,   0., 119., 136.],
        [  0.,   0.,   0.,   0.,   0.,   0., 119., 136.],
        [  0.,   0.,   0.,   0.,   0.,   0., 119., 136.],
        [  0.,   0.,   0.,   0.,   0.,   0., 119., 136.],
        [  0.,   0.,   0.,   0.,   0.,   0., 777., 888.],
        [  0.,   0.,   0.,   0.,   0.,   0., 119., 136.],
        [  0.,   0.,   0.,   0.,   0.,   0., 119., 136.]],
       grad_fn=<RoundBackward1>)


In [None]:
out_wt = torch.zeros(8,8)
# out_wt[6:8,2:4] = torch.eye(2)      # zeroed output of first attention head
out_wt[6:8,6:8] = -1. * torch.eye(2)
attn2.out_proj.weight = nn.Parameter(out_wt)

attn_output2, attn_weights2 = attn2(q2, k2 , v2, average_attn_weights=False)

print(torch.round(attn_output2, decimals=0))

tensor([[   0.,    0.,    0.,    0.,    0.,    0., -119., -136.],
        [   0.,    0.,    0.,    0.,    0.,    0., -119., -136.],
        [   0.,    0.,    0.,    0.,    0.,    0., -119., -136.],
        [   0.,    0.,    0.,    0.,    0.,    0., -119., -136.],
        [   0.,    0.,    0.,    0.,    0.,    0.,  -33.,  -44.],
        [   0.,    0.,    0.,    0.,    0.,    0., -119., -136.],
        [   0.,    0.,    0.,    0.,    0.,    0., -119., -136.]],
       grad_fn=<RoundBackward1>)


In [None]:
print(attn_output1 + attn_output2)
print(torch.round(k2 + attn_output1 + attn_output2, decimals=4))

tensor([[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0., 744., 844.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]],
       grad_fn=<AddBackward0>)
tensor([[  0.,   0.,   0.,   1.,   2.,   0.,   1.,   2.],
        [  1.,   0.,   0.,   0.,   6.,   0., 777., 888.],
        [  0.,   0.,   0.,   1.,   4.,   4.,   7.,   8.],
        [  0.,   0.,   1.,   0.,   0.,   5.,   5.,   6.],
        [  0.,   1.,   0.,   0.,   0.,   8., 777., 888.],
        [  0.,   0.,   1.,   0.,   7.,   0.,   9.,   0.],
        [  0.,   0.,   0.,   1.,   0.,   3.,   3.,   4.]],
       grad_fn=<RoundBackward1>)


The values for the two attention heads are negatives of each other for every word except the fifth word.
For our critical fifth word, the second attention head outputs the negative of the current value for the seventh and eighth embedding floats, zeroing them out.
The first attention head copies the seventh and eighth embedding floats from the word "boy."
We could just as easily have copied different embedding values into the seventh and eighth positions, and applied a linear transformation along the way.

Feel free to play around with the weights and see what kinds of behaviors you can create.

## 3. Conclusion

We hope these examples have given you some additional intuition for why self-attention seems to be so powerful.

* One strength is that rules embedded into the Query and Key weights can work based on the meaning of the words as encoded into the word embeddings, and these rules will work regardless of where the words appear.
* If we want the model to care about the absolute or relative position of words, the queries and keys can also utilize the positional embedding data.
* In addition to calculating attention scores for which input(s) in the sequence to pay attention to, the Value weights allow the attention mechanism to transform or embed portions of the input, making more interesting data manipulations possible.
* Multi-headed attention increases the number of rules one attention unit can perform without increasing its parameter count and with little extra computation (when implemented with the default per-head size, which shrinks in inverse proportion to the number of heads).
* Finally, because self-attention sits inside a residual (ResNet-style) connection, the attention mechanism can copy values, blend values, and even clear values in the residual stream.
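One way to see the cost point for PyTorch's module: the number of projection parameters does not depend on the number of heads, because `head_dim` shrinks as `num_heads` grows. This sketch only counts parameters; the softmax work (one L x L matrix per head) does still grow with the number of heads.

```python
import torch.nn as nn

counts = {}
for num_heads in (1, 2, 4, 8):
    mha = nn.MultiheadAttention(embed_dim=8, num_heads=num_heads, bias=False)
    n_params = sum(p.numel() for p in mha.parameters())  # in_proj (24x8) + out_proj (8x8)
    counts[num_heads] = (mha.head_dim, n_params)
    print(f"heads={num_heads}  head_dim={mha.head_dim}  parameters={n_params}")
```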

While the examples provided here were hand-crafted with interpretable word embeddings that had orthogonal meanings, such as the first float indicating whether the word was a noun or not, this is not a completely far-fetched scenario.
Large language models commonly use embedding spaces of size 1024, 2048, or larger.
In such high-dimensional vector spaces, randomly chosen vectors are very likely to be close to orthogonal, and there is room for as many exactly orthogonal directions as there are dimensions.
It seems intuitive that during training the transformer should have little difficulty finding orthogonal representations for pattern matching, if that is desired.
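A quick numerical illustration of the near-orthogonality point, assuming random Gaussian directions (real learned embeddings are not random, so this is only suggestive): the typical cosine similarity between two random unit vectors shrinks roughly like 1/sqrt(d) as the dimension d grows.

```python
import torch

torch.manual_seed(0)
n = 100  # number of random vectors per dimension
results = {}
for dim in (8, 1024, 2048):
    v = torch.randn(n, dim)
    v = v / v.norm(dim=1, keepdim=True)                    # normalize to unit length
    cos = v @ v.T                                          # pairwise cosine similarities
    off_diag = cos[~torch.eye(n, dtype=torch.bool)].abs()  # drop each vector's self-similarity
    results[dim] = (off_diag.mean().item(), off_diag.max().item())
    print(f"dim={dim:5d}  mean |cos|={results[dim][0]:.3f}  max |cos|={results[dim][1]:.3f}")
```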