In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected into queries/keys/values, split into ``num_heads``
    heads of size ``d_k = d_model // num_heads``, attended per head, and the
    concatenated heads are passed through an output projection.

    Args:
        num_heads: Number of attention heads (must be positive).
        d_model: Model dimension; must be divisible by ``num_heads``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        # Without this check the head split in forward() can "succeed" with a
        # wrong sequence length and silently mix positions and features.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.W_Q = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_K = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_V = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_O = nn.Linear(in_features=d_model, out_features=d_model)

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Args:
            Q: (Batch, num_head, seq_len_q, d_k)
            K: (Batch, num_head, seq_len_k, d_k)
            V: (Batch, num_head, seq_len_k, d_v)
            mask: optional tensor broadcastable to
                (Batch, num_head, seq_len_q, seq_len_k); positions where
                ``mask == 0`` (or ``False``) are excluded from attention.
        Returns:
            attention_output: (Batch, num_head, seq_len_q, d_v)
        """
        # Scale by sqrt(d_k) using a plain Python float; no tensor needed.
        scores = torch.einsum("bnqd,bnkd->bnqk", Q, K) / (Q.size(-1) ** 0.5)
        if mask is not None:
            # Use the dtype's most negative finite value instead of -inf so a
            # fully masked row yields uniform weights rather than NaNs.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attention_weights = F.softmax(scores, dim=-1)
        return torch.einsum("bnqk,bnkv->bnqv", attention_weights, V)

    def _split_heads(self, t, batch, seq_len):
        """Reshape (Batch, seq_len, d_model) -> (Batch, num_heads, seq_len, d_k)."""
        return t.reshape(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """
        Args:
            x: (Batch, seq_len, d_model)
            mask: optional attention mask, see ``scaled_dot_product_attention``
                (e.g. a padding mask of shape (Batch, 1, 1, seq_len)).
        Returns:
            multi_head_attention: (Batch, seq_len, d_model)
        """
        batch, seq_len, _ = x.shape
        Q = self._split_heads(self.W_Q(x), batch, seq_len)
        K = self._split_heads(self.W_K(x), batch, seq_len)
        V = self._split_heads(self.W_V(x), batch, seq_len)
        heads = self.scaled_dot_product_attention(Q, K, V, mask=mask)

        # Merge heads back: (Batch, num_heads, seq_len, d_k) -> (Batch, seq_len, d_model)
        concat = heads.transpose(1, 2).reshape(batch, seq_len, self.d_model)
        return self.W_O(concat)

In [4]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sublayer.

    Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model) -> Dropout,
    applied independently over the last dimension of the input.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (inner) feature size.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_ff),   # expand to the hidden size
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),   # project back to the model size
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Run the feed-forward stack; output has the same shape as ``x``."""
        out = self.ffn(x)
        return out

In [12]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    pe[pos, 2j]   = sin(pos / 10000^(2j / d_model))
    pe[pos, 2j+1] = cos(pos / 10000^(2j / d_model))

    Args:
        seq_len: Maximum sequence length the table is precomputed for.
        d_model: Embedding dimension. Odd values are supported; the extra
            final column is then a sine term.
    """

    def __init__(self, seq_len, d_model):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model

        positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float32)  # (ceil(d_model/2),)
        # Broadcast (seq_len, 1) against (ceil(d_model/2),) -> (seq_len, ceil(d_model/2))
        angles = positions / (10000 ** (even_dims / d_model))

        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis so the table broadcasts over (Batch, seq_len, d_model).
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: (Batch, L, d_model) with ``L <= seq_len``; shorter sequences
                receive the first ``L`` rows of the table.
        Returns:
            Tensor with the same shape as ``x``.
        Raises:
            ValueError: If ``L`` exceeds the precomputed ``seq_len``.
        """
        length = x.size(1)
        if length > self.seq_len:
            raise ValueError(
                f"input length {length} exceeds max seq_len {self.seq_len}"
            )
        return x + self.pe[:, :length]

In [None]:
class Encoder(nn.Module):
    """A single post-LayerNorm Transformer encoder layer.

    h   = LayerNorm(x + Dropout(MultiHeadAttention(x)))
    out = LayerNorm(h + FeedForwardNetwork(h))

    The feed-forward sublayer applies its own output dropout internally, so
    only the attention output is dropped explicitly here.

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sublayer.
        dropout: Dropout probability for the attention sublayer output; also
            passed to the feed-forward sublayer (default 0.1, the same default
            FeedForwardNetwork uses).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
        # Residual dropout on the attention sublayer, before the skip connection.
        self.dropout = nn.Dropout(dropout)
        self.norm_1 = nn.LayerNorm(normalized_shape=d_model)
        self.ffn = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout=dropout)
        self.norm_2 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x):
        """
        Args:
            x: (Batch, seq_len, d_model)
        Returns:
            (Batch, seq_len, d_model)
        """
        attn_out = self.dropout(self.mha(x))
        h = self.norm_1(x + attn_out)
        return self.norm_2(h + self.ffn(h))


In [None]:
class BERT(nn.Module):
    """Placeholder for a BERT model; not implemented yet.

    TODO: implement (presumably token embeddings + PositionalEncoding + a stack
    of Encoder layers) once the model hyper-parameters are decided.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Not implemented; raises instead of silently returning ``None``."""
        raise NotImplementedError("BERT.forward is not implemented yet")
