All of this is run in a Docker container using the following image:

nvcr.io/nvidia/tensorflow:23.12-tf2-py3

In [1]:
import os
import sys

# Add the root directory (one level up from notebooks/) to the import path so
# the later `from utils...` imports can resolve. The membership check keeps
# the cell idempotent: re-running it does not append duplicate entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

Download and install WikiExtractor

In [2]:
if not os.path.isdir(r"../wikiextractor-master"):
    # Step 1: Download the ZIP file
    !curl -L -o ../wikiextractor.zip https://github.com/qfcy/wikiextractor/archive/refs/heads/master.zip

    # Step 2: Extract it
    import zipfile
    import os

    zip_path = r"../wikiextractor.zip"
    extract_to = r"../"

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

    # Step 3: Delete the ZIP file
    os.remove(zip_path)

    # Step 4: Install Wikiextractor
    !pip install -e ../wikiextractor-master
else:
    print("Wikiextractor already exists")

Wikiextractor already exists


Download the Wikipedia dump (this takes roughly 2 hours)

In [4]:
os.makedirs("../data/raw", exist_ok=True)
if not os.path.isfile(r"../data/raw/enwiki-latest-pages-articles.xml.bz2"):
    !wget -P ../data/raw https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
else:
    print("Wikipedia dump already downloaded")

Wikipedia dump already downloaded


Extract XML data from the Wikipedia dump

In [5]:
if not os.path.isdir(r"../data/raw/extracted_wikidata"):
    !python -m wikiextractor.WikiExtractor \
        ../data/raw/enwiki-latest-pages-articles.xml.bz2 \
        -o ../data/raw/extracted_wikidata \
        --no-templates
else:
    print("Wikipedia XML extract already exists")

Wikipedia XML extract already exists


Create JSON data from the extracted Wikipedia dump

In [6]:
from utils.data_prep import traverse_directory

# Source: WikiExtractor output. Destination: processed JSON files.
input_dir = r'../data/raw/extracted_wikidata'
output_dir = r'../data/processed/wikidata_json'

# The conversion is skipped whenever the destination directory already exists.
if os.path.isdir(output_dir):
    print("wikidata_json already exists")
else:
    traverse_directory(input_dir, output_dir)

wikidata_json already exists


Initialize spark session for metadata and triplets creation

Create metadata from articles

Create Training data triplets using Pyspark from json wikidata

Note: when running locally on Windows, winutils had to be installed for Hadoop and PySpark to work.

In [7]:
# Paths used by the Spark cell that follows.
# NOTE: `input_dir` and `output_dir` rebind the names assigned in the JSON
# conversion cell above; from here on they refer to the locations below.

# Directory of JSON files read by Spark.
input_dir = r"../data/processed/wikidata_json"
# Directory where Spark writes the triplet part files.
output_dir = r"../data/processed/triplets/parts"
# Path passed to create_article_metadata in the next cell.
metadata_path = r"../data/custom_model/article_metadata.json"

In [8]:
# NOTE(review): the star import is kept on purpose. Later cells use names
# (e.g. json, np, tqdm) that are not imported anywhere visible in this
# notebook and may be reaching the namespace through a star import.
from utils.spark_functions import *
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder \
        .appName("Capstone") \
        .master("local[*]") \
        .config("spark.driver.memory", "20g") \
        .config("spark.sql.shuffle.partitions", "100") \
        .config("spark.local.dir", "../spark-temp") \
        .config("spark.driver.maxResultSize", "2g") \
        .getOrCreate()

