In [1]:
import torch
from torch import nn

In [2]:
class BERTModel(nn.Module):
    """Minimal BERT-style encoder with MLM and NSP pre-training heads.

    Args:
        vocab_size: number of token ids; rows of the token embedding table and
            width of the MLM logits.
        num_hiddens: size of every token vector (the encoder's ``d_model``).
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability used on the embeddings and in the encoder.
        max_len: number of learned position vectors; input sequences must not
            be longer than this.
    """

    def __init__(self, vocab_size, num_hiddens, num_heads, num_layers, dropout, max_len = 1000):
        super().__init__()
        token_vec_size = num_hiddens

        # 1. Embeddings
        # Lookup table holding `vocab_size` vectors (e.g. 30,000), each of
        # `token_vec_size` dimensions (e.g. 768), i.e. a 30,000 x 768 matrix.
        # It is initialised randomly.
        self.token_embeddings = nn.Embedding(vocab_size, token_vec_size)

        # Two segment vectors: one for sentence A and one for sentence B,
        # i.e. a 2 x token_vec_size matrix.
        self.segment_embedding = nn.Embedding(2, token_vec_size)

        # Learnable positional table: one vector per position 0 .. max_len-1.
        # The leading 1 lets it broadcast over the batch dimension of the
        # (batch, seq_len, token_vec_size) embedded input.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, token_vec_size))

        self.dropout = nn.Dropout(dropout)

        # 2. Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=token_vec_size,
            nhead=num_heads,
            dim_feedforward=token_vec_size*4,
            dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 3. Task heads
        # A. Next Sentence Prediction (NSP) head: applied to the position-0 vector.
        self.hidden = nn.Sequential(nn.Linear(token_vec_size, token_vec_size),nn.Tanh())
        self.nsp = nn.Linear(token_vec_size, 2)

        # B. Masked Language Modeling (MLM) head: projects a token vector to
        # vocabulary logits. The LayerNorm standardises each vector over its
        # feature dimension before the final projection.
        self.mlm=nn.Sequential(
            nn.Linear(token_vec_size, token_vec_size),
            nn.ReLU(),
            nn.LayerNorm(token_vec_size),
            nn.Linear(token_vec_size,vocab_size)
        )

    def forward(self, tokens, segments, pred_positions=None):
        """Encode a batch and compute the NSP (and optionally MLM) logits.

        Args:
            tokens: integer token ids, shape (batch_size, seq_len).
            segments: integer segment ids (0 or 1), same shape as ``tokens``.
            pred_positions: optional integer tensor (batch_size, num_preds)
                with the sequence positions to predict for MLM.

        Returns:
            Tuple ``(encoded_X, mlm_Y_hat, nsp_Y_hat)``:
            encoded_X is (batch_size, seq_len, num_hiddens);
            mlm_Y_hat is (batch_size, num_preds, vocab_size), or None when
            ``pred_positions`` is None;
            nsp_Y_hat is (batch_size, 2).
        """
        # 1. Input embedding
        X = self.token_embeddings(tokens) + self.segment_embedding(segments)
        # Index the Parameter itself (not `.data`) so autograd tracks the
        # addition and the positional vectors are actually trained.
        X = X + self.pos_embedding[:, :X.shape[1], :]
        X = self.dropout(X)

        # 2. Encoder pass
        encoded_X = self.encoder(X)

        # 3. NSP output
        # Token 0 ([CLS]) is the designated "summary token": through
        # self-attention every token attends to every other token, and the NSP
        # objective trains the model to put a representation of the whole
        # sentence pair into this vector.
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))

        # 4. MLM output
        if pred_positions is None:
            return encoded_X, None, nsp_Y_hat

        # Only predict at the requested positions.
        # batch_idx is a (batch_size, 1) column [[0], [1], ..., [B-1]] created
        # on the encoder output's device; advanced indexing broadcasts it
        # against pred_positions (batch_size, num_preds), so that
        #   masked_X[b, k] == encoded_X[b, pred_positions[b, k]]
        # e.g. pred_positions = [[1, 3], [0, 4]] selects
        #   encoded_X[0, 1], encoded_X[0, 3], encoded_X[1, 0], encoded_X[1, 4].
        batch_size = encoded_X.shape[0]
        batch_idx = torch.arange(batch_size, device=encoded_X.device).unsqueeze(1)
        masked_X = encoded_X[batch_idx, pred_positions]  # (batch_size, num_preds, num_hiddens)

        mlm_Y_hat = self.mlm(masked_X)
        return encoded_X, mlm_Y_hat, nsp_Y_hat

