In [None]:
import math
import torch
from torch import nn
from torch.nn import functional as F

# Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with a single fused
    linear layer, split into ``num_heads`` heads, attended per head, merged
    back and passed through an output projection.

    Args:
        model_dim: Size of the last dimension of the input (and output).
        num_heads: Number of attention heads; must divide ``model_dim``.
        attn_dropout: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, model_dim, num_heads, attn_dropout, proj_drop):
        super().__init__()
        if model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = model_dim // num_heads
        # Scores are scaled by 1 / sqrt(head_dim): each dot product is taken
        # over a single head's features, not over the full model dimension.
        self.scale = head_dim ** -0.5
        self.qkv_proj = nn.Linear(model_dim, model_dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(model_dim, model_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, T, D); returns (B, T, D)."""
        B, T, D = x.shape
        qkv = self.qkv_proj(x)  # B, T, D * 3
        # B, T, 3, num_heads, head_dim -> 3, B, num_heads, T, head_dim
        qkv = qkv.reshape(B, T, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # torch.mm only handles 2-D matrices; matmul batches over (B, num_heads).
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # B, num_heads, T, T
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # B, num_heads, T, head_dim -> B, T, num_heads, head_dim -> B, T, D
        x = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, -1)

        return self.proj_drop(self.proj(x))


# Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of sequences.

    The table is precomputed once for ``max_len`` positions:

        pe[pos, 2i]     = sin(pos / 10000 ** (2i / model_dim))
        pe[pos, 2i + 1] = cos(pos / 10000 ** (2i / model_dim))

    Args:
        model_dim: Size of the feature dimension of the inputs.
        max_len: Number of positions to precompute encodings for.
    """

    def __init__(self, model_dim, max_len):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()  # max_len, 1
        # exp(-ln(10000) * 2i / model_dim) == 1 / 10000 ** (2i / model_dim)
        div_term = torch.exp(
            (-math.log(10000.0) / model_dim)
            * torch.arange(0, model_dim, 2).float()
        )
        angles = position * div_term  # max_len, ceil(model_dim / 2)

        pe = torch.zeros(max_len, model_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd model_dim there is one fewer odd column than even ones,
        # so trim the angles to the number of cosine slots available.
        pe[:, 1::2] = torch.cos(angles[:, : model_dim // 2])
        # pe: (max_len, model_dim); a buffer, so it follows .to()/.cuda() and
        # is saved in the state dict without being a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):  # B T C
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        x = x + self.pe[:x.size(1), :].unsqueeze(0)
        return x
        