All of this is run in a Docker container using the following image:

nvcr.io/nvidia/tensorflow:23.12-tf2-py3

In [1]:
import os
import sys

# Make the repository root (the parent of notebooks/) importable.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

Download and install WikiExtractor

In [None]:
if not os.path.isdir(r"../wikiextractor-master"):
    # Step 1: Download the ZIP file
    !curl -L -o ../wikiextractor.zip https://github.com/qfcy/wikiextractor/archive/refs/heads/master.zip

    # Step 2: Extract it
    import zipfile
    import os

    zip_path = r"../wikiextractor.zip"
    extract_to = r"../"

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

    # Step 3: Delete the ZIP file
    os.remove(zip_path)

    # Step 4: Install Wikiextractor
    !pip install -e ../wikiextractor-master
else:
    print("Wikiextractor already exists")

Wikiextractor already exists


Download the Wikipedia dump (takes about 2 hours)

In [None]:
os.makedirs(r"../data/raw", exist_ok=True)
if not os.path.isfile(r"../data/raw/enwiki-latest-pages-articles.xml.bz2"):
    !wget -P ../data/raw https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
else:
    print("Wikipedia dump already downloaded")

Wikipedia dump already downloaded


Install wikiextractor if needed

In [13]:
import pkg_resources

try:
    pkg_resources.get_distribution("wikiextractor")
    print("wikiextractor is already installed.")
except pkg_resources.DistributionNotFound:
    print("Installing wikiextractor...")
    !pip install -e ../wikiextractor-master

wikiextractor is already installed.


Extract XML data from the Wikipedia dump

In [None]:
if not os.path.isdir(r"../data/raw/extracted_wikidata"):
    !python -m wikiextractor.WikiExtractor \
        ../data/raw/enwiki-latest-pages-articles.xml.bz2 \
        -o ../data/raw/extracted_wikidata \
        --no-templates
else:
    print("Wikipedia XML extract already exists")

INFO: Starting page extraction from ../data/raw/enwiki-latest-pages-articles.xml.bz2.
INFO: Using 11 extract processes.


Create JSON data from the extracted Wikipedia dump

In [None]:
from utils.data_prep import traverse_directory

# Source: the directory WikiExtractor writes to in the extraction step.
input_dir = r'../data/raw/extracted_wikidata'
# Destination: the same path the Spark cell later uses as its JSON input.
output_dir = r'../data/processed/wikidata_json'

# Project helper; presumably walks input_dir and writes JSON files under
# output_dir — confirm in utils/data_prep.py.
# NOTE(review): unlike the download and extraction cells, this call has no
# visible "already done" guard, so it runs on every execution unless the helper
# skips existing output itself.
traverse_directory(input_dir, output_dir)

Processing XML files:   0%|          | 0/1 [00:00<?, ?file/s]

Initialize spark session for metadata and triplets creation

Create metadata from articles

Create training-data triplets from the JSON Wikipedia data using PySpark

Note: when running locally on Windows, winutils must be installed for Hadoop and PySpark to work.

In [2]:
# Paths for the Spark job. These rebind input_dir / output_dir from the JSON
# conversion cell, so that cell must be re-run before calling it again.
# JSON articles read by the Spark job.
input_dir = r"../data/processed/wikidata_json"
# Directory the Spark job writes the triplet part files to.
output_dir = r"../data/processed/triplets/parts"
# File handed to create_article_metadata; read back later to look up search results.
metadata_path = r"../data/custom_model/article_metadata.json"

In [4]:
from utils.spark_functions import create_article_metadata, create_paragraphs_df, create_triplets
from pyspark.sql import SparkSession

# Work out which artifacts are missing. The previous guard
# (`not isdir(output_dir) or isfile(metadata_path)`) re-ran the whole job
# whenever the metadata file already existed and skipped it when only the
# metadata file was missing.
need_triplets = not os.path.isdir(output_dir)
need_metadata = not os.path.isfile(metadata_path)

