In [1]:
import sys
from pathlib import Path

def _find_project_root(start: Path = Path.cwd()):
	for p in [start] + list(start.parents):
		if (p / "pyproject.toml").exists():
			return p
	return start

# Locate the project root once and put it first on the import path, so the
# project-local imports in the next cell resolve regardless of where the
# notebook is launched from.
project_root = _find_project_root()
sys.path.insert(0, str(project_root))

In [2]:
import os
import json
from typing import List, Tuple
from tqdm.notebook import tqdm
from dotenv import load_dotenv
from load_data import LoadData
from lora.core import LORA, FUSE
from datasets import load_dataset
from mlx_lm import generate, utils
from langchain_postgres import PGVector
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from langchain_community.document_loaders import JSONLoader

In [3]:
# Load environment variables from the project-level .env file. PG_CONN_URI is
# read from the environment further down.
# NOTE(review): the OpenAI embeddings presumably need an API key from here as well — confirm.
load_dotenv(project_root / ".env")

True

In [4]:
# Configuration
# All paths below are relative, so they resolve against the kernel's working
# directory rather than against project_root.

# Base folder for this experiment's artefacts (data, adapter weights, fused model).
root_folder = "../../.cache/ragVsFinetuning/model_1"

# Fine-tuning data: folder handed to LoadData and to the LoRA run.
data_folder = f"./{root_folder}/data"
dataset_name = "Kaludi/Customer-Support-Responses"
# Row limit passed to both data-preparation steps; None means no limit.
n = None
# Split ratios forwarded to LoadData.save.
test_split_ratio = 0.2
valid_split_ratio = 0.2

# Base model, adapter weights file and output folder of the fused model.
model_path = "mistralai/Mistral-7B-Instruct-v0.2"
adapter_file = f"./{root_folder}/adapters.npz"
save_model_path = f"./{root_folder}/model"

# PGVector collection and the JSON file that feeds it.
collection_name = "rag_finetuning_comparison"
rag_data_file="../../.cache/ragVsFinetuning/data/data.json"

## Prepare Data

In [None]:
# Prepare data for finetuning

# System prompt placed at the start of every fine-tuning conversation.
system_message = """
You are a helpful ticket support agent for company XYZ. Provide clear and concise responses to customer queries.
"""

def create_conversation(input: dict) -> dict:
    """Turn one dataset row into a three-turn chat example.

    Args:
        input: Row with a ``query`` and a ``response`` entry.

    Returns:
        ``{"messages": [...]}`` holding the system, user and assistant turns,
        in that order.

    Note:
        Rows whose ``query`` or ``response`` is ``None`` are not filtered
        here: the ``None`` is passed straight through as the turn's content.
    """
    turns = (
        ("system", system_message),
        ("user", input["query"]),
        ("assistant", input["response"]),
    )
    return {"messages": [{"role": role, "content": content} for role, content in turns]}

# Convert the dataset into chat-format examples with create_conversation.
# NOTE(review): presumably writes the train/valid/test files into data_folder
# because write_files=True — confirm in load_data.LoadData.
data_loader = LoadData(folder=data_folder, dataset_name=dataset_name)
data_loader.save(
    function=create_conversation,
    n=n,
    test_split_ratio=test_split_ratio,
    valid_split_ratio=valid_split_ratio,
    write_files=True,
)

# Prepare data for RAG

def process_rag_data(dataset_name: str, output_file: str, n: int = None, seed: int = None) -> List[dict]:
    """Export the ``train`` split of a dataset as a JSON list of RAG records.

    Args:
        dataset_name: Identifier passed to ``load_dataset``.
        output_file: Path of the JSON file to write. Missing parent
            directories are created.
        n: When given, only the first ``n`` rows of the ``train`` split are
            used (capped at the split size). ``None`` uses every row.
        seed: Optional seed forwarded to ``shuffle`` so that the row order,
            and therefore the assigned ids, can be reproduced.

    Returns:
        The records written to ``output_file``. Each record has ``id``
        (position in the shuffled split), ``query`` and ``response``. Rows
        with a missing query or response are skipped, so ids may have gaps.
    """
    # load_dataset() without a split returns a mapping of splits; row selection
    # has to happen on the "train" split itself, not on that mapping.
    train_split = load_dataset(dataset_name)["train"]
    if n is not None:
        train_split = train_split.select(range(min(n, len(train_split))))
    train_split = train_split.shuffle(seed=seed)

    rag_data = []
    for i, item in enumerate(tqdm(train_split)):
        if item["query"] is None or item["response"] is None:
            continue
        rag_data.append({"id": i, "query": item["query"], "response": item["response"]})

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as f:
        json.dump(rag_data, f, indent=4)

    return rag_data

# Write the RAG records to rag_data_file. The trailing semicolon keeps the
# notebook from echoing whatever the call returns as cell output.
process_rag_data(dataset_name=dataset_name, output_file=rag_data_file, n=n);

## Prepare RAG

In [None]:
# Embedding model and Postgres-backed vector store used to index the tickets.
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
vector_store = PGVector(
    embeddings=embedding_model,
    collection_name=collection_name,
    # os.environ[...] raises a KeyError naming the variable when it is unset,
    # instead of handing None to PGVector; same lookup as in compare().
    connection=os.environ["PG_CONN_URI"],
    # Spelled out so the store is configured exactly like the one compare() queries.
    use_jsonb=True,
)

def metadata_func(sample: dict, metadata: dict) -> dict:
    """Attach ticket-level fields to a document's metadata.

    Args:
        sample: One record from the RAG JSON file; its ``id`` is copied over.
        metadata: Metadata dict supplied by the loader. It is modified in
            place.

    Returns:
        The same ``metadata`` object, now also carrying ``id``, ``source``
        (the module-level ``dataset_name``) and ``type``.
    """
    # Build all values first so a missing key leaves ``metadata`` untouched.
    extra_fields = dict(id=sample["id"], source=dataset_name, type="support_ticket")
    for key, value in extra_fields.items():
        metadata[key] = value
    return metadata