# try/finally guarantees the session (and its driver memory) is released even
# if reading, transforming, writing or counting raises.
try:
    # Rebuild both artifacts if either the triplets directory or the metadata
    # file is missing.
    if not os.path.isdir(output_dir) or not os.path.isfile(metadata_path):
        print("Loading JSON input")
        json_df = spark.read.option("multiLine", True).json(f"{input_dir}/**/*.json")

        # Create metadata file from articles
        create_article_metadata(json_df, metadata_path)

        # Logic for creating training data triplets from json
        df = create_paragraphs_df(json_df)
        triplets = create_triplets(df)
        print(f"Writing triplets to Spark part files: {output_dir}")
        triplets.write.mode("overwrite").json(output_dir)
    else:
        print("Triplets data already exists")

    # Calculate total lines so that we can determine epoch size.
    # `total_lines` is consumed by the training cells further down.
    total_lines = count_triplets_with_spark(spark, output_dir)
    print("Total lines across training files:", total_lines)
finally:
    # Stop spark
    spark.stop()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/14 19:55:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/05/14 19:55:29 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


Triplets data already exists


                                                                                

Total lines across training files: 4091164


In [9]:
from utils.data_prep import delete_files_in_dir_based_on_ext

# Delete unneeded files produced by pyspark.
# NOTE(review): `output_dir` is the triplets directory that the training cell
# reads from, and the triplets themselves were written as JSON. Confirm whether
# the ".json" argument names the extension to keep or the extension to delete.
delete_files_in_dir_based_on_ext(output_dir, ".json")

Training embedding model

Embed and Index Wikidata

In [None]:
import tensorflow as tf
from utils.custom_embedder import *

# Parameters
input_dir = "../data/processed/triplets/parts"
vectorizer_dir = "../data/custom_model/saved_vectorizer"
weights_dir = "../data/custom_model/encoder_weights"
vocab_size = 30000
max_len = 32
embed_dim = 128
num_heads = 4
ff_dim = 256
batch_size = 512
num_epochs = 10

# Make sure the weights directory exists before any weights are written to it.
os.makedirs(weights_dir, exist_ok=True)

# Load or create vectorizer
if os.path.exists(vectorizer_dir):
    print("Loading saved vectorizer")
    vectorizer = tf.keras.models.load_model(vectorizer_dir)
else:
    vectorizer = create_vectorizer(input_dir, vectorizer_dir, vocab_size, max_len)

# Load dataset
train_dataset = load_triplet_dataset_streamed(input_dir, vectorizer, batch_size)

# Model setup
encoder = CustomEncoder(vocab_size, max_len, embed_dim, num_heads, ff_dim, num_layers=5)
trainer = TripletTrainer(encoder)
trainer.compile(optimizer=tf.keras.optimizers.Adam(1e-3))

# Callbacks
# NOTE(review): EarlyStopping / ModelCheckpoint are not imported explicitly and
# presumably arrive via the star import above. If ModelCheckpoint is the stock
# Keras callback it saves the weights of `trainer`, not of `encoder` alone.
# Confirm that best_encoder.weights.h5 loads cleanly into a bare CustomEncoder.
callbacks = [
    EarlyStopping(monitor="loss", patience=2),
    ModelCheckpoint(
        filepath=f"{weights_dir}/best_encoder.weights.h5",
        monitor="loss",
        save_best_only=True,
        save_weights_only=True
    )
]

# Training
# `total_lines` comes from the Spark cell; it sizes one pass over the data.
trainer.fit(
    train_dataset.repeat(),  # infinite generator
    steps_per_epoch=total_lines // batch_size,
    epochs=num_epochs,
    callbacks=callbacks
)

# Save final encoder weights.
# The original referenced an undefined `weights_path`, which raised NameError
# only after training had finished. The final weights get their own file so the
# best checkpoint written by ModelCheckpoint is not overwritten.
final_weights_path = f"{weights_dir}/final_encoder.weights.h5"
encoder.save_weights(final_weights_path)
print("Training complete and weights saved.")

2025-05-14 19:58:59.674243: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-14 19:58:59.707471: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9360] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-14 19:58:59.707533: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-14 19:58:59.707574: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1537] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-14 19:58:59.717666: I tensorflow/core/platform/cpu_feature_g

Loading saved vectorizer


