In [None]:
import einops
import torch
from torch import nn
from torch import Tensor

In [None]:
def apply_causal_mask(
    attn_scores: Tensor,
    masked_value: float = float('-inf')
) -> Tensor:
    """
    Applies a causal mask to attention scores, and returns masked scores.

    Every score whose key position is strictly after its query position
    (the strict upper triangle of the last two dims) is replaced with
    ``masked_value``. After a softmax, each query can then only attend to
    itself and earlier positions.

    Args:
        attn_scores: tensor of shape [..., query_pos, key_pos]; in this notebook
            [batch, n_heads, query_pos, key_pos]. Any number of leading dims
            (including none, i.e. a plain [query_pos, key_pos] matrix) is accepted.
        masked_value: value written into masked positions (default -inf).

    Returns:
        A new tensor with the same shape, dtype and device as ``attn_scores``.
        The input tensor is left unmodified.
    """
    query_len, key_len = attn_scores.shape[-2], attn_scores.shape[-1]
    # Build a single [query_pos, key_pos] mask on the scores' device and let
    # masked_fill broadcast it over the leading (batch / head) dims. A mask of
    # the full input shape on the default device wastes memory and fails for
    # CUDA inputs. True = position to mask out.
    mask = torch.triu(
        torch.ones(query_len, key_len, device=attn_scores.device),
        diagonal=1,
    ).bool()
    # Out-of-place so the caller's tensor is not mutated as a side effect.
    return attn_scores.masked_fill(mask, masked_value)

# Sanity check for apply_causal_mask: a 3x3 matrix should keep its lower
# triangle (incl. diagonal) and have everything above it replaced by `ignore`.
ignore = float('-inf')
test1 = apply_causal_mask(
    torch.tensor([[1., 2, 3], [4, 5, 6], [7, 8, 9]]),
    ignore,
)

assert torch.allclose(
    test1,
    torch.tensor([[1., ignore, ignore],
                  [4., 5., ignore],
                  [7., 8., 9.]]),
), "Oh no it looks like your matrix doesnt pass test 1"
print(test1)

tensor([[1., -inf, -inf],
        [4., 5., -inf],
        [7., 8., 9.]])


In [None]:
class Attention(nn.Module):
    """
    Multi-head causal self-attention layer.

    Each head has its own query/key/value projection (W_Q, W_K, W_V with
    biases b_Q, b_K, b_V). Head outputs are mapped back to the model width by
    W_O, summed over heads, and offset by b_O.

    Weights are initialised to ones and biases to zeros, so any two freshly
    constructed instances with the same hyperparameters are identical. Call
    ``init_testing_weights`` to replace every parameter with standard-normal
    samples.
    """

    def __init__(self, num_heads: int, dim_model: int, dim_head: int) -> None:
        """
        Args:
            num_heads: number of attention heads.
            dim_model: width of the residual stream (last dim of the input).
            dim_head: width of each head's query/key/value vectors.
        """
        super().__init__()

        # hyper parameters
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_head = dim_head

        # weights -- constant init keeps separately built instances identical
        self.W_Q = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_K = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_V = nn.Parameter(torch.ones((num_heads, dim_model, dim_head)))
        self.W_O = nn.Parameter(torch.ones((num_heads, dim_head, dim_model)))

        # biases
        self.b_Q = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_K = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_V = nn.Parameter(torch.zeros((num_heads, dim_head)))
        self.b_O = nn.Parameter(torch.zeros((dim_model,)))

    def init_testing_weights(self) -> None:
        """Overwrite every weight and bias in place with N(0, 1) samples."""
        # In-place writes to leaf Parameters that require grad are only
        # permitted with autograd recording disabled.
        with torch.no_grad():
            for param in (self.W_Q, self.W_K, self.W_V, self.W_O,
                          self.b_Q, self.b_K, self.b_V, self.b_O):
                param.normal_()

    def _project(self, x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        """Map x [batch, posn, d_model] to per-head vectors [batch, posn, nheads, d_head]."""
        return einops.einsum(
            x, weight,
            "batch posn d_model, nheads d_model d_head -> batch posn nheads d_head"
        ) + bias

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the attention layer.

        Args:
            x: tensor of shape [batch, tokens, dim_model].

        Returns:
            Tensor of shape [batch, tokens, dim_model]: the per-head outputs
            projected by W_O, summed over heads, plus the bias b_O.
        """
        batch_size, tokens_size = x.shape[0], x.shape[1]
        per_head_shape = torch.Size([batch_size, tokens_size, self.num_heads, self.dim_head])

        # Calculate query, key and value vectors
        q = self._project(x, self.W_Q, self.b_Q)
        k = self._project(x, self.W_K, self.b_K)
        v = self._project(x, self.W_V, self.b_V)
        assert q.shape == per_head_shape
        assert k.shape == per_head_shape
        assert v.shape == per_head_shape

        # Attention scores: dot product of every query with every key, per head
        attn_scores = einops.einsum(
            q, k,
            "batch posn_Q nheads d_head, batch posn_K nheads d_head -> batch nheads posn_Q posn_K",
        )
        assert attn_scores.shape == torch.Size([batch_size, self.num_heads, tokens_size, tokens_size])
        # Mask future keys with -inf, scale by 1/sqrt(d_head), softmax over keys.
        # -inf stays -inf after the positive scaling, so masked keys get probability 0.
        attn_scores_masked = apply_causal_mask(attn_scores, float('-inf'))
        attn_probs = torch.softmax(attn_scores_masked / self.dim_head**0.5, dim=-1)

        # Take weighted sum of value vectors, according to attention probabilities
        z = einops.einsum(
            v, attn_probs,
            "batch posn_K nheads d_head, batch nheads posn_Q posn_K -> batch posn_Q nheads d_head",
        )
        assert z.shape == per_head_shape

        # Calculate output (by applying matrix W_O and summing over heads, then adding bias b_O)
        attn_out = einops.einsum(
            z, self.W_O,
            "batch posn nheads d_head, nheads d_head d_model -> batch posn d_model",
        ) + self.b_O
        assert attn_out.shape == torch.Size([batch_size, tokens_size, self.dim_model])

        return attn_out

In [None]:
# Compare two Attention instances on the same random input (no need to change this cell).
# NOTE(review): both models are instances of the same Attention class with
# deterministic constant init, so this checks that forward runs and is
# deterministic, not that it matches an independent reference implementation.
batch_size, tokens_dim, dim_model, dim_heads, num_heads = 12, 20, 30, 10, 2

ground_truth, user_model = (Attention(num_heads, dim_model, dim_heads) for _ in range(2))
test = torch.rand((batch_size, tokens_dim, dim_model))

truth_output, user_output = ground_truth(test), user_model(test)

assert torch.allclose(truth_output, user_output), "Uh oh your model doesn't give the same outputs"
print("passed all tests!")

passed all tests!
