In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SelfAttention(nn.Module):
    # nn.Module is base class for all neural network modules made with PyTorch

    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        ## d_model = the number of embedding values per token, note in Attention Is All You Need, d_model=512
        ## usuall row_dim is the batch size (for now, no batches)

        super().__init__() # call parent's init method to inherit

        ## Initialize the Weights (W) that we'll use to create the query (q), key (k), and value (v) for each token
        ## Attention Is All You Need did not use bias terms so follow that implementation here

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False) # currently untrained weights needed to calculate query values
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim
    
    def forward(self, token_encodings, mask=None):
        # used to caluclate the self-attention values for each token
        # create the query, key, value matrices using the encoding numbers associated with each token (after word embedding and position encoding)
        q = self.W_q(token_encodings) # this will do the matrix multiplication of encoded values and query weights
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        # now calculate attention using formula Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

        # computer the similarities scores: (q * k^T)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # scale the similarities by dividing by sqrt(k.col_dim)
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        if mask is not None: # see below for explainer on masked self-attention
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9) # we replace values we wanted masked out
            # with a very small negative number so that the SoftMax() function will give all masked elements an output value (or "probability") of 0

        # apply softmax to determine what percent of each tokens' value to use in the final attention values
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # scale values by their associated percentages and sum them up
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

In [None]:
# create a matrix of token encodings...
encodings_matrix = torch.tensor([[1.16, 0.23],
                                 [0.57, 1.36],
                                 [4.41, -2.16]])

# set the seed for the random number generator
torch.manual_seed(42)

# create a basic self-attention ojbect
selfAttention = SelfAttention(d_model=2,
                               row_dim=0,
                               col_dim=1)

# calculate basic attention for the token encodings
selfAttention(encodings_matrix)

# calculate masked self-attention
mask = torch.tril(torch.ones(3, 3))
mask = mask == 0
selfAttention(encodings_matrix, mask)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [5]:
# print out the weight matrix that creates the queries
selfAttention.W_q.weight.transpose(0, 1)

tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)

In [6]:
# print out the weight matrix that creates the keys
selfAttention.W_k.weight.transpose(0, 1)

tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)

In [7]:
# print out the weight matrix that creates the values
selfAttention.W_v.weight.transpose(0, 1)

tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)

In [8]:
# calculate the queries
selfAttention.W_q(encodings_matrix)

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [9]:
# calculate the keys
selfAttention.W_k(encodings_matrix)

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [11]:
# calculate the values
selfAttention.W_v(encodings_matrix)

tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)

In [12]:
q = selfAttention.W_q(encodings_matrix)
q

tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)

In [13]:
k = selfAttention.W_k(encodings_matrix)
k

tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)

In [14]:
sims = torch.matmul(q, k.transpose(dim0=0, dim1=1))
sims

tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)

In [15]:
scaled_sims = sims / (torch.tensor(2)**0.5)
scaled_sims

tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)

In [16]:
attention_percents = F.softmax(scaled_sims, dim=1)
attention_percents

tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)

In [17]:
torch.matmul(attention_percents, selfAttention.W_v(encodings_matrix))

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

#### Some notes on self-attention

Transformers that only use self-attention are called encoder-only transformers and the context aware embeddings they create are super useful. In additon to clustering sentences and documents, we can use context aware embeddings as inputs to a normal neural network that classifies the sentiment of the input.

Alternatively, we could use the context aware embeddings as variables in a logistic regression model that does classification.

#### Some notes on masked self-attention

In contrast, let's compare decoder-only transformers. Like an encoder-only transformer, a decoder-only transformer starts out with word embeddings and positional encoding, but instead of using self-attention, it uses masked self-attention. 

Big difference: self-attention can look at words before and after the word of interest. Masked self-attention ignores words that come after the word of interest.

This means decoder-only transformers can be trained to do a good job at *generating* responses to prompts. We only give the context up to a certain point during training and then modify weights in the model until it generates the rest of the sentence correctly.

While encoder-only transformers create context aware embeddings, decoder-only transformers create generative inputs that can be plugged into a simple neural net that generates new tokens.

In [19]:
# calculate masked self-attention using formula MaskedAttention(Q, K, V) = softmax((QK^T / sqrt(d_k)) + M) * V
# add 0s to prior values and -infinity to future values (0% similarity to tokens coming after)

#### Some notes on encoder-decoder transformers

First transformers had an encoder that used self-attention and a decoder that used masked self-attention. Encoder and decoder were connected to calculate encoder-decoder attention.

Encoder-decoder attention (also called cross-attention) uses output from encoder to calculate keys and values and the queries are calculated from the output of the masked self-attention generated by the decoder. 

This first transformer was based on Seq2Seq model (first designed to translate text).

This style is still used in multi-modal models - in a multi-model model, we might have an encoder that has been trained on images or sound and the context aware embeddings could be fed into a text based decoder via cross-attention in order to generate captions or respond to audio prompts.




#### Based on the DeepLearning.ai course 'Attention in Transformers'
Link here: https://learn.deeplearning.ai/courses/attention-in-transformers-concepts-and-code-in-pytorch/information 

Further details and good explainer: https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html