2025-05-14 19:59:01.269132: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-05-14 19:59:01.289363: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-05-14 19:59:01.289389: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-05-14 19:59:01.291191: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:880] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2025-05-14 19:59:01.291225: I tensorflow/compile

Epoch 1/10


2025-05-14 19:59:16.179132: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f72bc28d890 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-14 19:59:16.179169: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce RTX 3070 Ti, Compute Capability 8.6
2025-05-14 19:59:16.182444: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-05-14 19:59:16.193269: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:442] Loaded cuDNN version 8907
2025-05-14 19:59:16.230883: I ./tensorflow/compiler/jit/device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10

In [None]:
from itertools import product

# Grid of encoder hyperparameters; every combination below is trained once.
search_space = {
    "embed_dim": [128, 256],
    "num_heads": [4, 8],
    "ff_dim": [256, 512],
    "num_layers": [2, 3],
    "learning_rate": [1e-3, 5e-4],
}

# Expand the grid into one dict per combination (Cartesian product of values).
param_names = list(search_space)
configs = [
    dict(zip(param_names, combo))
    for combo in product(*search_space.values())
]

# Same step count for every run, derived from the triplet count.
steps_per_epoch = total_lines // batch_size

results = []
for run_index, config in enumerate(configs):
    loss, cfg, ckpt = train_with_config(
        config, vectorizer, input_dir, batch_size, steps_per_epoch, run_id=run_index
    )
    results.append((loss, cfg, ckpt))

# Order runs by ascending loss so the best one comes first
results.sort(key=lambda run: run[0])

# Report the winning run
best_loss, best_config, best_checkpoint = results[0]
print("🏆 Best Config:", best_config)
print("📉 Best Loss:", best_loss)
print("💾 Best Checkpoint:", best_checkpoint)


🔧 Training config 0: {'embed_dim': 128, 'num_heads': 4, 'ff_dim': 256, 'num_layers': 2, 'learning_rate': 0.001}
Epoch 1/5


KeyboardInterrupt: 

Create the FAISS index if it does not already exist

In [None]:
from utils.faiss_index import create_faiss_index_from_dir

# NOTE(review): no visible cell in this notebook writes the embedding files
# under `embedding_path`; they must already exist before this cell runs.
embedding_path = "../data/custom_model/embeddings_output1"
faiss_path = "../data/custom_model/faiss/faiss_index.index"

# Build the index only when it is missing, matching the skip-if-present
# pattern used by the other expensive cells in this notebook.
if not os.path.isfile(faiss_path):
    os.makedirs(os.path.dirname(faiss_path), exist_ok=True)
    create_faiss_index_from_dir(embedding_path, faiss_path)
else:
    print("FAISS index already exists")

Found 100 embedding files. Building FAISS index...


Processing embedding files:   0%|          | 0/100 [00:00<?, ?it/s]

FAISS index saved.


In [None]:
# Load saved vectorizer. The saved artifact is a Keras model; its first layer
# is what gets applied to raw strings below.
vectorizer_model = tf.keras.models.load_model(vectorizer_dir)
vectorizer = vectorizer_model.layers[0]

# Initialize the encoder with the same hyperparameters the training cell used.
# The original omitted `num_layers` (training used num_layers=5) and repeated
# the other values as literals; a differently shaped encoder cannot correctly
# load the trained weights.
encoder = CustomEncoder(
    vocab_size=vocab_size,
    max_len=max_len,
    embed_dim=embed_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
    num_layers=5,
)

# Build the model by calling it on dummy data
_ = encoder(tf.constant([[1] * max_len]))  # shape: (1, max_len)

# Now load the weights
encoder.load_weights(f"{weights_dir}/best_encoder.weights.h5")

# Your query
query = "What is the function of DNA?"

query_seq = vectorizer(tf.constant([query]))
query_embedding = encoder(query_seq).numpy()



In [None]:
# `json` is not imported anywhere else in the visible notebook; import it here
# so this cell does not depend on a name leaked by a star import.
import json

