In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``d_k = d_model // num_heads``, attended per
    head, re-concatenated and passed through the output projection ``W_O``.

    Args:
        num_heads: number of attention heads.
        d_model: model dimension; must be divisible by ``num_heads``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        if d_model % num_heads != 0:
            # Without this check the head split in forward() can silently
            # mis-reshape instead of failing.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.W_Q = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_K = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_V = nn.Linear(in_features=d_model, out_features=d_model)
        self.W_O = nn.Linear(in_features=d_model, out_features=d_model)

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V for every batch element and head.

        Args:
            Q: (batch, num_heads, seq_len_q, d_k)
            K: (batch, num_heads, seq_len_k, d_k)
            V: (batch, num_heads, seq_len_k, d_v)
            mask: optional tensor broadcastable to
                (batch, num_heads, seq_len_q, seq_len_k). Positions where it is
                zero/False are excluded from attention; ``None`` attends to all keys.
        Returns:
            attention_output: (batch, num_heads, seq_len_q, d_v)
        """
        d_k = Q.size(-1)
        # Dividing by a Python float keeps the result in Q's dtype.
        scores = torch.einsum("bnqd,bnkd->bnqk", Q, K) / (d_k ** 0.5)
        if mask is not None:
            # Use the most negative finite value rather than -inf, so a row whose
            # keys are all masked gives a uniform distribution instead of NaN.
            scores = scores.masked_fill(~mask.bool(), torch.finfo(scores.dtype).min)
        attention_weights = F.softmax(scores, dim=-1)
        return torch.einsum("bnqk,bnkv->bnqv", attention_weights, V)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: (batch, seq_len, d_model)
            mask: optional attention mask. A 2-D (batch, seq_len) tensor is treated
                as a key-padding mask (nonzero/True = real token, zero/False =
                padding). Any other shape must already broadcast to
                (batch, num_heads, seq_len, seq_len).
        Returns:
            multi_head_attention: (batch, seq_len, d_model)
        """
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
            return t.reshape(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.W_Q(x))
        K = split_heads(self.W_K(x))
        V = split_heads(self.W_V(x))

        if mask is not None and mask.dim() == 2:
            # Broadcast the key-padding mask over heads and query positions.
            mask = mask[:, None, None, :]

        heads = self.scaled_dot_product_attention(Q, K, V, mask)
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        concat = heads.transpose(1, 2).reshape(batch, seq_len, self.d_model)
        return self.W_O(concat)

In [3]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: input and output feature size.
        d_ff: hidden (inner) feature size.
        dropout: dropout probability, applied after the activation and again
            after the output projection.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        stages = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*stages)

    def forward(self, x):
        """Run the feed-forward stack over the last dimension of ``x``."""
        projected = self.ffn(x)
        return projected

In [4]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is precomputed for ``seq_len`` positions and registered as a
    (non-trainable) buffer of shape (1, seq_len, d_model).

    Args:
        seq_len: maximum sequence length supported.
        d_model: embedding dimension (odd values are supported).
    """

    def __init__(self, seq_len, d_model):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model

        pe = torch.zeros(seq_len, d_model)
        pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
        # Even column indices 2i; there are ceil(d_model / 2) of them, so an
        # odd d_model gets one more sin column than cos columns.
        two_i = torch.arange(0, d_model, 2, dtype=torch.float32).unsqueeze(0)
        angles = pos / (10000 ** (two_i / d_model))  # (seq_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis so the table broadcasts against (batch, seq, d_model).
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional encoding for the first ``x.size(1)`` positions.

        Args:
            x: (Batch, seq_len, d_model) with seq_len <= the constructor's seq_len.
        Returns:
            Tensor of the same shape as ``x``.
        Raises:
            ValueError: if the input is longer than the precomputed table.
        """
        length = x.size(1)
        if length > self.seq_len:
            raise ValueError(
                f"input length {length} exceeds positional table size {self.seq_len}"
            )
        return x + self.pe[:, :length]

In [5]:
class Encoder(nn.Module):
    """One post-LayerNorm Transformer encoder layer.

    Each sub-layer (self-attention, then feed-forward) is wrapped as
    ``LayerNorm(residual + sublayer(residual))``.

    Args:
        d_model: model dimension.
        num_heads: number of attention heads.
        d_ff: hidden size of the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
        self.norm_1 = nn.LayerNorm(normalized_shape=d_model)
        self.ffn = FeedForwardNetwork(d_model=d_model, d_ff=d_ff)
        self.norm_2 = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with residual + LayerNorm."""
        attended = self.norm_1(x + self.mha(x))
        return self.norm_2(self.ffn(attended) + attended)


