# Self-Attention

## Motivation

The original attention mechanism allowed the decoder to focus on different inputs to the encoder.

Is there a way to make the attention mechanism of the **encoder** pay attention to the different tokens of **its own inputs**?


The answer is yes, and we call that _self-attention_.

> In a self-attention layer, every input attends to every other input

> The output of a self-attention layer is the same number of same sized vectors as the input

> Each output is a weighted average of the inputs, weighted by how much attention is paid to each input

> Self-attention can be applied in every layer, not just the first



Combining neural attention with an RNN-based sequence-to-sequence model gives you the model shown below, made up of three core building blocks:
1. The encoder
1. The decoder
1. The cross-attention mechanism

# TODO add outline to dec-enc blocks

![](../images/RNN%20Seq2seq%20Attention.png)

The goal now is to try to see if we can replace any of the building blocks in this architecture. 

## Problems with RNN encoders
1. Reading in sequential order is how we write text, but it’s not how we understand it. You may have the subject of a sentence (“the man…”) followed, after a very long intervening sequence, by something describing the man. It’s hard to persist the representation between the two ends so that they can interact to build the right encoding. RNNs bake this order into their encoding.
1. Long-distance dependencies are also hard to learn because of vanishing and exploding gradients.
1. Lack of parallelisability. Future RNN states can’t be computed until all preceding states have been computed. You have to do $T$ steps of computation before you can make a gradient step.

> Recurrence, however useful for encoding, is the cause of these problems.

### Our goal now is to resolve the issue of time-dependency and long-term dependencies caused by RNNs

## Tackling the time dependency

The recurrent nature of RNNs means that sequences can only be processed one step at a time, so each parameter update requires $O(T)$ sequential operations that cannot be parallelised.
We need to find an alternative building block that can be used to encode/decode our sequences.

One alternative is to use _word windows_: a "window" of fixed size applied at every position in the text.

# TODO diagram Show very useful diagram of how many steps required to get to each layer. 

The windows at every position in the text can be computed immediately, in parallel, with no dependence on earlier time steps. This tackles parallelisation, but not long-range dependencies.

# TODO diagram Shows useful diagram of which neuron can influence each between layers (pyramid-looking). 

As you can see from this diagram, each neuron is a combination of only a few neurons in the layer below. 
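To put a number on this limitation, here is a small sketch in plain Python. The function names and the assumption of stride-1 windows are illustrative choices, not part of the original notes. It counts how many input tokens can influence a single output after stacking several window layers.

```python
def receptive_field(window_size: int, num_layers: int) -> int:
    """Number of input positions that can influence one output position
    after `num_layers` stacked stride-1 window layers."""
    # Each extra layer widens the receptive field by (window_size - 1).
    return 1 + num_layers * (window_size - 1)


def layers_needed(window_size: int, sequence_length: int) -> int:
    """Smallest depth whose receptive field is at least as wide as the sequence."""
    num_layers = 0
    while receptive_field(window_size, num_layers) < sequence_length:
        num_layers += 1
    return num_layers


for depth in (1, 2, 4):
    print(depth, receptive_field(window_size=5, num_layers=depth))

# Depth required before two tokens 100 positions apart can interact at all.
print(layers_needed(window_size=5, sequence_length=100))
```

The receptive field only grows linearly with depth, so long-range interactions need many layers.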

## Building block 2: Attention

Attention, in general, treats each word’s representation as a query to access and incorporate information from a set of values.

In an RNN seq2seq model, the set of encoder states for the source sentence are the values, the decoder state is the query, and their dot product gives an attention score.
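As a rough sketch of that computation for a single decoder step: the tensor names and sizes below are illustrative assumptions, and random tensors stand in for real RNN states.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, hidden_size = 5, 8                          # source length and RNN hidden size (illustrative values)
encoder_states = torch.randn(T, hidden_size)   # one encoder state per source token: the values
decoder_state = torch.randn(hidden_size)       # the current decoder state: the query

attention_scores = encoder_states @ decoder_state         # (T,) dot product of the query with each encoder state
attention_weights = F.softmax(attention_scores, dim=-1)   # (T,) non-negative and summing to 1
context = attention_weights @ encoder_states              # (hidden_size,) weighted average of the values

print(attention_weights.sum())
print(context.shape)
```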


## Self-attention introduces another set of vectors, the _keys_, as well as queries and values found in attention

Recall that attention operates on queries (q), keys (k) and values (v). 
# TODO put q, k, v in attention notebook

In self-attention, $q$, $k$ and $v$ come from the same source (like the same sentence). 
Where do these come from? 

> Regardless of what form of attention you use, and what $q$, $k$, and $v$ are, you’re doing the same thing: taking the dot product of queries and keys to get the “affinities” (alignment), then creating an affinity-weighted combination of the input values.

How is this different from a fully connected layer, now that you’re connecting everything to everything? 
1. Dynamic connectivity: The connection weights vary as a function of the input, because they are computed from the affinity between the keys and queries. In a fully connected layer, the connection weights are the same for every input. Transformers learn the alignment function which determines the connections between layers for each example.
2. The parameterisation is very different. “It has this inductive bias that’s not just everything to everything feedforward”. 

You get a key, query, value for each word embedding. 
You can stack the self-attention layers and have k, q, v at each layer. 
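The "dynamic connectivity" point can be illustrated with a minimal, unmasked sketch. The projection matrices `W_q`, `W_k`, `W_v` are random stand-ins for learned parameters, and the scaling by the square root of the dimension is an assumption added here to keep the softmax well behaved.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, D = 4, 6                                # sequence length and embedding size (illustrative)
W_q = torch.randn(D, D) / D**0.5           # fixed projections, shared by every input
W_k = torch.randn(D, D) / D**0.5
W_v = torch.randn(D, D) / D**0.5


def self_attention(X):
    """Unmasked single-head self-attention on a (T, D) input; returns (output, weights)."""
    queries, keys, values = X @ W_q, X @ W_k, X @ W_v   # all three come from the same X
    affinities = queries @ keys.T / D**0.5               # (T, T): every query with every key
    weights = F.softmax(affinities, dim=-1)              # each row sums to 1
    return weights @ values, weights                     # affinity-weighted combination of the values


X_a, X_b = torch.randn(T, D), torch.randn(T, D)
out_a, weights_a = self_attention(X_a)
out_b, weights_b = self_attention(X_b)

print(out_a.shape)                             # same number of same-sized vectors as the input
print(torch.allclose(weights_a, weights_b))    # do two different inputs get the same mixing weights?
```

The parameters are identical for both inputs, yet the mixing weights are recomputed from each input, which is what a fixed fully connected layer cannot do.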

## Self-attention as described so far cannot yet be used as a building block

There are several problems which we need to address:

### Problem 1: Positional Encoding
The order of words obviously matters, but self-attention as described so far contains no information about where each word appears: it is an operation on sets rather than on an ordered sequence. So we need to encode position explicitly.

Let’s bound the sentence length by $T$. 
For each $i \in \{1, \dots, T\}$, get a positional encoding $p_i$. 
Then just add it to each of the self-attention block inputs ($q$, $k$, $v$). 
A simple way to do this is to set $q_i = \tilde{q}_i + p_i$, where $\tilde{q}_i$ is the query before position is added, and likewise for $k_i$ and $v_i$. You could concatenate them instead, but it is simple and common to just add.
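The "operation on sets" claim can be checked directly. The sketch below uses a parameter-free self-attention (queries, keys and values are all the input itself) and random stand-in positional encodings `P`; both are simplifying assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, D = 5, 4
X = torch.randn(T, D)        # token embeddings without any positional information
P = torch.randn(T, D)        # stand-in positional encodings, one row per position


def self_attention(X):
    """Parameter-free self-attention: queries, keys and values are all X itself."""
    weights = F.softmax(X @ X.T, dim=-1)
    return weights @ X


perm = torch.tensor([4, 2, 0, 3, 1])   # an arbitrary reordering of the tokens

# Without positions: shuffling the inputs just shuffles the outputs in the same way.
order_ignored = torch.allclose(self_attention(X[perm]), self_attention(X)[perm], atol=1e-5)

# With positions added: slot i always receives row i of P, so word order now changes the result.
order_ignored_with_positions = torch.allclose(
    self_attention(X[perm] + P), self_attention(X + P)[perm], atol=1e-5
)

print(order_ignored, order_ignored_with_positions)
```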

One option is to use sinusoidal positional encodings, where each dimension of $p_i$ is a sine or cosine of the position at a different frequency.

Pros:
- Periodicity suggests that absolute position is not as important.
- They might extrapolate to longer sequences.

Cons:
- They are not learnable - perhaps a better positional encoding could be learnt?
- In practice, extrapolation doesn’t really work.
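For reference, a minimal sketch of sinusoidal encodings. It assumes the formulation from the original Transformer paper (sine on even dimensions, cosine on odd dimensions, geometrically spaced frequencies with base 10000); the function name is hypothetical.

```python
import math
import torch


def sinusoidal_positional_encoding(T: int, d: int) -> torch.Tensor:
    """Return a (T, d) matrix whose row i is the encoding of position i; d is assumed even."""
    positions = torch.arange(T, dtype=torch.float32).unsqueeze(1)   # (T, 1)
    # One frequency per pair of dimensions, decreasing geometrically from 1 towards 1/10000.
    frequencies = torch.exp(
        -math.log(10000.0) * torch.arange(0, d, 2, dtype=torch.float32) / d
    )                                                                # (d/2,)
    P = torch.zeros(T, d)
    P[:, 0::2] = torch.sin(positions * frequencies)   # even dimensions
    P[:, 1::2] = torch.cos(positions * frequencies)   # odd dimensions
    return P


P = sinusoidal_positional_encoding(T=16, d=8)
print(P.shape)

# The same function can be evaluated at positions beyond any training length,
# which is where the hoped-for extrapolation would come from.
print(sinusoidal_positional_encoding(T=64, d=8).shape)
```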

It is more common nowadays to learn the $p_i$. 
Define a $d \times T$ (embedding size by sequence length) matrix $P$ of learnable parameters, where column $i$ of $P$ is $p_i$.

Pros:
- Flexible: each position gets to be learned to fit the data.

Cons:
- You can’t extrapolate to sequences longer than $T$ because you haven’t learnt how to represent them. 
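The extrapolation limit is easy to see in code. This sketch assumes the learned encodings are stored in a `torch.nn.Embedding` table (which holds $P$ as $T$ rows of size $d$ rather than $d \times T$); the sizes are illustrative.

```python
import torch

T, d = 8, 4
positional_table = torch.nn.Embedding(T, d)     # learnable rows for positions 0 .. T-1

# Positions inside the trained range can be looked up.
print(positional_table(torch.arange(T)).shape)

# Position T has no row in the table, so the lookup fails.
try:
    positional_table(torch.tensor([T]))
    out_of_range_failed = False
except (IndexError, RuntimeError):
    out_of_range_failed = True

print(out_of_range_failed)
```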

Other ways to encode position include relative positions between words, or position representations that depend on syntax.

# TODO diagram of positional encoding

### Problem 2: No Nonlinearities
Self-attention by itself contains no elementwise nonlinearities, so stacking self-attention layers just re-averages the value vectors rather than building representations hierarchically.

Solution: add a feedforward layer, with a nonlinear activation, between self-attention blocks. 

The intuition is that the feedforward layers “process the result of the attention”.
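A minimal sketch of such a feedforward block follows. The class name, the ReLU activation and the 4x hidden expansion are common choices assumed here for illustration, not something the notes above specify.

```python
import torch


class FeedForward(torch.nn.Module):
    """Position-wise feedforward block, applied to each token's vector independently."""

    def __init__(self, embedding_size, hidden_multiplier=4):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(embedding_size, hidden_multiplier * embedding_size),
            torch.nn.ReLU(),   # the elementwise nonlinearity that plain self-attention lacks
            torch.nn.Linear(hidden_multiplier * embedding_size, embedding_size),
        )

    def forward(self, X):
        return self.net(X)     # (B, T, embedding_size) in, same shape out


feed_forward = FeedForward(embedding_size=32)
attention_output = torch.randn(4, 8, 32)   # stand-in for the output of a self-attention layer
print(feed_forward(attention_output).shape)
```

Because the block acts on each position separately, it does not mix information between tokens; that remains the job of the attention layers.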

# TODO improve this section

### Problem 3: Future Tokens in the Decoder Should be Hidden
Self-attention looks at the whole sequence at once, which is cheating for language modelling! It’s ok for that to happen in an encoder, but not in a decoder. So we mask the future in self-attention. 

One solution would be to change the set of keys and values at each timestep, but that would be inefficient. Instead, just set the attention affinities for future tokens to $-\infty$, which makes their attention weights 0 after the softmax.
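The two approaches give the same result, which the following sketch checks on random data. The tensors are illustrative and the learned projections are omitted to keep the comparison short.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

T, D = 5, 3
queries, keys, values = torch.randn(T, D), torch.randn(T, D), torch.randn(T, D)

# Loop version: at step t, only attend over keys and values 0..t.
loop_outputs = []
for t in range(T):
    weights = F.softmax(queries[t] @ keys[: t + 1].T, dim=-1)
    loop_outputs.append(weights @ values[: t + 1])
loop_outputs = torch.stack(loop_outputs)                     # (T, D)

# Vectorised version: compute all affinities at once, then hide the future with -inf.
affinities = queries @ keys.T                                 # (T, T)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
weights = F.softmax(affinities.masked_fill(future, float("-inf")), dim=-1)
vectorised_outputs = weights @ values                         # (T, D)

print(torch.allclose(loop_outputs, vectorised_outputs, atol=1e-5))
```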
# TODO improve

# TODO diagram

### Having addressed these problems, self-attention can now be used as a building block in the seq2seq model

As a recap:
- We removed recurrence by replacing it with self-attention, which can be computed for every position in parallel
- We then added a positional encoding to the inputs to tell the model the position of each word
- We added nonlinearities between each layer of self-attention to allow it to build hierarchical representations
- We apply masking in decoder self-attention so that the model can't "cheat" by seeing future tokens, which would not be visible during evaluation/inference

## Masked Attention
When we implement self-attention on the decoder, it should not see the future tokens.

This could be done using a for loop, but it is far more efficient to vectorise this operation.

We can implement that by setting the attention weights of future tokens to zero.

That can be done by setting their affinities to negative infinity. When they are normalised through the softmax function, they will become zero.

In [1]:
import torch
import torch.nn.functional as F

def mask_hidden_tokens(attention_affinities):
    T = attention_affinities.shape[-1] # sequence length is the last dimension, so batched (B, T, T) inputs also work
    mask = torch.tril(torch.ones(T, T, device=attention_affinities.device)) # mask showing which tokens should have access to others
    return attention_affinities.masked_fill(mask == 0, float("-inf"))

T = 4
attention_affinities = torch.ones((T, T))

print(attention_affinities)

attention_affinities = mask_hidden_tokens(attention_affinities)
attention_weights = F.softmax(attention_affinities, dim=-1)

print(attention_weights)

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500]])


In this demo, the affinities between every pair of tokens are set equal to one, so each token spreads its attention evenly over itself and the tokens before it.

In the complete masked attention module, the affinities between keys and queries would be computed differently, depending on the type of attention implemented, e.g. by taking the dot product between keys and queries in dot product attention.

In [None]:
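B, T, C = 4, 8, 2 # batch size, sequence length, embedding size
input_token_embeddings = torch.randn(B, T, C)

# simplest way to inject prev context is to average - TODO use RNN or transformer
attention_affinities = mask_hidden_tokens(torch.zeros((T, T))) # equal affinities for visible tokens, -inf for future tokens
attention_weights = F.softmax(attention_affinities, dim=-1) # these are the attention weights: row t averages tokens 0..t

weighted_masked_inputs = attention_weights @ input_token_embeddings # (T, T) @ (B, T, C) -> (B, T, C)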


class LanguageModel(torch.nn.Module):
    def __init__(self, sequence_length, vocab_size=100, embedding_dim=128):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
        self.positional_encoding_embeddings = torch.nn.Embedding(sequence_length, embedding_dim)
        self.head = torch.nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx):
        B, T = idx.shape

        token_embeddings = self.embedding(idx) # B x T x D
        positional_encodings = self.positional_encoding_embeddings(torch.arange(T, device=idx.device)) # T x D # this implementation limits the input length
        input_token_embeddings = token_embeddings + positional_encodings # positions broadcast over the batch
        logits = self.head(input_token_embeddings)
        return logits




- Query: What am I looking for?
- Key: What do I contain?

Affinities $= f(Q, K)$

All queries take the dot product with all keys to produce the affinities.

In [None]:
class SelfAttentionHead(torch.nn.Module):
    def __init__(self, embedding_size, head_size=16):
        super().__init__()
        self.head_size = head_size
        self.key = torch.nn.Linear(embedding_size, head_size, bias=False)
        self.query = torch.nn.Linear(embedding_size, head_size, bias=False)
        self.value = torch.nn.Linear(embedding_size, head_size, bias=False) # no bias: we want purely linear projections of the inputs
        # if we want the mask to be stored as part of the state of the model (but not as a parameter), then we should use self.register_buffer("mask", mask)

    def forward(self, X):
        # X is (B, T, embedding_size)

        # GET KEYS, VALUES AND QUERIES FOR THIS INPUT
        keys = self.key(X) # what do I contain?
        queries = self.query(X) # what am I looking for?
        values = self.value(X)
        
        # COMPUTE ATTENTION WEIGHTS
        keys = keys.transpose(-2, -1) # transpose T & D dimensions so that keys are (B, head_size, T)
        attention_affinities = queries @ keys # (B, T, T)
        attention_affinities = attention_affinities / self.head_size**0.5 # scale by the square root of the key size # this makes it "scaled dot product attention"
        attention_affinities = mask_hidden_tokens(attention_affinities)
        attention_weights = F.softmax(attention_affinities, dim=-1) # normalise over the keys

        # COMPUTE ATTENTION WEIGHTED CONTEXT
        context = attention_weights @ values # (B, T, head_size)
        return context


In [None]:
class AttentionLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, embedding_size, sequence_length):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, embedding_size)
        self.positional_encoding_table = torch.nn.Embedding(sequence_length, embedding_size)
        self.self_attention_head = SelfAttentionHead(embedding_size, head_size=embedding_size) # head output size must match the input size of self.head
        self.head = torch.nn.Linear(embedding_size, vocab_size)

    def forward(self, X, targets=None):
        # X is (B, L)
        B, L = X.shape
        embeddings = self.token_embedding_table(X) # (B, L, D)
        positional_encodings = self.positional_encoding_table(torch.arange(L, device=X.device)) # (L, D): indexed by position, not by token id
        final_input_embeddings = embeddings + positional_encodings # positions broadcast over the batch
        context = self.self_attention_head(final_input_embeddings)
        logits = self.head(context)

        return logits

    def generate(self):
        # crop box size
        pass 





## Multi-Headed Self-Attention

The key, query, value projections allow the model to learn how to embed each input token and attend to others depending on what tokens appear. 
But they can only learn to do that in one way.
For any input, its keys, queries and values will only be represented in one way - the way defined by the parameters of that single self-attention head.

It is possible that learning many different ways to represent the same inputs may be useful.

So, what if we used multiple different self-attention heads in parallel? We call this multi-headed self-attention.

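To avoid multiplying the number of parameters used in the self-attention layers, we divide the size of each head's key, query and value vectors by the number of heads, so that the concatenated head outputs have the same size as the original embedding.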
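A quick count makes the point, assuming bias-free key, query and value projections as in the single head above. The helper name is hypothetical.

```python
def attention_projection_parameters(embedding_size: int, head_size: int, num_heads: int) -> int:
    """Parameters in the bias-free key, query and value projections of all heads."""
    per_head = 3 * embedding_size * head_size   # one weight matrix each for key, query and value
    return num_heads * per_head


embedding_size = 32
single_head = attention_projection_parameters(embedding_size, head_size=32, num_heads=1)
four_heads = attention_projection_parameters(embedding_size, head_size=32 // 4, num_heads=4)
print(single_head, four_heads)
```

One head of size 32 and four heads of size 8 use the same number of projection parameters.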

In [2]:
class MultiHeadSelfAttention(torch.nn.Module):
    def __init__(self, head_size=8, num_heads=4, embedding_size=32): # defaults chosen so that num_heads * head_size == embedding_size
        super().__init__()
        self.heads = torch.nn.ModuleList(
            [SelfAttentionHead(embedding_size, head_size) for _ in range(num_heads)]
        )

    def forward(self, X):
        return torch.cat([head(X) for head in self.heads], dim=-1)

Now we can put multi-head self-attention into our language model:

In [None]:
class MultiHeadSelfAttentionLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, sequence_length, num_heads=4, embedding_size=32):
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, embedding_size)
        self.positional_encoding_table = torch.nn.Embedding(sequence_length, embedding_size)
        self.self_attention_head = MultiHeadSelfAttention(head_size=embedding_size // num_heads, num_heads=num_heads, embedding_size=embedding_size) # each head is scaled down so the concatenated output is embedding_size wide
        self.head = torch.nn.Linear(embedding_size, vocab_size)

    def forward(self, X, targets=None):
        # X is (B, L)
        B, L = X.shape
        embeddings = self.token_embedding_table(X) # (B, L, D)
        positional_encodings = self.positional_encoding_table(torch.arange(L, device=X.device)) # (L, D): indexed by position, not by token id
        final_input_embeddings = embeddings + positional_encodings # positions broadcast over the batch
        context = self.self_attention_head(final_input_embeddings)
        logits = self.head(context)

        return logits

    def generate(self):
        # crop box size
        pass 





# TODO diagram directed graph of indexed nodes, where each node is pointed to by only itself and those before it
