In [1]:
!pip install datasets



In [None]:
HF_TOKEN = "Put your token here"

In [3]:
from datasets import load_dataset
import itertools # We need this to limit our streaming dataset

# Load the dataset
wiki_dataset = load_dataset("wikimedia/wikipedia", "20231101.en", split='train', streaming=True)

# We'll create one large text file to train our tokenizer on.
# We'll start with 10,000 articles to keep the process fast for now.
num_articles_to_use = 10000

print(f"Creating training corpus file from {num_articles_to_use} Wikipedia articles...")

with open("wiki_corpus.txt", "w", encoding="utf-8") as f:
    # Loop through the first 10,000 articles in the dataset
    for article in itertools.islice(wiki_dataset, num_articles_to_use):
        # Get the text content of the article
        text_content = article['text']
        # Write it to the file, followed by two newlines to separate articles
        f.write(text_content + "\n\n")

print("Finished creating wiki_corpus.txt!")

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Creating training corpus file from 10000 Wikipedia articles...
Finished creating wiki_corpus.txt!


In [4]:
# import regex as re
# import requests

# class Tokenizer:
#     def __init__(self):
#         self.merges = {}
#         self.vocab = {}
#         self.pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

#     def _get_pair_counts(self, tokens):
#         counts = {}
#         for i in range(len(tokens) - 1):
#             pair = (tokens[i], tokens[i+1])
#             counts[pair] = counts.get(pair, 0) + 1
#         return sorted(((v, k) for k, v in counts.items()), reverse=True)

#     def _merge(self, tokens, pair, new_idx):
#         new_tokens = []
#         i = 0
#         while i < len(tokens):
#             if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == pair:
#                 new_tokens.append(new_idx)
#                 i += 2
#             else:
#                 new_tokens.append(tokens[i])
#                 i += 1
#         return new_tokens

#     def train(self, text, num_merges, verbose=False):
#         print("Starting training...")
#         text_chunks = self.pattern.findall(text)
#         tokens = []
#         for chunk in text_chunks:
#             tokens.extend(list(chunk.encode("utf-8")))

#         merges = {}
#         for i in range(num_merges):
#             pair_counts = self._get_pair_counts(tokens)
#             if not pair_counts:
#                 break

#             top_pair = pair_counts[0][1]
#             new_idx = 256 + i
#             merges[top_pair] = new_idx
#             tokens = self._merge(tokens, top_pair, new_idx)
#             if verbose and (i + 1) % 50 == 0:
#                 print(f"  Merge {i+1}/{num_merges} completed...")

#         self.merges = merges
#         self._build_vocab()
#         print("Training finished!")

#     def _build_vocab(self):
#         self.vocab = {idx: bytes([idx]) for idx in range(256)}
#         for (p1, p2), idx in self.merges.items():
#             self.vocab[idx] = self.vocab[p1] + self.vocab[p2]

#     def save(self, filepath="merge_rules.bpe"):
#         print(f"Saving merge rules to {filepath}...")
#         with open(filepath, 'w', encoding="utf-8") as f:
#             for (p1, p2) in self.merges:
#                 f.write(f"{p1} {p2}\n")
#         print("Done.")

#     def load(self, filepath="merge_rules.bpe"):
#         print(f"Loading merge rules from {filepath}...")
#         merges = {}
#         with open(filepath, 'r', encoding="utf-8") as f:
#             for i, line in enumerate(f):
#                 p1, p2 = line.strip().split()
#                 merges[(int(p1), int(p2))] = 256 + i
#         self.merges = merges
#         self._build_vocab()
#         print("Tokenizer loaded.")

#     def encode(self, text):
#         text_chunks = self.pattern.findall(text)
#         tokens = []
#         for chunk in text_chunks:
#             tokens.extend(list(chunk.encode("utf-8")))

#         for pair, new_idx in self.merges.items():
#             tokens = self._merge(tokens, pair, new_idx)
#         return tokens

#     def decode(self, ids):
#         byte_chunk = b"".join(self.vocab[idx] for idx in ids)
#         return byte_chunk.decode("utf-8", errors="replace")



