In [45]:
import torch
import torch.nn as nn

In [46]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all bias-free linear projections of the same
    input tensor.

    Args:
        d_in: Size of the last dimension of the input.
        d_kq: Size of the query/key projections; also sets the softmax scale
            (scores are divided by ``sqrt(d_kq)``).
        d_v: Size of the value projection, i.e. of the output's last dimension.
    """

    def __init__(self, d_in, d_kq, d_v):
        super().__init__()
        self.d_kq = d_kq
        # Projection matrices, initialised uniformly in [0, 1) via torch.rand.
        self.q_weights = nn.Parameter(torch.rand(d_in, d_kq))
        self.k_weights = nn.Parameter(torch.rand(d_in, d_kq))
        self.v_weights = nn.Parameter(torch.rand(d_in, d_v))

    def forward(self, x, attention_mask=None):
        """Attend from every position of ``x`` to every position of ``x``.

        Args:
            x: Tensor whose last two dimensions are ``(seq_len, d_in)``.
            attention_mask: Optional tensor broadcastable to the score matrix
                ``(..., seq_len, seq_len)``. Non-zero / ``True`` entries mark
                query-key pairs that must NOT attend to each other. ``None``
                means nothing is masked.

        Returns:
            Tensor whose last two dimensions are ``(seq_len, d_v)``.

        Note:
            A query row whose keys are all masked has no valid softmax and
            yields NaN in the output.
        """
        query = x @ self.q_weights
        key = x @ self.k_weights
        value = x @ self.v_weights

        attn_scores = query @ key.transpose(-1, -2)

        if attention_mask is not None:
            # Move the mask next to the scores so that a module living on a
            # non-CPU device does not hit a device mismatch in masked_fill.
            blocked = attention_mask.to(device=attn_scores.device, dtype=torch.bool)
            attn_scores = attn_scores.masked_fill(blocked, float("-inf"))

        # -inf entries stay -inf after scaling and get zero weight in softmax.
        attn_weights = torch.softmax(attn_scores / self.d_kq**0.5, dim=-1)
        return attn_weights @ value

In [47]:
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries are projected from ``x``; keys and values are projected from
    ``encoder_x``. Both inputs must share the same last dimension ``d_in``
    because all three projection matrices are built with ``d_in`` rows.

    Args:
        d_in: Size of the last dimension of both inputs.
        d_kq: Size of the query/key projections; also sets the softmax scale
            (scores are divided by ``sqrt(d_kq)``).
        d_v: Size of the value projection, i.e. of the output's last dimension.
    """

    def __init__(self, d_in, d_kq, d_v):
        super().__init__()
        self.d_kq = d_kq
        # Projection matrices, initialised uniformly in [0, 1) via torch.rand.
        self.q_weights = nn.Parameter(torch.rand(d_in, d_kq))
        self.k_weights = nn.Parameter(torch.rand(d_in, d_kq))
        self.v_weights = nn.Parameter(torch.rand(d_in, d_v))

    def forward(self, x, encoder_x, attention_mask=None):
        """Attend from every position of ``x`` to every position of ``encoder_x``.

        Args:
            x: Query-side tensor whose last two dimensions are
                ``(tgt_len, d_in)``.
            encoder_x: Key/value-side tensor whose last two dimensions are
                ``(src_len, d_in)``.
            attention_mask: Optional tensor broadcastable to the score matrix
                ``(..., tgt_len, src_len)``. Non-zero / ``True`` entries mark
                query-key pairs that must NOT attend to each other. ``None``
                means nothing is masked.

        Returns:
            Tensor whose last two dimensions are ``(tgt_len, d_v)``.

        Note:
            A query row whose keys are all masked has no valid softmax and
            yields NaN in the output.
        """
        query = x @ self.q_weights
        key = encoder_x @ self.k_weights
        value = encoder_x @ self.v_weights

        attn_scores = query @ key.transpose(-1, -2)

        if attention_mask is not None:
            # Move the mask next to the scores so that a module living on a
            # non-CPU device does not hit a device mismatch in masked_fill.
            blocked = attention_mask.to(device=attn_scores.device, dtype=torch.bool)
            attn_scores = attn_scores.masked_fill(blocked, float("-inf"))

        # -inf entries stay -inf after scaling and get zero weight in softmax.
        attn_weights = torch.softmax(attn_scores / self.d_kq**0.5, dim=-1)
        return attn_weights @ value

