In [None]:
# Learning to write out the BERT architecture
# Embeddings
# Positional Embedding / Encoding
# Encoder
# Attention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class PostionalEmbeddings(nn.Module):
    """Fixed (non-learned) sinusoidal positional encodings.

    Builds a ``[1, max_seq_length, d_model]`` table once and stores it as a
    buffer named ``pe``. For position ``pos`` and pair index ``k``::

        pe[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        pe[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))

    Args:
        d_model: Size of the embedding dimension (odd values are supported).
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()

        # Positions as a column vector: [max_seq_length, 1].
        position = torch.arange(max_seq_length, dtype=torch.float32).unsqueeze(1)
        # `even_idx` already holds the values 2k (0, 2, 4, ...), so the
        # exponent is even_idx / d_model -- multiplying by 2 again would
        # double the frequency decay.
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float32)
        div_term = 10000.0 ** (even_idx / d_model)
        # [max_seq_length, ceil(d_model / 2)]; each sin/cos pair shares a frequency.
        angles = position / div_term

        pe = torch.zeros(max_seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # add batch size dimension
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the positional table to ``x``.

        Args:
            x: Tensor indexed as ``[batch_size, seq_length, d_model]``.

        Returns:
            ``x`` plus the first ``seq_length`` rows of the table. If
            ``seq_length`` exceeds ``max_seq_length`` the addition fails to
            broadcast and PyTorch raises a ``RuntimeError``.
        """
        # Slice so sequences shorter than max_seq_length broadcast correctly.
        return x + self.pe[:, : x.size(1)]

In [None]:
class BERTEmbeddings(nn.Module):
    """Input embedding layer: token + positional + segment, followed by dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_size: Width of every embedding vector.
        max_seq_len: Number of positions handed to ``PostionalEmbeddings``.
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, embed_size, max_seq_len, dropout):
        super().__init__()
        self.vocab_size, self.embed_size = vocab_size, embed_size

        # Token lookup table; index 0 is the padding index.
        self.token_embed = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=0
        )
        # Segment lookup table with 3 rows; index 0 is the padding index.
        # NOTE(review): presumably 1 and 2 mark the two sentences -- confirm
        # against the data pipeline.
        self.segment_embed = nn.Embedding(
            num_embeddings=3, embedding_dim=embed_size, padding_idx=0
        )
        # Positional information added on top of the token vectors.
        self.position = PostionalEmbeddings(embed_size, max_seq_len)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input, segment_label):
        """Embed token ids and segment ids and return the dropped-out sum."""
        token_vectors = self.token_embed(input)
        positioned = self.position(token_vectors)
        segment_vectors = self.segment_embed(segment_label)
        return self.dropout(positioned + segment_vectors)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_model // n_heads  # width of each head

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Reshape ``[batch, seq, d_model]`` into ``[batch, n_heads, seq, d_k]``."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.n_heads, self.d_k).transpose(1, 2)

    def attention(self, Q, K, V, mask=None):
        """Scaled dot-product attention over already-split heads.

        Args:
            Q, K, V: Tensors shaped ``[batch, n_heads, seq, d_k]``.
            mask: Optional tensor broadcastable to ``[batch, n_heads, seq, seq]``;
                positions where ``mask == 0`` receive zero attention weight.

        Returns:
            Tensor shaped ``[batch, n_heads, seq, d_k]``.
        """
        # [batch, n_heads, seq, d_k] @ [batch, n_heads, d_k, seq]
        # -> [batch, n_heads, seq, seq]
        attn_score = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill masked positions with the most negative finite value of the
            # score dtype so softmax drives their probability to zero.
            attn_score = attn_score.masked_fill(
                mask == 0, torch.finfo(attn_score.dtype).min
            )

        attn_probs = torch.softmax(attn_score, dim=-1)
        return torch.matmul(attn_probs, V)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, and merge the heads.

        Args:
            Q, K, V: Tensors shaped ``[batch, seq, d_model]``.
            mask: Optional attention mask, see ``attention``.

        Returns:
            Tensor shaped ``[batch, query_seq, d_model]``.
        """
        # The output sequence length follows the query.
        batch_size, seq_length, _ = Q.size()

        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        # [batch, n_heads, seq, d_k]
        attn_output = self.attention(Q, K, V, mask)

        # Merge heads back to [batch, seq, d_model]; .contiguous() is needed
        # because transpose returns a non-contiguous view.
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, self.d_model)
        )
        return self.W_o(attn_output)

