In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


@dataclass
class BertConfig:
    """Hyper-parameters for the BERT modules in this notebook.

    Defaults describe a 12-layer encoder of width 768 with 12 attention heads.
    Field order is also the positional-argument order of the constructor.
    """

    # maximum sequence length: number of rows in the positional-embedding table
    block_size: int = 512
    # 30_528 == 64 * 477; presumably the 30_522-token vocab padded to a multiple of 64
    vocab_size: int = 30_528
    n_layer: int = 12  # number of stacked encoder layers
    n_embd: int = 768  # hidden width C
    n_head: int = 12  # attention heads; must evenly divide n_embd
    intermediate_size: int = 3072  # inner width of the feed-forward block
    type_vocab_size: int = 2  # number of token-type ids (not read by the modules below)

In [3]:
class BertEmbedding(nn.Module):
    """Token embeddings plus learned absolute position embeddings, then LayerNorm and dropout.

    ``forward`` maps integer ids of shape (B, T) to embeddings of shape (B, T, C),
    with C == config.n_embd. T must not exceed config.block_size.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.positional_embeddings = nn.Embedding(config.block_size, config.n_embd)
        self.layernorm = nn.LayerNorm(config.n_embd, eps=1e-12)
        self.dropout = nn.Dropout(0.1)
        # NOTE(review): config.type_vocab_size is never used; no token-type
        # (segment) embedding is added here — confirm this is intended.

    def forward(self, input_ids):
        """Embed ``input_ids`` (B, T) into a (B, T, C) tensor.

        Raises:
            ValueError: if T exceeds the size of the positional-embedding table.
        """
        seq_len = input_ids.size(1)
        max_len = self.positional_embeddings.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds block_size {max_len}")
        # Positions 0..T-1 of the actual input (not block_size), so any T <= block_size
        # broadcasts against the (B, T, C) token embeddings.
        position_ids = torch.arange(seq_len, device=input_ids.device)
        input_embeds = self.word_embeddings(input_ids)  # (B, T, C)
        position_embeds = self.positional_embeddings(position_ids)  # (T, C), broadcast over B
        embeddings = self.layernorm(input_embeds + position_embeds)
        return self.dropout(embeddings)  # (B, T, C)

class BertAttention(nn.Module):
    """Multi-head self-attention sub-layer with output projection, residual and post-LayerNorm.

    Input and output are (B, T, C) with C == config.n_embd, split across
    config.n_head heads of size C // n_head.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        # Shared by the attention probabilities and the projected output (same p=0.1).
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(config.n_embd, config.n_embd)
        self.layernorm = nn.LayerNorm(config.n_embd, eps=1e-12)

    def forward(self, x):
        """Apply self-attention to ``x`` (B, T, C) and return a (B, T, C) tensor."""
        B, T, C = x.shape
        head_size = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(B, T, self.n_head, head_size).transpose(1, 2)

        query = split_heads(self.query(x))
        key = split_heads(self.key(x))
        value = split_heads(self.value(x))

        # scores[..., i, j] = q_i . k_j, normalised over the keys j (last dim).
        scores = query @ key.transpose(-1, -2) / math.sqrt(head_size)  # (B, nh, T, T)
        attention_probs = self.dropout(F.softmax(scores, dim=-1))
        context = attention_probs @ value  # (B, nh, T, hs)
        # transpose() yields a non-contiguous tensor, so contiguous() must precede view().
        context = context.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)

        # Dropout on the sub-layer output before residual + LayerNorm, the same
        # ordering BertLayer uses for its feed-forward output.
        output = self.dropout(self.dense(context))
        return self.layernorm(output + x)  # (B, T, C)
    