In [5]:
# print("Dowloading text for training...")
# with open("wiki_corpus.txt",'r') as file:
#   text = file.read()
# tokenizer = Tokenizer()
# print("Training tokenizer...")
# tokenizer.train(text, num_merges=500, verbose=True)
# tokenizer.save("my_tokenizer.bpe")

In [7]:
!pip install tokenizers

Collecting tokenizers
  Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m68.8 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: tokenizers
Successfully installed tokenizers-0.21.4


In [34]:
# Due to comuting time I have switched to default library which made it faster to train the tokenizer
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import ByteLevel


tokenizer = Tokenizer(BPE(unk_token="[UNK]"))


tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=False)

print("Tokenizer initialized. Starting training...")

trainer = BpeTrainer(
    vocab_size=30000,
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
)

tokenizer.train(files=["wiki_corpus.txt"], trainer=trainer)

print("Training complete!")


tokenizer.save("wiki-bpe-tokenizer.json")

print("Tokenizer saved to wiki-bpe-tokenizer.json")

Tokenizer initialized. Starting training...



Training complete!
Tokenizer saved to wiki-bpe-tokenizer.json


In [35]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"

        # Linear layers to project inputs into V, K, Q spaces
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.query = nn.Linear(self.embed_size, self.embed_size, bias=False)

        # Final fully connected layer to produce the output
        self.fc_out = nn.Linear(self.embed_size, self.embed_size)

    # (This code goes inside the SelfAttention class)

    def forward(self, values, keys, query, mask):
        # Get the number of training examples in the batch
        N = query.shape[0]

        # Get the length of the input sequences
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Pass inputs through the linear layers
        values = self.values(values)
        keys = self.keys(keys)
        query = self.query(query)

        # Reshape the tensors to split them into multiple heads
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        # --- The Core Attention Calculation ---
        # Matrix multiply Queries and Keys to get attention scores
        # query shape: (N, query_len, heads, head_dim)
        # keys shape: (N, key_len, heads, head_dim)
        # energy shape: (N, heads, query_len, key_len)
        energy = torch.einsum("nqhd,nkhd->nhqk", [query, keys])

        # Apply the mask, if one is provided
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Apply softmax to get attention weights (probabilities)
        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        # --- Apply Attention to Values ---
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # out shape: (N, query_len, heads, head_dim)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # Pass the result through the final linear layer
        out = self.fc_out(out)
        return out

In [36]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()

        # 1. The attention layer
        self.attention = SelfAttention(embed_size, heads)

        # 2. Normalization layers
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # 3. The feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

        # 4. Dropout layer for regularization
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
      # 1. Pass inputs through the self-attention layer
      attention = self.attention(value, key, query, mask)

      # 2. Add the output of attention to the original query (the skip connection)
      #    and then pass it through the first normalization layer.
      x = self.dropout(self.norm1(attention + query))

      # 3. Pass the result through the feed-forward network
      forward = self.feed_forward(x)

      # 4. Add the output of the feed-forward network to its input (the second skip connection)
      #    and then pass it through the second normalization layer.
      out = self.dropout(self.norm2(forward + x))

      return out

In [37]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)

        out = self.dropout(
            (self.word_embedding(x) + self.position_embedding(positions))
        )

        # The attention mask from the tokenizer is of shape (N, seq_length).
        # We need to reshape it to (N, 1, 1, seq_length) to make it broadcastable
        # with the attention scores tensor of shape (N, heads, query_len, key_len).
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)

        # Pass the data through each Transformer Block in the stack
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

In [38]:
class BERT(nn.Module):
    def __init__(self, encoder, vocab_size):
        super(BERT, self).__init__()
        self.encoder = encoder

        # This is the final layer that makes the word predictions
        self.fc_out = nn.Linear(encoder.embed_size, vocab_size)

    def forward(self, x, mask):
        # First, get the context-rich output from the encoder
        out = self.encoder(x, mask)

        # Then, pass it through the final linear layer to get predictions
        return self.fc_out(out)