In [48]:
class MultiHeadAttention(nn.Module):
    """Run several independent attention heads and concatenate their outputs.

    Args:
        d_in: Size of the last dimension of the input(s).
        d_kq: Query/key projection size used by every head.
        d_v: Value projection size used by every head; the concatenated output
            has ``num_heads * d_v`` features.
        num_heads: Number of heads to build.
        attn_type: ``"self"`` to build ``SelfAttention`` heads or ``"cross"``
            to build ``CrossAttention`` heads.

    Raises:
        ValueError: If ``attn_type`` is neither ``"self"`` nor ``"cross"``.
    """

    def __init__(self, d_in, d_kq, d_v, num_heads, attn_type):
        super().__init__()
        self.attn_type = attn_type

        # Pick the head class once, then build all heads the same way.
        if attn_type == "self":
            head_cls = SelfAttention
        elif attn_type == "cross":
            head_cls = CrossAttention
        else:
            raise ValueError("attn_type should be either 'self' or 'cross'.")

        self.heads = nn.ModuleList([head_cls(d_in, d_kq, d_v) for _ in range(num_heads)])

    def forward(self, x, encoder_x=None, attention_mask=None):
        """Apply every head and join the results along the last dimension.

        Args:
            x: Query-side input passed to every head.
            encoder_x: Key/value-side input; required for cross-attention and
                ignored for self-attention.
            attention_mask: Optional mask forwarded unchanged to every head.

        Returns:
            The per-head outputs concatenated on ``dim=-1``.
        """
        if self.attn_type != "self":
            assert encoder_x is not None
            per_head = [head(x, encoder_x, attention_mask) for head in self.heads]
        else:
            per_head = [head(x, attention_mask) for head in self.heads]
        return torch.cat(per_head, dim=-1)

In [49]:
# Toy batch: 3 sequences of 10 tokens, each token a 4-dimensional embedding.
x = torch.rand(3, 10, 4)
num_heads = 4  # should be a factor of the input embedding dimension
d_in = x.size(-1)

In [50]:
# Unmasked multi-head self-attention. Each head emits d_in // num_heads
# features, so the concatenated output is d_in wide again.
attn = MultiHeadAttention(
    d_in=d_in,
    d_kq=2,
    d_v=d_in // num_heads,
    num_heads=num_heads,
    attn_type="self",
)
attn_out = attn(x)
attn_out

tensor([[[1.1020, 1.2812, 1.1678, 0.6591],
         [1.1390, 1.2630, 1.1815, 0.6455],
         [1.1505, 1.2572, 1.2308, 0.6521],
         [1.0164, 1.1798, 1.0717, 0.6096],
         [1.1310, 1.2302, 1.1732, 0.6401],
         [0.9395, 1.1880, 1.0003, 0.5949],
         [1.0072, 1.1757, 1.0807, 0.5926],
         [1.1049, 1.2723, 1.1463, 0.6492],
         [0.9459, 1.1627, 1.0231, 0.5941],
         [1.0344, 1.2051, 1.0586, 0.6200]],

        [[1.0565, 1.2376, 1.1092, 0.6287],
         [1.0131, 1.1884, 1.0658, 0.5924],
         [1.0106, 1.1686, 1.0631, 0.5924],
         [1.0133, 1.1804, 1.0842, 0.5865],
         [1.1307, 1.2880, 1.2340, 0.6702],
         [0.9289, 1.1238, 0.9584, 0.5365],
         [1.0138, 1.1989, 1.1022, 0.5883],
         [1.0265, 1.2034, 1.0854, 0.5961],
         [0.9708, 1.1396, 1.0344, 0.5606],
         [1.0020, 1.1658, 1.0357, 0.5864]],

        [[1.3801, 1.5805, 1.5027, 0.7440],
         [1.3828, 1.5768, 1.5108, 0.7442],
         [1.3772, 1.5872, 1.4833, 0.7425],
       

In [51]:
# Show the input dimensions (3 x 10 x 4) for comparison with the attention output.
x.shape

torch.Size([3, 10, 4])

In [53]:
# Fully masked self-attention: an all-ones mask blocks every key for every
# query, so each softmax row is all -inf and the whole output is NaN.
attn_mask = torch.ones(x.size(-2), x.size(-2))  # mask everything out
masked_attn_out = attn(x, attention_mask=attn_mask)
masked_attn_out

tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan],
         [nan, nan, nan, nan]]], grad_fn=<CatBackward0>)

