## **Amazon Sales Dataset - RAG System**

**Components**
- Dense Retrieval: FAISS
- Sparse Retrieval: TF-IDF
- Reranking: CrossEncoder
- Generation: Gemini API

**Model Storage**
The embeddings, FAISS index and TF-IDF vectorizer/matrix are cached in `../models/rag/`, so later runs load them instead of rebuilding.

## 1. Dependencies

In [50]:
# !pip install -q sentence-transformers faiss-cpu google-generativeai scikit-learn pandas numpy tqdm
# NOTE(review): later cells also import `from google import genai` (the
# google-genai package) and `gradio`; neither is installed by the line above.

import warnings
# Suppress every warning process-wide for cleaner notebook output. This also
# hides library deprecation notices; remove it when debugging.
warnings.filterwarnings('ignore')

## 2. Imports

In [51]:
import os
import sys
import numpy as np
import pandas as pd
import json
import pickle
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from pathlib import Path
import time

from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import google.generativeai as genai
from tqdm.auto import tqdm

print("Libraries imported")

Libraries imported


## 3. Configuration

In [80]:
@dataclass
class RAGConfig:
    """Central configuration for the hybrid RAG pipeline.

    Values are validated in ``__post_init__`` so a bad setting fails loudly at
    construction time instead of silently skewing retrieval scores.
    """
    # Paths
    data_path: str = "../data/processed/amazon.csv"  # preprocessed product CSV
    model_dir: str = "../models/rag"  # cache dir for embeddings / FAISS / TF-IDF

    # Models
    embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    gemini_model: str = "models/gemini-2.5-flash"

    # Parameters
    embedding_dim: int = 384  # embedding vector size; recorded in metadata.json
    top_k_dense: int = 20  # candidates requested from the FAISS index
    top_k_sparse: int = 20  # candidates taken from TF-IDF similarity
    top_k_final: int = 10  # fused candidates kept before reranking
    hybrid_alpha: float = 0.6  # weight of the dense score; sparse gets 1 - alpha

    # Reranking
    use_reranker: bool = True
    rerank_top_k: int = 5  # maximum results kept after cross-encoder reranking

    # LLM
    gemini_temperature: float = 0.7
    gemini_max_tokens: int = 12000

    # Cache
    cache_embeddings: bool = True
    cache_tfidf: bool = True
    cache_index: bool = True

    def __post_init__(self):
        """Reject settings that would make retrieval or generation meaningless.

        Raises:
            ValueError: if ``hybrid_alpha`` is outside [0, 1], any count/size
                field is not positive, or ``gemini_temperature`` is negative.
        """
        if not 0.0 <= self.hybrid_alpha <= 1.0:
            raise ValueError(f"hybrid_alpha must be in [0, 1], got {self.hybrid_alpha}")
        for name in ("embedding_dim", "top_k_dense", "top_k_sparse",
                     "top_k_final", "rerank_top_k", "gemini_max_tokens"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        if self.gemini_temperature < 0:
            raise ValueError(
                f"gemini_temperature must be non-negative, got {self.gemini_temperature}"
            )

# Build the default configuration and make sure the artifact cache dir exists.
config = RAGConfig()
os.makedirs(config.model_dir, exist_ok=True)
print(f"Config loaded. Model dir: {config.model_dir}")

Config loaded. Model dir: ../models/rag


## 4. Gemini API

In [66]:
from google import genai
from google.genai import types

def setup_gemini(api_key: Optional[str] = None):
    """Create a google-genai ``Client``.

    Args:
        api_key: Explicit key. When None, ``GEMINI_API_KEY`` is read from the
            environment instead.

    Returns:
        A ``genai.Client``, or None (after printing a warning) when no
        non-empty key is available.
    """
    key = api_key if api_key is not None else os.getenv("GEMINI_API_KEY")

    if key:
        client = genai.Client(api_key=key)
        print("Gemini client configured")
        return client

    print("WARNING: GEMINI_API_KEY not found")
    print("Set with: os.environ['GEMINI_API_KEY'] = 'your-key'")
    return None

# SECURITY: never hard-code API keys in a notebook or source file. The key
# previously committed on this line is exposed and must be revoked/rotated.
# Export GEMINI_API_KEY in the environment (e.g. shell profile or a
# git-ignored .env loaded before the kernel starts); setup_gemini() reads it.
gemini_client = setup_gemini()

Gemini client configured


In [70]:
def create_chat_session():
    """Create a multi-turn Gemini chat session configured from ``config``.

    Returns:
        The chat object from ``gemini_client.chats.create``, or None when the
        Gemini client is not configured.
    """
    if not gemini_client:
        return None

    # Take model, temperature and token budget from the shared config so this
    # session behaves like every other Gemini call in the notebook (these were
    # previously hard-coded, with a 32000-token cap that contradicted
    # config.gemini_max_tokens).
    return gemini_client.chats.create(
        model=config.gemini_model,
        config=types.GenerateContentConfig(
            temperature=config.gemini_temperature,
            max_output_tokens=config.gemini_max_tokens,
        ),
    )

## 5. Data Loading

In [54]:
def load_and_preprocess_data(csv_path: str) -> pd.DataFrame:
    """Read the product CSV, fill gaps and add a ``product_context`` column.

    Missing ``rating``/``rating_count`` become 0, missing ``about_product``
    becomes ``''`` and missing ``category`` becomes ``'Unknown'``. Every row is
    then rendered as one ``" | "``-joined string (name, category, first 300
    characters of the description, price, rating, review count). That string
    is what the embedding and TF-IDF steps index.

    Args:
        csv_path: Path to the preprocessed product CSV.

    Returns:
        The loaded DataFrame with the extra ``product_context`` column.
    """
    print(f"Loading data from {csv_path}")
    frame = pd.read_csv(csv_path)
    print(f"Loaded {len(frame)} products")

    defaults = {
        'rating': 0,
        'rating_count': 0,
        'about_product': '',
        'category': 'Unknown',
    }
    for column, fill in defaults.items():
        frame[column] = frame[column].fillna(fill)

    def build_context(record) -> str:
        name = record.get('product_name')
        category = record.get('category')
        about = record.get('about_product')
        price = record.get('discounted_price')
        rating = record.get('rating')
        reviews = record.get('rating_count')

        pieces = []
        if pd.notna(name):
            pieces.append(f"Product: {name}")
        if pd.notna(category):
            pieces.append(f"Category: {category}")
        if pd.notna(about) and str(about) != '':
            pieces.append(f"Description: {str(about)[:300]}")
        if pd.notna(price):
            pieces.append(f"Price: {price}")
        if pd.notna(rating):
            pieces.append(f"Rating: {rating}/5")
        if pd.notna(reviews):
            pieces.append(f"Reviews: {int(reviews)}")
        return " | ".join(pieces)

    frame['product_context'] = frame.apply(build_context, axis=1)
    print("Preprocessing complete")
    return frame

df = load_and_preprocess_data(csv_path=config.data_path)  # global product table used by retrieval

Loading data from ../data/processed/amazon.csv
Loaded 1351 products
Preprocessing complete


## 6. Embedding Generation

In [55]:
def generate_or_load_embeddings(texts: List[str], model_name: str, cache_path: str = None):
    """Return the SentenceTransformer encoder and one normalised embedding per text.

    A cached ``.npy`` file is reused only when it has one row per text and the
    encoder's embedding dimension. Otherwise the embeddings are regenerated and
    the cache is overwritten, so edits to the dataset or a model switch cannot
    silently desynchronise the embeddings from ``df``.

    Args:
        texts: Documents to embed, in ``df`` row order.
        model_name: SentenceTransformer model id.
        cache_path: Optional ``.npy`` path for reading/writing the embeddings.

    Returns:
        ``(model, embeddings)``. The model is returned because queries are
        encoded with it at search time.
    """
    # The encoder is needed in every case (query encoding), so load it exactly
    # once; the previous version loaded it twice on a cache miss.
    model = SentenceTransformer(model_name)
    expected_dim = model.get_sentence_embedding_dimension()

    if cache_path and os.path.exists(cache_path):
        print(f"Loading embeddings from {cache_path}")
        cached = np.load(cache_path)
        rows_ok = cached.ndim == 2 and cached.shape[0] == len(texts)
        dim_ok = expected_dim is None or (cached.ndim == 2 and cached.shape[1] == expected_dim)
        if rows_ok and dim_ok:
            print(f"Loaded: {cached.shape}")
            return model, cached
        print(
            f"Cached embeddings {cached.shape} do not match {len(texts)} texts "
            f"(dim={expected_dim}); regenerating"
        )

    print(f"Generating embeddings: {model_name}")
    embeddings = model.encode(
        texts, batch_size=32, show_progress_bar=True,
        convert_to_numpy=True, normalize_embeddings=True
    )
    print(f"Generated: {embeddings.shape}")
    if cache_path:
        np.save(cache_path, embeddings)
        print(f"Cached to {cache_path}")
    return model, embeddings

cache_path = os.path.join(config.model_dir, "product_embeddings.npy")
embedding_model, product_embeddings = generate_or_load_embeddings(
    df['product_context'].tolist(),
    config.embedding_model,
    cache_path=cache_path if config.cache_embeddings else None
)

Loading embeddings from ../models/rag\product_embeddings.npy
Loaded: (1351, 384)


## 7. FAISS Index

In [56]:
def build_or_load_faiss(embeddings: np.ndarray, cache_path: str = None):
    """Build an exact inner-product FAISS index over *embeddings*, or load a cached one.

    A cached index is reused only if it holds the same number of vectors and
    the same dimension as ``embeddings``. Otherwise it is rebuilt and the cache
    overwritten, so a stale index cannot return ids that point at the wrong
    rows of ``df``.

    Args:
        embeddings: 2-D array, one row per product. Inner product equals
            cosine similarity only if rows are L2-normalised (the embedding
            step in this notebook requests normalised embeddings).
        cache_path: Optional file for ``faiss.write_index`` / ``read_index``.

    Returns:
        The FAISS index.
    """
    n_rows, dim = embeddings.shape

    if cache_path and os.path.exists(cache_path):
        print(f"Loading FAISS from {cache_path}")
        index = faiss.read_index(cache_path)
        if index.ntotal == n_rows and index.d == dim:
            print(f"Loaded: {index.ntotal} vectors")
            return index
        print(
            f"Cached index ({index.ntotal} x {index.d}) does not match "
            f"embeddings ({n_rows} x {dim}); rebuilding"
        )

    # FAISS expects a C-contiguous float32 matrix.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    print(f"Building FAISS (dim={dim})")
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    print(f"Built: {index.ntotal} vectors")
    if cache_path:
        faiss.write_index(index, cache_path)
        print(f"Cached to {cache_path}")
    return index

cache_path = os.path.join(config.model_dir, "faiss_index.bin")
faiss_index = build_or_load_faiss(
    product_embeddings,
    cache_path=cache_path if config.cache_index else None
)

Loading FAISS from ../models/rag\faiss_index.bin
Loaded: 1351 vectors


## 8. TF-IDF Sparse Retriever

In [57]:
def build_or_load_tfidf(documents: List[str], cache_dir: str = None):
    """Fit a TF-IDF vectorizer on *documents*, or load a cached one.

    The cache is two pickles in ``cache_dir``:
    - ``tfidf_vectorizer.pkl``: the fitted vectorizer.
    - ``tfidf_matrix.pkl``: the document-term matrix.

    The cache is reused only when both files load and the matrix has exactly
    one row per document. Otherwise everything is refit and the cache rewritten.

    Args:
        documents: One text per product, in ``df`` row order.
        cache_dir: Directory for the cache files, or None to disable caching.

    Returns:
        ``(vectorizer, doc_vectors)``.
    """
    vec_path = os.path.join(cache_dir, "tfidf_vectorizer.pkl") if cache_dir else None
    mat_path = os.path.join(cache_dir, "tfidf_matrix.pkl") if cache_dir else None

    if vec_path and os.path.exists(vec_path) and os.path.exists(mat_path):
        print("Loading TF-IDF from cache")
        try:
            # NOTE(security): pickle.load can execute arbitrary code. These files
            # are written by this notebook; never point cache_dir at untrusted data.
            with open(vec_path, "rb") as f:
                vectorizer = pickle.load(f)
            with open(mat_path, "rb") as f:
                doc_vectors = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError) as e:
            print(f"TF-IDF cache unreadable ({e}); rebuilding")
        else:
            if doc_vectors.shape[0] == len(documents):
                print(f"Loaded: {doc_vectors.shape}")
                return vectorizer, doc_vectors
            print(
                f"Cached TF-IDF has {doc_vectors.shape[0]} rows but there are "
                f"{len(documents)} documents; rebuilding"
            )

    print("Building TF-IDF")
    vectorizer = TfidfVectorizer(
        max_features=5000, max_df=0.8, min_df=2,
        ngram_range=(1, 2), lowercase=True
    )
    doc_vectors = vectorizer.fit_transform(documents)
    print(f"Built: {doc_vectors.shape}")
    if vec_path:
        with open(vec_path, "wb") as f:
            pickle.dump(vectorizer, f)
        with open(mat_path, "wb") as f:
            pickle.dump(doc_vectors, f)
        print("Cached")
    return vectorizer, doc_vectors

