## Setup: project root and environment variables

In [1]:
import os
from pathlib import Path

# Later cells open ./prompts/... and ./chroma_db relative to the working
# directory, i.e. they expect the project root (the parent of this notebook's
# directory) as cwd. Only step up while ./prompts is not visible, so that
# re-running this cell does not keep climbing the directory tree.
if not Path("prompts").is_dir():
    os.chdir("..")
Path.cwd()

/mnt/d/workspace/mlops_final_project_01


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [6]:
# Load variables from the local .env file into the process environment.
# override=True: values from .env replace variables that are already set.
from dotenv import load_dotenv

load_dotenv(override=True)

True

## Utils

In [3]:
from llm_inference.llm_inference_base import ai_completion, ai_completion_with_backoff, truncate_text, get_length
from llm_inference.llm_response_parser import parse_llm_response
from explore.rag.prompt_utils import (
    msg_to_conv, read_json_file, read_yaml_file, 
    load_jsonl, to_print, extract_placeholders
)
from IPython.display import display, Markdown
from pandarallel import pandarallel
import pandas as pd
import json_repair
import json
import random
import glob
import yaml
import time

In [4]:
import re
import unicodedata


def clean_text(text):
    """Normalise an LLM-generated search query into plain words.

    Punctuation and symbols (including the trailing ``**`` the rewrite prompt
    asks the model to emit) are removed, every run of whitespace becomes a
    single space, and the result is stripped. Letters and digits from any
    script are kept, so accented queries such as "Hà Nội" survive intact.

    Args:
        text: Raw string produced by the query-rewriter LLM.

    Returns:
        The cleaned query string (possibly empty).
    """
    # Compose decomposed accents (e.g. "a" + U+0300) into single letters so
    # they count as word characters below instead of being stripped.
    composed = unicodedata.normalize("NFC", text)
    # [^\w\s] drops punctuation/symbols; "_" is a \w character, so it is
    # removed explicitly as well.
    without_symbols = re.sub(r"[^\w\s]|_", "", composed)
    # Collapse newlines/tabs/repeated spaces so adjacent words stay separated.
    return re.sub(r"\s+", " ", without_symbols).strip()

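## Approach 1

- Query rewriting: an LLM rewrites the user's question into a search query before retrieval.
- Retrieved-document clustering: cluster the retrieved chunks and sample from each cluster to diversify the context.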

In [5]:
# Central settings for this notebook: embedding and LLM model names, the
# on-disk Chroma location, the prompt-template file, LLM sampling parameters
# (temperature, top_p) and search_k (intended retriever k).
config = dict(
    embedding_model='models/gemini-embedding-exp',
    llm_model='gemini-2.5-flash',
    chroma_db_path='./chroma_db',
    prompt_template_path='./prompts/rag_base.yaml',
    temperature=0.3,
    top_p=0.85,
    search_k=4,
)

### Load vectorstore

In [7]:
import chromadb
from langchain_chroma import Chroma
from langchain_google_genai import GoogleGenerativeAIEmbeddings

In [8]:
def load_vectorstore(db_path, coll_name, embeddings):
    """Open a persisted Chroma collection as a LangChain vector store.

    Args:
        db_path: Directory of the persistent Chroma database.
        coll_name: Name of the collection inside that database.
        embeddings: Embedding function handed to Chroma.

    Returns:
        A ``Chroma`` vector store bound to the collection.
    """
    store_kwargs = {
        "persist_directory": db_path,
        "collection_name": coll_name,
        "embedding_function": embeddings,
    }
    return Chroma(**store_kwargs)

In [9]:
# Low-level Chroma client on the same on-disk database the LangChain vector
# store opens; taking the path from `config` keeps the two in sync.
persistent_client = chromadb.PersistentClient(
    path=config["chroma_db_path"]
)

In [11]:
# Load the named prompt templates from the file configured in `config`; the
# context manager closes the file as soon as it has been parsed.
with open(config["prompt_template_path"], encoding="utf-8") as template_file:
    prompt_templates = yaml.safe_load(template_file)
prompt_templates.keys()

dict_keys(['basic_rag_01', 'rewrite_01', 'synth_qa_01', 'synth_qa_json_01', 'synth_question_01', 'synth_answer_01', 'question_rewriter_01'])

In [10]:
# Retrieve and list all collection names
# NOTE(review): newer chromadb releases may return plain name strings from
# list_collections() rather than Collection objects; confirm `.name` works with
# the installed version.
collections = persistent_client.list_collections()

# Print collection names
for collection in collections:
    print(collection.name)

In [12]:
# Open the "llm_rag_01" collection (get_or_create_collection creates it when it
# is missing) and expose it through LangChain with Gemini embeddings.
# NOTE(review): a mistyped collection name silently yields a new, empty store.
coll_01 = persistent_client.get_or_create_collection("llm_rag_01")
embed_01 = GoogleGenerativeAIEmbeddings(model=config["embedding_model"])
vector_store_01 = load_vectorstore(config["chroma_db_path"], coll_01.name, embed_01)

AttributeError: 'Chroma' object has no attribute 'list_documents'

### Define RAG chain

In [13]:
from langchain_google_genai import GoogleGenerativeAI, ChatGoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.runnables import RunnablePassthrough

In [14]:
def format_docs(docs):
    """Join retrieved documents into one context string for the prompt.

    Accepts LangChain ``Document`` objects (their ``page_content`` is used)
    and plain strings (used as-is), including a mix of both. Entries are
    separated by a blank line.

    Args:
        docs: Iterable of ``Document`` objects and/or strings.

    Returns:
        The concatenated context string ("" when ``docs`` is empty).

    Raises:
        TypeError: If an entry is neither a string nor has a string
            ``page_content``.
    """
    # Resolve each entry individually in a single pass, so one-shot iterators
    # and mixed Document/str inputs both work and unrelated errors surface.
    return "\n\n".join(getattr(doc, "page_content", doc) for doc in docs)

def log_prompt(prompt):
    """Print the prompt about to be sent to the LLM and pass it through.

    Meant to sit between a prompt template and the model in a LangChain
    pipeline. Chat-style prompt values are printed message by message (each
    message's ``content``); anything without messages is printed as-is.

    Args:
        prompt: The prompt value produced by the previous pipeline step.

    Returns:
        ``prompt``, unmodified, so the chain continues unchanged.
    """
    print("\n--- Prompt Fed to LLM ---")
    messages = getattr(prompt, "messages", None)
    if messages:
        for message in messages:
            print(getattr(message, "content", message))
    else:
        print(prompt)
    print("\n--- End Prompt ---")
    return prompt

In [15]:
# Both Gemini wrappers share one set of generation settings from `config`:
# `llm` is the plain text interface used by the chains here, `chat_llm` the
# chat-model interface (not used in the chains shown in this section).
llm_settings = {
    "model": config["llm_model"],
    "temperature": config["temperature"],
    "top_p": config["top_p"],
}
llm = GoogleGenerativeAI(**llm_settings)
chat_llm = ChatGoogleGenerativeAI(**llm_settings)

#### Query rewriter

In [16]:
# Build the query-rewrite prompt from the "rewrite_01" template and print it;
# the template takes a single input variable, {x}.
rewrite_prompt = ChatPromptTemplate.from_template(prompt_templates["rewrite_01"])
rewrite_prompt.pretty_print()


Provide a better search query for web search engine to answer the given question, end the queries with `**`.
Question: {x}
Answer:



In [18]:
# Question -> rewrite prompt -> logged -> LLM -> plain string -> clean_text,
# which strips punctuation such as the `**` terminator the prompt asks for,
# leaving a bare search query.
query_rewriter = rewrite_prompt | log_prompt | llm | StrOutputParser() | clean_text

In [19]:
# Smoke test: rewrite a sample question into a cleaned search query.
query_rewriter.invoke("What is the core value of vietjet air?")


--- Prompt Fed to LLM ---
Provide a better search query for web search engine to answer the given question, end the queries with `**`.
Question: What is the core value of vietjet air? 
Answer:


--- End Prompt ---


'Vietjet Air core values'

#### Query rewrite RAG chain

In [22]:
# Answer prompt built from "basic_rag_01": {context} receives retrieved text and
# {question} the user's question.
prompt_resp = ChatPromptTemplate.from_template(prompt_templates["basic_rag_01"])
prompt_resp.pretty_print()


Answer the users question based only on the following context:

<context>
{context}
</context>

Question: {question}



In [23]:
# Maximal-marginal-relevance retriever; k comes from config["search_k"] so the
# retriever and the configuration cannot disagree.
retriever_01 = vector_store_01.as_retriever(
    search_type="mmr",
    search_kwargs={"k": config["search_k"]},
)

In [None]:
# Retrieve documents for a sample question with the MMR retriever.
# NOTE(review): this cell is unexecuted (`In [None]`) in the saved run, yet the
# next cell reads `relevant_docs`; run it first on a fresh kernel.
relevant_docs = retriever_01.invoke("What is vietjet air?")

In [35]:
# Metadata stored with the top-ranked chunk; its "doc_id" is used as the Chroma
# id when fetching embeddings further below.
relevant_docs[0].metadata

{'chunk': 1,
 'doc_id': '6f1abd16-fb4d-4427-bf05-d44661811232',
 'title': 'VietJet Air (VJ) - Flights, Airline Tickets & Reviews',
 'url': 'https://www.kayak.com/VietJet-Air.VJ.airline.html'}

In [40]:
# The same question feeds two branches: "context" rewrites it into a search
# query (as the rewrite template's {x}), retrieves with MMR and joins the hits;
# "question" forwards it unchanged. Both fill the basic RAG answer prompt.
query_rewrite_chain = (
    dict(
        context=(
            {"x": RunnablePassthrough()}
            | query_rewriter
            | retriever_01
            | format_docs
        ),
        question=RunnablePassthrough(),
    )
    | prompt_resp
    | log_prompt
    | llm
    | StrOutputParser()
)

In [41]:
# Run the query-rewrite RAG chain end to end on a sample question.
query_rewrite_chain.invoke("What is vietjet air?")


--- Prompt Fed to LLM ---
Provide a better search query for web search engine to answer the given question, end the queries with `**`.
Question: What is vietjet air? 
Answer:


--- End Prompt ---

--- Prompt Fed to LLM ---
Answer the users question based only on the following context:

<context>
and cheerful low-cost airline based in Vietnam. It has a large flight network spanning central and Southeast Asia, China, Japan, India and Australia. It’s an excellent option for visiting secondary cities within Vietnam (and China, where it serves a great variety of destinations), as well as for longer connecting trips like Australia to South Korea or Japan to India. VietJet also operates a Thailand-based subsidiary, Thai VietJet, with more than a dozen destinations there. The VietJet Air model VietJet’s flights can be almost unbelievably cheap, whether you’re flying domestically, to neighboring countries or further afield to the likes of Australia, India and Japan. Don’t be put off by the som

"VietJet Air, or Vietjet, is a low-cost airline based in Hanoi, Vietnam.  It's Vietnam's largest private carrier, a full member of the International Air Transport Association (IATA), and holds the IOSA certificate.  It has a large flight network spanning central and Southeast Asia, China, Japan, India, and Australia, offering both domestic and international flights.  VietJet also operates a subsidiary, Thai VietJet, based in Thailand.  The airline uses Airbus planes (A320 and A321 models) and is known for its inexpensive flights.\n"

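#### Cluster and sample the retrieved docs

Retrieve a larger candidate set, cluster the chunk embeddings with K-Means, and sample a few chunks from each cluster so the final context is drawn from every cluster rather than only the top-ranked chunks.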

In [52]:
from sklearn.cluster import KMeans
import numpy as np

In [70]:
# Cluster-resampling settings: over-fetch NUM_DOCS candidates, split them into
# NUM_CLUSTERS groups and keep up to NUM_SAMPLES_PER_CLUSTER from each group.
# RANDOM_STATE is the random seed (passed to K-Means).
NUM_DOCS, NUM_CLUSTERS, NUM_SAMPLES_PER_CLUSTER, RANDOM_STATE = 30, 3, 2, 42

In [49]:
# Plain similarity retriever that over-fetches NUM_DOCS candidates, which are
# later clustered and resampled.
retriever_02 = vector_store_01.as_retriever(
    search_kwargs=dict(k=NUM_DOCS),
    search_type="similarity",
)

In [50]:
# Over-fetch candidates with the similarity retriever. Note this rebinds
# `relevant_docs`, replacing the MMR results from the earlier cell.
relevant_docs = retriever_02.invoke("What is vietjet air?")
print(len(relevant_docs))

30


In [68]:
# Fetch the stored embedding of every retrieved chunk, using metadata["doc_id"]
# as the Chroma id. The order of get() results is not relied upon: doc_ids is
# rebuilt from the ids Chroma returns, so row i of retrieved_embeddings always
# belongs to doc_ids[i] (the sampling cell below indexes doc_ids by row).
fetched = vector_store_01.get(
    ids=[doc.metadata["doc_id"] for doc in relevant_docs],
    include=["embeddings"],
)
doc_ids = fetched["ids"]
retrieved_embeddings = fetched["embeddings"]

print("Perform K-Means clustering")
kmeans = KMeans(n_clusters=NUM_CLUSTERS, random_state=RANDOM_STATE)
kmeans.fit(retrieved_embeddings)

Perform K-Means clustering


In [69]:
# Cluster id assigned by K-Means to each row of retrieved_embeddings.
labels = kmeans.labels_
labels

array([2, 2, 2, 2, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 0, 2, 0, 1, 1, 0,
       2, 2, 2, 2, 1, 1, 1, 1], dtype=int32)

In [55]:
# Draw up to NUM_SAMPLES_PER_CLUSTER documents from each cluster, without
# replacement, so the final context covers every cluster. A generator seeded
# with RANDOM_STATE makes the draw reproducible across kernel restarts.
sampling_rng = np.random.default_rng(RANDOM_STATE)
sampled_doc_ids = []
for cluster_id in range(NUM_CLUSTERS):
    cluster_indices = np.flatnonzero(labels == cluster_id)
    if cluster_indices.size == 0:
        continue
    n_take = min(NUM_SAMPLES_PER_CLUSTER, cluster_indices.size)
    sampled_indices = sampling_rng.choice(cluster_indices, size=n_take, replace=False)
    sampled_doc_ids.extend(doc_ids[i] for i in sampled_indices)

print("Sampled doc IDs:", sampled_doc_ids)

Sampled doc IDs: ['df569422-07ca-43ee-900a-284445227cdb', 'c64b867c-e3f6-43c0-82c8-12de9ccee979', 'e33e6a40-6eaf-48e0-9f1d-42d0b1bc6ffb', 'a1e0a9d6-4390-430f-846d-d13325cc7f5a', '4e89bc02-8a21-43d1-8f41-9aae8840af6f', '673fc849-17b7-4058-8922-b4c8529fd5d5']


In [56]:
# Number of sampled ids; at most NUM_CLUSTERS * NUM_SAMPLES_PER_CLUSTER.
len(sampled_doc_ids)

6

In [71]:
# Fetch the sampled chunks from the vector store by id.
sampled_docs = vector_store_01.get(ids=sampled_doc_ids)

In [67]:
# Inspect which fields the Chroma get() result contains.
sampled_docs.keys()

dict_keys(['ids', 'embeddings', 'documents', 'uris', 'data', 'metadatas', 'included'])

In [72]:
# Number of sampled documents actually returned by the store.
len(sampled_docs['documents'])

6

In [73]:
# Reusable version of the clustering + sampling steps above, for use in a chain.
def re_sampling_docs(relevant_docs):
    """Diversify retrieved chunks by clustering them and sampling per cluster.

    The stored embeddings of ``relevant_docs`` are clustered with K-Means
    (``NUM_CLUSTERS`` clusters, capped at the number of chunks, seeded with
    ``RANDOM_STATE``). Up to ``NUM_SAMPLES_PER_CLUSTER`` chunks are then drawn
    without replacement from every cluster, using a generator seeded with
    ``RANDOM_STATE`` so the same input yields the same sample.

    Args:
        relevant_docs: Retrieved LangChain documents whose
            ``metadata["doc_id"]`` is their id in ``vector_store_01``.

    Returns:
        The ``"documents"`` field of ``vector_store_01.get`` for the sampled
        ids. Returns an empty list when there is nothing to cluster.
    """
    if not relevant_docs:
        return []

    fetched = vector_store_01.get(
        ids=[doc.metadata["doc_id"] for doc in relevant_docs],
        include=["embeddings"],
    )
    # Use the ids Chroma returns so row i of the embeddings matches doc_ids[i].
    doc_ids = fetched["ids"]
    doc_embeddings = fetched["embeddings"]
    if len(doc_ids) == 0:
        return []

    # K-Means cannot form more clusters than there are samples.
    n_clusters = min(NUM_CLUSTERS, len(doc_ids))

    print("Perform K-Means clustering")
    kmeans = KMeans(n_clusters=n_clusters, random_state=RANDOM_STATE)
    # Labels from THIS clustering (not a notebook-global `labels`).
    labels = kmeans.fit_predict(doc_embeddings)

    rng = np.random.default_rng(RANDOM_STATE)
    sampled_doc_ids = []
    for cluster_id in range(n_clusters):
        cluster_indices = np.flatnonzero(labels == cluster_id)
        if cluster_indices.size == 0:
            continue
        sampled_indices = rng.choice(
            cluster_indices,
            size=min(NUM_SAMPLES_PER_CLUSTER, cluster_indices.size),
            replace=False,
        )
        sampled_doc_ids.extend(doc_ids[i] for i in sampled_indices)

    print("Sampled doc IDs:", sampled_doc_ids)

    return vector_store_01.get(ids=sampled_doc_ids)["documents"]

In [75]:
# Same shape as the query-rewrite chain, but the rewritten query goes to the
# over-fetching similarity retriever, and the hits are cluster-resampled
# before being joined into the prompt context.
query_rewrite_resampling_chain = (
    dict(
        context=(
            {"x": RunnablePassthrough()}
            | query_rewriter
            | retriever_02
            | re_sampling_docs
            | format_docs
        ),
        question=RunnablePassthrough(),
    )
    | prompt_resp
    | log_prompt
    | llm
    | StrOutputParser()
)

In [76]:
# Run the rewrite + cluster-resampling RAG chain on a sample question.
query_rewrite_resampling_chain.invoke("What is vietjet air")


--- Prompt Fed to LLM ---
Provide a better search query for web search engine to answer the given question, end the queries with `**`.
Question: What is vietjet air 
Answer:


--- End Prompt ---
Perform K-Means clustering
Sampled doc IDs: ['0dad213a-6d94-4042-90e9-c6228c567fb1', 'e97eb090-9548-4500-805e-8b018444ddcc', 'b14e02cd-3977-43bb-98af-292713b7ad3e', '254632af-6060-4a3b-9e5b-8a9913577989', '59eb094f-4850-49f3-9c5c-3932aed32189', 'ffe3dd6d-45ae-4b72-84d3-7c30aa63549f']

--- Prompt Fed to LLM ---
Answer the users question based only on the following context:

<context>
airport and onboard is very similar to other low-cost carriers in the region. If you want a basic seat with just a personal item to take with you, you’ll get it at an incredibly low cost. Anything more, like baggage, meals or seat reservations, comes with an extra fee, though I still find these pretty reasonable, usually. Classes of travel and fares On almost all its aircraft, VietJet is all-economy and, like many 

"Based on the provided text, VietJet Air is a low-cost carrier in the region, offering basic seats at low prices with extra fees for baggage, meals, and seat reservations.  It operates mostly all-economy class flights with various fare types offering added benefits for a higher price.  The airline's fleet consists entirely of Airbus A320 and A321 models, and it was the first in Southeast Asia to fly the A320neo model.  VietJet Air has a codeshare agreement with Japan Airlines and a subsidiary, Thai Vietjet Air, focusing on domestic flights within Thailand.  In 2015, it was named the Best Asian Low-Cost Carrier at the TTG Travel Awards.\n"