In [6]:
# BERT uses learned position embeddings (an nn.Embedding table) rather than the fixed sinusoidal encoding.
class BERTEmbedding(nn.Module):
    """BERT input embeddings: token + learned position + segment, then LayerNorm and dropout.

    Args:
        vocab_size: size of the token vocabulary.
        d_model: embedding dimension.
        max_seq_len: number of learned position embeddings (longest supported input).
        dropout: dropout probability applied after the LayerNorm.
    """

    def __init__(self, vocab_size, d_model, max_seq_len, dropout=0.1):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, d_model)
        self.positional_embeddings = nn.Embedding(max_seq_len, d_model)
        self.segment_embeddings = nn.Embedding(2, d_model)
        self.drop_out_1 = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_ids, segment_ids=None):
        """Embed a batch of token ids.

        Args:
            input_ids: integer token ids of shape (batch_size, seq_len).
            segment_ids: optional segment ids (0 or 1) with the same shape as
                ``input_ids``; all zeros when ``None``.
        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        Raises:
            ValueError: if seq_len exceeds the number of position embeddings.
        """
        seq_len = input_ids.size(1)
        max_len = self.positional_embeddings.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")

        # Build positions on the same device as the inputs; shape (1, seq_len)
        # broadcasts over the batch.
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        if segment_ids is None:
            segment_ids = torch.zeros_like(input_ids)

        summed = (
            self.token_embeddings(input_ids)
            + self.positional_embeddings(position_ids)
            + self.segment_embeddings(segment_ids)
        )
        # The original BERT normalises first and then applies dropout.
        return self.drop_out_1(self.layer_norm(summed))

In [7]:
class BERTEncoder(nn.Module):
    """Stack of ``num_layers`` Encoder layers applied one after another.

    Args:
        d_model: model dimension.
        num_heads: number of attention heads per layer.
        d_ff: feed-forward hidden size per layer.
        num_layers: number of stacked Encoder layers.
    """

    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(Encoder(d_model=d_model, num_heads=num_heads, d_ff=d_ff))
        self.model = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final hidden states."""
        hidden = x
        for layer in self.model:
            hidden = layer(hidden)
        return hidden

In [8]:
# Pooler: a Linear layer + activation applied to the [CLS] (first) token.
# The original BERT uses a tanh activation here.
class BERTPooler(nn.Module):
    """Pool a sequence by passing its first ([CLS]) position through Linear + tanh.

    Args:
        d_model: hidden size; the projection maps d_model -> d_model.
    """

    def __init__(self, d_model):
        super().__init__()
        self.linear_layer = nn.Linear(d_model, d_model)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: (batch_size, seq_len, d_model)
        Returns:
            pooled output of shape (batch_size, d_model)
        """
        cls_token = hidden_states[:, 0]
        projected = self.linear_layer(cls_token)
        return self.activation(projected)

In [9]:
class BERT(nn.Module):
    """BERT model: embeddings -> stack of encoder layers -> [CLS] pooler.

    The default hyper-parameters match the BERT-Base sizes.

    Args:
        vocab_size: token vocabulary size.
        d_model: hidden size.
        num_heads: attention heads per layer.
        d_ff: feed-forward hidden size.
        num_layers: number of encoder layers.
        max_seq_len: number of learned position embeddings.
    """

    def __init__(self, vocab_size=30522, d_model=768, num_heads=12, d_ff=3072, num_layers=12, max_seq_len=512):
        super().__init__()
        self.bert_embedding = BERTEmbedding(vocab_size=vocab_size, d_model=d_model, max_seq_len=max_seq_len)
        self.bert_encoder = BERTEncoder(d_model=d_model, num_heads=num_heads, d_ff=d_ff, num_layers=num_layers)
        self.bert_pooler = BERTPooler(d_model)

    def forward(self, input_ids, segment_ids=None):
        """Run embeddings, encoder and pooler.

        Args:
            input_ids: token ids, presumably (batch_size, seq_len) integer tensor.
            segment_ids: optional segment (sentence A/B) ids with the same shape
                as ``input_ids``; ``None`` lets the embedding layer default to
                segment 0 everywhere.
        Returns:
            (sequence_output, pooled_output): the final encoder hidden states and
            the pooled [CLS] representation.
        """
        # segment_ids must reach the embedding layer, otherwise sentence-pair
        # inputs are embedded as if every token were in segment 0.
        embeddings = self.bert_embedding(input_ids, segment_ids)
        sequence_output = self.bert_encoder(embeddings)
        pooled_output = self.bert_pooler(sequence_output)
        return sequence_output, pooled_output