tfidf_vectorizer, tfidf_vectors = build_or_load_tfidf(
    documents=df['product_context'].tolist(),
    cache_dir=(config.model_dir if config.cache_tfidf else None),
)

Loading TF-IDF from cache
Loaded: (1351, 5000)


## 9. Reranker Model

In [58]:
def load_reranker(model_name: str):
    """Load and return the sentence-transformers ``CrossEncoder`` used for reranking.

    Args:
        model_name: Hugging Face model id (normally ``config.reranker_model``).
    """
    print(f"Loading reranker: {model_name}")
    reranker = CrossEncoder(model_name)
    print("Loaded")
    return reranker

# Honour config.use_reranker, which was previously never read: when it is
# False, skip loading the model. hybrid_retrieve only reranks when
# reranker_model is set.
if config.use_reranker:
    reranker_model = load_reranker(config.reranker_model)
else:
    reranker_model = None
    print("Reranker disabled (config.use_reranker=False)")

Loading reranker: cross-encoder/ms-marco-MiniLM-L-12-v2
Loaded


## 10. Hybrid Retrieval

In [59]:
def _product_record(idx: int, score: float) -> Dict:
    """Convert row *idx* of the global ``df`` into a hybrid_retrieve result dict."""
    p = df.iloc[idx]
    return {
        'index': int(idx),
        'score': float(score),
        'product_name': str(p.get('product_name', 'N/A')),
        'category': str(p.get('category', 'N/A')),
        'price': str(p.get('discounted_price', 'N/A')),
        'rating': float(p.get('rating', 0)),
        'rating_count': int(p.get('rating_count', 0)),
        'description': str(p.get('about_product', ''))[:200],
        'product_link': str(p.get('product_link', 'N/A'))
    }