In [39]:
import torch
from torch.utils.data import Dataset
import random

class MLMDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.lines = []

        # Get special token IDs once
        self.mask_token_id = tokenizer.token_to_id("[MASK]")
        self.pad_token_id = tokenizer.token_to_id("[PAD]")
        self.cls_token_id = tokenizer.token_to_id("[CLS]")
        self.sep_token_id = tokenizer.token_to_id("[SEP]")
        self.vocab_size = tokenizer.get_vocab_size()

        # Let the library handle padding and truncation
        self.tokenizer.enable_padding(length=self.max_length, pad_id=self.pad_token_id)
        self.tokenizer.enable_truncation(max_length=self.max_length)

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.lines.append(line.strip())

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        # The encoding object now contains ids, attention_mask, etc.
        encoding = self.tokenizer.encode(line)

        input_ids = torch.tensor(encoding.ids, dtype=torch.long)
        attention_mask = torch.tensor(encoding.attention_mask, dtype=torch.long)
        labels = input_ids.clone()

        # Create a mask for tokens we can potentially mask.
        # We don't want to mask special tokens like [CLS], [SEP], [PAD].
        can_mask = (input_ids != self.cls_token_id) & \
                   (input_ids != self.sep_token_id) & \
                   (input_ids != self.pad_token_id)

        # Determine which tokens to mask (15% of the maskable tokens)
        mask_indices = torch.where(can_mask & (torch.rand(input_ids.shape) < 0.15))

        for i in mask_indices[0]:
            prob = random.random()
            if prob < 0.8:
                # 80% of the time: Replace with [MASK]
                input_ids[i] = self.mask_token_id
            elif prob < 0.9:
                # 10% of the time: Replace with a random token
                input_ids[i] = random.randint(0, self.vocab_size - 1)
            # else 10% of the time: Keep the original token (do nothing)

        # We only compute loss on the tokens we masked.
        # Set all other labels to -100 so they are ignored by the loss function.
        # Create a boolean mask for all positions that were NOT masked.
        non_masked_indices = torch.ones(labels.shape, dtype=torch.bool)
        non_masked_indices[mask_indices] = False

        # Use this boolean mask to set the labels at these positions to -100.
        labels[non_masked_indices] = -100

        return input_ids, attention_mask, labels

In [14]:
import torch
from torch.utils.data import DataLoader
from tokenizers import Tokenizer

# --- Assume previous code for model, dataset, etc. is loaded ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tokenizer = Tokenizer.from_file("wiki-bpe-tokenizer.json")

