In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import mlflow
import mlflow.sklearn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
import pandas as pd
import numpy as np

In [0]:
# Reuse the Spark session already attached to this notebook, or create one if none exists.
spark = SparkSession.builder.getOrCreate()
# Select the catalog and schema for the session; the table read further down
# still uses its fully qualified name.
spark.sql("USE CATALOG workspace")
spark.sql("USE SCHEMA med")

In [0]:
# Read the chunk table and keep only rows that actually carry text:
# chunk_text must be non-NULL and longer than zero characters.
chunks_df = (
    spark.table("workspace.med.doc_chunks")
    .filter(F.col("chunk_text").isNotNull() & (F.length("chunk_text") > 0))
)

In [0]:
# Collect the columns used for retrieval onto the driver as a pandas DataFrame.
chunks_pdf = (
    chunks_df
    .select(["chunk_id", "chunk_text", "source", "category", "title"])
    .toPandas()
)

# Peek at the first rows as a sanity check.
chunks_pdf.head()

In [0]:
# Plain Python lists taken column-wise from chunks_pdf; element i of each list
# comes from row position i of chunks_pdf.
chunk_ids = chunks_pdf.loc[:, "chunk_id"].tolist()
texts = chunks_pdf.loc[:, "chunk_text"].tolist()

In [0]:
# TF-IDF over character n-grams (3 to 5 characters long), not word tokens.
vectorizer = TfidfVectorizer(
    analyzer="char",
    ngram_range=(3, 5),
    sublinear_tf=True,  # scale term frequency as 1 + log(tf) so repeated n-grams count less
    min_df=2,  # drop n-grams that occur in fewer than 2 chunks
    max_df=0.8,  # drop n-grams that occur in more than 80% of chunks
    max_features=50000,  # cap the vocabulary at the 50,000 most frequent n-grams
)

# Learn the vocabulary and IDF weights from all chunk texts and build the
# sparse matrix with one row per chunk; show its (rows, features) shape.
matrix = vectorizer.fit_transform(texts)
matrix.shape

In [0]:
# Nearest-neighbour index over the TF-IDF rows. "brute" compares a query
# against every stored row, using cosine distance.
# n_neighbors=10 is the default neighbour count for queries that do not pass their own.
knn = NearestNeighbors(
    metric="cosine",
    algorithm="brute",
    n_neighbors=10,
    n_jobs=1,
)
knn = knn.fit(matrix)

In [0]:
def search_chunks_local(question: str, top_k: int = 5):
    """Return the indexed chunks closest to ``question`` by TF-IDF cosine distance.

    The question is vectorised with the notebook-level ``vectorizer`` and looked
    up in the notebook-level ``knn`` index. Metadata for each hit is read from
    ``chunks_pdf`` by row position, so ``chunks_pdf`` must be the frame whose
    ``chunk_text`` column the index was fitted on.

    Args:
        question: Free-text query. ``None`` is treated like an empty string.
        top_k: Maximum number of chunks to return. Values larger than the
            number of indexed chunks are clamped to that number.

    Returns:
        A list of dicts in the order returned by ``knn.kneighbors``, each with
        the keys ``rank`` (1-based), ``chunk_id``, ``title``, ``source``,
        ``category``, ``distance`` (cosine distance as a float), ``text`` and
        ``chunk_text_preview`` (first 300 characters of the text). The list is
        empty when the question shares no n-gram with the fitted vocabulary.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    question_vector = vectorizer.transform([question or ""])

    # An all-zero query vector has no overlap with the vocabulary; its cosine
    # distance to every chunk is the same, so any "nearest" chunks would be
    # arbitrary. Report no hits instead.
    if question_vector.nnz == 0:
        return []

    # kneighbors raises when asked for more neighbours than were indexed.
    n_neighbors = min(top_k, knn.n_samples_fit_)
    distances, indices = knn.kneighbors(question_vector, n_neighbors=n_neighbors)

    # A single query was submitted, so only the first result row is relevant.
    indices = indices[0]
    distances = distances[0]

    results = []
    for rank, (idx, dist) in enumerate(zip(indices, distances), start=1):
        row = chunks_pdf.iloc[int(idx)]  # positional lookup into the chunks frame
        results.append({
            "rank": rank,
            "chunk_id": row["chunk_id"],
            "title": row.get("title"),
            "source": row.get("source"),
            "category": row.get("category"),
            "distance": float(dist),
            "text": row["chunk_text"],
            "chunk_text_preview": (row.get("chunk_text") or "")[:300],
        })

    return results

In [0]:
# Smoke test: for each hit print a one-line summary, the text preview and a divider.
for r in search_chunks_local(question="what is ibuprofen used for", top_k=5):
    print(*(r[field] for field in ("rank", "source", "title", "distance", "chunk_id")))
    print(r["chunk_text_preview"])
    print(100 * "-")