def hybrid_retrieve(query: str, top_k: int = 10, use_reranker: bool = True) -> List[Dict]:
    """Retrieve products by fusing dense and sparse scores, then optionally reranking.

    Pipeline:
      1. Dense: FAISS inner-product search with the normalised query embedding
         (``config.top_k_dense`` hits).
      2. Sparse: TF-IDF cosine similarity against every product; the best
         ``config.top_k_sparse`` are kept.
      3. Fusion: ``alpha * dense + (1 - alpha) * sparse`` per product, with
         ``alpha = config.hybrid_alpha``. A product found by only one
         retriever gets only that retriever's share. The best
         ``config.top_k_final`` survive.
      4. Rerank (optional): the global ``reranker_model`` scores the survivors
         against the query.

    Args:
        query: Free-text user query.
        top_k: Maximum number of results. It is additionally capped by
            ``config.rerank_top_k`` when reranking, and by
            ``config.top_k_final`` otherwise. Values <= 0 return ``[]``.
        use_reranker: Rerank when ``reranker_model`` is loaded.

    Returns:
        Product dicts ordered best-first. ``score`` is the reranker score when
        reranking ran, otherwise the fused score.
    """
    if top_k <= 0:
        return []

    # --- Dense retrieval ---
    q_emb = embedding_model.encode(
        query, convert_to_numpy=True, normalize_embeddings=True
    ).astype(np.float32).reshape(1, -1)
    d_scores, d_idx = faiss_index.search(q_emb, k=config.top_k_dense)
    d_scores, d_idx = d_scores[0], d_idx[0]
    # FAISS pads with id -1 when fewer than k vectors are available; without
    # this filter df.iloc[-1] (the last product) would be returned spuriously.
    found = d_idx >= 0
    d_scores, d_idx = d_scores[found], d_idx[found]

    # --- Sparse retrieval ---
    q_vec = tfidf_vectorizer.transform([query])
    all_sparse = cosine_similarity(q_vec, tfidf_vectors)[0]
    s_idx = np.argsort(all_sparse)[::-1][:config.top_k_sparse]
    s_scores = all_sparse[s_idx]

    # --- Weighted fusion ---
    alpha = config.hybrid_alpha
    combined: Dict[int, float] = {}
    for idx, score in zip(d_idx, d_scores):
        combined[int(idx)] = alpha * float(score)
    for idx, score in zip(s_idx, s_scores):
        key = int(idx)
        combined[key] = combined.get(key, 0.0) + (1 - alpha) * float(score)

    ranked = sorted(combined.items(), key=lambda item: item[1], reverse=True)
    ranked = ranked[:config.top_k_final]
    if not ranked:
        return []
    top_idx = [i for i, _ in ranked]
    top_scores = [s for _, s in ranked]

    # --- Optional cross-encoder rerank ---
    if use_reranker and reranker_model is not None:
        pairs = [[query, df.iloc[i]['product_context']] for i in top_idx]
        r_scores = np.asarray(reranker_model.predict(pairs))
        # Respect the caller's top_k (previously ignored here), capped by config.
        keep = min(top_k, config.rerank_top_k)
        order = np.argsort(r_scores)[::-1][:keep]
        final_idx = [top_idx[i] for i in order]
        final_scores = [r_scores[i] for i in order]
    else:
        final_idx = top_idx[:top_k]
        final_scores = top_scores[:top_k]

    return [_product_record(idx, score) for idx, score in zip(final_idx, final_scores)]

