In [109]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace

In [110]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax((q / temperature) @ k^T) @ v``.

    Args:
        temperature: Divisor applied to ``q`` before the dot product with
            ``k`` (e.g. ``d_k ** 0.5``).
        dropout: Dropout probability. Note that dropout is applied to the
            attention-weighted output, not to the attention weights.
    """

    def __init__(self, temperature, dropout=0.0):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k`` and aggregate ``v``.

        Args:
            q: Queries of shape ``(..., len_q, d_k)``.
            k: Keys of shape ``(..., len_k, d_k)``.
            v: Values of shape ``(..., len_k, d_v)``.
            mask: Optional tensor broadcastable to ``(..., len_q, len_k)``.
                Positions where ``mask == 0`` (or ``False``) receive zero
                attention weight. ``None`` (the default) disables masking.

        Returns:
            Tuple ``(output, attn)`` where ``output`` has shape
            ``(..., len_q, d_v)`` and ``attn`` holds the softmax-normalised
            weights of shape ``(..., len_q, len_k)``.
        """
        scores = torch.matmul(q / self.temperature, k.transpose(-2, -1))
        if mask is not None:
            # Fill with the most negative *finite* value instead of -inf so a
            # fully masked row degrades to a uniform distribution, not NaN,
            # and so the fill value is representable in the scores' dtype.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)
        output = self.dropout(torch.matmul(attn, v))
        return output, attn

In [125]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by dropout, a residual add and layer norm.

    Args:
        n_head: Number of attention heads.
        d_model: Feature size of the inputs and of the output.
        d_k: Per-head feature size of queries and keys.
        d_v: Per-head feature size of values.
        dropout: Dropout probability applied to the output projection
            before it is added to the residual.

    The constructor asserts ``d_k == d_v`` and ``d_k * n_head == d_model``.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head, self.d_model = n_head, d_model
        self.d_k, self.d_v = d_k, d_v
        assert self.d_k == self.d_v and self.d_k * self.n_head == self.d_model, \
            "d_model must equal d_k * n_head and d_v * n_head"

        # Bias-free projections; each one produces all heads at once.
        qk_width = n_head * d_k
        v_width = n_head * d_v
        self.w_qs = nn.Linear(d_model, qk_width, bias=False)
        self.w_ks = nn.Linear(d_model, qk_width, bias=False)
        self.w_vs = nn.Linear(d_model, v_width, bias=False)
        self.fc = nn.Linear(v_width, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v):
        """Run multi-head attention.

        Args:
            q: Queries, ``(b, len_q, d_model)``.
            k: Keys, ``(b, len_k, d_model)``.
            v: Values, ``(b, len_v, d_model)``.

        Returns:
            Tuple ``(output, attn)``: ``output`` is ``(b, len_q, d_model)``
            after residual add and layer norm; ``attn`` is the per-head
            attention weights, ``(b, n_head, len_q, len_k)``.
        """
        n_batch = q.size(0)
        q_len, k_len, v_len = q.size(1), k.size(1), v.size(1)
        shortcut = q  # untouched input, re-added after the output projection

        def to_heads(projected, seq_len, head_dim):
            # (b, n, h * d) -> (b, n, h, d) -> (b, h, n, d)
            per_head = projected.view(n_batch, seq_len, self.n_head, head_dim)
            return per_head.transpose(1, 2)

        q_heads = to_heads(self.w_qs(q), q_len, self.d_k)
        k_heads = to_heads(self.w_ks(k), k_len, self.d_k)
        v_heads = to_heads(self.w_vs(v), v_len, self.d_v)

        # context: (b, h, len_q, d_v), attn: (b, h, len_q, len_k)
        context, attn = self.attention(q_heads, k_heads, v_heads)

        # Concatenate heads back together: (b, h, len_q, d_v) -> (b, len_q, h * d_v)
        merged = context.transpose(1, 2).contiguous().view(n_batch, q_len, -1)

        projected = self.dropout(self.fc(merged))  # (b, len_q, d_model)
        return self.layer_norm(shortcut + projected), attn


