In [2]:
import torch
import numpy as np
from torch import nn

In [3]:
# Shape of the demo input batch.
b = 32              # batch size
input_length = 64   # sequence length
embed_dim = 128     # feature width per position

# Seed the global torch RNG so the random input below (and any module weights
# initialised afterwards) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

x = torch.randn(b, input_length, embed_dim)
print(x.shape)

torch.Size([32, 64, 128])


In [4]:
class SelfAttention(nn.Module):
  """Single-head causal (masked) scaled dot-product self-attention.

  Args:
    in_dim: size of the last axis of the input.
    out_dim: size of the key/query/value projections and of the output's last axis.

  forward(x) takes a tensor of shape (batch, seq, in_dim) and returns a tensor
  of shape (batch, seq, out_dim). Position i only attends to positions <= i.
  """
  def __init__(self, in_dim, out_dim):
    super().__init__()

    self.K = nn.Linear(in_dim, out_dim, bias=False)
    self.Q = nn.Linear(in_dim, out_dim, bias=False)
    self.V = nn.Linear(in_dim, out_dim, bias=False)

  def forward(self, x):
    key = self.K(x)
    query = self.Q(x)
    value = self.V(x)

    # Scale by the key feature width (last axis). Axis 1 is the sequence
    # length, which is unrelated to the magnitude of the dot products.
    d_k = key.shape[-1]
    # Rows index queries, columns index keys, so the softmax over the last
    # axis normalises each query's weights over the keys it may attend to.
    scores = (query @ key.transpose(1, 2)) / d_k**0.5

    context_length = x.shape[1]
    # Lower-triangular boolean mask built on the input's device so the module
    # also works when x lives on an accelerator.
    causal_mask = torch.tril(
      torch.ones(context_length, context_length, dtype=torch.bool, device=x.device)
    )
    scores = scores.masked_fill(~causal_mask, -torch.inf)
    weights = torch.nn.functional.softmax(scores, dim=-1)

    return weights @ value

In [5]:
# Smoke test: single-head attention should preserve the (batch, seq, dim) shape here.
attention = SelfAttention(in_dim=128, out_dim=128)
out = attention(x)
print(out.size())

torch.Size([32, 64, 128])


In [6]:
class NaiveMultiHeadAttention(nn.Module):
  """Multi-head attention assembled from independent SelfAttention heads.

  Builds num_heads heads, each projecting in_dim -> out_dim // num_heads.
  forward(x) runs every head on the same input, concatenates the results along
  the last axis and applies a final linear layer that outputs out_dim features.
  """
  def __init__(self, in_dim, out_dim, num_heads):
    super().__init__()

    # Integer division: each head gets the same (floored) share of out_dim.
    head_dim = out_dim // num_heads

    head_modules = []
    for _ in range(num_heads):
      head_modules.append(SelfAttention(in_dim, head_dim))
    self.heads = nn.ModuleList(head_modules)

    # Input width is the concatenated width of all heads.
    self.linear_layer = nn.Linear(num_heads * head_dim, out_dim)

  def forward(self, x):
    per_head = [head(x) for head in self.heads]
    merged = torch.concat(per_head, dim=-1)
    return self.linear_layer(merged)

# Smoke test: 8 independent heads over the demo batch, projected back to 128 features.
multi_head_attention = NaiveMultiHeadAttention(in_dim=128, out_dim=128, num_heads=8)
out = multi_head_attention(x)
print(out.size())

torch.Size([32, 64, 128])


In [7]:
class MultiHeadAttention(nn.Module):
  """Causal multi-head self-attention using one fused projection per K/Q/V.

  Args:
    in_dim: size of the last axis of the input.
    out_dim: total projection width across all heads; must be divisible by
      num_heads.
    num_heads: number of attention heads; each head works on
      out_dim // num_heads features.
    dropout: dropout probability applied to the attention weights (active in
      training mode only).
    bias: whether the K/Q/V linear projections use a bias term.

  forward(x) takes a tensor of shape (batch, seq, in_dim) and returns a tensor
  of shape (batch, seq, out_dim). Position i only attends to positions <= i.

  Raises:
    ValueError: if out_dim is not divisible by num_heads.
  """
  def __init__(self, in_dim, out_dim, num_heads, dropout=0.2, bias=False):
    super().__init__()
    # Fail early with a clear message; otherwise the per-head reshape in
    # forward() fails later with an opaque shape error.
    if out_dim % num_heads != 0:
      raise ValueError(
        f"out_dim ({out_dim}) must be divisible by num_heads ({num_heads})"
      )

    self.K = nn.Linear(in_dim, out_dim, bias=bias)
    self.Q = nn.Linear(in_dim, out_dim, bias=bias)
    self.V = nn.Linear(in_dim, out_dim, bias=bias)
    # Parameter-free, so adding it does not change the module's state_dict.
    self.dropout = nn.Dropout(dropout)

    self.num_heads = num_heads
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.head_dim = out_dim // num_heads

  def _split_heads(self, t):
    # (batch, seq, out_dim) -> (batch, num_heads, seq, head_dim)
    return t.view(t.shape[0], t.shape[1], self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, x):
    # Shapes come from the input itself, never from notebook globals, so any
    # batch size / sequence length works.
    batch_size, context_length = x.shape[0], x.shape[1]

    key = self._split_heads(self.K(x))
    query = self._split_heads(self.Q(x))
    value = self._split_heads(self.V(x))

    scores = (query @ key.transpose(2, 3)) / self.head_dim**0.5

    # Lower-triangular boolean mask on the input's device; it broadcasts over
    # the batch and head axes.
    causal_mask = torch.tril(
      torch.ones(context_length, context_length, dtype=torch.bool, device=x.device)
    )
    scores = scores.masked_fill(~causal_mask, -torch.inf)
    weights = torch.nn.functional.softmax(scores, dim=-1)
    weights = self.dropout(weights)

    out = weights @ value
    # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim); make the
    # memory contiguous so the heads can be merged with view().
    out = out.transpose(1, 2).contiguous()

    return out.view(batch_size, context_length, self.out_dim)
  
# Smoke test: fused multi-head attention with 2 heads sharing an 8-wide output.
multi_head_attention = MultiHeadAttention(in_dim=embed_dim, out_dim=8, num_heads=2)
multi_head_attention(x).size()

torch.Size([32, 64, 8])