# Load the RAG JSON file, selecting every element of the top-level array
# (jq schema ".[]") and enriching each document's metadata via metadata_func.
# NOTE(review): text_content=False is presumably needed because each element is
# an object rather than a plain string — confirm against the JSONLoader docs.
loader = JSONLoader(file_path=rag_data_file, jq_schema=".[]", text_content=False, metadata_func=metadata_func)
docs = loader.load()

def batch_add_documents(vector_store: PGVector, documents: List[Document], batch_size: int = 20):
    """Add ``documents`` to ``vector_store`` in consecutive slices.

    Args:
        vector_store: Store whose ``add_documents`` is called once per slice.
        documents: Documents to add, in order.
        batch_size: Maximum number of documents per ``add_documents`` call.
    """
    batch_starts = range(0, len(documents), batch_size)
    # tqdm wraps the slice offsets, so the bar advances once per batch.
    for start in tqdm(batch_starts):
        stop = start + batch_size
        vector_store.add_documents(documents[start:stop])

# Index all loaded documents, 20 per add_documents call.
# NOTE(review): re-running this cell calls add_documents again for every
# document; whether that duplicates rows depends on how the store assigns ids — confirm.
batch_add_documents(vector_store, docs, batch_size=20)

## Finetune Model

In [None]:
# Run LoRA on the base model with training enabled, batch size 1 and 4 LoRA layers.
# NOTE(review): presumably trains on the files under data_folder and writes the
# adapter weights to adapter_file — confirm in lora.core.LORA.
lora = LORA(config={"train": True, "adapter_file": adapter_file, "batch_size": 1, "lora_layers": 4})
lora.invoke(model_path=model_path, data=data_folder)

In [None]:
# Produce the model stored at save_model_path, which compare() later loads as
# the fine-tuned model.
# NOTE(review): presumably merges the adapter weights from adapter_file into the
# base model — confirm in lora.core.FUSE.
fuse = FUSE(config={"adapter_file": adapter_file})
fuse.invoke(model_path=model_path, save_path=save_model_path)

## Comparison

In [5]:
def _answer_with_rag(query: str, k: int) -> str:
    """Answer ``query`` with the base model, using the ``k`` most similar stored documents as context."""
    vectordb = PGVector(embeddings=OpenAIEmbeddings(model="text-embedding-3-small"), collection_name=collection_name, connection=os.environ['PG_CONN_URI'], use_jsonb=True)
    similar_docs = vectordb.similarity_search(query=query, k=k)
    context = " ".join(doc.page_content for doc in similar_docs)

    model, tokenizer = utils.load(model_path)
    prompt_template = """System: You are a helpful ticket support agent for company XYZ. Provide clear and concise responses to customer queries. RAG Context: {context} User: {query} Answer:"""
    return generate(model=model, tokenizer=tokenizer, prompt=prompt_template.format(context=context, query=query))


def _answer_with_finetuned_model(query: str) -> str:
    """Answer ``query`` with the fused fine-tuned model, without any retrieved context."""
    model, tokenizer = utils.load(save_model_path)
    prompt_template = """System: You are a helpful ticket support agent for company XYZ. Provide clear and concise responses to customer queries. User: {query} Answer:"""
    return generate(model=model, tokenizer=tokenizer, prompt=prompt_template.format(query=query))


def compare(query: str, k: int = 5) -> Tuple[str, str]:
    """Answer the same query with the RAG pipeline and with the fine-tuned model.

    Args:
        query: Customer message to answer.
        k: Number of similar documents retrieved as RAG context.

    Returns:
        ``(rag_answer, finetuned_answer)``.

    Both models are loaded anew on every call. Each helper keeps its model and
    tokenizer in local variables only, so the first model's references are
    dropped when its helper returns, before the second model is loaded.
    """
    rag_result = _answer_with_rag(query, k)
    finetuned_result = _answer_with_finetuned_model(query)
    return rag_result, finetuned_result

In [6]:
# Ask both pipelines the same question and print the two answers between banner lines.
query = "I received a damaged product."
rag_answer, finetuned_answer = compare(query=query)

print(
    "======================================",
    "================RAG==================",
    rag_answer,
    "======================================",
    "==============Fine-tuned===============",
    finetuned_answer,
    "======================================",
    sep="\n",
)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

We apologize for the inconvenience. Could you please provide a clear photo of the damage so we can assess the situation and provide a solution?
We apologize for the inconvenience. Can you please provide your order number and a description of the damage so we can assist you in resolving the issue?


In [7]:
# Ask both pipelines the same question and print the two answers between banner lines.
query = "I'd like to track my order."
rag_answer, finetuned_answer = compare(query=query)

print(
    "======================================",
    "================RAG==================",
    rag_answer,
    "======================================",
    "==============Fine-tuned===============",
    finetuned_answer,
    "======================================",
    sep="\n",
)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

Sure, please provide your order number for us to check the current status.
Certainly. Can you please provide your order number so we can check the current status for you?


In [8]:
# Ask both pipelines the same question and print the two answers between banner lines.
query = "Can I place a bulk order?"
rag_answer, finetuned_answer = compare(query=query)

print(
    "======================================",
    "================RAG==================",
    rag_answer,
    "======================================",
    "==============Fine-tuned===============",
    finetuned_answer,
    "======================================",
    sep="\n",
)

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

Yes, you can place a bulk order. Please provide the product name or SKU and the quantity you'd like to order for us to check availability and pricing.
Yes, we do accept bulk orders. Can you provide information on your bulk ordering process?
