All of this is ran in a docker container using the following image:

nvcr.io/nvidia/tensorflow:23.12-tf2-py3

In [1]:
import os
import sys

# Add root directory (one level up from notebooks/)
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

download and install wikiextractor

In [2]:
if not os.path.isdir(r"../wikiextractor-master"):
    # Step 1: Download the ZIP file
    !curl -L -o ../wikiextractor.zip https://github.com/qfcy/wikiextractor/archive/refs/heads/master.zip

    # Step 2: Extract it
    import zipfile
    import os

    zip_path = r"../wikiextractor.zip"
    extract_to = r"../"

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

    # Step 3: Delete the ZIP file
    os.remove(zip_path)

    # Step 4: Install Wikiextractor
    !pip install -e ../wikiextractor-master
else:
    print("Wikiextractor already exists")

Wikiextractor already exists


Get wikipedia dump (takes like 2 hours to download)

In [4]:
os.makedirs("../data/raw", exist_ok=True)
if not os.path.isfile(r"../data/raw/enwiki-latest-pages-articles.xml.bz2"):
    !wget -P ../data/raw https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
else:
    print("Wikipedia dump already downloaded")

Wikipedia dump already downloaded


Extract xml data from wikidump

In [5]:
if not os.path.isdir(r"../data/raw/extracted_wikidata"):
    !python -m wikiextractor.WikiExtractor \
        ../data/raw/enwiki-latest-pages-articles.xml.bz2 \
        -o ../data/raw/extracted_wikidata \
        --no-templates
else:
    print("Wikipedia XML extract already exists")

Wikipedia XML extract already exists


Create json data from wiki-dump

In [6]:
from utils.data_prep import traverse_directory

input_dir = r'../data/raw/extracted_wikidata'
output_dir = r'../data/processed/wikidata_json'

if not os.path.isdir(output_dir):
    traverse_directory(input_dir, output_dir)
else:
    print("wikidata_json already exists")

wikidata_json already exists


Initialize spark session for metadata and create random triplets

Create metadata from articles

Create Training data of random triplets using Pyspark from json wikidata

Had to install winutils for hadoop and pyspark to work on windows locally.

In [2]:
input_dir = r"../data/processed/wikidata_json"
output_dir = r"../data/processed/triplets/parts"
metadata_path = r"../data/custom_model/article_metadata.json"

In [None]:
from utils.spark_functions import *
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder \
        .appName("Capstone") \
        .master("local[*]") \
        .config("spark.driver.memory", "20g") \
        .config("spark.sql.shuffle.partitions", "100") \
        .config("spark.local.dir", "../spark-temp") \
        .config("spark.driver.maxResultSize", "2g") \
        .getOrCreate()

if not os.path.isdir(output_dir) or not os.path.isfile(metadata_path):
    print("Loading JSON input")
    json_df = spark.read.option("multiLine", True).json(f"{input_dir}/**/*.json")

    # Create metadata file from articles
    create_article_metadata(json_df, metadata_path)

    # Logic for creating training data triplets from json
    df = create_paragraphs_df(json_df)
    triplets = create_random_triplets(df)
    print(f"Writing triplets to Spark part files: {output_dir}")
    triplets.write.mode("overwrite").json(output_dir)

else:
    print("Triplets data already exists")

# Calculate total lines so that we can determine epoch size
total_lines = spark.read.json(f"{input_dir}/*.json", multiLine=False).count()
print("Total lines across training files:", total_lines)

# Stop spark
spark.stop()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/14 19:55:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/05/14 19:55:29 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


Triplets data already exists


                                                                                

Total lines across training files: 4091164


In [9]:
from utils.data_prep import delete_files_in_dir_based_on_ext

# Delete unneeded files produced by pyspark
delete_files_in_dir_based_on_ext(output_dir, ".json")

Training embedding model using Random Triplets Data

grid search on smaller subset of data for speed

In [2]:
import logging
import tensorflow as tf
from itertools import product
from utils.custom_embedder import *

# Turn off warnings for Tensorflow 
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.ERROR)

# So we don't have to rerun it every time
total_lines = 4091164 

2025-05-16 04:08:46.117175: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-16 04:08:46.431699: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9360] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-16 04:08:46.431760: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-16 04:08:46.431942: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1537] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-16 04:08:46.570366: I tensorflow/core/platform/cpu_feature_g

In [None]:
# Parameters
input_dir = "../data/processed/triplets/parts_test"
vectorizer_dir = "../data/custom_model/saved_vectorizer"
weights_dir = "../data/custom_model/encoder_weights"
vocab_size = 30000
max_len = 32
embed_dim = 128
num_heads = 4
ff_dim = 256
batch_size = 512
num_epochs = 30

# Load or create vectorizer
if os.path.exists(vectorizer_dir):
    print("Loading saved vectorizer")
    vectorizer = tf.keras.models.load_model(vectorizer_dir)
else:
    vectorizer = create_vectorizer(input_dir, vectorizer_dir, vocab_size, max_len)