from utils.faiss_index import query_faiss

# Ask the index for the 10 nearest neighbours of the query embedding.
indices = query_faiss(faiss_path, query_embedding, 10)

# Load article metadata
metadata_path = "../data/custom_model/article_metadata.json"
with open(metadata_path, encoding='utf-8') as f:
    metadata = json.load(f)

# Retrieve top-k articles.
# NOTE(review): presumably `indices` is the raw FAISS label array, in which a
# missing neighbour is reported as -1. A negative index would silently pick
# the last metadata entry, so negative entries are skipped.
results = [metadata[i] for i in indices[0] if i >= 0]

In [None]:
from glob import glob

# `np` and `tqdm` are used below but not imported anywhere else in the visible
# notebook; import them explicitly instead of relying on leaked names.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

# Gather all article embeddings from .npy files
# NOTE(review): files are concatenated in lexicographic filename order; this
# must match the order used when the FAISS index was built, otherwise the row
# numbers returned by the index will not line up with `article_embeddings`.
embedding_files = sorted(glob(os.path.join(embedding_path, '*.npy')))
if not embedding_files:
    # np.vstack on an empty list fails with an unhelpful message; fail clearly.
    raise FileNotFoundError(f"No .npy embedding files found in {embedding_path}")

all_embeddings = []
for file_path in tqdm(embedding_files, desc="Loading embeddings"):
    all_embeddings.append(np.load(file_path))

# Stack into one large matrix (shape: [N_total, dim]), then drop the list of
# per-file arrays so the embeddings are not held in memory twice.
article_embeddings = np.vstack(all_embeddings)
del all_embeddings

# Create combined metadata (skip negative ids; see NOTE below).
# NOTE(review): a negative id would silently index from the end of both
# `metadata` and `article_embeddings`; presumably FAISS uses -1 for "no
# neighbour found", so such entries are dropped.
top_articles = [
    metadata[i] | {"vec": article_embeddings[i]}
    for i in indices[0]
    if i >= 0
]

# Rerank using cosine similarity (highest similarity first)
query_vec = query_embedding[0].reshape(1, -1)
top_articles.sort(
    key=lambda x: cosine_similarity(query_vec, x["vec"].reshape(1, -1))[0][0],
    reverse=True
)

Loading embeddings:   0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
import pprint

# Reduce each reranked article to its title and URL ("N/A" when the article
# has no "url" key) and pretty-print the whole list.
final_results = []
for article in top_articles:
    final_results.append({
        "title": article["title"],
        "url": article.get("url", "N/A"),
    })
pprint.pprint(final_results)

[{'title': 'Whitson, Texas',
  'url': 'https://en.wikipedia.org/wiki?curid=74340115'},
 {'title': 'My Kink Is Karma',
  'url': 'https://en.wikipedia.org/wiki?curid=74920372'},
 {'title': 'Marcial Moreno-Mañas',
  'url': 'https://en.wikipedia.org/wiki?curid=74635281'},
 {'title': 'Iran at the 2022 Asian Games',
  'url': 'https://en.wikipedia.org/wiki?curid=74864425'},
 {'title': 'Sir Charles Saxton, 2nd Baronet',
  'url': 'https://en.wikipedia.org/wiki?curid=74896144'},
 {'title': "List of Girls' Crystal comic stories",
  'url': 'https://en.wikipedia.org/wiki?curid=74899201'},
 {'title': 'Not My Neighbour',
  'url': 'https://en.wikipedia.org/wiki?curid=74886103'},
 {'title': 'Alaska Airlines Flight 2059',
  'url': 'https://en.wikipedia.org/wiki?curid=75127810'},
 {'title': 'Ting Ting Chaoro',
  'url': 'https://en.wikipedia.org/wiki?curid=74877608'},
 {'title': 'Detdet Pepito',
  'url': 'https://en.wikipedia.org/wiki?curid=74590708'}]