if need_triplets or need_metadata:
    # Initialize Spark session
    spark = (
        SparkSession.builder
        .appName("Capstone")
        .master("local[*]")
        .config("spark.driver.memory", "20g")
        .config("spark.sql.shuffle.partitions", "100")
        .config("spark.local.dir", "../spark-temp")
        .config("spark.driver.maxResultSize", "2g")
        .getOrCreate()
    )
    try:
        print("Loading JSON input")
        json_df = spark.read.option("multiLine", True).json(f"{input_dir}/**/*.json")

        # Create metadata file from articles.
        # NOTE(review): a recorded run printed "Metadata already exists.", so the
        # helper appears to skip an existing file itself — confirm in
        # utils/spark_functions.py.
        create_article_metadata(json_df, metadata_path)

        # Build and write the training triplets only when they are missing, so
        # existing output is not overwritten just to recreate the metadata.
        if need_triplets:
            df = create_paragraphs_df(json_df)
            triplets = create_triplets(df)
            print(f"Writing triplets to Spark part files: {output_dir}")
            triplets.write.mode("overwrite").json(output_dir)
    finally:
        # Always release the Spark session, even when a step above fails.
        spark.stop()
else:
    print("Triplets data has already exists")

Loading JSON input


                                                                                

Metadata already exists.
Writing triplets to Spark part files: ../data/processed/triplets/parts


                                                                                

In [5]:
from utils.data_prep import delete_files_in_dir_based_on_ext

# Delete unneeded files produced by pyspark
# NOTE(review): output_dir is the directory the Spark cell writes its JSON part
# files to, and the extension passed here is ".json". Confirm in
# utils/data_prep.py whether the helper removes files WITH this extension or
# keeps them, so the triplet data itself is not deleted.
delete_files_in_dir_based_on_ext(output_dir, ".json")

Training embedding model

Embed and Index Wikidata

In [6]:
import tensorflow as tf
# NOTE(review): wildcard import. CustomEncoder, TripletTrainer,
# load_triplet_dataset, EarlyStopping and ModelCheckpoint are not imported
# anywhere else in this notebook, so they presumably come from here — confirm
# and import them explicitly.
from utils.custom_embedder import *


# Parameters
# NOTE(review): this rebinds input_dir, which the Spark cell reads as its JSON
# input directory; re-running that cell after this one would point it at the
# triplets instead.
input_dir = "../data/processed/triplets/parts"
vectorizer_path = "../data/custom_model/saved_vectorizer"
vocab_size = 30000
max_len = 32
embed_dim = 128
num_heads = 4
ff_dim = 256
batch_size = 64
num_epochs = 10
weights_path = "../data/custom_model/encoder_weights.h5"

# Load vectorizer
# NOTE(review): the recorded run failed on this line with
# "OSError: No file or directory found at ../data/custom_model/saved_vectorizer";
# no cell in this notebook creates that path — TODO document where it is produced.
vectorizer = tf.keras.models.load_model(vectorizer_path)

# Load dataset
train_dataset = load_triplet_dataset(input_dir, vectorizer, batch_size)

# Model setup
# num_layers=2 here must match the encoder rebuilt later for load_weights.
encoder = CustomEncoder(vocab_size, max_len, embed_dim, num_heads, ff_dim, num_layers=2)
trainer = TripletTrainer(encoder)
trainer.compile(optimizer=tf.keras.optimizers.Adam(1e-3))

# Callbacks
# Both callbacks monitor the training "loss"; no validation data is passed to fit().
callbacks = [
    EarlyStopping(monitor="loss", patience=2),
    ModelCheckpoint("../data/custom_model/best_encoder.keras", monitor="loss", save_best_only=True)
]

# Training
trainer.fit(train_dataset, epochs=num_epochs, callbacks=callbacks)

# Save final encoder weights
encoder.save_weights(weights_path)
print("✅ Training complete and weights saved.")

2025-05-13 22:27:43.546743: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-13 22:27:43.974103: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9360] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-13 22:27:43.974596: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-13 22:27:43.976919: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1537] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-13 22:27:44.189202: I tensorflow/core/platform/cpu_feature_g

OSError: No file or directory found at ../data/custom_model/saved_vectorizer

Create the FAISS index if it does not already exist

In [None]:
from utils.faiss_index import create_faiss_index_from_dir

# Directory holding the embedding files (the rerank cell globs *.npy from it).
embedding_path = "../data/custom_model/embeddings_output1"
# Index file; the query cell later passes this same path to query_faiss.
faiss_path = "../data/custom_model/faiss/faiss_index.index"

# NOTE(review): the section heading says the index is created only if it does
# not already exist, but no such check is visible in this cell — confirm that
# create_faiss_index_from_dir performs it, otherwise the index is rebuilt on
# every run.
create_faiss_index_from_dir(embedding_path, faiss_path)

Found 100 embedding files. Building FAISS index...


