# Lecture 16: Causal Self-Attention

### only previous tokens and current token are being taken under consideration during computation of the attention weights

In [813]:
import torch
from torch import nn

inputs = torch.tensor([[0.43, 0.15, 0.89],
                       [0.55, 0.87, 0.66],
                       [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33],
                       [0.77, 0.25, 0.10],
                       [0.05, 0.80, 0.55]]
                       )

### initializing the variables

In [814]:
d_in = inputs.shape[-1]
d_out = 2

## Step 1:
### --> attention weights are going to be computed like L15 trainable weights (scaling by the square root of the keys dimension is performed and the softmax normalization also)

## Step 2:
### --> attention weights above the diagonal of the attention weights matrix are going to be masked out (set to 0)

## Step 3:
### --> after zeroing the values above the diagonal a re-normalization is being performed to the masked attention weight matrix

# Finally the masking is applied to the attention scores before dividing by the square root and before applying softmax in order to avoid data leakage as the softmax already involves future tokens when making the rows sum up to 1

### *Self Attention Class using Linear Layers*

In [815]:
class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights_scaled = attn_scores / torch.sqrt(torch.tensor(keys.shape[-1]))
        attn_weights = torch.softmax(attn_weights_scaled, dim=-1)

        context_matrix = attn_weights @ values
        return context_matrix

### *Instance of Self Attention Class with manually computing Class matrices*

In [816]:
torch.manual_seed(789)
self_attention_v2 = SelfAttentionV2(d_in, d_out)

queries = self_attention_v2.W_query(inputs)
keys = self_attention_v2.W_key(inputs)
values = self_attention_v2.W_value(inputs)

attn_scores = queries @ keys.T
attn_weights_scaled = attn_scores / torch.sqrt(torch.tensor(keys.shape[-1]))
attn_weights = torch.softmax(attn_weights_scaled, dim=-1)

print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### generating the masking matrix --> lower triangular matrix, all elements above the diagonal are set to 0

### context length specifying the amount of words in the sequence which is represented by the number of rows in the input matrix

In [817]:
context_length = attn_weights.shape[0]

mask_step1 = torch.ones(context_length, context_length)
mask_step2 = torch.tril(mask_step1)
print(f"Create Matrix with ones only:\n{mask_step1}\n")
print(f"Apply the tril function to generate triangular Matrix:\n{mask_step2}")

Create Matrix with ones only:
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

Apply the tril function to generate triangular Matrix:
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


### do step1 and step2 in one line of code

In [818]:
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


### multiplying the attention weights matrix with the mask simple matrix to apply the masking on the attention weights

In [819]:
masked_attn_weights = attn_weights * mask_simple
print(masked_attn_weights)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


### normalization of the lower triangular matrix using the sums of the remaining rows and dividing each remaining row value by the sum

In [820]:
sum_rows = masked_attn_weights.sum(dim=1, keepdim=True)
masked_attn_weights_norm = masked_attn_weights / sum_rows
print(f"aquiring the remaining sum of each row in the masked matrix:\n{sum_rows}\n")
print(f"deviding the masked attention weights matrix by the sum of its rows:\n{masked_attn_weights_norm}")

aquiring the remaining sum of each row in the masked matrix:
tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)

deviding the masked attention weights matrix by the sum of its rows:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


### example calculation for the first two rows of the masked attention weights matrix

In [821]:
print("Example calculation for rows 2 and 3")
r2_c1, r2_c2 = 0.2041 / 0.37, 0.1659 / 0.37
r3_c1, r3_c2, r3_c3 = 0.2036 / 0.5357, 0.1659 / 0.5357, 0.1662 / 0.5357
print(f"{r2_c1:.4f}, {r2_c2:.4f}")
print(f"{r3_c1:.4f}, {r3_c2:.4f}, {r3_c3:.4f}")

Example calculation for rows 2 and 3
0.5516, 0.4484
0.3801, 0.3097, 0.3102


# Approach to prevent data leakage applying triangular mask to attention scores matrix 

### taking the attention scores

In [822]:
attn_scores

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)

### creating upper triangular matrix with ones, essentially the mask

In [823]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

### applying the mask to the attention scores and replacing the ones with negativ infinity

In [824]:
attn_scores_masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_scores_masked

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

### dividing by the square root of the keys dimension

In [825]:
attn_scores_masked_scaled = attn_scores_masked / torch.sqrt(torch.tensor(keys.shape[-1]))
attn_scores_masked_scaled

tensor([[0.2050,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.3293, 0.1218,   -inf,   -inf,   -inf,   -inf],
        [0.3249, 0.1204, 0.1224,   -inf,   -inf,   -inf],
        [0.1868, 0.0724, 0.0733, 0.0132,   -inf,   -inf],
        [0.1544, 0.0618, 0.0624, 0.0125, 0.0556,   -inf],
        [0.2410, 0.0898, 0.0912, 0.0140, 0.0912, 0.0055]],
       grad_fn=<DivBackward0>)

### applying softmax function to normalize the along the rows, does NOT work when 0s are above the diagonal

In [826]:
attn_weights_masked = torch.softmax(attn_scores_masked_scaled, dim=1)
attn_weights_masked

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

### applying dropout within the causal attention mechanism, when weights are swithed of randomly the remaining weights are going to be rescaled by the dropout rate, hence the 2s in the matrix after dropout is applied

In [827]:
dropout = torch.nn.Dropout(0.5)
test_case = torch.ones(6, 6)
print(dropout(test_case))

tensor([[2., 0., 2., 2., 2., 2.],
        [2., 0., 2., 2., 0., 2.],
        [0., 2., 2., 0., 2., 0.],
        [2., 2., 2., 0., 0., 0.],
        [0., 0., 2., 2., 0., 2.],
        [0., 2., 2., 2., 0., 2.]])


### applying dropout to the attention weights matrix

In [828]:
attn_weights_do = dropout(attn_weights_masked)
attn_weights_do

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.4925, 0.4638, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3968, 0.3775, 0.3941, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3058]],
       grad_fn=<MulBackward0>)

### computing the context matrix

In [829]:
context_matrix = attn_weights_do @ values
context_matrix

tensor([[ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [-0.0925,  0.0445],
        [-0.0199,  0.1769],
        [-0.0637, -0.0473]], grad_fn=<MmBackward0>)

### creating a batch of multiple inputs to do simultaneous processing of sequences

In [830]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
print(batch)

torch.Size([2, 6, 3])
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


# Creating a Causal Attention Class

In [831]:
class CausalAttentionV1(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        

    def forward(self, x):
        num_tokens = x.shape[1]
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores_masked = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_scores_masked_scaled = attn_scores_masked / torch.sqrt(torch.tensor(keys.shape[-1]))
        attn_weights = torch.softmax(attn_scores_masked_scaled, dim=-1)
        attn_weights_do = self.dropout(attn_weights)
        context_matrix = attn_weights_do @ values
        return context_matrix

### --> instead of self.mask, self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) can be used which is essentially device agnostic code ensuring the triangular matrix is moved to the correct device automatically

### --> keys.transpose has to now do the transpose considering two inout batches which is why the expression changes comared to before

### creating an instance of the causal self attention class

In [832]:
torch.manual_seed(123)
context_length = batch.shape[1]
causal_attention_v1 = CausalAttentionV1(d_in, d_out, context_length, 0.0)
context_vecs = causal_attention_v1.forward(batch)
print(context_vecs.shape)
print(context_vecs)

torch.Size([2, 6, 2])
tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
