# Coding Attention in PyTorch!!!

---- 

In this tutorial, we will code a class that can perform all **3** types of **Attention** that we have studied: **Self-Attention**, **Masked Self-Attention**, and **Encoder-Decoder Attention**. We'll also write a few more lines of code to make **Multi-Head Attention** work.

In this tutorial, you will...

- **[Code an Attention Class!!!](#attention)** This class will be able to perform **Self-Attention**, **Masked Self-Attention**, and **Encoder-Decoder Attention**.

- **[Calculate Encoder-Decoder Attention Values!!!](#calculate)** We'll then use the class that we created, Attention, to calculate **Encoder-Decoder Attention** values for some sample data.
 
- **[Code Multi-Head Attention!!!](#multi)** We'll code **Multi-Head Attention** by running several Attention objects side by side.

- **[Calculate Multi-Head Attention!!!!](#calcMulti)** Lastly, we'll calculate **Multi-Head Attention** values for some sample data.


## Import the modules that will do all the work

```python
import torch ## torch lets us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.Module and nn.Linear()
import torch.nn.functional as F ## this gives us softmax()
```
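The W_q, W_k and W_v weights that we create next are bias-free nn.Linear() layers. As a quick sanity check (a small sketch that is not part of the original tutorial, using made-up encodings), a bias-free nn.Linear() layer just multiplies its input by the transpose of its weight matrix:

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

## a bias-free linear layer, like the W_q, W_k and W_v layers used for attention
linear = nn.Linear(in_features=2, out_features=2, bias=False)

## two sample token encodings (one row per token)
encodings = torch.tensor([[1.16, 0.23],
                          [0.57, 1.36]])

## nn.Linear computes encodings @ weight^T (there is no bias to add here)
by_layer = linear(encodings)
by_hand = torch.matmul(encodings, linear.weight.T)

print(torch.allclose(by_layer, by_hand))  # True
```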

## Code Attention
<a id="attention"></a>
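For reference, the forward() method coded next computes scaled dot-product attention. Writing $Q$, $K$ and $V$ for the query, key and value matrices (q, k and v in the code) and $d_k$ for the number of values per key (k.size(col_dim) in the code), the calculation is:

$$
\text{Attention}(Q, K, V) = \text{SoftMax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V
$$

For **Masked Self-Attention**, the masked entries of $Q K^{T} / \sqrt{d_k}$ are replaced with a very large negative number before the SoftMax, so they end up with a weight of 0.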

```python
class Attention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        ## d_model = the number of embedding values per token.
        ## row_dim and col_dim = the indices of the row and column dimensions
        ## of the encoding matrices (0 and 1 for 2-D, unbatched inputs).
        super().__init__()

        ## Initialize the Weights (W) that we'll use to create the
        ## query (q), key (k) and value (v) for each token
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim

    ## Unlike a plain self-attention class, forward() takes 3 separate sets of
    ## encodings, one each for the queries, keys and values...
    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        ## ...and passes each set through its own weight matrix.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        ## Compute similarity scores: (q * k^T)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        ## Scale the similarities by dividing by the square root of the
        ## number of values in each key, sqrt(d_k)
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)

        if mask is not None:
            ## Mask out the tokens we don't want to pay attention to.
            ##
            ## mask is a boolean tensor: wherever it is True, we replace the
            ## scaled similarity with a very large negative number (-1e9) so that
            ## SoftMax() gives that element an output value (or "probability") of 0.
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        ## Apply softmax to determine what percent of each token's value to
        ## use in the final attention values.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        ## Scale the values by their associated percentages and add them up.
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores
```
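The mask argument is what turns this class into **Masked Self-Attention**. One common way to build such a mask (a sketch, assuming 3 tokens and made-up similarity scores) is to use torch.tril() so that every position above the diagonal is True, which means each token can only pay attention to itself and the tokens that came before it:

```python
import torch
import torch.nn.functional as F

num_tokens = 3

## True above the diagonal = "don't pay attention to later tokens"
## [[False,  True,  True],
##  [False, False,  True],
##  [False, False, False]]
mask = torch.tril(torch.ones(num_tokens, num_tokens)) == 0

## made-up scaled similarities, only to show the effect of the mask
torch.manual_seed(42)
scaled_sims = torch.randn(num_tokens, num_tokens)

## the same masking and SoftMax steps used in Attention.forward()
masked_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
attention_percents = F.softmax(masked_sims, dim=1)

print(attention_percents)  # masked (upper-right) entries are 0, each row sums to 1
```

With the Attention class above, passing the same mask, for example attention(encodings, encodings, encodings, mask=mask) for 3 tokens, applies this step to the real scaled similarities.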

## Calculate Encoder-Decoder Attention
<a id="calculate"></a>

```python
## create matrices of token encodings...
## (all three use the same values here to keep the example small; in a real
## Transformer the queries come from the decoder and the keys and values
## come from the encoder)
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
attention = Attention(d_model=2, row_dim=0, col_dim=1)

## calculate encoder-decoder attention
attention(encodings_for_q, encodings_for_k, encodings_for_v)
```

```text
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)
```
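Because the queries, keys and values above all come from the same encodings, this calculation is the same as ordinary self-attention on those encodings. In real **Encoder-Decoder Attention**, the queries come from the decoder and the keys and values come from the encoder, and the two can have different numbers of tokens. The sketch below repeats the steps in Attention.forward() with hypothetical decoder encodings (2 tokens, values made up for illustration) to show the shapes involved:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
d_model = 2
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

## 3 tokens from the encoder (same values as above)
encoder_encodings = torch.tensor([[1.16, 0.23],
                                  [0.57, 1.36],
                                  [4.41, -2.16]])
## 2 hypothetical tokens from the decoder (made-up values)
decoder_encodings = torch.tensor([[0.50, -0.30],
                                  [1.20, 0.80]])

q = W_q(decoder_encodings)   ## queries come from the decoder: shape (2, 2)
k = W_k(encoder_encodings)   ## keys come from the encoder: shape (3, 2)
v = W_v(encoder_encodings)   ## values come from the encoder: shape (3, 2)

## one row per decoder token, one column per encoder token: shape (2, 3)
scaled_sims = q @ k.T / k.size(1) ** 0.5
attention_percents = F.softmax(scaled_sims, dim=1)

## one attention vector per decoder token: shape (2, 2)
attention_values = attention_percents @ v
print(attention_values.shape)  # torch.Size([2, 2])
```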

## Code Multi-Head Attention
<a id="multi"></a>

```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1, num_heads=1):
        super().__init__()

        ## create one Attention object per head; each head has its own
        ## W_q, W_k and W_v weights
        self.heads = nn.ModuleList(
            [Attention(d_model=d_model, row_dim=row_dim, col_dim=col_dim) for _ in range(num_heads)]
        )

        self.col_dim = col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        ## run the data through each attention head (passing along the optional
        ## mask) and concatenate the results along the columns, so the output
        ## has num_heads * d_model columns
        return torch.cat([head(encodings_for_q, encodings_for_k, encodings_for_v, mask=mask)
                          for head in self.heads],
                         dim=self.col_dim)
```
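In this tutorial, each head keeps d_model output values, so concatenating the heads gives num_heads * d_model columns per token. Many Transformer implementations follow the concatenation with one more linear layer, often called W_o, that maps the result back to d_model values. A minimal sketch of that idea, using a hypothetical ProjectedMultiHeadAttention class that is not part of this tutorial (its random weights are created in a different order, so its numbers will not match the tutorial's outputs):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProjectedMultiHeadAttention(nn.Module):
    ## hypothetical variant: concatenate the heads, then project back to d_model
    def __init__(self, d_model=2, num_heads=2):
        super().__init__()
        ## one set of query, key and value weights per head (d_model -> d_model each)
        self.W_q = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in range(num_heads)])
        self.W_k = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in range(num_heads)])
        self.W_v = nn.ModuleList([nn.Linear(d_model, d_model, bias=False) for _ in range(num_heads)])
        ## output projection: (num_heads * d_model) columns -> d_model columns
        self.W_o = nn.Linear(num_heads * d_model, d_model, bias=False)

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v):
        head_outputs = []
        for W_q, W_k, W_v in zip(self.W_q, self.W_k, self.W_v):
            q = W_q(encodings_for_q)
            k = W_k(encodings_for_k)
            v = W_v(encodings_for_v)
            ## scaled dot-product attention for this head (2-D inputs, no mask)
            attention_percents = F.softmax(q @ k.T / k.size(1) ** 0.5, dim=1)
            head_outputs.append(attention_percents @ v)
        concatenated = torch.cat(head_outputs, dim=1)  ## (tokens, num_heads * d_model)
        return self.W_o(concatenated)                   ## (tokens, d_model)


torch.manual_seed(42)
encodings = torch.tensor([[1.16, 0.23],
                          [0.57, 1.36],
                          [4.41, -2.16]])
projected = ProjectedMultiHeadAttention(d_model=2, num_heads=2)
output = projected(encodings, encodings, encodings)
print(output.shape)  # torch.Size([3, 2])
```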

## Calculate Multi-Head Attention
<a id="calcMulti"></a>

First, verify that we can still correctly calculate attention with a single head...

```python
## set the seed for the random number generator
torch.manual_seed(42)

## create a multi-head attention object with a single head
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=1)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)
```

```text
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)
```

Second, calculate attention with multiple heads...

```python
## set the seed for the random number generator
torch.manual_seed(42)

## create a multi-head attention object with 2 heads
multiHeadAttention = MultiHeadAttention(d_model=2,
                                        row_dim=0,
                                        col_dim=1,
                                        num_heads=2)

## calculate encoder-decoder attention
multiHeadAttention(encodings_for_q, encodings_for_k, encodings_for_v)
```

```text
tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)
```
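Notice that the first two columns of the 2-head output are exactly the single-head result. That happens because both objects were created right after torch.manual_seed(42), so the first head draws the same random initial weights as the single head, while the second head draws the next random numbers and ends up with different weights. A small sketch of the same effect with bare nn.Linear() layers:

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
single = nn.Linear(2, 2, bias=False)

torch.manual_seed(42)
first = nn.Linear(2, 2, bias=False)   ## created first after reseeding: same weights as `single`
second = nn.Linear(2, 2, bias=False)  ## draws the next random numbers: different weights

print(torch.equal(single.weight, first.weight))   # True
print(torch.equal(single.weight, second.weight))  # False
```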