print("Hybrid retrieval ready")

Hybrid retrieval ready


## 11. Recommendation Generation

In [76]:
def generate_recommendation(query: str, retrieved: List[Dict]) -> str:
    """Ask Gemini for a recommendation grounded in the retrieved products.

    Only the first five products go into the prompt. The prompt instructs the
    model to answer in Vietnamese.

    Args:
        query: The user's request.
        retrieved: Product dicts as returned by ``hybrid_retrieve``.

    Returns:
        The model's text. When no answer can be produced, one of these status
        strings instead:
        - ``"Gemini not configured"``
        - a "no matching products" notice
        - ``"Response blocked or empty ..."``
        - ``"Error: ..."``
    """
    if not gemini_client:
        return "Gemini not configured"
    if not retrieved:
        # With no products in context the model could only invent items.
        return "No matching products found. Try rephrasing query."

    context = "\n\n".join([
        f"{i}. {p['product_name']}\n"
        f"   Category: {p['category']}\n"
        f"   Price: {p['price']}\n"
        f"   Rating: {p['rating']}/5 ({p['rating_count']} reviews)\n"
        f"   Description: {p['description']}"
        for i, p in enumerate(retrieved[:5], 1)
    ])

    prompt = f"""Provide product recommendation based on query.

Query: "{query}"

Products:
{context}

Please provide:
1. Brief summary
2. Top 2-3 recommendations with reasons
3. Key features

Keep response concise and helpful. Bên cạnh đó sử dụng Tiếng Việt trong phần trả lời."""

    try:
        response = gemini_client.models.generate_content(
            model=config.gemini_model,  # single source of truth for the model id
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=config.gemini_temperature,
                max_output_tokens=config.gemini_max_tokens,
            )
        )
        text = response.text
        # text is empty/None when the reply was blocked or had no candidates.
        if not text:
            return "Response blocked or empty. Try rephrasing query."
        return text

    except Exception as e:
        return f"Error: {str(e)}"

print("Recommendation generator ready")


Recommendation generator ready


## 12. Question Answering

In [77]:
def answer_question(question: str, retrieved: Optional[List[Dict]] = None) -> Dict:
    """Answer a product question with Gemini, grounded in retrieved products.

    The prompt instructs the model to answer in Vietnamese.

    Args:
        question: The user's question.
        retrieved: Products the caller already fetched. When None,
            ``hybrid_retrieve(question, top_k=5)`` is run here.

    Returns:
        On success: ``{"question", "answer", "num_retrieved", "products"}``.
        On failure: a dict with an ``"error"`` message and ``"question"``.
        It also includes ``"products"`` once retrieval has succeeded.
    """
    if not gemini_client:
        return {"error": "Gemini not configured", "question": question}

    if retrieved is None:
        try:
            retrieved = hybrid_retrieve(question, top_k=5)
        except Exception as e:
            return {"error": f"Retrieval error: {str(e)}", "question": question}

    context = "\n".join([
        f"- {p['product_name']} ({p['category']}): "
        f"Price {p['price']}, Rating {p['rating']}/5. {p['description']}"
        for p in retrieved
    ])

    prompt = f"""Answer this question about products.

Question: "{question}"

Available Products:
{context}

Provide a clear and helpful answer.Bên cạnh đó sử dụng Tiếng Việt trong phần trả lời."""

    try:
        response = gemini_client.models.generate_content(
            model=config.gemini_model,  # single source of truth for the model id
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=config.gemini_temperature,
                max_output_tokens=config.gemini_max_tokens,
            )
        )
        text = response.text
    except Exception as e:
        # Keep the products so callers can still show what was found.
        return {"error": str(e), "question": question, "products": retrieved}

    if not text:
        return {
            "error": "Response blocked or empty",
            "question": question,
            "products": retrieved
        }

    return {
        "question": question,
        "answer": text,
        "num_retrieved": len(retrieved),
        "products": retrieved
    }