# Initialize the Dataset and DataLoader
train_dataset = MLMDataset("wiki_corpus.txt", tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Initialize the Model
VOCAB_SIZE = tokenizer.get_vocab_size()
EMBED_SIZE = 256
NUM_LAYERS = 6
HEADS = 8
FORWARD_EXPANSION = 4
DROPOUT = 0.1
MAX_LENGTH = 128

encoder = Encoder(VOCAB_SIZE, EMBED_SIZE, NUM_LAYERS, HEADS, device, FORWARD_EXPANSION, DROPOUT, MAX_LENGTH)
model = BERT(encoder, VOCAB_SIZE).to(device)

# Initialize Optimizer and Loss Function
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# --- The Training Loop ---
epochs = 5
model.train()

for epoch in range(epochs):
    print(f"\n--- Epoch {epoch+1}/{epochs} ---")
    total_loss = 0

    for i, (input_ids, attention_mask, labels) in enumerate(train_loader):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        outputs = model(input_ids, mask=attention_mask)

        # Reshape for loss calculation
        # The outputs tensor is reshaped to (batch_size * seq_len, vocab_size)
        outputs = outputs.view(-1, VOCAB_SIZE)

        # FIX: The labels tensor should be flattened to (batch_size * seq_len)
        labels = labels.view(-1)

        # Calculate loss
        loss = criterion(outputs, labels)
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"  Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

    print(f"Epoch {epoch+1} finished. Average Loss: {total_loss / len(train_loader):.4f}")

print("\nTraining complete!")

Using device: cuda

--- Epoch 1/5 ---
  Batch 100/17006, Loss: 8.6460
  Batch 200/17006, Loss: 8.1178
  Batch 300/17006, Loss: 7.7455
  Batch 400/17006, Loss: 7.8862
  Batch 500/17006, Loss: 8.2632
  Batch 600/17006, Loss: 8.1931
  Batch 700/17006, Loss: 7.9859
  Batch 800/17006, Loss: 8.2194
  Batch 900/17006, Loss: 7.9190
  Batch 1000/17006, Loss: 7.4765
  Batch 1100/17006, Loss: 7.5844
  Batch 1200/17006, Loss: 7.6251
  Batch 1300/17006, Loss: 7.5998
  Batch 1400/17006, Loss: 8.0714
  Batch 1500/17006, Loss: 7.6162
  Batch 1600/17006, Loss: 7.4739
  Batch 1700/17006, Loss: 7.9230
  Batch 1800/17006, Loss: 8.2018
  Batch 1900/17006, Loss: 8.2211
  Batch 2000/17006, Loss: 7.4853
  Batch 2100/17006, Loss: 7.7598
  Batch 2200/17006, Loss: 8.0013
  Batch 2300/17006, Loss: 7.9136
  Batch 2400/17006, Loss: 7.3744
  Batch 2500/17006, Loss: 7.7510
  Batch 2600/17006, Loss: 7.9412
  Batch 2700/17006, Loss: 7.7274
  Batch 2800/17006, Loss: 7.3608
  Batch 2900/17006, Loss: 7.4658
  Batch 3000/1

In [40]:
# Assuming 'model' and 'optimizer' are your trained model and optimizer objects
# and 'epoch' is the last completed epoch number.

# Define a path to save the checkpoint
SAVE_PATH = "bert_mlm_checkpoint.pth"

torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss, # You can save the last loss value too
    }, SAVE_PATH)

print(f"Model saved to {SAVE_PATH}")

Model saved to bert_mlm_checkpoint.pth


In [41]:
# --- In a new script, or later in your notebook ---

# 1. Define the model architecture again (or import it)
# Make sure all your classes like SelfAttention, TransformerBlock, Encoder, and BERT are defined.

# 2. Instantiate the model with the SAME hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters must match the saved model
VOCAB_SIZE = 30000 
EMBED_SIZE = 256
NUM_LAYERS = 6
HEADS = 8
FORWARD_EXPANSION = 4
DROPOUT = 0.1
MAX_LENGTH = 128

# Create an instance of the model
encoder = Encoder(VOCAB_SIZE, EMBED_SIZE, NUM_LAYERS, HEADS, device, FORWARD_EXPANSION, DROPOUT, MAX_LENGTH)
loaded_model = BERT(encoder, VOCAB_SIZE).to(device)

# 3. Load the checkpoint
checkpoint = torch.load(SAVE_PATH, map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])

print("Model weights loaded successfully!")

