In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import mlflow
import mlflow.sklearn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import FeatureUnion
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np
import re
import hashlib

In [0]:
# Reuse the active Spark session if there is one, otherwise create it.
spark = SparkSession.builder.getOrCreate()
# Make workspace.med the session's current catalog/schema so unqualified
# table names issued through this session resolve there.
spark.sql("USE CATALOG workspace")
spark.sql("USE SCHEMA med")

In [0]:
# Read the chunk table, keeping only rows whose chunk_text is non-null and non-empty
chunks_df = (
    spark.table("workspace.med.doc_chunks")
    .where(F.col("chunk_text").isNotNull() & (F.length("chunk_text") > 0))
)

In [0]:
# Collect the columns needed for retrieval onto the driver as a pandas DataFrame
chunks_pdf = chunks_df.select(
    ["chunk_id", "doc_id", "chunk_text", "source", "category", "title"]
).toPandas()

# Preview the first five rows
chunks_pdf.head(5)

In [0]:
# Plain Python lists in chunks_pdf row order: chunk ids and the matching chunk text
chunk_ids = chunks_pdf.loc[:, "chunk_id"].to_list()
texts = chunks_pdf.loc[:, "chunk_text"].to_list()

In [0]:
# Word-level TF-IDF: unigrams and bigrams with English stop words removed.
# A term must occur in at least 2 chunks and in no more than 80% of them.
word_vec = TfidfVectorizer(
    analyzer="word",
    stop_words="english",
    ngram_range=(1, 2),
    min_df=2,
    max_df=0.8,
    max_features=50000,
    sublinear_tf=True,
)

# Character-level TF-IDF: 3- to 5-character n-grams built inside word boundaries.
# An n-gram must occur in at least 2 chunks and in no more than 90% of them.
char_vec = TfidfVectorizer(
    analyzer="char_wb",
    ngram_range=(3, 5),
    min_df=2,
    max_df=0.9,
    max_features=100000,
    sublinear_tf=True,
)

# Concatenate both feature spaces side by side and fit them on every chunk text
vectorizer = FeatureUnion(transformer_list=[("word", word_vec), ("char", char_vec)])
matrix = vectorizer.fit_transform(texts)
matrix.shape

In [0]:
# Nearest-neighbour index over the TF-IDF matrix: brute-force comparison of the
# query vector against every chunk vector using cosine distance.
knn = NearestNeighbors(
    algorithm="brute",
    metric="cosine",
    n_neighbors=10,  # default neighbour count when a query does not specify one
    n_jobs=1,
)
knn = knn.fit(matrix)  # fit() returns the estimator; reassigning keeps the cell output empty

In [0]:
def normalize_text(s: str) -> str:
    """Return ``s`` lowercased with every whitespace run collapsed to one space.

    Leading and trailing whitespace is removed. Falsy input (``None`` or an
    empty string) yields an empty string.
    """
    if not s:
        return ""
    collapsed = re.sub(r"\s+", " ", s.lower())
    return collapsed.strip()

def text_hash(s: str) -> str:
    """Return the MD5 hex digest of the normalized form of ``s``.

    The text is passed through ``normalize_text`` first, so the digest can be
    used as a key for spotting repeated chunk text across documents.
    """
    canonical_bytes = normalize_text(s).encode("utf-8")
    digest = hashlib.md5(canonical_bytes)
    return digest.hexdigest()

In [0]:
def search_chunks_local(question: str, top_k: int = 5, pool_k: int = 50, max_dist: float = 0.85):
    """Return up to ``top_k`` de-duplicated chunks nearest to ``question``.

    Relies on notebook-level state: ``vectorizer`` (fitted feature extractor),
    ``knn`` (fitted neighbour index), ``chunks_pdf`` (pandas frame whose row
    positions match the indexed matrix) and ``text_hash``.

    Args:
        question: Free-text query. ``None``, empty or whitespace-only input
            returns an empty list without querying the index.
        top_k: Maximum number of results. Values <= 0 return an empty list.
        pool_k: Number of raw neighbours fetched before filtering. It is raised
            to at least ``top_k`` (filtering can only shrink the pool) and
            capped at the number of rows in ``chunks_pdf``.
        max_dist: Neighbours whose distance exceeds this value are discarded.

    Returns:
        A list of dicts, one per kept chunk, with keys ``rank`` (1-based
        position in the returned list), ``chunk_id``, ``doc_id``, ``title``,
        ``source``, ``category``, ``cosine_distance``, ``cosine_similarity``
        (``1 - cosine_distance``) and ``chunk_text_preview`` (first 300
        characters of the chunk text).
    """
    # Without this guard the loop appends one hit before it ever checks top_k.
    if top_k <= 0:
        return []

    # A blank query carries no signal; skip vectorizing and searching entirely.
    if not question or not question.strip():
        return []

    corpus_size = len(chunks_pdf)
    if corpus_size == 0:
        return []

    # Candidate pool: at least top_k (and therefore >= 1), at most the corpus size.
    pool_k = min(max(pool_k, top_k), corpus_size)

    question_vector = vectorizer.transform([question])  # query text -> feature vector
    distances, indices = knn.kneighbors(question_vector, n_neighbors=pool_k)

    seen_text = set()  # hashes of normalized chunk text already returned
    seen_doc = set()   # doc_ids that already contributed a result
    results = []

    # Candidates are visited in the order kneighbors returned them.
    for idx, dist in zip(indices[0], distances[0]):
        # Skip weak matches (a higher distance is a worse match).
        if dist > max_dist:
            continue

        row = chunks_pdf.iloc[int(idx)]  # positional lookup into the chunk frame
        doc_id = row.get("doc_id", None)
        chunk_text = row.get("chunk_text") or ""
        h = text_hash(chunk_text)

        # Drop text that duplicates an earlier hit after normalization.
        if h in seen_text:
            continue

        # Keep at most one chunk per document.
        if doc_id is not None and doc_id in seen_doc:
            continue

        seen_text.add(h)
        if doc_id is not None:
            seen_doc.add(doc_id)

        results.append({
            "rank": len(results) + 1,
            "chunk_id": row["chunk_id"],
            "doc_id": doc_id,
            "title": row.get("title"),
            "source": row.get("source"),
            "category": row.get("category"),
            "cosine_distance": float(dist),
            "cosine_similarity": float(1.0 - float(dist)),
            "chunk_text_preview": chunk_text[:300],
        })

        # Stop as soon as enough results have been collected.
        if len(results) >= top_k:
            break

    return results

In [0]:
# Run a sample question through the retriever and print each hit
for r in search_chunks_local("what is ibuprofen used for", top_k=5):
    print(*(r[key] for key in ("rank", "source", "title", "cosine_distance", "chunk_id")))
    print(r["chunk_text_preview"], "-" * 100, sep="\n")