# ZenML pipeline for Learning-to-Rank (LTR) with XGBoost. We:
  ### - We already performed chunking + embedding in a separate pipeline,
    storing question_embedding/context_embedding in MongoDB.
  ### - Each question may have multiple contexts: 1 positive, 2 negatives, etc.
  ### - Our final goal is to rank contexts so that relevant ones appear first.


### Steps:
  1) fetch_ltr_data: Reads from "ltr_dataset" collection in MongoDB.
  2) split_train_test: Splits data into train/test sets.
  3) build_features: Computes numeric features (e.g. cos_sim, L2 distance).
  4) train_ranker: Trains an XGBRanker in ranking mode (rank:ndcg).
  5) evaluate_ranker: Computes NDCG@k on test set.
  6) save_ranker: Saves the final model artifact.

In this pipeline, we demonstrate that:
  - A production-ready approach with real embeddings and negative sampling.
  - Grouping by question_id for ranking.
  - Weights & Biases integration for experiment tracking.

In [29]:
import os, subprocess,json
import wandb
import numpy as np
import pandas as pd
from typing import Tuple, List, Any

from zenml.pipelines import pipeline
from zenml.steps import step
from pymongo import MongoClient
from sklearn.model_selection import train_test_split
from xgboost import XGBRanker

In [30]:
def ensure_mongodb_running():
    """Checks if MongoDB is running, and starts it if not."""
    try:
        # Try connecting to MongoDB
        subprocess.run(["mongosh", "--eval", "db.runCommand({ ping: 1 })"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        print("✅ MongoDB is already running.")
    except subprocess.CalledProcessError:
        print("⚠️ MongoDB is NOT running. Attempting to start it...")
        os.system("brew services start mongodb-community")
        print("✅ MongoDB is running now!!!")


def get_mongo_connection(mongo_uri:str = "mongodb://localhost:27017/", db_name:str = "medimaven_db"):
    client = MongoClient(mongo_uri)
    db = client[db_name]

    return db

In [31]:
ensure_mongodb_running()

✅ MongoDB is already running.


In [32]:
db = get_mongo_connection()

In [33]:
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    dot_val = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    return dot_val / (norm_a * norm_b + 1e-9)


def l2_distance(a: np.ndarray, b: np.ndarray) -> float:
    return np.linalg.norm(a - b)

In [34]:
def fetch_ltr_data(
    collection_name: str = "ltr_emb_dataset"
) -> pd.DataFrame:
    """
    Reads the LTR dataset from MongoDB, which already contains:
      question, context, label, question_id, context_id,
      question_embedding, context_embedding, etc.

    Expects each row to represent (question, context) pair with label=1 or 0.
    We have (1) positive for each question, (2) negatives, etc.

    Returns:
      DataFrame with the necessary columns.
    """
    ensure_mongodb_running()
    db = get_mongo_connection()
    collection = db[collection_name]

    data = list(collection.find({}, {'_id': 0}))
    df = pd.DataFrame(data)

    # Basic cleaning
    df.dropna(subset=["question", "context", "question_embeddings", "context_embeddings", "label"], inplace=True)
    df.drop_duplicates(subset=["question_id", "context_id"], inplace=True)

    print(f"Fetched {len(df)} rows from '{collection_name}'")
    return df

In [35]:
ltr_df = fetch_ltr_data()
ltr_df

✅ MongoDB is already running.
Fetched 48070 rows from 'ltr_emb_dataset'


Unnamed: 0,question,context,label,answer,Dataset,focus,synonyms,qtype,speciality,tags,created_at,updated_at,context_length,question_id,context_id,question_embeddings,context_embeddings
0,"1 year old fell, hurt head. Complications?","Frequent fall and hit on the head, if associat...",1,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,460,0,0,"[0.010470500215888023, -0.017503174021840096, ...","[0.07581809163093567, -0.025507749989628792, -..."
1,"1 year old fell, hurt head. Complications?",Difficulty swallowing and thick mucus can be d...,0,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,460,0,1,"[0.010470500215888023, -0.017503174021840096, ...","[0.058364640921354294, -0.0880638137459755, 0...."
2,"1 year old fell, hurt head. Complications?",Laryngitis can cause irritation in the lungs a...,0,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,460,0,2,"[0.010470500215888023, -0.017503174021840096, ...","[0.04832500219345093, -0.02734067291021347, 0...."
3,1.8 years old is not crawling yet. What can be...,Babies usually start crawling by 12 or 13 mont...,1,"Hello,\nWelcome to icliniq.com.\n1. Your daugh...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,330,1,3,"[-0.015964794903993607, -0.021806009113788605,...","[0.012998882681131363, -0.09156108647584915, 0..."
4,1.8 years old is not crawling yet. What can be...,H. Pylori is a type of bacteria responsible fo...,0,"Hello,\nWelcome to icliniq.com.\n1. Your daugh...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,330,1,4,"[-0.015964794903993607, -0.021806009113788605,...","[0.08642300218343735, -0.07272278517484665, -0..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48065,Would tendinosis cause swelling along with dis...,Tendinosis causes pain and a loss of flexibili...,1,"Hello,\nWelcome to icliniq.com.\nI understand ...",iCliniQ,,,,Family Physician,[],2025-03-04 23:55:05.657,2025-03-04 23:55:05.657,647,28171,48065,"[-0.030853671953082085, -0.024006305262446404,...","[0.016239769756793976, -0.011710045859217644, ..."
48066,"Would using Azelaic acid cream treat acne, bla...",Acne is a skin condition that occurs when hair...,1,"Hello,\nWelcome to icliniq.com.\nI went throug...",iCliniQ,,,,Venereology,[],2025-03-04 23:55:05.657,2025-03-04 23:55:05.657,505,28172,48066,"[-0.03918802738189697, 0.046093229204416275, -...","[-0.011104153469204903, 0.02846224419772625, -..."
48067,Would you recommend photo test for solar urtic...,Hives or urticaria that develop when exposed t...,1,"Hi,\nI am glad you chose icliniq for your medi...",iCliniQ,,,,Dermatology,[],2025-03-04 23:55:05.647,2025-03-04 23:55:05.647,1154,28173,48067,"[-0.017249321565032005, 0.056969717144966125, ...","[0.035568609833717346, 0.04882590472698212, 0...."
48068,an allergic rhinitis cause sneezing fits early...,"Allergic rhinitis, also called hay fever, is a...",1,"Hello,\nWelcome to icliniq.com\nI understand y...",iCliniQ,,,,General Practitioner,[],2025-03-04 23:55:05.650,2025-03-04 23:55:05.650,375,28174,48068,"[0.03091682679951191, 0.018054476007819176, 0....","[0.06415484100580215, -0.022760339081287384, 0..."


Uses question_embedding/context_embedding to generate numeric features:
      - cos_sim
      - l2_dist
      - text length, domain signals, etc.

    Output columns:
      question_id, context_id, label, cos_sim, l2_dist, ...

In [36]:
mod_df = ltr_df.copy()

In [37]:

# Convert stored embeddings from list to np array (if they aren't already)
# We'll produce a final DataFrame with columns for numeric features + label
cos_sims = []
l2_dists = []
context_lengths = []

for idx, row in mod_df.iterrows():
    q_emb = np.array(row["question_embeddings"], dtype=np.float32)
    c_emb = np.array(row["context_embeddings"], dtype=np.float32)

    cos_sims.append(cosine_similarity(q_emb, c_emb))
    l2_dists.append(l2_distance(q_emb, c_emb))
    context_lengths.append(len(str(row["context"]).split()))

In [38]:
mod_df["cos_sim"] = cos_sims
mod_df["l2_dist"] = l2_dists
mod_df["context_length"] = context_lengths

In [39]:
mod_df

Unnamed: 0,question,context,label,answer,Dataset,focus,synonyms,qtype,speciality,tags,created_at,updated_at,context_length,question_id,context_id,question_embeddings,context_embeddings,cos_sim,l2_dist
0,"1 year old fell, hurt head. Complications?","Frequent fall and hit on the head, if associat...",1,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,89,0,0,"[0.010470500215888023, -0.017503174021840096, ...","[0.07581809163093567, -0.025507749989628792, -...",0.703953,0.769476
1,"1 year old fell, hurt head. Complications?",Difficulty swallowing and thick mucus can be d...,0,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,89,0,1,"[0.010470500215888023, -0.017503174021840096, ...","[0.058364640921354294, -0.0880638137459755, 0....",0.245083,1.228753
2,"1 year old fell, hurt head. Complications?",Laryngitis can cause irritation in the lungs a...,0,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,109,0,2,"[0.010470500215888023, -0.017503174021840096, ...","[0.04832500219345093, -0.02734067291021347, 0....",0.255682,1.220096
3,1.8 years old is not crawling yet. What can be...,Babies usually start crawling by 12 or 13 mont...,1,"Hello,\nWelcome to icliniq.com.\n1. Your daugh...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,62,1,3,"[-0.015964794903993607, -0.021806009113788605,...","[0.012998882681131363, -0.09156108647584915, 0...",0.737150,0.725052
4,1.8 years old is not crawling yet. What can be...,H. Pylori is a type of bacteria responsible fo...,0,"Hello,\nWelcome to icliniq.com.\n1. Your daugh...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,56,1,4,"[-0.015964794903993607, -0.021806009113788605,...","[0.08642300218343735, -0.07272278517484665, -0...",0.082553,1.354583
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
48065,Would tendinosis cause swelling along with dis...,Tendinosis causes pain and a loss of flexibili...,1,"Hello,\nWelcome to icliniq.com.\nI understand ...",iCliniQ,,,,Family Physician,[],2025-03-04 23:55:05.657,2025-03-04 23:55:05.657,109,28171,48065,"[-0.030853671953082085, -0.024006305262446404,...","[0.016239769756793976, -0.011710045859217644, ...",0.746299,0.712322
48066,"Would using Azelaic acid cream treat acne, bla...",Acne is a skin condition that occurs when hair...,1,"Hello,\nWelcome to icliniq.com.\nI went throug...",iCliniQ,,,,Venereology,[],2025-03-04 23:55:05.657,2025-03-04 23:55:05.657,88,28172,48066,"[-0.03918802738189697, 0.046093229204416275, -...","[-0.011104153469204903, 0.02846224419772625, -...",0.769600,0.678823
48067,Would you recommend photo test for solar urtic...,Hives or urticaria that develop when exposed t...,1,"Hi,\nI am glad you chose icliniq for your medi...",iCliniQ,,,,Dermatology,[],2025-03-04 23:55:05.647,2025-03-04 23:55:05.647,210,28173,48067,"[-0.017249321565032005, 0.056969717144966125, ...","[0.035568609833717346, 0.04882590472698212, 0....",0.808645,0.595370
48068,an allergic rhinitis cause sneezing fits early...,"Allergic rhinitis, also called hay fever, is a...",1,"Hello,\nWelcome to icliniq.com\nI understand y...",iCliniQ,,,,General Practitioner,[],2025-03-04 23:55:05.650,2025-03-04 23:55:05.650,62,28174,48068,"[0.03091682679951191, 0.018054476007819176, 0....","[0.06415484100580215, -0.022760339081287384, 0...",0.734878,0.728178


### - Splits the dataset 
Splits the dataset into train/test. In ranking scenarios,
group by question. Here we do a simple row-based split. For large
production systems, we'll consider grouping or stratified approaches.

Returns: (train_df, test_df)

In [42]:

test_size = 0.2
random_seed = 42

train_df, test_df = train_test_split(mod_df, test_size=test_size, random_state=random_seed)
print(f"Train set: {train_df.shape} rows, Test set: {test_df.shape} rows")

Train set: (38456, 19) rows, Test set: (9614, 19) rows


### - Train XGBoost
We train an XGBoost ranker on our numeric features, grouping by question_id.

Feature columns being [cos_sim, l2_dist, context_length, ...].
Label is the binary relevance (1 or 0).
Groups are derived from how many rows belong to each question_id.

Returns the fitted model.

In [48]:

learning_rate = 0.1
n_estimators = 100
ranking_objective = "rank:ndcg"


# Sort so that all rows for a question are contiguous
train_df = train_df.sort_values("question_id")
train_df.head(5)

Unnamed: 0,question,context,label,answer,Dataset,focus,synonyms,qtype,speciality,tags,created_at,updated_at,context_length,question_id,context_id,question_embeddings,context_embeddings,cos_sim,l2_dist
0,"1 year old fell, hurt head. Complications?","Frequent fall and hit on the head, if associat...",1,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,89,0,0,"[0.010470500215888023, -0.017503174021840096, ...","[0.07581809163093567, -0.025507749989628792, -...",0.703953,0.769476
2,"1 year old fell, hurt head. Complications?",Laryngitis can cause irritation in the lungs a...,0,"Hello,\nWelcome to icliniq.com.\nFrequent fall...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.651,2025-03-04 23:55:05.651,109,0,2,"[0.010470500215888023, -0.017503174021840096, ...","[0.04832500219345093, -0.02734067291021347, 0....",0.255682,1.220096
3,1.8 years old is not crawling yet. What can be...,Babies usually start crawling by 12 or 13 mont...,1,"Hello,\nWelcome to icliniq.com.\n1. Your daugh...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,62,1,3,"[-0.015964794903993607, -0.021806009113788605,...","[0.012998882681131363, -0.09156108647584915, 0...",0.73715,0.725052
5,1.8 years old is not crawling yet. What can be...,You can keep in touch with your friend through...,0,"Hello,\nWelcome to icliniq.com.\n1. Your daugh...",iCliniQ,,,,Pediatrics,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,112,1,5,"[-0.015964794903993607, -0.021806009113788605,...","[0.01929549314081669, -0.003438724437728524, 0...",0.007168,1.409136
6,10 month old has VUR grade 2 in both kidneys. ...,"Hello doctor,\nI have a 10 month old daughter....",1,"Hello,\nWelcome to icliniq.com.\nSorry to hear...",iCliniQ,,,,Nephrology,[],2025-03-04 23:55:05.659,2025-03-04 23:55:05.659,84,2,6,"[-0.011081097647547722, 0.017025841400027275, ...","[-0.013550026342272758, -0.006064726505428553,...",0.754278,0.701031


In [49]:
# Build group array for XGBoost (size of each question's doc set).
group_series = train_df.groupby("question_id").size()
group_series

question_id
0        2
1        2
2        2
3        3
4        2
        ..
28171    1
28172    1
28173    1
28174    1
28175    1
Length: 25013, dtype: int64

In [50]:
group_sizes = group_series.values.tolist()
group_series

question_id
0        2
1        2
2        2
3        3
4        2
        ..
28171    1
28172    1
28173    1
28174    1
28175    1
Length: 25013, dtype: int64

In [51]:
# Extract features + labels
feature_cols = ["cos_sim", "l2_dist", "context_length"]
X_train = train_df[feature_cols].values
y_train = train_df["label"].values

In [52]:
# Initialize ranker
ranker = XGBRanker(
objective=ranking_objective,
learning_rate=learning_rate,
n_estimators=n_estimators,
eval_metric="ndcg",
tree_method="auto"   # or 'gpu_hist' if you have GPU
)

In [53]:
# Train
ranker.fit(
    X_train,
    y_train,
    group=group_sizes,
    verbose=True
)

In [None]:

# # Log hyperparams to W&B
# wandb.init(project="MediMaven-LTR", job_type="train_ranker", reinit=True)
# wandb.config.update({
# "learning_rate": learning_rate,
# "n_estimators": n_estimators,
# "objective": ranking_objective
# })

### - NDCG Calculation   
    Compute NDCG@k for a single query:
      1) Sort docs by predicted score descending
      2) Compute DCG of top-k
      3) Compute IDCG (ideal ranking)
      4) Return DCG/IDCG

In [54]:
def compute_ndcg_at_k(labels: np.ndarray, scores: np.ndarray, k: int = 10) -> float:
    """
    Compute NDCG@k for a single query:
      1) Sort docs by predicted score descending
      2) Compute DCG of top-k
      3) Compute IDCG (ideal ranking)
      4) Return DCG/IDCG
    """
    from math import log2

    # Sort by predicted score, descending
    idx_sorted = np.argsort(-scores)
    ideal_sorted = np.argsort(-labels)

    dcg = 0.0
    idcg = 0.0

    for i in range(k):
        if i < len(idx_sorted):
            rel = labels[idx_sorted[i]]
            dcg += (2**rel - 1) / log2(i+2)
        if i < len(ideal_sorted):
            ideal_rel = labels[ideal_sorted[i]]
            idcg += (2**ideal_rel - 1) / log2(i+2)

    return dcg / (idcg + 1e-9)

### - Evaluate ranker on test set
Computes mean NDCG@k across all queries in the test set.

test_df must have question_id, label, and the numeric feature columns used in training.
We predict a score for each row, group by question_id, then compute NDCG@k.

Returns:
    mean_ndcg (float)

In [55]:


k_eval = 10

test_df = test_df.sort_values("question_id")

feature_cols = ["cos_sim", "l2_dist", "context_length"]
X_test = test_df[feature_cols].values
y_true = test_df["label"].values

# Build group array
group_series = test_df.groupby("question_id").size()
group_sizes = group_series.values.tolist()

# Predict
y_scores = ranker.predict(X_test)
y_scores

array([-2.2177625 , -4.6322207 , -4.2933974 , ...,  3.7861598 ,
        0.42110166,  3.8443065 ], shape=(9614,), dtype=float32)

In [56]:
len(y_scores)

9614

In [59]:
# Compute NDCG@k per query
ndcg_values = []
start_idx = 0
for size in group_sizes:
    end_idx = start_idx + size
    labels_group = y_true[start_idx:end_idx]
    scores_group = y_scores[start_idx:end_idx]

    ndcg_val = compute_ndcg_at_k(labels_group, scores_group, k=k_eval)
    ndcg_values.append(ndcg_val)
    start_idx = end_idx


In [60]:
ndcg_values

[np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.9999999989999999),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.9999999989999999),
 np.float64(0.9999999989999999),
 np.float64(0.9999999989999999),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.9999999989999999),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0.0),
 np.float64(0

In [61]:
mean_ndcg = float(np.mean(ndcg_values))
print(f"NDCG@{k_eval} on test set = {mean_ndcg:.4f}")

NDCG@10 on test set = 0.4813