In [3]:
import random

def generate_fake_data(vocab_size, batch_size, seq_len, num_mlm_preds):
    """Build one random batch shaped like BERT pre-training data.

    Returns a 6-tuple
    ``(tokens, segments, pred_positions, mlm_labels, mlm_weights, nsp_labels)``:

    * tokens         -- (batch_size, seq_len) ids drawn from [0, vocab_size)
    * segments       -- (batch_size, seq_len) ids drawn from {0, 1}
    * pred_positions -- (batch_size, num_mlm_preds) positions in [0, seq_len)
    * mlm_labels     -- (batch_size, num_mlm_preds) ids in [0, vocab_size)
    * mlm_weights    -- (batch_size, num_mlm_preds) of 1.0 (no padding here)
    * nsp_labels     -- (batch_size,) labels drawn from {0, 1}
    """
    sequence_shape = (batch_size, seq_len)
    prediction_shape = (batch_size, num_mlm_preds)

    # Inputs: random token ids and random sentence-A/B markers.
    tokens = torch.randint(low=0, high=vocab_size, size=sequence_shape)
    segments = torch.randint(low=0, high=2, size=sequence_shape)

    # Positions to "mask" (e.g. about 15% of the sequence).
    pred_positions = torch.randint(low=0, high=seq_len, size=prediction_shape)

    # Fake ground truth for MLM: the "true" token at each masked position.
    mlm_labels = torch.randint(low=0, high=vocab_size, size=prediction_shape)

    # Weight 1.0 marks a real mask and 0.0 would mark padding; all real here.
    mlm_weights = torch.ones(prediction_shape)

    # Fake ground truth for NSP: 0 (true next) or 1 (random next).
    nsp_labels = torch.randint(low=0, high=2, size=(batch_size,))

    return tokens, segments, pred_positions, mlm_labels, mlm_weights, nsp_labels



In [4]:
# Tiny example batch so each returned tensor can be inspected by eye
generate_fake_data(vocab_size=3, batch_size=3, seq_len=5, num_mlm_preds=2)

(tensor([[2, 2, 1, 2, 2],
         [1, 2, 0, 0, 1],
         [0, 1, 0, 2, 0]]),
 tensor([[1, 0, 0, 1, 0],
         [1, 1, 0, 0, 0],
         [0, 1, 0, 0, 0]]),
 tensor([[3, 2],
         [0, 2],
         [4, 3]]),
 tensor([[1, 2],
         [0, 1],
         [2, 1]]),
 tensor([[1., 1.],
         [1., 1.],
         [1., 1.]]),
 tensor([0, 1, 0]))

In [5]:
# Configuration for the toy training run

# -- model size --
VOCAB_SIZE = 10_000   # number of distinct token ids
NUM_HIDDENS = 128     # width of every token vector
NUM_HEADS = 2         # attention heads per encoder layer
NUM_LAYERS = 2        # stacked encoder layers
DROPOUT = 0.1         # dropout probability

# -- fake batch shape --
BATCH_SIZE = 8        # sequences per batch
SEQ_LEN = 32          # tokens per sequence
NUM_MLM_PREDS = 5     # masked positions per sequence, approx. 15% of 32