In [None]:
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width between the two linear layers.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, d_ff, dropout):
        super().__init__()
        self.fc1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.fc2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [None]:
class EncodeLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention and a feed-forward network, each wrapped in a residual
    connection followed by ``LayerNorm``.

    Args:
        d_model: Model width.
        d_ff: Hidden width of the feed-forward sublayer.
        n_heads: Number of attention heads.
        dropout: Dropout probability used on both sublayer outputs.
    """

    def __init__(self, d_model, d_ff, n_heads, dropout):
        super().__init__()
        self.multihead = MultiHeadAttention(d_model, n_heads)
        self.feedforward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Run the layer.

        Args:
            x: Embeddings, expected as ``[batch_size, seq_length, d_model]``.
            mask: Attention mask, expected as ``[batch_size, 1, 1, seq_length]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Sublayer 1: self-attention + residual + norm.
        attn_output = self.multihead(x, x, x, mask)
        out = self.norm1(x + self.dropout(attn_output))

        # Sublayer 2: feed-forward + residual + norm. The residual must come
        # from the first sublayer's output (`out`), not the original input.
        ff_output = self.feedforward(out)
        return self.norm2(out + self.dropout(ff_output))

In [None]:
class BERTPooler(nn.Module):
    """Pool a sequence into one vector using its first position.

    Takes the hidden state at sequence index 0 and passes it through a
    ``Linear(d_model, d_model)`` layer followed by ``tanh``.

    Args:
        d_model: Width of the encoder output and of the pooled vector.
    """

    def __init__(self, d_model):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()

    def forward(self, encoder_out):
        """Return ``tanh(dense(encoder_out[:, 0]))``.

        Args:
            encoder_out: Tensor indexed as ``[batch_size, seq_length, d_model]``.

        Returns:
            Tensor shaped ``[batch_size, d_model]``.
        """
        # NOTE(review): index 0 is presumably the [CLS] token -- confirm
        # against the tokenizer.
        cls_token = encoder_out[:, 0]
        return self.tanh(self.dense(cls_token))

In [None]:
class BERT(nn.Module):
    """BERT encoder with masked-language-model and next-sentence heads.

    Args:
        vocab_size: Vocabulary size (embedding rows and MLM output width).
        seq_length: Maximum sequence length for the positional table.
        d_model: Model width.
        n_heads: Number of attention heads per encoder layer.
        d_ff: Hidden width of each feed-forward sublayer.
        n_layers: Number of stacked encoder layers.
        dropout: Dropout probability used throughout (defaults to 0.1).
    """

    def __init__(self, vocab_size, seq_length, d_model, n_heads, d_ff, n_layers, dropout=0.1):
        super().__init__()
        self.n_layers = n_layers

        self.embeddings = BERTEmbeddings(vocab_size, d_model, seq_length, dropout)
        self.encoder = nn.ModuleList(
            [EncodeLayer(d_model, d_ff, n_heads, dropout) for _ in range(n_layers)]
        )

        # Masked-language-model head: one score per vocabulary entry per position.
        self.mlm_head = nn.Linear(d_model, vocab_size)

        # Next-sentence-prediction head: two classes from the first position.
        self.softmax = nn.LogSoftmax(dim=-1)
        self.nsp_head = nn.Linear(d_model, 2)

    def forward(self, input, segment_label, mask):
        """Encode the batch and return log-probabilities for both heads.

        Args:
            input: Token ids, expected as ``[batch_size, seq_length]``.
            segment_label: Segment ids with the same shape as ``input``.
            mask: Attention mask passed unchanged to every encoder layer.

        Returns:
            Tuple ``(mlm_log_probs, nsp_log_probs)`` shaped
            ``[batch_size, seq_length, vocab_size]`` and ``[batch_size, 2]``.
        """
        # [batch_size, seq_length, d_model]
        out = self.embeddings(input, segment_label)
        for layer in self.encoder:
            out = layer(out, mask)

        mlm_output = self.mlm_head(out)  # [batch_size, seq_length, vocab_size]

        cls_token = out[:, 0]  # first position: [batch_size, d_model]
        nsp_output = self.nsp_head(cls_token)  # [batch_size, 2]

        return self.softmax(mlm_output), self.softmax(nsp_output)

In [None]:
class BERTClassification(nn.Module):
    """BERT encoder with a pooled sequence-classification head.

    Args:
        vocab_size: Vocabulary size for the token embedding table.
        seq_length: Maximum sequence length for the positional table.
        d_model: Model width.
        n_heads: Number of attention heads per encoder layer.
        n_classes: Number of output classes.
        d_ff: Hidden width of each feed-forward sublayer.
        n_layers: Number of stacked encoder layers.
        dropout: Dropout probability used throughout (defaults to 0.1).
    """

    def __init__(self, vocab_size, seq_length, d_model, n_heads, n_classes, d_ff, n_layers, dropout=0.1):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.n_layers = n_layers

        self.embeddings = BERTEmbeddings(vocab_size, d_model, seq_length, dropout)
        self.encoder = nn.ModuleList(
            [EncodeLayer(d_model, d_ff, n_heads, dropout) for _ in range(n_layers)]
        )
        self.pooler = BERTPooler(d_model)

        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, input, segment_label, mask):
        """Encode the batch and return unnormalised class scores.

        Args:
            input: Token ids, expected as ``[batch_size, seq_length]``.
            segment_label: Segment ids with the same shape as ``input``.
            mask: Attention mask passed unchanged to every encoder layer.

        Returns:
            Logits shaped ``[batch_size, n_classes]``.
        """
        # [batch_size, seq_length, d_model]
        out = self.embeddings(input, segment_label)
        for layer in self.encoder:
            out = layer(out, mask)

        pooled_output = self.pooler(out)
        return self.classifier(pooled_output)


In [None]:
# Pre-training setup: build the model, optimizer and losses, then run the loop.
model = BERT(
    vocab_size=2000,
    seq_length=512,
    d_model=768,
    n_heads=3,
    d_ff=4 * 768,
    n_layers=12,
    dropout=0.1,
)

# NOTE(review): `data_loader` is not defined anywhere in the visible notebook.
# It must be created in an earlier cell and yield
# (input_ids, is_next, bert_label, segment) batches -- TODO confirm the
# dataset and batch layout.

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# MLM loss: label 0 is skipped via ignore_index.
# NOTE(review): presumably 0 marks padded / unmasked positions -- confirm.
criterion = nn.NLLLoss(ignore_index=0)
# NSP loss needs its own criterion: with ignore_index=0 every class-0 label
# would be silently excluded from the loss.
nsp_criterion = nn.NLLLoss()

epochs = 10
model.train()
for epoch in range(epochs):
    for input_ids, is_next, bert_label, segment in data_loader:
        # Attention mask [batch_size, 1, 1, seq_length]; token id 0 is the
        # padding index of the token embedding table.
        mask = (input_ids != 0).unsqueeze(1).unsqueeze(2)

        optimizer.zero_grad()
        mlm_output, nsp_output = model(input_ids, segment, mask)

        # NSP loss
        nsp_loss = nsp_criterion(nsp_output, is_next)

        # MLM loss: NLLLoss expects the class dimension second, so move the
        # vocabulary axis to [batch_size, vocab_size, seq_length].
        mlm_loss = criterion(mlm_output.transpose(1, 2), bert_label)
        loss = nsp_loss + mlm_loss

        loss.backward()
        optimizer.step()