search_space = {
    "embed_dim": [128, 256],
    "num_heads": [4, 8],
    "ff_dim": [256, 512],
    "num_layers": [2, 3],
    "learning_rate": [5e-4],
}

keys, values = zip(*search_space.items())
configs = [dict(zip(keys, v)) for v in product(*values)]

results = []
for i, config in enumerate(configs):
    loss, cfg, ckpt = train_with_config(config, vectorizer, input_dir, batch_size, (total_lines // batch_size)//10, run_id=i)
    results.append((loss, cfg, ckpt))

# Sort by loss
results.sort(key=lambda x: x[0])

# Print best
best_loss, best_config, best_checkpoint = results[0]
print("🏆 Best Config:", best_config)
print("📉 Best Loss:", best_loss)
print("💾 Best Checkpoint:", best_checkpoint)


Loading saved vectorizer
🔧 Training config 0: {'embed_dim': 128, 'num_heads': 4, 'ff_dim': 256, 'num_layers': 2, 'learning_rate': 0.0005}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
🔧 Training config 1: {'embed_dim': 128, 'num_heads': 4, 'ff_dim': 256, 'num_layers': 3, 'learning_rate': 0.0005}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
🔧 Training config 2: {'embed_dim': 128, 'num_heads': 4, 'ff_dim': 512, 'num_layers': 2, 'learning_rate': 0.0005}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
🔧 Training config 3: {'embed_dim': 128, 'num_heads': 4, 'ff_dim': 512, 'num_layers': 3, 'learning_rate': 0.0005}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
🔧 Training config 4: {'embed_dim': 128, 'num_heads': 8, 'ff_dim': 256, 'num_layers': 2, 'learning_rate': 0.0005}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
🔧 Training config 5: {'embed_dim': 128, 'num_heads': 8, 'ff_dim': 256, 'num_layers': 3, 'learning_rate': 0.0005}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5

ResourceExhaustedError: Graph execution error:

Detected at node custom_encoder_18/transformer_block_61/layer_normalization_123/batchnorm_2/mul defined at (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main

  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code

  File "/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py", line 17, in <module>

  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 1043, in launch_instance

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 739, in start

  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 205, in start

  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever

  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once

  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 529, in dispatch_queue

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 518, in process_one

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 424, in dispatch_shell

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 766, in execute_request

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 429, in do_execute

  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 549, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3048, in run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes

  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3550, in run_code

  File "/tmp/ipykernel_88948/72152965.py", line 42, in <module>

  File "/opt/files/Capstone/WikipediaNLP/utils/custom_embedder.py", line 203, in train_with_config

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1783, in fit

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1377, in train_function

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1360, in step_function

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 1349, in run_step

  File "/opt/files/Capstone/WikipediaNLP/utils/custom_embedder.py", line 79, in train_step

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py", line 589, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/opt/files/Capstone/WikipediaNLP/utils/custom_embedder.py", line 46, in call

  File "/opt/files/Capstone/WikipediaNLP/utils/custom_embedder.py", line 47, in call

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/opt/files/Capstone/WikipediaNLP/utils/custom_embedder.py", line 32, in call

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 65, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/engine/base_layer.py", line 1149, in __call__

  File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 96, in error_handler

  File "/usr/local/lib/python3.10/dist-packages/keras/src/layers/normalization/layer_normalization.py", line 297, in call

failed to allocate memory
	 [[{{node custom_encoder_18/transformer_block_61/layer_normalization_123/batchnorm_2/mul}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_398073]

Embed and Index Wikidata

In [None]:
# Parameters
input_dir = "../data/processed/triplets/parts"
vectorizer_dir = "../data/custom_model/saved_vectorizer"
weights_dir = "../data/custom_model/encoder_weights"
vocab_size = 30000
max_len = 32
embed_dim = 128
num_heads = 4
ff_dim = 256
batch_size = 512
num_epochs = 30
learning_rate = 5e-4

# Load or create vectorizer
if os.path.exists(vectorizer_dir):
    print("Loading saved vectorizer")
    vectorizer = tf.keras.models.load_model(vectorizer_dir)
else:
    vectorizer = create_vectorizer(input_dir, vectorizer_dir, vocab_size, max_len)

# Load dataset
train_dataset = load_triplet_dataset_streamed(input_dir, vectorizer, batch_size)

# Model setup
encoder = CustomEncoder(vocab_size, max_len, embed_dim, num_heads, ff_dim)
trainer = TripletTrainer(encoder)
trainer.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))

# Callbacks
callbacks = [
    EarlyStopping(monitor="loss", patience=2),
    ModelCheckpoint(
        filepath=f"{weights_dir}/best_encoder.weights.h5",
        monitor="loss",
        save_best_only=True,
        save_weights_only=True
    )
]

# Training
trainer.fit(
    train_dataset.repeat(),  # infinite generator
    steps_per_epoch=total_lines // batch_size,
    epochs=num_epochs,
    callbacks=callbacks
)

print("Training complete and weights saved.")

Loading saved vectorizer
Epoch 1/30
1500/7990 [====>.........................] - ETA: 29:31 - loss: 0.2886

KeyboardInterrupt: 

In [None]:
vectorizer = tf.keras.models.load_model(vectorizer_dir)

triplet_dataset = load_triplet_dataset_streamed(input_dir, vectorizer, batch_size)

anchor_embeddings = []

for anchor_batch, _, _ in triplet_dataset:
    emb = encoder(anchor_batch, training=False)  # shape (batch_size, embed_dim)
    anchor_embeddings.append(emb.numpy())

all_anchor_embeddings = np.concatenate(anchor_embeddings, axis=0)
np.save("../data/custom_model/embeddings_output/wiki_anchor_embeddings.npy", all_anchor_embeddings)
print("✅ Saved embeddings to 'wiki_anchor_embeddings.npy'")

✅ Saved embeddings to 'wiki_anchor_embeddings.npy'


Create faiss index if not already existing

In [None]:
from utils.faiss_index import create_faiss_index_from_dir

embedding_path = "../data/custom_model/embeddings_output"
faiss_path = "../data/custom_model/faiss/faiss_index.index"

# Creates embeddings and faiss.index
create_faiss_index_from_dir(embedding_path, faiss_path)

Found 1 embedding files. Building FAISS index...


Processing embedding files:   0%|          | 0/1 [00:00<?, ?it/s]

FAISS index saved.


In [5]:
input_dir = "../data/processed/triplets/parts"
vectorizer_dir = "../data/custom_model/saved_vectorizer"
weights_dir = "../data/custom_model/encoder_weights"
vocab_size = 30000
max_len = 32
embed_dim = 128
num_heads = 4
ff_dim = 256
batch_size = 512
num_epochs = 10
learning_rate = 5e-4

# Load saved vectorizer
vectorizer = tf.keras.models.load_model(vectorizer_dir)
vectorizer = vectorizer.layers[0]

# Initialize the encoder
encoder = CustomEncoder(vocab_size, max_len, embed_dim, num_heads, ff_dim)

# Build the model by calling it on dummy data
# _ = encoder(tf.constant([[1] * 32]))  # shape: (1, max_len)

# Now load the weights
encoder.load_weights(f"{weights_dir}/best_encoder.weights.h5")

# Your query
query = "What is the function of DNA?"

query_seq = vectorizer(tf.constant([query]))
query_embedding = encoder(query_seq).numpy()

In [7]:
from utils.faiss_index import query_faiss
faiss_path = "../data/custom_model/faiss/faiss_index.index"
indices = query_faiss(faiss_path, query_embedding, 10)

# Load article metadata
metadata_path = "../data/custom_model/article_metadata.json"
with open(metadata_path, encoding='utf-8') as f:
    metadata = json.load(f)

# Retrieve top-k articles
results = [metadata[i] for i in indices[0]]

In [12]:
from sklearn.metrics.pairwise import cosine_similarity

embedding_path = "../data/custom_model/embeddings_output/wiki_anchor_embeddings.npy"

# Normalize for cosine similarity
query_vec = query_embedding[0].reshape(1, -1)
article_embeddings = np.load(embedding_path)
top_articles = [metadata[i] | {"vec": article_embeddings[i]} for i in indices[0]]

# Rerank
top_articles.sort(
    key=lambda x: cosine_similarity(query_vec, x["vec"].reshape(1, -1))[0][0],
    reverse=True
)

final_results = [
    {
        "title": article["title"],
        "url": article.get("url", "N/A")
    }
    for article in top_articles[:5]
]

# Print nicely
import pprint
pprint.pprint(final_results)

[{'title': "Yangon Children's Hospital",
  'url': 'https://en.wikipedia.org/wiki?curid=22418344'},
 {'title': 'Twitterature',
  'url': 'https://en.wikipedia.org/wiki?curid=43700363'},
 {'title': 'Peddapuram Assembly constituency',
  'url': 'https://en.wikipedia.org/wiki?curid=37904163'},
 {'title': 'Mount Buller (Victoria)',
  'url': 'https://en.wikipedia.org/wiki?curid=6211759'},
 {'title': 'Villanova Wildcats football',
  'url': 'https://en.wikipedia.org/wiki?curid=15373239'}]


In [None]:
from utils.top_k_testing import evaluate_all_metrics, retrieval_function, load_test_set

test_set = load_test_set("../data/test_data/test_queries.json")
results = evaluate_all_metrics(test_set, retrieval_function)

print("Evaluation Metrics:")
for metric, value in results.items():
    print(f"{metric}: {value}")


Evaluation Metrics:
Top-1 Accuracy: 0.0
Top-3 Accuracy: 0.0
Top-5 Accuracy: 0.0
Top-10 Accuracy: 0.0
Precision@1: 0.0
Precision@3: 0.0
Precision@5: 0.0
Precision@10: 0.0
Recall@1: 0.0
Recall@3: 0.0
Recall@5: 0.0
Recall@10: 0.0
MRR: 0.0