In [6]:
def train_bert_custom(num_steps=50, lr=0.001, log_every=10):
    """Run a short training simulation of BERTModel on random data.

    A fresh model is built from the notebook's configuration constants and
    optimised with Adam on the sum of the MLM and NSP cross-entropy losses.
    The data is random, so this only exercises the training plumbing.

    Args:
        num_steps: number of optimisation steps to simulate.
        lr: Adam learning rate.
        log_every: print both losses every this many steps (0 disables logging).
    """
    # 1. Initialize model, loss and optimizer
    net = BERTModel(VOCAB_SIZE, NUM_HIDDENS, NUM_HEADS, NUM_LAYERS, DROPOUT)
    net.train()  # make the training mode (dropout active) explicit
    # reduction='none' keeps per-element losses so MLM can be weight-masked.
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    print("Starting Training Simulation...")

    # 2. Training loop
    for step in range(num_steps):
        # Generate a batch of fake data
        tokens, segments, pred_positions, mlm_labels, mlm_weights, nsp_labels = \
            generate_fake_data(VOCAB_SIZE, BATCH_SIZE, SEQ_LEN, NUM_MLM_PREDS)

        optimizer.zero_grad()

        # Forward pass: call the module (not .forward) so module hooks run.
        _, mlm_Y_hat, nsp_Y_hat = net(tokens, segments, pred_positions)

        # --- Calculate losses ---

        # A. MLM loss
        # Flatten predictions to (batch * num_preds, vocab); the vocab width is
        # read from the logits themselves rather than from a global.
        mlm_l = loss_fn(mlm_Y_hat.reshape(-1, mlm_Y_hat.shape[-1]), mlm_labels.reshape(-1))
        # Weighted mean: weights of 0 would drop padded predictions; the
        # epsilon guards against division by zero if every weight is 0.
        mlm_l = mlm_l * mlm_weights.reshape(-1)
        mlm_l = mlm_l.sum() / (mlm_weights.sum() + 1e-8)

        # B. NSP loss
        nsp_l = loss_fn(nsp_Y_hat, nsp_labels).mean()

        # Total loss
        total_loss = mlm_l + nsp_l

        # Backward pass
        total_loss.backward()
        optimizer.step()

        if log_every and (step+1) % log_every == 0:
            print(f"Step {step+1}: MLM Loss: {mlm_l.item():.4f}, NSP Loss: {nsp_l.item():.4f}")

# Run the training simulation; the losses are printed every 10 steps
train_bert_custom()

Starting Training Simulation...
Step 10: MLM Loss: 9.3282, NSP Loss: 0.8525
Step 20: MLM Loss: 9.3183, NSP Loss: 0.6907
Step 30: MLM Loss: 9.4742, NSP Loss: 0.6678
Step 40: MLM Loss: 9.2611, NSP Loss: 0.7680
Step 50: MLM Loss: 9.6456, NSP Loss: 0.6863


In [7]:
def check_context_sensitivity():
    """Show that the same token gets a different vector in different contexts.

    Two six-token sentences sharing the token "crane" at index 2 are encoded
    by a freshly initialised model in eval mode; the two encoder outputs at
    that index are printed together with their Euclidean distance.
    """
    model = BERTModel(VOCAB_SIZE, NUM_HIDDENS, NUM_HEADS, NUM_LAYERS, DROPOUT)
    model.eval()  # evaluation mode

    # Fake ids for demonstration:
    # <cls>=0, a=5, crane=10, is=11, flying=12, driver=15, came=16, <sep>=2
    crane_index = 2
    sentences = (
        torch.tensor([[0, 5, 10, 11, 12, 2]]),  # "A crane is flying"
        torch.tensor([[0, 5, 10, 15, 16, 2]]),  # "A crane driver came"
    )

    crane_vectors = []
    with torch.no_grad():
        for token_ids in sentences:
            # Every token belongs to segment 0 (single-sentence input).
            segment_ids = torch.zeros_like(token_ids)
            encoded, _, _ = model(token_ids, segment_ids)
            crane_vectors.append(encoded[0, crane_index, :])
    flying_vec, driver_vec = crane_vectors

    print("\nContext Sensitivity Check:")
    print(f"Crane Vector 1 (Flying): {flying_vec[:3]}")  # first 3 numbers only
    print(f"Crane Vector 2 (Driver): {driver_vec[:3]}")

    distance = torch.dist(flying_vec, driver_vec)
    print(f"Euclidean Distance between vectors: {distance.item():.4f}")
    if distance > 0.0:
        print("SUCCESS: The vectors are different! The model is context-sensitive.")

# Run the check on a freshly initialised (untrained) model
check_context_sensitivity()


Context Sensitivity Check:
Crane Vector 1 (Flying): tensor([0.2405, 1.3969, 0.4458])
Crane Vector 2 (Driver): tensor([0.2401, 1.5742, 0.3580])
Euclidean Distance between vectors: 1.7053
SUCCESS: The vectors are different! The model is context-sensitive.