Processing embedding files:   0%|          | 0/100 [00:00<?, ?it/s]

FAISS index saved.


In [None]:
# Load saved vectorizer and take its first layer, which is applied directly to
# raw query strings below.
vectorizer = tf.keras.models.load_model("../data/custom_model/saved_vectorizer")
vectorizer = vectorizer.layers[0]

# Initialize the encoder with the same architecture used for training.
# num_layers=2 is passed explicitly because the training cell built the encoder
# with it; relying on the class default could give a model whose variables do
# not match the saved weights.
encoder = CustomEncoder(
    vocab_size=30000,
    max_len=32,
    embed_dim=128,
    num_heads=4,
    ff_dim=256,
    num_layers=2,
)

# Build the model by calling it on dummy data
_ = encoder(tf.constant([[1] * 32]))  # shape: (1, max_len)

# Now load the weights
encoder.load_weights("../data/custom_model/encoder_weights.h5")

# Your query
query = "What is the function of DNA?"

# Vectorize the query text, then encode it into an embedding.
query_seq = vectorizer(tf.constant([query]))
query_embedding = encoder(query_seq).numpy()



In [None]:
# json is imported here explicitly; no other cell in the notebook imports it.
import json

from utils.faiss_index import query_faiss

# Ask the FAISS index for the 10 nearest neighbours of the query embedding.
indices = query_faiss(faiss_path, query_embedding, 10)

# Load article metadata
metadata_path = "../data/custom_model/article_metadata.json"
with open(metadata_path, encoding='utf-8') as f:
    metadata = json.load(f)

# Retrieve top-k articles
# NOTE(review): assumes index row i corresponds to metadata[i] — confirm the
# index and the metadata file were built in the same order.
results = [metadata[i] for i in indices[0]]

In [None]:
# numpy and tqdm are imported here explicitly; no other cell in the notebook
# imports them.
from glob import glob

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm

# Gather all article embeddings from .npy files
all_embeddings = []
# NOTE(review): sorted() orders file names lexicographically; confirm this is
# the order in which the index was built, otherwise the row numbers returned by
# the index will not line up with article_embeddings.
embedding_files = sorted(glob(os.path.join(embedding_path, '*.npy')))
for file_path in tqdm(embedding_files, desc="Loading embeddings"):
    all_embeddings.append(np.load(file_path))

# Stack into one large matrix (shape: [N, dim])
article_embeddings = np.vstack(all_embeddings)  # Now shape: (N_total, dim)

# Create combined metadata: each retrieved article plus its embedding vector
top_articles = [metadata[i] | {"vec": article_embeddings[i]} for i in indices[0]]

# Rerank using cosine similarity (most similar first)
query_vec = query_embedding[0].reshape(1, -1)
top_articles.sort(
    key=lambda x: cosine_similarity(query_vec, x["vec"].reshape(1, -1))[0][0],
    reverse=True
)

Loading embeddings:   0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
import pprint

# Keep only the title and URL of every reranked article; "N/A" stands in for a
# missing URL.
final_results = [
    dict(title=article["title"], url=article.get("url", "N/A"))
    for article in top_articles
]
pprint.pprint(final_results)

[{'title': 'Whitson, Texas',
  'url': 'https://en.wikipedia.org/wiki?curid=74340115'},
 {'title': 'My Kink Is Karma',
  'url': 'https://en.wikipedia.org/wiki?curid=74920372'},
 {'title': 'Marcial Moreno-Mañas',
  'url': 'https://en.wikipedia.org/wiki?curid=74635281'},
 {'title': 'Iran at the 2022 Asian Games',
  'url': 'https://en.wikipedia.org/wiki?curid=74864425'},
 {'title': 'Sir Charles Saxton, 2nd Baronet',
  'url': 'https://en.wikipedia.org/wiki?curid=74896144'},
 {'title': "List of Girls' Crystal comic stories",
  'url': 'https://en.wikipedia.org/wiki?curid=74899201'},
 {'title': 'Not My Neighbour',
  'url': 'https://en.wikipedia.org/wiki?curid=74886103'},
 {'title': 'Alaska Airlines Flight 2059',
  'url': 'https://en.wikipedia.org/wiki?curid=75127810'},
 {'title': 'Ting Ting Chaoro',
  'url': 'https://en.wikipedia.org/wiki?curid=74877608'},
 {'title': 'Detdet Pepito',
  'url': 'https://en.wikipedia.org/wiki?curid=74590708'}]