print("Q&A ready")


Q&A ready


## 13. Chatbot

In [None]:
class RAGChatBot:
    """Single-turn RAG chatbot that logs every exchange in ``self.history``.

    Each :meth:`chat` call retrieves products with ``hybrid_retrieve``, then
    either generates a recommendation (``mode="recommend"``) or answers a
    question through ``answer_question`` (any other mode).
    """

    def __init__(self):
        self.history = []

    def chat(self, user_input: str, mode: str = "recommend", verbose: bool = False) -> Dict:
        """Process one user message and return (and record) a result dict.

        Args:
            user_input: The user's message.
            mode: ``"recommend"`` for a recommendation; anything else runs Q&A.
            verbose: Also print the retrieved products and the generated text.

        Returns:
            Recommendation mode: ``mode, query, products, recommendation,
            retrieval_time, generation_time``.
            Q&A mode: ``mode, query, answer, products, retrieval_time,
            generation_time``. On failure it is the ``answer_question`` error
            dict plus the same mode/query/timing keys.
        """
        print(f"\nQuery: {user_input}")
        print("=" * 80)

        print("Retrieving...")
        start = time.time()
        retrieved = hybrid_retrieve(user_input, top_k=5)
        ret_time = time.time() - start
        if verbose:
            self._print_products(retrieved, ret_time)

        if mode == "recommend":
            print("Generating...\n")
            start = time.time()
            rec = generate_recommendation(user_input, retrieved)
            gen_time = time.time() - start
            if verbose:
                self._print_generation("Recommendation:", rec, gen_time)

            result = {
                "mode": "recommend",
                "query": user_input,
                "products": retrieved,
                "recommendation": rec,
                "retrieval_time": ret_time,
                "generation_time": gen_time
            }
        else:
            print("Answering...\n")
            start = time.time()
            ans = answer_question(user_input)
            gen_time = time.time() - start

            if "error" in ans:
                if verbose:
                    print(f"Error: {ans['error']}")
                # Keep the error payload but give it the same bookkeeping keys
                # as successful entries so history/stats have a uniform shape.
                result = {
                    **ans,
                    "mode": "qa",
                    "query": user_input,
                    "retrieval_time": ret_time,
                    "generation_time": gen_time
                }
            else:
                if verbose:
                    self._print_generation("Answer:", ans['answer'], gen_time)
                result = {
                    "mode": "qa",
                    "query": user_input,
                    "answer": ans['answer'],
                    "products": retrieved,
                    "retrieval_time": ret_time,
                    "generation_time": gen_time
                }

        self.history.append(result)
        return result

    @staticmethod
    def _print_products(products: List[Dict], elapsed: float) -> None:
        """Print a compact listing of retrieved products (verbose mode)."""
        print(f"Retrieved {len(products)} in {elapsed:.2f}s\n")
        print("Products:")
        print("-" * 80)
        for i, p in enumerate(products, 1):
            print(f"{i}. {p['product_name']}")
            print(f"   {p['category']} | {p['price']} | {p['rating']}/5")
            print(f"   Score: {p['score']:.4f}\n")

    @staticmethod
    def _print_generation(title: str, text: str, elapsed: float) -> None:
        """Print a generated recommendation/answer with its timing (verbose mode)."""
        print(title)
        print("-" * 80)
        print(text)
        print("-" * 80)
        print(f"\nGeneration: {elapsed:.2f}s")

    def get_stats(self):
        """Summarise the history.

        Returns:
            A dict with keys:
            - ``queries``: number of history entries.
            - ``avg_retrieval``: mean over all entries (0 for a missing time).
            - ``avg_generation``: mean over all entries (0 for a missing time).
            - ``errors``: entries containing an ``"error"`` key.
        """
        if not self.history:
            return {"queries": 0, "avg_retrieval": 0, "avg_generation": 0, "errors": 0}

        retrieval_times = [h.get('retrieval_time', 0) for h in self.history]
        generation_times = [h.get('generation_time', 0) for h in self.history]

        return {
            "queries": len(self.history),
            "avg_retrieval": float(np.mean(retrieval_times)),
            "avg_generation": float(np.mean(generation_times)),
            "errors": sum(1 for h in self.history if 'error' in h)
        }

    def get_history(self):
        """Return the list of result dicts recorded by :meth:`chat`."""
        return self.history

    @staticmethod
    def generate_recommendation_stream(query: str, retrieved: List[Dict]) -> str:
        """Stream a short recommendation to stdout and return the full text.

        Declared as a staticmethod (it uses no instance state). The previous
        definition lacked ``self``, so calling it on an instance passed the
        bot as ``query`` and raised TypeError.

        Args:
            query: The user's request.
            retrieved: Product dicts from ``hybrid_retrieve``; the top 3 are used.

        Returns:
            The concatenated streamed text, ``"Gemini not configured"``, or
            ``"Error: ..."``.
        """
        if not gemini_client:
            return "Gemini not configured"

        context = "\n".join([
            f"{i}. {p['product_name']} - {p['price']} - {p['rating']}/5"
            for i, p in enumerate(retrieved[:3], 1)
        ])

        # Built without the stray leading indentation the old triple-quoted
        # prompt carried after being moved inside the class.
        prompt = (
            f"Recommend product for: {query}\n\n"
            f"Products:\n{context}\n\n"
            "Provide brief recommendation."
        )

        try:
            response_stream = gemini_client.models.generate_content_stream(
                model=config.gemini_model,
                contents=prompt,
                config=types.GenerateContentConfig(
                    temperature=config.gemini_temperature,
                    max_output_tokens=config.gemini_max_tokens
                )
            )

            full_response = ""
            print("\nRecommendation:")
            print("-" * 80)

            for chunk in response_stream:
                if chunk.text:
                    print(chunk.text, end="", flush=True)
                    full_response += chunk.text

            print("\n" + "-" * 80)
            return full_response

        except Exception as e:
            return f"Error: {str(e)}"