RuntimeError: Error(s) in loading state_dict for BERT:
	Missing key(s) in state_dict: "encoder.word_embedding.weight", "encoder.position_embedding.weight", "encoder.layers.0.attention.values.weight", "encoder.layers.0.attention.keys.weight", "encoder.layers.0.attention.query.weight", "encoder.layers.0.attention.fc_out.weight", "encoder.layers.0.attention.fc_out.bias", "encoder.layers.0.norm1.weight", "encoder.layers.0.norm1.bias", "encoder.layers.0.norm2.weight", "encoder.layers.0.norm2.bias", "encoder.layers.0.feed_forward.0.weight", "encoder.layers.0.feed_forward.0.bias", "encoder.layers.0.feed_forward.2.weight", "encoder.layers.0.feed_forward.2.bias", "encoder.layers.1.attention.values.weight", "encoder.layers.1.attention.keys.weight", "encoder.layers.1.attention.query.weight", "encoder.layers.1.attention.fc_out.weight", "encoder.layers.1.attention.fc_out.bias", "encoder.layers.1.norm1.weight", "encoder.layers.1.norm1.bias", "encoder.layers.1.norm2.weight", "encoder.layers.1.norm2.bias", "encoder.layers.1.feed_forward.0.weight", "encoder.layers.1.feed_forward.0.bias", "encoder.layers.1.feed_forward.2.weight", "encoder.layers.1.feed_forward.2.bias", "encoder.layers.2.attention.values.weight", "encoder.layers.2.attention.keys.weight", "encoder.layers.2.attention.query.weight", "encoder.layers.2.attention.fc_out.weight", "encoder.layers.2.attention.fc_out.bias", "encoder.layers.2.norm1.weight", "encoder.layers.2.norm1.bias", "encoder.layers.2.norm2.weight", "encoder.layers.2.norm2.bias", "encoder.layers.2.feed_forward.0.weight", "encoder.layers.2.feed_forward.0.bias", "encoder.layers.2.feed_forward.2.weight", "encoder.layers.2.feed_forward.2.bias", "encoder.layers.3.attention.values.weight", "encoder.layers.3.attention.keys.weight", "encoder.layers.3.attention.query.weight", "encoder.layers.3.attention.fc_out.weight", "encoder.layers.3.attention.fc_out.bias", "encoder.layers.3.norm1.weight", "encoder.layers.3.norm1.bias", "encoder.layers.3.norm2.weight", "encoder.layers.3.norm2.bias", "encoder.layers.3.feed_forward.0.weight", "encoder.layers.3.feed_forward.0.bias", "encoder.layers.3.feed_forward.2.weight", "encoder.layers.3.feed_forward.2.bias", "encoder.layers.4.attention.values.weight", "encoder.layers.4.attention.keys.weight", "encoder.layers.4.attention.query.weight", "encoder.layers.4.attention.fc_out.weight", "encoder.layers.4.attention.fc_out.bias", "encoder.layers.4.norm1.weight", "encoder.layers.4.norm1.bias", "encoder.layers.4.norm2.weight", "encoder.layers.4.norm2.bias", "encoder.layers.4.feed_forward.0.weight", "encoder.layers.4.feed_forward.0.bias", "encoder.layers.4.feed_forward.2.weight", "encoder.layers.4.feed_forward.2.bias", "encoder.layers.5.attention.values.weight", "encoder.layers.5.attention.keys.weight", "encoder.layers.5.attention.query.weight", "encoder.layers.5.attention.fc_out.weight", "encoder.layers.5.attention.fc_out.bias", "encoder.layers.5.norm1.weight", "encoder.layers.5.norm1.bias", "encoder.layers.5.norm2.weight", "encoder.layers.5.norm2.bias", "encoder.layers.5.feed_forward.0.weight", "encoder.layers.5.feed_forward.0.bias", "encoder.layers.5.feed_forward.2.weight", "encoder.layers.5.feed_forward.2.bias", "fc_out.weight", "fc_out.bias". 
	Unexpected key(s) in state_dict: "encoder.embedding.weight". 

In [30]:
# --- Example Usage ---
# Make sure you have your tokenizer loaded
tokenizer = Tokenizer.from_file("wiki-bpe-tokenizer.json")

# Set the model to evaluation mode
loaded_model.eval()

# Let's get embeddings for a new sentence
test_sentence = "The capital of India is New Delhi."
embeddings = get_embeddings(test_sentence, loaded_model, tokenizer, device)

# The output shape will be (batch_size, sequence_length, embedding_size)
print(f"Sentence: '{test_sentence}'")
print(f"Shape of output embeddings: {embeddings.shape}")

# You can get the embedding for the entire sentence from the [CLS] token (the first token)
cls_embedding = embeddings[0, 0, :]
print(f"Shape of the [CLS] token embedding: {cls_embedding.shape}")

# --- ROBUST METHOD TO FIND A SPECIFIC WORD ---
tokens = tokenizer.encode(test_sentence).tokens
print("\nGenerated tokens:", tokens)

target_word = "India"
india_index = -1

