In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import math

In [None]:
class SelfAttention(nn.Module):
  """Multi-head self-attention over a sequence of embeddings.

  One linear layer projects the input to queries, keys and values. These are
  split into ``n_heads`` heads of size ``d_embed // n_heads``, combined with
  scaled dot-product attention, merged back together, and passed through an
  output projection.

  Args:
    n_heads: Number of attention heads. Must divide ``d_embed``.
    d_embed: Embedding dimension of both the input and the output.
    in_proj_bias: Whether the joint Q/K/V projection has a bias term.
    out_proj_bias: Whether the output projection has a bias term.

  Raises:
    ValueError: If ``d_embed`` is not divisible by ``n_heads``.
  """

  def __init__(self, n_heads : int, d_embed :int, in_proj_bias = True, out_proj_bias = True):
    super().__init__()
    if d_embed % n_heads != 0:
      raise ValueError(
          f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.d_head = d_embed // n_heads
    # Single projection producing Q, K and V stacked along the last dim.
    self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias = in_proj_bias)
    self.out_proj = nn.Linear(d_embed, d_embed, bias = out_proj_bias)

  def forward(self, x, causal_mask = False):
    """Apply multi-head self-attention to ``x``.

    Args:
      x: Tensor of shape (Batch_Size, Seq_Len, Dim).
      causal_mask: If True, query position i may only attend to key
        positions j <= i. Future positions are masked out.

    Returns:
      Tensor of shape (Batch_Size, Seq_Len, Dim).
    """
    batch_size, sequence_length, d_embed = x.shape
    # Split the channel dim into heads first: (Batch_Size, Seq_Len, H, Dim / H).
    # The heads axis is moved before Seq_Len with a transpose below. Viewing
    # straight to (B, H, S, Dim / H) would mix different tokens into each head.
    interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

    # (B, S, Dim) -> (B, S, 3 * Dim) -> 3 tensors of shape (B, S, Dim)
    q, k, v = self.in_proj(x).chunk(3, dim = -1)

    # (B, S, Dim) -> (B, S, H, Dim / H) -> (B, H, S, Dim / H)
    # reshape (not view) because chunk() returns non-contiguous slices.
    q = q.reshape(interim_shape).transpose(1, 2)
    k = k.reshape(interim_shape).transpose(1, 2)
    v = v.reshape(interim_shape).transpose(1, 2)

    # (B, H, S, Dim / H) @ (B, H, Dim / H, S) -> (B, H, S, S)
    weight = q @ k.transpose(-1, -2)
    # Scale by 1/sqrt(d_head) so the softmax logits do not grow with head size.
    weight = weight / math.sqrt(self.d_head)

    if causal_mask:
      # True strictly above the diagonal marks the "future" keys of each query.
      # The diagonal stays visible, so no row is fully masked.
      mask = torch.ones(
          sequence_length, sequence_length,
          dtype = torch.bool, device = x.device).triu(1)
      weight = weight.masked_fill(mask, -torch.inf)

    # Normalise over the key axis.
    weight = F.softmax(weight, dim = -1)
    # (B, H, S, S) @ (B, H, S, Dim / H) -> (B, H, S, Dim / H) -> (B, S, H, Dim / H)
    output = (weight @ v).transpose(1, 2)
    # (B, S, H, Dim / H) -> (B, S, Dim)
    output = output.reshape(batch_size, sequence_length, d_embed)
    # (B, S, Dim) -> (B, S, Dim)
    return self.out_proj(output)

In [None]:
class CrossAttention(nn.Module):
  """Multi-head cross-attention from a query sequence to a context sequence.

  Queries are projected from ``x`` (width ``d_embed``). Keys and values are
  projected from the context ``y`` (width ``d_cross``) into the same
  ``d_embed`` space. The projections are split into ``n_heads`` heads and
  combined with scaled dot-product attention.

  Args:
    n_heads: Number of attention heads. Must divide ``d_embed``.
    d_embed: Embedding dimension of the query input and of the output.
    d_cross: Embedding dimension of the context input.
    in_proj_bias: Whether the Q/K/V projections have bias terms.
    out_proj_bias: Whether the output projection has a bias term.

  Raises:
    ValueError: If ``d_embed`` is not divisible by ``n_heads``.
  """

  def __init__(self, n_heads : int, d_embed : int, d_cross : int, in_proj_bias = True, out_proj_bias = True):
    super().__init__()
    if d_embed % n_heads != 0:
      raise ValueError(
          f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.d_head = d_embed // n_heads
    # Attribute names (including the "querry" spelling) are kept as-is so that
    # existing state_dict keys continue to load.
    self.querry_proj = nn.Linear(d_embed, d_embed, bias = in_proj_bias)
    self.key_proj = nn.Linear(d_cross, d_embed, bias = in_proj_bias)
    self.value_proj = nn.Linear(d_cross, d_embed, bias = in_proj_bias)
    self.out_proj = nn.Linear(d_embed, d_embed, bias = out_proj_bias)

  def forward(self, x, y):
    """Attend from ``x`` (queries) to ``y`` (keys/values).

    Args:
      x: Query tensor of shape (Batch_Size, Seq_Len_Q, Dim_Q), with Dim_Q == d_embed.
      y: Context tensor of shape (Batch_Size, Seq_Len_KV, Dim_KV), with Dim_KV == d_cross.

    Returns:
      Tensor of shape (Batch_Size, Seq_Len_Q, Dim_Q).
    """
    batch_size, sequence_length, d_embed = x.shape
    # (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
    query = self.querry_proj(x)
    # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
    key = self.key_proj(y)
    # (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
    value = self.value_proj(y)

    # -1 lets the same shape serve both Seq_Len_Q and Seq_Len_KV.
    multihead_shape = (batch_size, -1, self.n_heads, self.d_head)
    # (B, Seq_Len_Q, Dim_Q) -> (B, Seq_Len_Q, H, Dim_Q / H) -> (B, H, Seq_Len_Q, Dim_Q / H)
    query = query.contiguous().view(*multihead_shape).transpose(1, 2)
    # (B, Seq_Len_KV, Dim_Q) -> (B, Seq_Len_KV, H, Dim_Q / H) -> (B, H, Seq_Len_KV, Dim_Q / H)
    key = key.contiguous().view(*multihead_shape).transpose(1, 2)
    # (B, Seq_Len_KV, Dim_Q) -> (B, Seq_Len_KV, H, Dim_Q / H) -> (B, H, Seq_Len_KV, Dim_Q / H)
    value = value.contiguous().view(*multihead_shape).transpose(1, 2)

    # (B, H, Seq_Len_Q, Dim_Q / H) @ (B, H, Dim_Q / H, Seq_Len_KV) -> (B, H, Seq_Len_Q, Seq_Len_KV)
    weight = query @ key.transpose(-1, -2)
    # Scale by 1/sqrt(d_head) so the softmax logits do not grow with head size.
    weight = weight / math.sqrt(self.d_head)
    # Normalise over the key (context) axis.
    weight = F.softmax(weight, dim = -1)
    # (B, H, Seq_Len_Q, Seq_Len_KV) @ (B, H, Seq_Len_KV, Dim_Q / H) -> (B, H, Seq_Len_Q, Dim_Q / H)
    output = weight @ value
    # (B, H, Seq_Len_Q, Dim_Q / H) -> (B, Seq_Len_Q, H, Dim_Q / H) -> (B, Seq_Len_Q, Dim_Q)
    output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, d_embed)
    return self.out_proj(output)