# Previously executed inside the class body at definition time.
print("Streaming recommendation generator ready")


chatbot = RAGChatBot()
print("\nChatbot ready")


Chatbot ready


In [None]:
class ContinuousRAGChat:
    """Multi-turn chat over a Gemini chat session, optionally grounded with RAG.

    The Gemini chat object keeps the conversation context between turns.
    ``self.history`` is a local transcript of ``{"user", "assistant",
    "timestamp"}`` dicts for display; both chat methods append to it.
    """

    def __init__(self):
        self.chat_session = None
        self.history = []
        self.initialize_chat()

    def initialize_chat(self):
        """Create a Gemini chat session configured from ``config``.

        Prints a warning and returns when the client is missing; prints the
        error when session creation fails.
        """
        if not gemini_client:
            print("WARNING: Gemini not configured")
            return

        try:
            self.chat_session = gemini_client.chats.create(
                model=config.gemini_model,
                config=types.GenerateContentConfig(
                    # Shared setting instead of a hard-coded 0.7.
                    temperature=config.gemini_temperature,
                    max_output_tokens=config.gemini_max_tokens,
                )
            )
            print("Chat session initialized")
        except Exception as e:
            print(f"Error initializing chat: {e}")

    def _record(self, user_input: str, reply: str) -> None:
        """Append one exchange to the local transcript."""
        self.history.append({
            "user": user_input,
            "assistant": reply,
            "timestamp": time.time()
        })

    def chat_stream(self, user_input: str, use_rag: bool = True):
        """Send one turn and print the reply as it streams.

        Args:
            user_input: The user's message.
            use_rag: Prepend the top-3 retrieved products (name, price, rating)
                to the message. If retrieval fails, the turn is sent without
                product context.

        Returns:
            The full reply text, or None if there is no session or sending failed.
        """
        if not self.chat_session:
            print("Chat session not initialized")
            return None

        print(f"\nYou: {user_input}")
        print("=" * 80)

        augmented_prompt = user_input
        if use_rag:
            print("Searching products...")
            start = time.time()
            try:
                retrieved = hybrid_retrieve(user_input, top_k=5)
            except Exception as e:
                # Degrade to a plain chat turn instead of aborting the conversation.
                print(f"Retrieval failed ({e}); answering without product context\n")
            else:
                print(f"Found {len(retrieved)} products in {time.time()-start:.2f}s\n")
                context = "Available products:\n"
                for i, p in enumerate(retrieved[:3], 1):
                    context += f"{i}. {p['product_name']} - {p['price']} - {p['rating']}/5\n"
                augmented_prompt = f"{context}\n\nUser question: {user_input}"

        print("Assistant: ", end="", flush=True)

        try:
            full_response = ""
            for chunk in self.chat_session.send_message_stream(augmented_prompt):
                if chunk.text:
                    print(chunk.text, end="", flush=True)
                    full_response += chunk.text
            print("\n" + "=" * 80)
        except Exception as e:
            print(f"\nError: {e}")
            return None

        self._record(user_input, full_response)
        return full_response

    def chat_no_stream(self, user_input: str, use_rag: bool = True):
        """Send one turn and print the complete reply at once.

        Args:
            user_input: The user's message.
            use_rag: Prepend the top-3 retrieved products (name, price) to the
                message. If retrieval fails, the turn is sent without product
                context.

        Returns:
            The reply text (``""`` if the model returned no text), the string
            ``"Chat session not initialized"``, or None if sending failed.
        """
        if not self.chat_session:
            return "Chat session not initialized"

        print(f"\nYou: {user_input}")

        augmented = user_input
        if use_rag:
            try:
                retrieved = hybrid_retrieve(user_input, top_k=3)
            except Exception as e:
                print(f"Retrieval failed ({e}); answering without product context")
            else:
                context = "\n".join([
                    f"{i}. {p['product_name']} - {p['price']}"
                    for i, p in enumerate(retrieved[:3], 1)
                ])
                augmented = f"Products:\n{context}\n\nQuestion: {user_input}"

        try:
            response = self.chat_session.send_message(augmented)
        except Exception as e:
            print(f"Error: {e}")
            return None

        text = response.text or ""
        print(f"Assistant: {text}")
        # Previously non-streamed turns were never recorded in the transcript.
        self._record(user_input, text)
        return text

    def get_history(self):
        """Return the local transcript of recorded exchanges."""
        return self.history

    def clear_history(self):
        """Clear the transcript and start a fresh Gemini chat session."""
        self.history = []
        self.initialize_chat()
        print("Chat history cleared, session restarted")