# Find the index of the token corresponding to the target word
for i, token in enumerate(tokens):
    # The token could be the word itself, or the word with a space prefix 'Ġ'
    if target_word in token:
        india_index = i
        break

if india_index != -1:
    india_embedding = embeddings[0, india_index, :]
    print(f"Shape of the '{target_word}' token embedding: {india_embedding.shape}")
else:
    # This might happen if the word is split into multiple sub-tokens
    print(f"Could not find a single token for the word '{target_word}'. It might be split into subwords.")

Sentence: 'The capital of India is New Delhi.'
Shape of output embeddings: (256,)


IndexError: too many indices for array: array is 1-dimensional, but 3 were indexed

In [22]:
!pip install numpy

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [42]:
import numpy as np
import heapq
import random
from tokenizers import Tokenizer
import torch

# Assume the BERT model and Encoder classes from your notebook are defined here
# For demonstration, I'm including placeholder classes.
# Replace these with your actual model definitions.
class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, device, forward_expansion, dropout, max_length):
        super(Encoder, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, embed_size)
    def forward(self, x, mask):
        return self.embedding(x)

class BERT(torch.nn.Module):
    def __init__(self, encoder, vocab_size):
        super(BERT, self).__init__()
        self.encoder = encoder
    def forward(self, x, mask):
        return self.encoder(x, mask)