In [54]:
# Causal self-attention: ones strictly above the diagonal block every later
# position, so token i only attends to tokens 0..i.
attn_mask = torch.ones(x.size(-2), x.size(-2)).triu(diagonal=1)
masked_attn_out = attn(x, attention_mask=attn_mask)
masked_attn_out

tensor([[[1.3479, 1.4108, 1.5758, 0.5809],
         [1.1998, 1.5331, 1.2518, 0.7061],
         [1.3952, 1.6170, 1.3963, 0.7079],
         [1.2700, 1.4358, 1.2939, 0.6728],
         [1.3060, 1.4643, 1.3055, 0.7131],
         [1.0887, 1.3328, 1.1396, 0.6430],
         [1.0930, 1.2644, 1.1313, 0.6113],
         [1.1566, 1.3594, 1.1821, 0.6647],
         [0.9795, 1.1972, 1.0469, 0.5930],
         [1.0344, 1.2051, 1.0586, 0.6200]],

        [[0.9822, 1.3813, 1.1298, 0.6932],
         [0.9153, 1.2364, 1.0199, 0.6251],
         [0.9466, 1.1497, 1.0459, 0.6128],
         [0.9421, 1.1454, 0.9991, 0.5804],
         [1.2732, 1.4739, 1.3936, 0.7595],
         [0.9869, 1.1869, 1.0523, 0.5858],
         [1.0773, 1.2857, 1.1756, 0.6160],
         [1.0588, 1.2822, 1.1167, 0.6142],
         [0.9847, 1.1697, 1.0418, 0.5593],
         [1.0020, 1.1658, 1.0357, 0.5864]],

        [[1.3049, 1.4533, 1.4994, 0.6550],
         [1.3706, 1.4679, 1.5526, 0.6610],
         [1.2700, 1.4445, 1.4605, 0.6651],
       

In [55]:
# Multi-head cross-attention: queries come from x (10 tokens per sequence),
# keys and values from a shorter encoder sequence (5 tokens per sequence).
cross_attn = MultiHeadAttention(
    d_in=d_in,
    d_kq=2,
    d_v=d_in // num_heads,
    num_heads=num_heads,
    attn_type="cross",
)
encoder_x = torch.rand(3, 5, 4)
cross_attn_out = cross_attn(x, encoder_x)
cross_attn_out

tensor([[[1.0636, 1.6300, 0.9427, 1.1702],
         [1.0666, 1.6138, 0.9201, 1.1986],
         [1.0707, 1.6189, 0.9491, 1.2104],
         [1.0348, 1.5422, 0.8956, 1.1456],
         [1.0635, 1.5966, 0.9176, 1.1955],
         [1.0049, 1.5263, 0.8846, 1.1035],
         [1.0299, 1.5262, 0.8897, 1.1575],
         [1.0605, 1.6176, 0.9208, 1.1739],
         [1.0130, 1.5172, 0.8936, 1.1126],
         [1.0369, 1.5615, 0.8887, 1.1417]],

        [[1.0180, 1.6868, 1.0566, 1.1646],
         [1.0207, 1.6812, 1.0537, 1.1679],
         [1.0210, 1.6801, 1.0542, 1.1668],
         [1.0205, 1.6802, 1.0546, 1.1686],
         [1.0138, 1.6928, 1.0631, 1.1617],
         [1.0275, 1.6719, 1.0459, 1.1747],
         [1.0200, 1.6812, 1.0555, 1.1691],
         [1.0198, 1.6825, 1.0543, 1.1683],
         [1.0236, 1.6753, 1.0519, 1.1707],
         [1.0219, 1.6795, 1.0521, 1.1678]],

        [[1.1801, 2.0325, 1.3801, 1.4393],
         [1.1801, 2.0337, 1.3822, 1.4408],
         [1.1801, 2.0284, 1.3734, 1.4372],
       

In [56]:
# The output keeps the query-side length (10), not the encoder length (5).
cross_attn_out.shape

torch.Size([3, 10, 4])