# Initialize continuous chat. Construction calls initialize_chat(), which
# creates a Gemini chat session when gemini_client is configured.
continuous_chat = ContinuousRAGChat()


In [None]:
def interactive_chat_loop():
    """Run a blocking read-eval-print loop on stdin that chats via ``continuous_chat``.

    Commands (case-insensitive):
    - ``history``: print the transcript.
    - ``clear``: reset the chat session.
    - ``quit`` / ``exit`` / ``q``: leave the loop.

    Any other non-empty line goes to ``continuous_chat.chat_stream`` with RAG
    enabled. Ctrl-C cancels the current turn without leaving the loop. End of
    input (Ctrl-D, or a closed or exhausted stdin) exits.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("RAG CHATBOT - INTERACTIVE MODE")
    print(rule)
    print("\nCommands:")
    print("  - Type your question for product recommendations")
    print("  - 'history' - Show conversation history")
    print("  - 'clear' - Clear chat history")
    print("  - 'quit' - Exit")
    print("\n" + rule + "\n")

    while True:
        try:
            user_input = input("\nYou: ").strip()
        except EOFError:
            # input() raises EOFError on every call once stdin is exhausted;
            # the old generic `except Exception` turned that into an endless loop.
            print("\nGoodbye!")
            break
        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.")
            continue

        if not user_input:
            continue

        command = user_input.lower()
        if command in ('quit', 'exit', 'q'):
            print("\nGoodbye!")
            break

        try:
            if command == 'history':
                print("\n--- Chat History ---")
                for i, item in enumerate(continuous_chat.get_history(), 1):
                    print(f"\n{i}. You: {item['user']}")
                    print(f"   Bot: {item['assistant'][:100]}...")
            elif command == 'clear':
                continuous_chat.clear_history()
            else:
                continuous_chat.chat_stream(user_input, use_rag=True)
        except KeyboardInterrupt:
            print("\n\nInterrupted. Type 'quit' to exit.")
        except Exception as e:
            # A failing turn should not end the session.
            print(f"\nError: {e}")

# Run interactive mode
# interactive_chat_loop()


## 14. Testing

In [81]:
print("\nRunning tests\n")
print("="*80)

print("\nTest 1: Recommendation")
test1 = chatbot.chat("wireless earbuds under 3000", mode='recommend')

print("\n\nTest 2: Q&A")
test2 = chatbot.chat("What are best Bluetooth speakers?", mode='qa')

print("\n\nSummary")
print("="*80)
# Reuse the chatbot's own aggregation so both averages use the same
# denominator (every history entry) and always agree with get_stats().
stats = chatbot.get_stats()
print(f"Queries: {stats['queries']}")
print(f"Avg retrieval: {stats['avg_retrieval']:.2f}s")
print(f"Avg generation: {stats['avg_generation']:.2f}s")
print(f"Errors: {stats.get('errors', 0)}")


Running tests


Test 1: Recommendation

Query: wireless earbuds under 3000
Retrieving...
Retrieved 5 in 0.30s

Products:
--------------------------------------------------------------------------------
1. boAt Airdopes 141 Bluetooth Truly Wireless in Ear Earbuds with mic, 42H Playtime, Beast Mode(Low Latency Upto 80ms) for Gaming, ENx Tech, ASAP Charge, IWP, IPX4 Water Resistance (Bold Black)
   Electronics|Headphones,Earbuds&Accessories|Headphones|In-Ear | 1499.0 | 3.9/5
   Score: -0.2658

2. boAt Airdopes 171 in Ear Bluetooth True Wireless Earbuds with Upto 13 Hours Battery, IPX4, Bluetooth v5.0, Dual Tone Finish with Mic (Mysterious Blue)
   Electronics|Headphones,Earbuds&Accessories|Headphones|In-Ear | 1199.0 | 3.9/5
   Score: -0.8701

3. ZEBRONICS Zeb-Sound Bomb N1 True Wireless in Ear Earbuds with Mic ENC, Gaming Mode (up to 50ms), up to 18H Playback, BT V5.2, Fidget Case, Voice Assistant, Splash Proof, Type C (Midnight Black)
   Electronics|Headphones,Earbuds&Accessories|Headph

## 15. Model Persistence

In [83]:
def save_all_models():
    """Write all retrieval artifacts plus ``metadata.json`` into ``config.model_dir``.

    Files written (existing files with these names are overwritten):
    - ``faiss_index.bin``: the FAISS index.
    - ``product_embeddings.npy``: the product embeddings.
    - ``tfidf_vectorizer.pkl`` / ``tfidf_matrix.pkl``: the TF-IDF artifacts.
    - ``metadata.json``: product count and model settings.
    """
    print("Saving models...")
    out_dir = config.model_dir

    faiss_file = os.path.join(out_dir, "faiss_index.bin")
    faiss.write_index(faiss_index, faiss_file)
    print(f"Saved FAISS: {faiss_file}")

    emb_file = os.path.join(out_dir, "product_embeddings.npy")
    np.save(emb_file, product_embeddings)
    print(f"Saved embeddings: {emb_file}")

    pickled = (
        (tfidf_vectorizer, "tfidf_vectorizer.pkl", "Saved TF-IDF vec"),
        (tfidf_vectors, "tfidf_matrix.pkl", "Saved TF-IDF mat"),
    )
    for obj, filename, label in pickled:
        pkl_file = os.path.join(out_dir, filename)
        with open(pkl_file, "wb") as fh:
            pickle.dump(obj, fh)
        print(f"{label}: {pkl_file}")

    meta = dict(
        total_products=len(df),
        embedding_model=config.embedding_model,
        reranker_model=config.reranker_model,
        embedding_dim=config.embedding_dim,
        hybrid_alpha=config.hybrid_alpha,
    )
    meta_file = os.path.join(out_dir, "metadata.json")
    with open(meta_file, "w") as fh:
        json.dump(meta, fh, indent=2)
    print(f"Saved metadata: {meta_file}")

    print(f"\nAll saved to: {out_dir}")

# save_all_models()
print("Save function ready. Call: save_all_models()")

Save function ready. Call: save_all_models()


## 16. Usage Summary



In [88]:
# Usage: chatbot.chat('query', mode='recommend') or mode='qa'
# ## Recommendation
# result = chatbot.chat("wireless earbuds", mode='recommend')

# ## Q&A
# result = chatbot.chat("best keyboards?", mode='qa')

# ## Direct Retrieval
# products = hybrid_retrieve("gaming mouse", top_k=5)

result = chatbot.chat("best keyboards?", mode='qa')
# A failed Q&A result carries "error" instead of "answer"; indexing
# result["answer"] directly would raise KeyError in that case.
print(result["answer"] if "answer" in result else f"Error: {result.get('error')}")


Query: best keyboards?
Retrieving...
Answering...

Chào bạn, để xác định "bàn phím tốt nhất" còn tùy thuộc vào nhu cầu sử dụng của bạn (chơi game, làm việc văn phòng, sử dụng hàng ngày) và ngân sách. Dựa trên các sản phẩm có sẵn, đây là những lựa chọn hàng đầu:

**1. Tốt nhất cho Gaming và Tính năng cao cấp:**

*   **HP K500F Backlit Membrane Wired Gaming Keyboard**
    *   **Giá:** 1149.0
    *   **Đánh giá:** 4.3/5 (Cao nhất)
    *   **Điểm nổi bật:** Đây là lựa chọn tốt nhất nếu bạn cần một bàn phím chơi game. Nó có đèn nền LED nhiều màu, tấm kim loại chắc chắn, 26 phím chống ghosting (anti-ghosting) và phím khóa Windows, rất quan trọng cho game thủ. Kèm theo bảo hành 3 năm.

**2. Tốt nhất cho Sử dụng hàng ngày/Văn phòng (Yên tĩnh và Tiện lợi):**

*   **Dell KB216 Wired Multimedia USB Keyboard**
    *   **Giá:** 549.0
    *   **Đánh giá:** 4.3/5 (Cao nhất)
    *   **Điểm nổi bật:** Nếu bạn cần một bàn phím đáng tin cậy cho công việc văn phòng hoặc sử dụng hàng ngày, Dell KB216 là m

## 17. UI Chat

In [92]:
import gradio as gr

def chat_interface(message, history):
    """Gradio ChatInterface callback: retrieve products, then stream a Gemini reply.

    Args:
        message: The latest user message.
        history: Earlier turns supplied by Gradio. Unused: each turn is
            answered from fresh retrieval only.

    Yields:
        The cumulative reply after each streamed chunk, so Gradio renders it
        incrementally. If no reply can be produced, a single status message.
    """
    if not gemini_client:
        yield "Gemini not configured"
        return

    try:
        retrieved = hybrid_retrieve(message, top_k=3)
    except Exception as e:
        # Report in the chat instead of letting the exception reach Gradio.
        yield f"Retrieval error: {e}"
        return

    context = "\n".join([
        f"{i}. {p['product_name']} - {p['price']}"
        for i, p in enumerate(retrieved[:3], 1)
    ])
    prompt = f"Products:\n{context}\n\nQuestion: {message}"

    full_response = ""
    try:
        response_stream = gemini_client.models.generate_content_stream(
            model=config.gemini_model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=config.gemini_temperature,
                max_output_tokens=config.gemini_max_tokens,
            )
        )
        for chunk in response_stream:
            if chunk.text:
                full_response += chunk.text
                yield full_response
    except Exception as e:
        # Keep any partial text the user has already seen.
        yield f"{full_response}\n\nError: {e}" if full_response else f"Error: {e}"
        return

    if not full_response:
        # An empty stream would otherwise leave the chat bubble blank.
        yield "Response blocked or empty. Try rephrasing query."

# Create Gradio interface
# Clickable example prompts shown under the chat box.
EXAMPLE_QUERIES = [
    "I need wireless earbuds under 3000",
    "What are the best gaming keyboards?",
    "Show me budget Bluetooth speakers",
]

# Wire chat_interface (a streaming generator) into a Gradio chat UI.
demo = gr.ChatInterface(
    fn=chat_interface,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask about products...", container=False, scale=7),
    title="RAG Product Recommendation Chatbot",
    description="Ask questions about products and get AI-powered recommendations",
    theme="soft",
    examples=EXAMPLE_QUERIES,
    cache_examples=False,
)

# share=True also requests a public share link. That request can fail
# (e.g. offline), in which case only the local URL is served.
demo.launch(share=True)


* Running on local URL:  http://127.0.0.1:7860

Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