def get_embeddings(sentence, model, tokenizer, device, max_length=128):
    """
    Generates embeddings for a given sentence using the provided BERT model.
    """
    model.eval()
    tokenizer.enable_padding(length=max_length, pad_id=tokenizer.token_to_id("[PAD]"))
    tokenizer.enable_truncation(max_length=max_length)
    encoding = tokenizer.encode(sentence)
    input_ids = torch.tensor(encoding.ids, dtype=torch.long).unsqueeze(0).to(device)
    attention_mask = torch.tensor(encoding.attention_mask, dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model.encoder(input_ids, attention_mask)
        # We'll use the embedding of the [CLS] token as the sentence representation
        cls_embedding = outputs[0, 0, :]
    return cls_embedding.cpu().numpy()

class HNSW:
    """
    Hierarchical Navigable Small World (HNSW) implementation for vector indexing.
    """
    def __init__(self, m=16, ef_construction=200, ef_search=50, ml=0.5):
        """
        Initializes the HNSW index.

        Args:
            m (int): The maximum number of connections for each node.
            ef_construction (int): The size of the dynamic list for insertion.
            ef_search (int): The size of the dynamic list for search.
            ml (float): The normalization factor for level generation.
        """
        self.m = m
        self.ef_construction = ef_construction
        self.ef_search = ef_search
        self.ml = ml
        self.graphs = []
        self.entry_point = None
        self.vectors = {}
        self.next_node_id = 0

    def _distance(self, v1, v2):
        """Calculates the Euclidean distance between two vectors."""
        # Ensure inputs are numpy arrays for safety
        v1_np = np.asarray(v1)
        v2_np = np.asarray(v2)
        return np.linalg.norm(v1_np - v2_np)

    def _get_random_level(self):
        """Generates a random level for a new node."""
        return int(-np.log(np.random.rand()) * self.ml)

    def _search_layer(self, query_vec, entry_point, num_neighbors, layer_num):
        """
        Searches for the nearest neighbors in a specific layer.
        """
        if not entry_point or entry_point['id'] not in self.graphs[layer_num]:
             return []
        
        visited = {entry_point['id']}
        
        dist_to_entry = self._distance(query_vec, self.vectors[entry_point['id']])
        
        candidates = [(-dist_to_entry, entry_point['id'])]
        nearest_neighbors = [(dist_to_entry, entry_point['id'])]

        while candidates:
            dist, node_id = heapq.heappop(candidates)
            dist = -dist

            if len(nearest_neighbors) >= num_neighbors and dist > nearest_neighbors[-1][0]:
                break

            node = self.graphs[layer_num][node_id]
            for neighbor_id in node['neighbors']:
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    d = self._distance(query_vec, self.vectors[neighbor_id])
                    if len(nearest_neighbors) < num_neighbors or d < nearest_neighbors[-1][0]:
                        heapq.heappush(candidates, (-d, neighbor_id))
                        heapq.heappush(nearest_neighbors, (d, neighbor_id))
                        if len(nearest_neighbors) > num_neighbors:
                            heapq.heappop(nearest_neighbors)

        return sorted(nearest_neighbors)

    def _select_neighbors(self, query_vec, candidates, m):
        """
        Selects the best neighbors from a list of candidates.
        """
        if not candidates:
            return []
        # Sort by distance (first element of tuple) and take the top M
        return sorted(candidates, key=lambda c: c[0])[:m]

    def add(self, vector):
        """
        Adds a vector to the HNSW index.
        """
        node_id = self.next_node_id
        
        # FIX: If the input vector is a PyTorch tensor, move it to CPU and convert to NumPy.
        # This makes the HNSW class more robust to different input types.
        if torch.is_tensor(vector):
            vector = vector.cpu().numpy()

        # Ensure vector is a flat numpy array of a consistent type
        self.vectors[node_id] = np.asarray(vector, dtype=np.float32).flatten()
        self.next_node_id += 1

        level = self._get_random_level()

        # Ensure graph lists are long enough
        while len(self.graphs) <= level:
            self.graphs.append({})

        # Add the new node to all its levels in the graph
        for i in range(level + 1):
            self.graphs[i][node_id] = {'neighbors': []}

        # If this is the first node, set it as the entry point and return
        if self.entry_point is None:
            self.entry_point = {'id': node_id, 'level': level}
            return

        # Start search from the top-level entry point
        ep = self.entry_point

        # Phase 1: Find the entry point for each level down to the new node's level + 1
        for i in range(self.entry_point['level'], level, -1):
            nearest_neighbor_list = self._search_layer(vector, ep, 1, i)
            if not nearest_neighbor_list: break # Should not happen if graph is connected
            ep = {'id': nearest_neighbor_list[0][1], 'level': i - 1}

        # Phase 2: Insert the new node by connecting it to neighbors, from insertion level down to 0
        for i in range(min(level, self.entry_point['level']), -1, -1):
            # Find the ef_construction nearest neighbors at the current level
            nearest_neighbors = self._search_layer(vector, ep, self.ef_construction, i)
            
            # Select the M best neighbors to connect to the new node
            neighbors_to_connect = self._select_neighbors(vector, nearest_neighbors, self.m)
            self.graphs[i][node_id]['neighbors'] = [n[1] for n in neighbors_to_connect]

            # Connect the selected neighbors back to the new node
            for _, neighbor_id in neighbors_to_connect:
                neighbor_connections = self.graphs[i][neighbor_id]['neighbors']
                
                # Add the new node to the neighbor's connections
                neighbor_connections.append(node_id)
                
                # If the neighbor now has too many connections, prune the farthest one
                if len(neighbor_connections) > self.m:
                    distances = [(self._distance(self.vectors[neighbor_id], self.vectors[c_id]), c_id) for c_id in neighbor_connections]
                    self.graphs[i][neighbor_id]['neighbors'] = [c[1] for c in sorted(distances)[:self.m]]
            
            # Update the entry point for the next level down to be the closest node found at this level
            if nearest_neighbors:
                ep = {'id': nearest_neighbors[0][1], 'level': i - 1}

        # If the new node is at a higher level than the current entry point, update it
        if level > self.entry_point['level']:
            self.entry_point = {'id': node_id, 'level': level}

    def search(self, query_vec, k):
        """
        Searches for the k-nearest neighbors to a query vector.
        """
        # FIX: Ensure query vector is a NumPy array on the CPU.
        if torch.is_tensor(query_vec):
            query_vec = query_vec.cpu().numpy()

        if not self.entry_point:
            return []

        ep = self.entry_point
        # Phase 1: Greedily search from the top level down to level 1
        for i in range(self.entry_point['level'], 0, -1):
            nearest_neighbor_list = self._search_layer(query_vec, ep, 1, i)
            if not nearest_neighbor_list: break
            ep = {'id': nearest_neighbor_list[0][1], 'level': i - 1}

        # Phase 2: Perform an exhaustive search at the base layer (level 0)
        nearest_neighbors = self._search_layer(query_vec, ep, self.ef_search, 0)
        return sorted(nearest_neighbors)[:k]


if __name__ == '__main__':
    # --- Setup ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load your trained tokenizer
    try:
        tokenizer = Tokenizer.from_file("wiki-bpe-tokenizer.json")
    except FileNotFoundError:
        print("Tokenizer file not found. Please ensure 'wiki-bpe-tokenizer.json' is in the same directory.")
        exit()


    # --- Load your trained model ---
    # Make sure the hyperparameters match your saved model
    VOCAB_SIZE = tokenizer.get_vocab_size()
    EMBED_SIZE = 256
    NUM_LAYERS = 6
    HEADS = 8
    FORWARD_EXPANSION = 4
    DROPOUT = 0.1
    MAX_LENGTH = 128

    encoder = Encoder(VOCAB_SIZE, EMBED_SIZE, NUM_LAYERS, HEADS, device, FORWARD_EXPANSION, DROPOUT, MAX_LENGTH)
    model = BERT(encoder, VOCAB_SIZE).to(device)

    try:
        checkpoint = torch.load("bert_mlm_checkpoint.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Model loaded successfully.")
    except FileNotFoundError:
        print("Model checkpoint not found. Using a randomly initialized model for demonstration.")


    # --- Create and populate the HNSW index ---
    hnsw_index = HNSW()

    # Sample documents to be indexed
    documents = [
        "The capital of France is Paris.",
        "The Eiffel Tower is a famous landmark in Paris.",
        "The currency of Japan is the Yen.",
        "Tokyo is the largest city in Japan.",
        "The Great Wall of China is an ancient wonder.",
        "Beijing is the capital of China."
    ]

    print("\nIndexing documents...")
    for i, doc in enumerate(documents):
        embedding = get_embeddings(doc, model, tokenizer, device)
        hnsw_index.add(embedding)
        print(f"  Indexed document {i+1}/{len(documents)}")

    # --- Perform a search ---
    query_sentence = "What is the capital of Japan?"
    query_embedding = get_embeddings(query_sentence, model, tokenizer, device)

    print(f"\nSearching for: '{query_sentence}'")
    results = hnsw_index.search(query_embedding, k=3)

    print("\nSearch Results:")
    if results:
        for dist, node_id in results:
            # Check if node_id is valid
            if node_id < len(documents):
                print(f"  - Document: '{documents[node_id]}', Distance: {dist:.4f}")
            else:
                print(f"  - Invalid node_id found: {node_id}")
    else:
        print("  No results found.")


Error loading state_dict: Error(s) in loading state_dict for BERT:
	Missing key(s) in state_dict: "encoder.word_embedding.weight", "encoder.position_embedding.weight", "encoder.layers.0.attention.values.weight", "encoder.layers.0.attention.keys.weight", "encoder.layers.0.attention.query.weight", "encoder.layers.0.attention.fc_out.weight", "encoder.layers.0.attention.fc_out.bias", "encoder.layers.0.norm1.weight", "encoder.layers.0.norm1.bias", "encoder.layers.0.norm2.weight", "encoder.layers.0.norm2.bias", "encoder.layers.0.feed_forward.0.weight", "encoder.layers.0.feed_forward.0.bias", "encoder.layers.0.feed_forward.2.weight", "encoder.layers.0.feed_forward.2.bias", "encoder.layers.1.attention.values.weight", "encoder.layers.1.attention.keys.weight", "encoder.layers.1.attention.query.weight", "encoder.layers.1.attention.fc_out.weight", "encoder.layers.1.attention.fc_out.bias", "encoder.layers.1.norm1.weight", "encoder.layers.1.norm1.bias", "encoder.layers.1.norm2.weight", "encoder.laye