## Causal Self Attention

`Causal Attention`, also known as **masked self attention**, is a specialized form of self attention. It restricts a model to only consider previous and current inputs in a sequence when processing any given token. This is in contrast to the standard self attention method, which allows access to the entire input sequence at once.

When computing attention scores the causal attention method ensures that the model only factors in tokens that occur at or before the current token in the sequence.

For each token processed, we mask out the future tokens, i.e. the tokens that come after it in the input text.

In [2]:
import torch

In [3]:
torch.manual_seed(789)  # seed the global RNG so later weight initialisation is reproducible

class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self attention without any masking.

    Every token attends to every token of the same sequence, including
    later ones.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the projected query/key/value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Creation order (query, key, value) determines which random numbers
        # each projection receives under a fixed seed; keep it stable.
        self.W_query = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=False)
        self.W_key = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=False)
        self.W_value = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=False)

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)`` or, batched,
                ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Transpose only the last two dimensions so leading batch dimensions
        # are preserved (``.T`` would reverse every dimension).
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out): the key dimension is always the last one,
        # whereas shape[1] is the token count once a batch dimension exists.
        attention_weights = torch.softmax(
            attention_scores / (keys.shape[-1] ** 0.5), dim=-1
        )
        return attention_weights @ values

In [4]:
# Toy 3-dimensional embeddings, one row per token of "A cat sat on the mat".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # A
    [0.55, 0.87, 0.66],  # cat
    [0.57, 0.85, 0.64],  # sat
    [0.22, 0.58, 0.33],  # on
    [0.77, 0.25, 0.10],  # the
    [0.05, 0.80, 0.55],  # mat
])

In [5]:
torch.manual_seed(789)

self_attention = SelfAttention(d_in=3, d_out=2)

# Recompute the intermediate quantities by hand so that the raw scores and
# weights are available for the masking experiments that follow.
queries = self_attention.W_query(inputs)
keys = self_attention.W_key(inputs)

# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)), one row per query.
attention_scores = torch.matmul(queries, keys.T)
attention_weights = (attention_scores / (keys.shape[-1] ** 0.5)).softmax(dim=-1)

attention_weights

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [6]:
# Build a lower-triangular mask: 1 where a query may attend
# (key position <= query position), 0 for future positions.

context_length = attention_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()

In [7]:
print(mask_simple) # lower-triangular matrix: ones on and below the diagonal, zeros above

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [8]:
masked_simple = torch.mul(attention_weights, mask_simple)  # element-wise product zeroes the weights of future tokens

In [9]:
# Future positions are now zero, so every row except the last no longer sums to 1.
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [10]:
# Sum each masked row; keepdim=True keeps a trailing dimension of size 1 so
# the result can be broadcast against the matrix when renormalising.
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)

row_sums


tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)

In [11]:
# Renormalise: the per-row sums broadcast across the columns, so every row
# sums to 1 again.
masked_simple_normalized = torch.div(masked_simple, row_sums)

masked_simple_normalized

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

##### Masking with `-inf` before the softmax function

In [12]:
# Ones strictly above the diagonal mark the future positions to hide.
mask = torch.ones(context_length, context_length).triu(diagonal=1)

mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [13]:
# Overwrite the future positions with -inf so that softmax assigns them a
# weight of exactly zero.
masked = attention_scores.masked_fill(mask.to(torch.bool), float("-inf"))

masked

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

In [14]:
# Scale, then softmax: the -inf entries become 0 and each row is normalised
# over the visible positions only, so no separate renormalisation is needed.
# NOTE: this rebinds `attention_weights` (previously the unmasked weights).
attention_weights = (masked / (keys.shape[-1] ** 0.5)).softmax(dim=-1)

attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

##### Dropout

In [15]:
torch.manual_seed(123)

# Demonstrate dropout on a matrix of ones: with p=0.5 roughly half of the
# entries are zeroed and the surviving ones are rescaled.
example = torch.ones((6, 6))
dropout = torch.nn.Dropout(p=0.5)

print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


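Dropout zeroes a random subset of the entries (each with probability p = 0.5) and scales the surviving entries by 1 / (1 - p) = 1 / 0.5 = 2, so that the expected value of each entry is unchanged.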

In [16]:
torch.manual_seed(123)  # same seed as the ones-matrix dropout example
# Apply dropout to the causal attention weights: dropped entries become 0 and
# the surviving ones are doubled (1 / (1 - 0.5)), so rows no longer sum to 1.
print(dropout(attention_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


## Causal Self Attention Class

In [23]:
import torch
class CausalAttention(torch.nn.Module):
    """Single-head scaled dot-product self attention with a causal mask.

    Each position attends only to itself and to earlier positions; dropout is
    applied to the resulting attention weights.

    Args:
        d_in: size of each input token embedding.
        d_out: size of the projected query/key/value vectors.
        context_length: longest sequence the precomputed mask supports.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        # Creation order (query, key, value) determines which random numbers
        # each projection receives under a fixed seed; keep it stable.
        self.W_query = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(in_features=d_in, out_features=d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the future positions. Stored as
        # a buffer so it follows the module across devices without being a
        # trainable parameter.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute causally masked context vectors for ``x``.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)``; any number of
                leading batch dimensions (including none) is accepted.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Without this check a too-small mask either fails with an opaque
            # broadcasting error or (for context_length == 1) broadcasts and
            # silently applies no causal masking at all.
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.transpose(-2, -1)
        # In place: future positions become -inf so softmax gives them weight 0.
        attention_scores.masked_fill_(
            self.mask[:num_tokens, :num_tokens].bool(), -torch.inf
        )
        attention_weights = torch.softmax(
            attention_scores / (keys.shape[-1] ** 0.5), dim=-1
        )
        attention_weights = self.dropout(attention_weights)
        return attention_weights @ values

In [24]:
# Duplicate the example sequence to form a batch of two identical sequences.
batch = torch.stack([inputs, inputs], dim=0)

print(batch.shape)

torch.Size([2, 6, 3])


In [25]:
torch.manual_seed(123)
context_length = batch.size(1)
# The module is left in training mode, so dropout (p=0.5) is active and some
# attention weights are zeroed at random in the printed result.
causal_attention = CausalAttention(3, 2, context_length, 0.5)
context_vectors = causal_attention(batch)

print(context_vectors)

tensor([[[ 0.0000,  0.0000],
         [-0.4368,  0.2142],
         [-0.7751,  0.0077],
         [-0.9140, -0.2769],
         [ 0.0000,  0.0000],
         [-0.6906, -0.0974]],

        [[-0.9038,  0.4432],
         [ 0.0000,  0.0000],
         [-0.2883,  0.1414],
         [-0.9140, -0.2769],
         [-0.4416, -0.1410],
         [-0.5272, -0.1706]]], grad_fn=<UnsafeViewBackward0>)
