# Concept Questions

1. What is the purpose of masking during self-attention?
2. What matrix do we mask, and what matrix do we use to perform masking?
3. What is the purpose of using dropout in a self-attention module?
4. What is the purpose of multi-head attention?

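1. Masking ensures that each token can attend only to itself and to earlier tokens, never to *future* tokens. This matters because an LLM is trained to predict the next token from the *past* tokens alone; if it could see future tokens, it could simply copy the answer instead of learning to predict it.
2. We mask the attention-score matrix $QK^T$ so that the attention weights on future tokens are 0 when we take the weighted sum of the values. The mask is a lower triangular matrix of 1s: entry $(i, j)$ is 1 when $j \le i$ (token $j$ comes at or before token $i$) and 0 otherwise.
3. Dropout randomly zeros out a fraction of the attention weights during training (and scales up the rest), so the model cannot rely too heavily on any single token-to-token connection. This helps the model generalize and reduces overfitting. Dropout is only active during training; when we evaluate or use the model, it is turned off.
4. Each head has its own query, key, and value weights, so each head can learn to focus on a different pattern in the text. Combining several heads generally improves the model's performance compared to using a single head.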
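The train/eval behavior of dropout can be seen with PyTorch's `nn.Dropout` directly. A minimal sketch, using a tensor of ones as a stand-in for attention weights: in training mode each entry is zeroed with probability `p` and the survivors are scaled by `1 / (1 - p)`, which keeps the expected value unchanged; in evaluation mode the layer returns its input untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
weights = torch.ones(2, 6)  # stand-in for a small matrix of attention weights

dropout.train()  # training mode: dropout is active
train_out = dropout(weights)
print(train_out)  # every entry is 0.0 (dropped) or 2.0 (kept and scaled by 1 / (1 - 0.5))

dropout.eval()  # evaluation mode: dropout does nothing
eval_out = dropout(weights)
print(torch.equal(eval_out, weights))  # True
```

On a full model, calling `model.train()` or `model.eval()` switches every dropout layer inside it at once.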

# Attention Review

In [1]:
import torch
import torch.nn as nn

Self-attention module from last lecture:

In [2]:
class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # W_query is a linear function that maps a d_in dimensional vector
        # to a d_out dimensional vector
        # Mathematically, it multiplies by a d_in by d_out matrix and then adds
        # a bias vector (nn.Linear uses bias=True by default)
        # Same for the key and value
        self.W_query = nn.Linear(d_in, d_out)
        self.W_key = nn.Linear(d_in, d_out)
        self.W_value = nn.Linear(d_in, d_out)

    def forward(self, x):
        # x is B x N x d_in
        # Q, K, and V are B x N x d_out
        Q = self.W_query(x)
        K = self.W_key(x)
        V = self.W_value(x)

        QKT = Q @ K.transpose(1, 2) # dim 1 is N, and dim 2 is d_out
        # QKT is B x N x N
        A = torch.softmax(QKT / (self.d_out ** 0.5), dim=-1)

        # A is B x N x N
        # V is B x N x d_out
        # A @ V is B x N x d_out
        context_vector = A @ V
        return context_vector
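The claim that `W_query` behaves like a matrix can be checked directly. `nn.Linear` stores its weight as a `d_out x d_in` tensor and, with the default `bias=True`, also adds a bias vector, so `W(x)` should equal `x @ W.weight.T + W.bias`. A quick sketch with small, arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 4, 3
W = nn.Linear(d_in, d_out)           # bias=True by default
x = torch.randn(2, 5, d_in)          # B x N x d_in

print(W.weight.shape)                # torch.Size([3, 4]), i.e. d_out x d_in
manual = x @ W.weight.T + W.bias     # (B x N x d_in) @ (d_in x d_out), plus a d_out bias
print(torch.allclose(W(x), manual, atol=1e-6))  # True
```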

Using the self-attention module:

In [3]:
X_batch = torch.randn(40, 50, 768)
self_attention = SelfAttentionV2(768, 1024) # calls __init__
context_batch = self_attention(X_batch) # calls the forward function
print(context_batch.shape)

torch.Size([40, 50, 1024])
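Why divide by `self.d_out ** 0.5` before softmax? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product has variance equal to the dimension, so unscaled scores grow with `d_out` and push softmax toward putting almost all the weight on one token. Dividing by the square root brings the variance back to about 1. A small empirical sketch, using random vectors rather than trained queries and keys:

```python
import torch

torch.manual_seed(0)
d = 1024
q = torch.randn(5000, d)  # 5,000 random "queries", entries ~ N(0, 1)
k = torch.randn(5000, d)  # 5,000 random "keys"

scores = (q * k).sum(dim=-1)          # one dot product per row
scaled = scores / d ** 0.5

print(scores.var())  # roughly d, i.e. about 1024
print(scaled.var())  # roughly 1
```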


# Causal Attention

We can use `torch.tril` to create a lower triangular matrix used for masking. The function `torch.tril` zeros out every entry above the main diagonal, i.e. the upper right portion of the matrix (we will only be using it for square matrices).

In [4]:
a = torch.tensor([
    [3, 4, 5],
    [1, 7, 3],
    [2, 4, 5]
])
b = torch.tril(a)
print(b)

tensor([[3, 0, 0],
        [1, 7, 0],
        [2, 4, 5]])


We can create a *mask* by creating a lower triangular matrix filled with ones.

In [6]:
a = torch.ones(3, 3)
b = torch.tril(a)
print(b)

# Using the mask
c = torch.tensor([
    [3, 4, 5],
    [1, 7, 3],
    [2, 4, 5]
])
d = b * c
print(d)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[3., 0., 0.],
        [1., 7., 0.],
        [2., 4., 5.]])
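The attention modules later in this notebook apply a single `N x N` mask to a whole batch of `B x N x N` score matrices at once. This works because of broadcasting: PyTorch treats the `3 x 3` mask as if it were repeated along the batch dimension. A small sketch with made-up scores:

```python
import torch

mask = torch.tril(torch.ones(3, 3))          # 3 x 3 lower triangular mask
scores = torch.arange(18.).reshape(2, 3, 3)  # a batch of two 3 x 3 "score" matrices

masked = scores * mask  # the mask broadcasts over the batch dimension
print(masked.shape)     # torch.Size([2, 3, 3])
print(masked[1])        # upper triangle of the second matrix is zeroed
```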


**Exercise 1:** Suppose the following matrix is $QK^T$. Apply the masking process after softmax.

In [9]:
n = 3
d_out = 100
QKT = torch.tensor([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9]
])
softmax_QKT = torch.softmax(QKT / (d_out ** 0.5), dim=-1)
mask = torch.tril(torch.ones(n, n))
masked_softmax_QKT = softmax_QKT * mask

# Normalizing so that the sum of each row is still 1
masked_attention = (masked_softmax_QKT / masked_softmax_QKT.sum(dim=-1, keepdim=True))
print(masked_attention)

tensor([[1.0000, 0.0000, 0.0000],
        [0.4975, 0.5025, 0.0000],
        [0.3300, 0.3333, 0.3367]])


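In practice, masking is usually applied *before* softmax by setting the entries we want to mask out to $-\infty$. Since $e^{-\infty} = 0$, softmax gives those entries a weight of exactly 0, and each row automatically sums to 1, so no renormalization step is needed.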

In [17]:
a = torch.tensor([
    [3, 4, 5],
    [1, 7, 3],
    [2, 4, 5]
]).float()
example_mask = torch.tril(torch.ones(3, 3))
print(example_mask)
b = a.masked_fill(example_mask == 0, float('-inf'))
print(b)
# We would then apply softmax afterwards so that the upper triangular portion becomes 0.

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[3., -inf, -inf],
        [1., 7., -inf],
        [2., 4., 5.]])


In [15]:
print(example_mask == 0)
# In the masked_fill line, everywhere that example_mask == 0 is true,
# it fills a -inf value

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
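A quick check of the $-\infty$ approach on single rows: a masked entry gets a softmax probability of exactly 0. One caveat worth knowing: if every entry in a row were $-\infty$, softmax would effectively compute 0/0 and return `nan`. A causal mask never produces such a row, because the diagonal (each token attending to itself) is always kept.

```python
import torch

row = torch.tensor([2.0, 4.0, float('-inf')])
probs = torch.softmax(row, dim=-1)
print(probs)        # last entry is exactly 0
print(probs.sum())  # 1, up to floating-point rounding

all_masked = torch.full((3,), float('-inf'))
print(torch.softmax(all_masked, dim=-1))  # every entry is nan
```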


**Exercise 2:** Apply the masking process before softmax.

In [21]:
n = 3
d_out = 100
QKT = torch.tensor([
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9]
])

# Your code here
mask = torch.tril(torch.ones(n, n))
masked_QKT = QKT.masked_fill(mask == 0, float('-inf'))
softmax_QKT = torch.softmax(masked_QKT / (d_out ** 0.5), dim=-1)
print(softmax_QKT)

tensor([[1.0000, 0.0000, 0.0000],
        [0.4975, 0.5025, 0.0000],
        [0.3300, 0.3333, 0.3367]])
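Exercises 1 and 2 produce the same attention weights. That is expected: zeroing the future after softmax and renormalizing each row is mathematically equivalent to filling the future with $-\infty$ before softmax. A quick check on random scores (the sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
n, d_out = 5, 64
QKT = torch.randn(n, n)
mask = torch.tril(torch.ones(n, n))

# Approach 1 (Exercise 1): softmax, zero out the future, renormalize each row
probs = torch.softmax(QKT / d_out ** 0.5, dim=-1) * mask
after = probs / probs.sum(dim=-1, keepdim=True)

# Approach 2 (Exercise 2): fill the future with -inf, then softmax
before = torch.softmax(QKT.masked_fill(mask == 0, float('-inf')) / d_out ** 0.5, dim=-1)

print(torch.allclose(after, before, atol=1e-6))  # True
```

The $-\infty$ version is the one used in the modules below because it skips the extra renormalization step.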


**Exercise 3:** Fill in missing parts of a causal attention module:

In [23]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # bias in a linear layer means instead of just multiplying by a matrix,
        # we multiply by a matrix and then add a vector afterwards
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.context_length = context_length
        causal_mask = torch.tril(torch.ones(context_length, context_length))

        self.register_buffer("mask", causal_mask)
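        # A buffer is saved along with the model and moves with it when we call
        # .to(device), so we don't have to move the mask to the correct device
        # ourselves. Unlike a parameter, it is not updated during training.
        # Access mask in the forward function using self.mask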

    def forward(self, x):
        B, N, D = x.shape
        # D should be equal to d_in
        # N can be at most context_length
        Q = self.W_query(x)     # (B, N, d_out)
        K = self.W_key(x)       # (B, N, d_out)
        V = self.W_value(x)     # (B, N, d_out)

        attention_scores = Q @ K.transpose(1, 2) # QKT, (B, N, N)
        # Apply masking; slice the mask to N x N in case N < context_length
        attention_scores = attention_scores.masked_fill(self.mask[:N, :N] == 0, float('-inf'))
        # Apply softmax, scaling by sqrt(d_out)
        attention_probs = torch.softmax(attention_scores / (self.d_out ** 0.5), dim=-1)

        attention_probs = self.dropout(attention_probs)
        context_vector = attention_probs @ V
        return context_vector
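The behavior of `register_buffer` can be checked on a toy module. A registered buffer is included in `state_dict()` (so it is saved with the model) but is not returned by `.parameters()`, so an optimizer will never update it. `TinyMasked` below is a hypothetical module made up only for this illustration.

```python
import torch
import torch.nn as nn

class TinyMasked(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(4, 4)  # trainable parameters: weight and bias
        self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length)))

m = TinyMasked(context_length=3)
param_names = [name for name, _ in m.named_parameters()]
print(param_names)               # ['proj.weight', 'proj.bias'], no mask
print("mask" in m.state_dict())  # True: saved along with the weights
print(m.mask.requires_grad)      # False: not trained
```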


**Exercise 4:** Create a tensor representing a batch of data with batch size 20, context length 100, and embedding size 512. Pass the tensor through a causal attention model with `d_out=768`.

In [26]:
# Your code here
x_batch = torch.randn(20, 100, 512)
causal_att_model = CausalAttention(x_batch.shape[2], 768, x_batch.shape[1])
contextual_embeddings = causal_att_model(x_batch)
print(contextual_embeddings.shape)

torch.Size([20, 100, 768])
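A useful sanity check for any causal attention implementation: changing a *future* token must not change the outputs at earlier positions. The sketch below uses a minimal, hypothetical single-head function `causal_attention` (random projection matrices, no bias, no dropout) rather than the `CausalAttention` class, so it runs on its own.

```python
import torch

def causal_attention(x, W_q, W_k, W_v):
    # Hypothetical minimal single-head causal attention; x is B x N x d_in
    N = x.shape[1]
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    scores = Q @ K.transpose(1, 2) / K.shape[-1] ** 0.5  # B x N x N
    mask = torch.tril(torch.ones(N, N))
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return torch.softmax(scores, dim=-1) @ V             # B x N x d_out

torch.manual_seed(0)
d_in, d_out, N = 8, 6, 5
W_q, W_k, W_v = (torch.randn(d_in, d_out) for _ in range(3))

x = torch.randn(1, N, d_in)
x_changed = x.clone()
x_changed[0, -1] = torch.randn(d_in)  # change only the last token

out = causal_attention(x, W_q, W_k, W_v)
out_changed = causal_attention(x_changed, W_q, W_k, W_v)

print(torch.allclose(out[:, :-1], out_changed[:, :-1]))  # True: earlier positions unaffected
print(torch.allclose(out[:, -1], out_changed[:, -1]))    # False: the last position changes
```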


# Multi-head Attention

We can implement multi-head attention by running several single-head causal attention modules in parallel and concatenating their outputs along the last dimension:

In [28]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_out // num_heads
        self.heads = nn.ModuleList([
            CausalAttention(d_in=d_in, d_out=self.d_head,
                            context_length=context_length, dropout=dropout, qkv_bias=qkv_bias)
                            for _ in range(num_heads)])
        # ModuleList is needed instead of python list since the model will not register
        # the parameters of modules in a regular python list, which will be important
        # when training later on

    def forward(self, x):
        # x is B x N x d_in
        # head(x) is B x N x d_head
        # Concatenation is B x N x (d_head * num_heads)
        # which is the same as B x N x d_out
        return torch.cat([head(x) for head in self.heads], dim=-1)
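The point about `nn.ModuleList` is easy to verify. Submodules stored in a plain Python list are not registered, so `.parameters()` (and therefore an optimizer) never sees them. The two classes below are hypothetical, made up for this comparison:

```python
import torch.nn as nn

class WithList(nn.Module):  # hypothetical: heads stored in a plain Python list
    def __init__(self):
        super().__init__()
        self.heads = [nn.Linear(4, 2) for _ in range(3)]

class WithModuleList(nn.Module):  # hypothetical: heads stored in nn.ModuleList
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([nn.Linear(4, 2) for _ in range(3)])

print(len(list(WithList().parameters())))        # 0: the heads were not registered
print(len(list(WithModuleList().parameters())))  # 6: a weight and a bias for each of 3 heads
```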

**Exercise 5:** Create a tensor representing a batch of data with batch size 20, context length 100, and embedding size 512. Pass the tensor through a multi-head attention model with `d_out=768` and `num_heads=3`.

In [29]:
# Your code here
x_batch = torch.randn(20, 100, 512)
multihead_model = MultiHeadAttention(x_batch.shape[2], 768, 3, x_batch.shape[1])
contextual_embeddings = multihead_model(x_batch)
print(contextual_embeddings.shape)

torch.Size([20, 100, 768])


**Exercise 6:** Fill in missing parts of the following multi-head attention class:

In [32]:
class MultiHeadAttentionOneModule(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_in = d_in
        self.d_out = d_out
        self.num_heads = num_heads
        self.d_head = d_out // num_heads # Dimension of each head
        self.context_length = context_length
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        causal_mask = torch.tril(torch.ones(context_length, context_length))
        self.projection = nn.Linear(d_out, d_out) # Optional linear layer at the end

        self.register_buffer("mask", causal_mask)

    def forward(self, x):
        B, N, D = x.shape   # D is d_in, N is at most context_length
        Q = self.W_query(x) # B x N x d_out
        K = self.W_key(x) # B x N x d_out
        V = self.W_value(x) # B x N x d_out

        # After Q.view: B x N x d_out -> B x N x num_heads x d_head
        # After .transpose(1, 2): B x num_heads x N x d_head
        Q = Q.view(B, N, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(B, N, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(B, N, self.num_heads, self.d_head).transpose(1, 2)
        # Q, K, V have size B x num_heads x N x d_head

        QKT = Q @ K.transpose(2, 3) # B x num_heads x N x N
        # Apply mask; slice it to N x N in case N < context_length
        masked_QKT = QKT.masked_fill(self.mask[:N, :N] == 0, float('-inf'))
        attention_probs = torch.softmax(masked_QKT / (self.d_head ** 0.5), dim=-1) # Apply softmax
        attention_probs = self.dropout(attention_probs)

        # attention_probs is B x num_heads x N x N
        # V is B x num_heads x N x d_head

        context_vector = attention_probs @ V # B x num_heads x N x d_head
        context_vector = context_vector.transpose(1, 2).contiguous().view(B, N, self.d_out)
        # context_vector.transpose(1, 2): B x N x num_heads x d_head
        # After .view: B x N x d_out
        return self.projection(context_vector)
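The head-splitting reshapes in `forward` can be checked in isolation. Splitting `d_out` into `num_heads x d_head` with `.view` and moving the head dimension forward with `.transpose(1, 2)` is exactly undone by transposing back and calling `.contiguous().view(...)`. The `.contiguous()` call matters: after the transpose, the tensor's memory layout no longer matches its shape, and `.view` cannot merge those dimensions. A sketch with small, arbitrary sizes:

```python
import torch

torch.manual_seed(0)
B, N, num_heads, d_head = 2, 4, 3, 5
d_out = num_heads * d_head
x = torch.randn(B, N, d_out)

# Split: B x N x d_out -> B x N x num_heads x d_head -> B x num_heads x N x d_head
heads = x.view(B, N, num_heads, d_head).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 3, 4, 5])

# Head h at token t holds the slice x[:, t, h * d_head:(h + 1) * d_head]
print(torch.equal(heads[:, 1, 2], x[:, 2, d_head:2 * d_head]))  # True

# attention_probs @ V returns a fresh contiguous tensor; .contiguous() stands in for that here
context = heads.contiguous()
print(context.transpose(1, 2).is_contiguous())  # False, so a plain .view would fail

# Merge: B x num_heads x N x d_head -> B x N x num_heads x d_head -> B x N x d_out
merged = context.transpose(1, 2).contiguous().view(B, N, d_out)
print(torch.equal(merged, x))  # True
```

`.reshape(B, N, d_out)` is a common alternative: it returns a view when possible and copies the data otherwise.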

**Exercise 7:** Create a tensor representing a batch of data with batch size 40, context length 80, and embedding size 768. Pass the tensor through a `MultiHeadAttentionOneModule` model with `d_out=1536` and `num_heads=3`.

In [35]:
# Your code here
x = torch.randn(40, 80, 768)
multihead_model = MultiHeadAttentionOneModule(x.shape[2], 1536, 3, 80)
contextual_embeddings = multihead_model(x)
print(contextual_embeddings.shape)

torch.Size([40, 80, 1536])


**Optional:** Update BadLM to use multi-head attention.

In [37]:
# Input: B x N
# After embedding: B x N x emb_dim
# After self-attention: B x N x att_dim
# What we want: B x N x vocab_size
# This is a complete language model,
# except it is way too small and is missing many of the ideas
# that make GPT work
# Later on, we will build a functional GPT
class BadLM_V2(nn.Module):
    def __init__(self, context_length, vocab_size, emb_dim, att_dim, num_heads):
        super().__init__()
        self.context_length = context_length
        self.emb_dim = emb_dim
        self.att_dim = att_dim
        self.num_heads = num_heads
        self.vocab_size = vocab_size
        self.token_embs = nn.Embedding(vocab_size, emb_dim)
        # One positional embedding per position, so the table has context_length rows
        self.pos_embs = nn.Embedding(context_length, emb_dim)
        self.att = MultiHeadAttentionOneModule(emb_dim, att_dim, num_heads, context_length)
        self.prediction_layer = nn.Linear(att_dim, vocab_size)

    def forward(self, x):
        # x is B x N, with N <= context_length
        N = x.shape[1]
        positions = torch.arange(N, device=x.device)                # 0, 1, ..., N - 1
        embedding = self.token_embs(x) + self.pos_embs(positions)  # B x N x emb_dim
        context_embedding = self.att(embedding)                     # B x N x att_dim
        prediction = self.prediction_layer(context_embedding)       # B x N x vocab_size
        return prediction
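In `forward`, the token embeddings have shape `B x N x emb_dim`, while the positional embeddings looked up with `torch.arange(N)` have shape `N x emb_dim`. Adding them broadcasts the same position vectors across every sequence in the batch. A small sketch with arbitrary sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, context_length, emb_dim = 100, 6, 8
token_embs = nn.Embedding(vocab_size, emb_dim)
pos_embs = nn.Embedding(context_length, emb_dim)  # one row per position

x = torch.randint(0, vocab_size, (2, 4))          # B = 2 sequences of N = 4 token IDs
positions = torch.arange(x.shape[1])              # tensor([0, 1, 2, 3])
embedding = token_embs(x) + pos_embs(positions)   # (2 x 4 x 8) + (4 x 8) -> 2 x 4 x 8
print(embedding.shape)  # torch.Size([2, 4, 8])

# The same position vector (row 3) is added at position 3 of every sequence
print(torch.allclose(embedding[1, 3] - token_embs(x[1, 3]), pos_embs.weight[3], atol=1e-6))  # True
```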


In [38]:
import tiktoken
from torch.utils.data import Dataset, DataLoader

# Dataset class
class MyData(Dataset):
    # Init function, called when the dataset is created
    # dataset = MyData(text, tokenizer, context_length=4, stride=1)
    def __init__(self, text, tokenizer, context_length, stride=1):
        self.input_ids = []
        self.target_ids = []
        token_ids = tokenizer.encode(text)
        for i in range(0, len(token_ids) - context_length, stride):
            self.input_ids.append(torch.tensor(token_ids[i : i + context_length]))
            self.target_ids.append(torch.tensor(token_ids[i + 1 : i + context_length + 1]))

    # Length function
    # len(dataset)
    def __len__(self):
        return len(self.input_ids)

    # Get item function
    # dataset[idx]
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

# Dataloader
def my_batch(text, batch_size, context_length, stride, shuffle=True, drop_last=True, num_workers=0):
    tokenizer = tiktoken.get_encoding("gpt2")

    # Create the dataset object
    dataset = MyData(text, tokenizer, context_length, stride)

    # Use the DataLoader library to create a dataloader that batches the data
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            drop_last=drop_last,
                            num_workers=num_workers)

    return dataloader
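To see what `MyData` produces without a tokenizer, here is the same sliding-window loop applied to a hypothetical list of ten token IDs, wrapped in a hypothetical helper `sliding_windows`. Each target is its input shifted one position to the right, and `stride` controls how far the window moves between examples.

```python
def sliding_windows(token_ids, context_length, stride=1):
    # Hypothetical helper: same loop as MyData.__init__, returning plain lists instead of tensors
    pairs = []
    for i in range(0, len(token_ids) - context_length, stride):
        inp = token_ids[i : i + context_length]
        tgt = token_ids[i + 1 : i + context_length + 1]
        pairs.append((inp, tgt))
    return pairs

token_ids = list(range(10))  # hypothetical token IDs 0..9
for inp, tgt in sliding_windows(token_ids, context_length=4, stride=2):
    print(inp, "->", tgt)
# [0, 1, 2, 3] -> [1, 2, 3, 4]
# [2, 3, 4, 5] -> [3, 4, 5, 6]
# [4, 5, 6, 7] -> [5, 6, 7, 8]
```

With `stride=1` the windows overlap by all but one token; a stride equal to `context_length` gives non-overlapping windows.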

In [40]:
text = "This is a very useless sentence that talks about itself."
batch_size = 2
context_length = 4
loader = my_batch(text, batch_size, context_length, 1)
vocab_size = 50257
model = BadLM_V2(context_length, vocab_size, 768, 768, 3)
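predictions = []
for inputs, targets in loader:
    output = model(inputs)                 # B x N x vocab_size
    tokens = torch.argmax(output, dim=-1)  # highest-scoring token ID at each position
    predictions.append(tokens)

# The model is untrained, so these are essentially random token IDs
print(predictions)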

[tensor([[15005, 20879, 20686, 44703],
        [46038, 41755,  1401, 13061]]), tensor([[22448, 23515, 23563, 18619],
        [ 5891, 45383, 49042, 18885]]), tensor([[ 5891,  9232, 48254, 16371],
        [29260, 20879, 18076, 35327]])]