In [119]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block with dropout, residual add and layer norm.

    Args:
        d_in: Feature size of the input (and of the output).
        d_hid: Feature size of the hidden layer.
        dropout: Dropout probability applied to the second layer's output.
    """

    def __init__(self, d_in, d_hid, dropout=0.0):
        super().__init__()
        # Expand to the hidden width, then project back to the input width.
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``layer_norm(x + dropout(w_2(relu(w_1(x)))))``."""
        hidden = F.relu(self.w_1(x))
        update = self.dropout(self.w_2(hidden))
        return self.layer_norm(x + update)
        

In [120]:
class EncoderLayer(nn.Module):
    """Encoder layer: self-attention followed by a position-wise feed-forward block.

    Args:
        d_model: Feature size of the layer's input and output.
        d_inner: Hidden size of the feed-forward block.
        n_head: Number of attention heads.
        d_k: Per-head feature size of queries and keys.
        d_v: Per-head feature size of values.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(
            n_head=n_head, d_model=d_model, d_k=d_k, d_v=d_v, dropout=dropout
        )
        self.pos_ffn = PositionwiseFeedForward(
            d_in=d_model, d_hid=d_inner, dropout=dropout
        )

    def forward(self, x):
        """Return ``(encoded, attn)`` for input ``x``.

        ``x`` serves as query, key and value of the attention sub-layer;
        ``attn`` is the attention tensor returned by that sub-layer.
        """
        attended, attn_weights = self.slf_attn(x, x, x)
        encoded = self.pos_ffn(attended)
        return encoded, attn_weights
        

In [124]:
# Smoke test: push random tensors through each module and check output shapes.

# Seed the global RNG so the random inputs, the weight initialisation and the
# dropout masks are identical on every "Restart Kernel -> Run All".
torch.manual_seed(0)

batch_size = 2
n = 3  # sequence length
d_model = 6
dropout = 0.5
n_head = 3
d_k = d_v = d_model // n_head
d_inner = d_model * 4

q = torch.randn((batch_size, n, d_model))
k = torch.randn((batch_size, n, d_model))
v = torch.randn((batch_size, n, d_model))
print(f"Shape of q, k, v: {q.shape}")

attention = ScaledDotProductAttention(temperature=1.)
attn_output, attn = attention(q, k, v)
# Single-head attention keeps the input layout and yields an (n, n) weight map.
assert attn_output.shape == (batch_size, n, d_model)
assert attn.shape == (batch_size, n, n)
print("\nScaled Dot Product Attention")
print(f"Attention output: \n{attn_output.shape}")
print(f"Attention: \n{attn.shape}")

print("\nMultihead Attention")
mh_attention = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
mh_att_out, mh_att = mh_attention(q, k, v)
# One (n, n) weight map per head.
assert mh_att_out.shape == (batch_size, n, d_model)
assert mh_att.shape == (batch_size, n_head, n, n)
print(f"Multihead attention output: \n{mh_att_out.shape}")
print(f"Multihead attention: \n{mh_att.shape}")

print("\nEncoder Layer")
encoder = EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
# NOTE: `mh_att` is rebound here to the encoder layer's attention weights.
encoder_layer_out, mh_att = encoder(q)
assert encoder_layer_out.shape == (batch_size, n, d_model)
assert mh_att.shape == (batch_size, n_head, n, n)
print(f"Encoder layer output: \n{encoder_layer_out.shape}")
print(f"Multihead attention: \n{mh_att.shape}")


Shape of q, k, v: torch.Size([2, 3, 6])

Scaled Dot Product Attention
Attention output: 
torch.Size([2, 3, 6])
Attention: 
torch.Size([2, 3, 3])

Multihead Attention
Multihead attention output: 
torch.Size([2, 3, 6])
Multihead attention: 
torch.Size([2, 3, 3, 3])

Encoder Layer
Encoder layer output: 
torch.Size([2, 3, 6])
Multihead attention: 
torch.Size([2, 3, 3, 3])