class BertLayer(nn.Module):
    """One encoder layer: self-attention, then a position-wise feed-forward block.

    The feed-forward block (Linear -> GELU -> Linear -> Dropout, then residual +
    LayerNorm) is evaluated in slices of ``chunk_size_ffwd`` positions along the
    sequence dimension to bound peak activation memory. Each position is processed
    independently, so chunking does not change the result.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = nn.Sequential(
            nn.Linear(config.n_embd, config.intermediate_size),
            # The nn.GELU keyword is `approximate`; `approximation` raises TypeError.
            # NOTE(review): reference BERT checkpoints usually use exact GELU
            # (approximate='none'); confirm the tanh approximation is intended.
            nn.GELU(approximate='tanh'),
        )
        self.output = nn.Sequential(
            nn.Linear(config.intermediate_size, config.n_embd),
            nn.Dropout(0.1),
            # LayerNorm lives in output_layernorm so the residual can be added first.
        )
        self.output_layernorm = nn.LayerNorm(config.n_embd, eps=1e-12)
        self.chunk_size_ffwd = 256  # positions per feed-forward chunk; <= 0 disables chunking
        self.seq_len_dim = 1  # dimension of the (B, T, C) activations that gets chunked

    def _feed_forward_chunk(self, attention_output):
        """Run the feed-forward block plus residual + LayerNorm on one (B, t, C) slice."""
        intermediate_output = self.intermediate(attention_output)  # (B, t, intermediate_size)
        layer_output = self.output(intermediate_output)  # (B, t, C)
        return self.output_layernorm(layer_output + attention_output)  # (B, t, C)

    def apply_chunk_to_ffwd(self, forward_fn, chunk_size, chunk_dim, attn_out):
        """Apply ``forward_fn`` to ``attn_out`` in slices of ``chunk_size`` along ``chunk_dim``.

        The last slice may be shorter when the dimension is not an exact multiple of
        ``chunk_size`` (the previous version asserted divisibility and crashed for,
        e.g., any sequence shorter than the chunk). A non-positive ``chunk_size``
        disables chunking.
        """
        if chunk_size <= 0 or attn_out.size(chunk_dim) <= chunk_size:
            return forward_fn(attn_out)
        chunks = attn_out.split(chunk_size, dim=chunk_dim)
        return torch.cat([forward_fn(chunk) for chunk in chunks], dim=chunk_dim)

    def forward(self, x):
        """Map (B, T, C) hidden states to (B, T, C) hidden states."""
        attention_output = self.attention(x)  # (B, T, C)
        return self.apply_chunk_to_ffwd(
            self._feed_forward_chunk, self.chunk_size_ffwd, self.seq_len_dim, attention_output
        )  # (B, T, C)
    

class BertEncoder(nn.Module):
    """Stack of ``config.n_layer`` BertLayer modules applied one after another."""

    def __init__(self, config):
        super().__init__()
        layers = [BertLayer(config) for _ in range(config.n_layer)]
        self.layer = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final hidden states."""
        hidden = x
        for block in self.layer:
            hidden = block(hidden)
        return hidden  # (B, T, C)
    
class BertPooler(nn.Module):
    """Pool a sequence into a single vector from its first position.

    Takes (B, T, C) hidden states and returns tanh(dense(x[:, 0])) of shape (B, C).
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.n_embd, config.n_embd)
        self.activation = nn.Tanh()

    def forward(self, x):
        """Project the first position's hidden state and squash it with tanh."""
        first_hidden = x[:, 0]  # (B, C)
        projected = self.dense(first_hidden)
        return self.activation(projected)  # (B, C)
    
class BertOnlyNSPHead(nn.Module):
    """Linear classifier turning a pooled (B, C) vector into two next-sentence logits."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.n_embd, 2)

    def forward(self, pooled_output):
        """Return the (B, 2) logits for a (B, C) pooled input."""
        logits = self.seq_relationship(pooled_output)
        return logits
    
class BertModel(nn.Module):
    """BERT backbone: embeddings -> stacked encoder layers -> first-position pooler.

    ``forward`` takes (B, T) integer ids and returns only the pooled output, (B, C).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = BertEmbedding(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        # Re-initialise every submodule of this model with the scheme in _init_weights.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule in place.

        Linear / Embedding weights ~ N(0, 0.01), biases and padding rows zeroed,
        LayerNorm reset to the identity (weight 1, bias 0).
        NOTE(review): many BERT setups use std=0.02; confirm 0.01 is intended.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Non-affine LayerNorm has no weight/bias to reset.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def forward(self, input_ids):
        """Encode (B, T) ids and return the pooled (B, C) representation.

        Raises:
            ValueError: if ``input_ids`` is not 2-D.
        """
        if input_ids.dim() != 2:
            raise ValueError(f"input_ids must be 2-D (B, T), got shape {tuple(input_ids.shape)}")
        embedding_output = self.embedding(input_ids)  # (B, T, C)
        encoder_output = self.encoder(embedding_output)  # (B, T, C)
        return self.pooler(encoder_output)  # (B, C)
    
class BertForNextSentencePrediction(nn.Module):
    """BERT backbone with a next-sentence-prediction head.

    ``forward`` returns ``(pooled_output, next_sentence_loss)``; the loss is ``None``
    when ``labels`` is not given.
    NOTE(review): the NSP logits are not returned; callers needing predictions must
    recompute ``self.cls(pooled_output)``.
    """

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)
        # BertModel.apply only reaches its own submodules; give the head the same init.
        self.cls.apply(self.bert._init_weights)

    def forward(self, input_ids, labels=None):
        """Run the backbone and NSP head; compute cross-entropy loss if ``labels`` is given."""
        pooled_output = self.bert(input_ids)  # (B, C)
        seq_relationship_score = self.cls(pooled_output)  # (B, 2)
        next_sentence_loss = None
        if labels is not None:
            # nn.CrossEntropyLoss is a module class: calling it with tensors builds a
            # loss object instead of a loss value. Use the functional form.
            next_sentence_loss = F.cross_entropy(
                seq_relationship_score.view(-1, 2), labels.view(-1)
            )
        return (pooled_output, next_sentence_loss)
        