# Week 3 — RAG Assistant (ShopLite)

Self-contained Colab: Qwen/Qwen2.5-7B-Instruct + FAISS + Flask + ngrok.

In [None]:
# I wanted to show a run of this, but I have hit my Colab usage limit.

In [None]:
# -*- coding: utf-8 -*-
# Colab one-cell: Qwen 2.5 7B Instruct (local), FAISS RAG, Flask, ngrok
# Behavior: refuses off-topic queries + terse two-line answers.

# 0) Installs
# NOTE: `!pip` is IPython/Colab shell syntax, so this cell only runs inside a notebook.
# bitsandbytes backs the 4-bit model load in step 5; faiss-cpu backs the vector index in step 4.
!pip -q install "transformers>=4.43" "accelerate>=0.33" "bitsandbytes>=0.43" \
                 sentencepiece "sentence-transformers>=2.7" faiss-cpu pyyaml flask pyngrok requests


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: Operation cancelled by user[0m[31m
[0m

In [None]:
# 1) Imports & GPU check
import os, json, time, threading, textwrap, yaml, numpy as np, torch, faiss, requests
from typing import List, Dict
from flask import Flask, request, jsonify
from pyngrok import ngrok
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

_cuda_ready = torch.cuda.is_available()
print("CUDA available:", _cuda_ready)
if _cuda_ready:
    print("GPU:", torch.cuda.get_device_name(0))

# --- Behavior knobs ---
TOP_K: int = 2                      # chunks retrieved per query
RELEVANCE_THRESHOLD: float = 0.32   # /chat refuses when the best similarity score is below this
MAX_NEW_TOKENS: int = 110           # generation budget for /chat answers
MAX_ANSWER_WORDS: int = 75          # the "Answer:" body is trimmed to this many words


ModuleNotFoundError: No module named 'faiss'

In [None]:


# 2) Knowledge base (REPLACE with your 15–20 docs in the assignment)
# Each entry: id (stable key), title (returned as the source citation), content (chunked + embedded).
KB_DOCS: List[Dict] = [
    dict(
        id="doc_returns",
        title="Returns & Refunds",
        content=(
            "ShopLite offers a 30-day return window from delivery. Items must be unused and in original packaging. "
            "Exclusions include perishables and final-sale items. Start from Your Orders to get an RMA and label. "
            "Return shipping is free for defective or mis-shipped items; otherwise a label cost may be deducted. "
            "Refunds post within 5–10 business days after inspection."
        ),
    ),
    dict(
        id="doc_tracking",
        title="Orders & Tracking",
        content=(
            "Statuses: Processing, Shipped, Out for delivery, Delivered. The tracking page shows carrier, latest scan, and ETA. "
            "If late by >3 business days, contact Support with your order ID. Address changes are possible only before Shipped."
        ),
    ),
]

In [None]:
# 3) Chunking helper (≈150–250 words per chunk, 50 words overlap)
WORD_MAX, OVERLAP = 220, 50
def chunk_text(doc_id: str, title: str, text: str, *, word_max: int = WORD_MAX, overlap: int = OVERLAP):
    """Split ``text`` into overlapping word windows tagged with their document.

    Args:
        doc_id: Document id copied into every chunk.
        title: Document title copied into every chunk.
        text: Raw text; split on whitespace and re-joined with single spaces.
        word_max: Maximum words per chunk (defaults to ``WORD_MAX``).
        overlap: Words shared between consecutive chunks (defaults to ``OVERLAP``).

    Returns:
        List of ``{"doc_id", "title", "text"}`` dicts in order; ``[]`` for empty text.

    Raises:
        ValueError: if ``word_max < 1`` or ``overlap`` is not in ``[0, word_max)``.
            Without this check, the window would never advance and the loop would not terminate.
    """
    if word_max < 1 or not 0 <= overlap < word_max:
        raise ValueError(f"need word_max >= 1 and 0 <= overlap < word_max, got {word_max=}, {overlap=}")
    words = text.split()
    step = word_max - overlap  # each window starts `step` words after the previous one
    chunks = []
    for start in range(0, len(words), step):
        end = min(len(words), start + word_max)
        chunks.append({"doc_id": doc_id, "title": title, "text": " ".join(words[start:end])})
        if end == len(words):  # last window reached the end; don't emit tail-only duplicates
            break
    return chunks
# Flatten every KB document into its chunk records, preserving document order.
CHUNKS: List[Dict] = [
    piece
    for doc in KB_DOCS
    for piece in chunk_text(doc["id"], doc["title"], doc["content"])
]
print(f"Docs: {len(KB_DOCS)}, chunks: {len(CHUNKS)}")

In [None]:



# 4) Embeddings + FAISS (cosine via inner product on L2-normalized vectors)
# Embeds every chunk once; the module-level `embedder` and `index` are reused by retrieve().
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
vecs = embedder.encode([c["text"] for c in CHUNKS], convert_to_numpy=True, show_progress_bar=False)
# Row-wise L2 normalisation; the 1e-12 avoids division by zero for an all-zero row.
vecs = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12)
# Inner-product index: on unit-length vectors the inner product equals cosine similarity.
index = faiss.IndexFlatIP(vecs.shape[1])
index.add(vecs.astype(np.float32))  # cast to float32 before handing vectors to FAISS
print("FAISS index built; dim=", vecs.shape[1])
# NOTE(review): if CHUNKS is empty, vecs.shape[1] is likely to fail — confirm the KB is never empty.

def retrieve(query: str, k: int = TOP_K):
    """Return up to ``k`` chunks most similar to ``query``.

    The query is embedded and L2-normalised exactly like the indexed chunks,
    so the scores are cosine similarities (higher = more relevant).

    Args:
        query: Free-text user question.
        k: Maximum number of hits. It is clamped to the index size, and a
            value of ``k <= 0`` returns ``[]`` instead of calling FAISS.

    Returns:
        List of ``{"score", "title", "text"}`` dicts, best match first.
    """
    k = min(int(k), index.ntotal)
    if k <= 0:
        return []
    # show_progress_bar=False: match the index-build call and keep server logs clean per request.
    qv = embedder.encode([query], convert_to_numpy=True, show_progress_bar=False)
    qv = qv / (np.linalg.norm(qv, axis=1, keepdims=True) + 1e-12)
    scores, ids = index.search(qv.astype(np.float32), k)
    hits = []
    for score, idx in zip(scores[0], ids[0]):
        if idx < 0:  # FAISS pads missing results with -1
            continue
        chunk = CHUNKS[int(idx)]
        hits.append({"score": float(score), "title": chunk["title"], "text": chunk["text"]})
    return hits

In [None]:

# 5) Load Qwen/Qwen2.5-7B-Instruct locally (4-bit so it fits a T4 16GB)
from transformers import BitsAndBytesConfig

MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# Passing load_in_4bit / bnb_4bit_* straight to from_pretrained is deprecated in
# transformers; the supported route is a BitsAndBytesConfig via quantization_config.
# NOTE(review): bitsandbytes 4-bit loading generally needs a CUDA GPU — the float32
# branch above is unlikely to work on CPU-only runtimes; confirm before relying on it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=dtype,
)
model.eval()

In [None]:
# 6) Prompt policy (inline YAML) — terse + refuse when not grounded
# NOTE: anything after "#" inside this YAML is a YAML comment and is stripped by the
# parser, so every instruction meant for the model must be written as a rule string.
PROMPTS_YAML = """
base_answer:
  role: system
  goal: >
    You are a ShopLite assistant. Answer ONLY if the provided context snippets
    contain the information. If not, refuse.
  rules:
    - Be terse and precise. If a single fact is asked, reply in ONE sentence.
    - Do not add definitions, disclaimers, or extra details unless asked.
    - Use only the snippets; do not invent facts.
    - Cite sources by document title(s).
    - Output exactly two lines:
      - "Answer: <your answer or refusal>"
      - "Sources: <Title A>; <Title B>"
    - If you refuse, leave the Sources line empty.
  format: |
    Answer: <your answer>
    Sources: <Title A>; <Title B>

refusal_message: |
  Answer: Sorry—this isn’t in the ShopLite knowledge base. Try a question about orders, returns, shipping, payments, promotions, reviews, account, or support.
  Sources:
"""
PROMPTS = yaml.safe_load(PROMPTS_YAML)

# helper to render mixed-type YAML rules safely
def _render_rules(rules):
    lines = []
    for r in rules:
        if isinstance(r, str):
            lines.append(r)
        elif isinstance(r, dict):
            for k, v in r.items():
                lines.append(k)
                if isinstance(v, list):
                    for item in v:
                        lines.append(f"- {item}")
                else:
                    lines.append(str(v))
        else:
            lines.append(str(r))
    return "\n".join(lines)

def build_chat(query: str, hits: List[Dict]):
    """Build the [system, user] chat messages for a grounded ShopLite answer.

    The system message combines the policy goal, the rendered rules and the
    required output format. The user message carries the question plus the
    retrieved snippets, or an explicit "(none)" context when there are no hits.
    """
    policy = PROMPTS["base_answer"]
    rules_block = _render_rules(policy["rules"])
    system_prompt = "\n".join([policy["goal"], rules_block, "Format:\n" + policy["format"]])

    if hits:
        snippets = "\n\n".join(f"Title: {h['title']}\nSnippet: {h['text']}" for h in hits)
        user_prompt = f"Question: {query}\n\nUse these snippets only:\n---\n{snippets}"
    else:
        user_prompt = f"Question: {query}\nContext: (none)"

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

@torch.inference_mode()
def generate_answer(messages: List[Dict], max_new_tokens=MAX_NEW_TOKENS, temperature=0.2):
    """Run the chat model on ``messages`` and normalise the reply to two lines.

    Args:
        messages: Chat messages (``{"role", "content"}`` dicts) rendered with
            the tokenizer's chat template.
        max_new_tokens: Generation budget for the reply.
        temperature: Sampling temperature. ``None`` or a value ``<= 0``
            switches to greedy decoding instead of raising inside ``generate``.

    Returns:
        ``"Answer: ...\\nSources: ..."``. If the model emits no ``Answer:`` line,
        the whole reply becomes the answer. The answer body is trimmed to
        ``MAX_ANSWER_WORDS`` words.
    """
    def _has_placeholder(line):
        # A template placeholder looks like "<...>"; a lone ">" (e.g. ">3 business days")
        # is legitimate content and must not disqualify the line.
        lt = line.find("<")
        return lt != -1 and line.find(">", lt + 1) != -1

    def _pick_line(lines, prefix):
        prefix = prefix.lower()
        cands = [ln for ln in lines if ln.lower().startswith(prefix)]
        # Prefer the last candidate with a non-empty body and no "<...>" placeholder
        for ln in reversed(cands):
            body = ln.split(":", 1)[1] if ":" in ln else ""
            if body.strip() and not _has_placeholder(ln):
                return ln
        # Fallback to the last candidate if none is clean
        return cands[-1] if cands else None

    # Generate
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": tok.eos_token_id}
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        gen_kwargs["do_sample"] = False
    out = model.generate(**inputs, **gen_kwargs)

    # out[0] is prompt + continuation; decode only the new tokens so the system/user
    # text (which contains "Answer:"/"Sources:" template lines) never leaks into the reply.
    prompt_len = inputs["input_ids"].shape[-1]
    text = tok.decode(out[0, prompt_len:], skip_special_tokens=True).strip()

    # Extract lines
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    answer_line = _pick_line(lines, "answer:") or f"Answer: {text}"
    sources_line = _pick_line(lines, "sources:") or "Sources: "

    # Trim the answer body to MAX_ANSWER_WORDS (answer_line always contains ":")
    prefix, _, body = answer_line.partition(":")
    words = body.split()
    if len(words) > MAX_ANSWER_WORDS:
        body = " ".join(words[:MAX_ANSWER_WORDS])
    answer_line = f"{prefix}: {body.strip()}"

    return f"{answer_line}\n{sources_line}"

In [None]:
# 7) Flask API
app = Flask(__name__)


@app.get("/health")
def health():
    """Liveness probe: report service status and the served model id."""
    payload = {"status": "ok", "model": MODEL_ID}
    return jsonify(payload)

@app.post("/ping")
def ping():
    """Sanity-check the LLM with a free-form prompt (no retrieval).

    Body: optional JSON ``{"prompt": str}``. A missing or non-JSON body uses a
    default greeting prompt.

    Returns:
        ``{"output": ...}``, or a 400 error if ``prompt`` is not a string.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):  # e.g. a JSON list/string body would break .get()
        data = {}
    prompt = data.get("prompt", "Say hello in one sentence.")
    if not isinstance(prompt, str):
        return jsonify(error="'prompt' must be a string"), 400
    msgs = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": prompt}]
    out = generate_answer(msgs, max_new_tokens=80)
    return jsonify(output=out)

@app.post("/chat")
def chat():
    """Answer a ShopLite question grounded in the retrieved KB snippets.

    Body: JSON ``{"query": str}``.

    Returns:
        ``{"answer", "sources", "confidence"}``. When the best retrieval score is
        below ``RELEVANCE_THRESHOLD``, or nothing is retrieved, the canned
        refusal is returned with no sources and confidence 0.0. A missing or
        non-string query returns a 400 error.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):  # a JSON list/string body would otherwise crash on .get()
        data = {}
    query = data.get("query") or ""
    if not isinstance(query, str):
        return jsonify(error="'query' must be a string"), 400
    query = query.strip()
    if not query:
        return jsonify(error="missing 'query'"), 400

    hits = retrieve(query, k=TOP_K)
    scores = [h["score"] for h in hits]
    top_score = max(scores, default=0.0)

    # Refuse if off-topic / weak match. `not hits` also guards np.mean([]) -> NaN,
    # which is not valid JSON, should the threshold ever be set <= 0.
    if not hits or top_score < RELEVANCE_THRESHOLD:
        answer = PROMPTS["refusal_message"]
        return jsonify(answer=answer, sources=[], confidence=0.0)

    messages = build_chat(query, hits)
    answer = generate_answer(messages, max_new_tokens=MAX_NEW_TOKENS)

    titles = [h["title"] for h in hits]
    # Map mean cosine score into [0, 1] as a rough confidence indicator.
    confidence = round(min(1.0, max(0.0, 0.5 + 0.5 * float(np.mean(scores)))), 2)

    return jsonify(answer=answer, sources=sorted(set(titles)), confidence=confidence)


In [None]:
# 8) ngrok (hidden token prompt) + Flask server in a background thread
import getpass

PORT = 5002  # single source of truth for both the Flask server and the ngrok tunnel

# getpass hides the token; input() would echo the secret into the notebook's saved output.
NGROK_TOKEN = getpass.getpass("Paste your ngrok token (https://dashboard.ngrok.com): ").strip()
if NGROK_TOKEN:
    ngrok.set_auth_token(NGROK_TOKEN)
public_tunnel = ngrok.connect(PORT, "http")
PUBLIC_URL = public_tunnel.public_url
print("Public URL:", PUBLIC_URL)


def run_server():
    """Serve the Flask app on PORT; blocks, so it is run in a daemon thread below."""
    app.run(host="0.0.0.0", port=PORT, debug=False, use_reloader=False)


thread = threading.Thread(target=run_server, daemon=True)
thread.start()

time.sleep(2)  # give the dev server a moment to bind before announcing it
if not thread.is_alive():
    # app.run() raising (e.g. bind failure) would otherwise kill the thread silently.
    raise RuntimeError(f"Flask server thread exited early; is port {PORT} already in use?")
print("Server started → endpoints: /health  /ping  /chat")

In [None]:



# 9) Smoke test — one request per endpoint through the public tunnel.
health_status = requests.get(f"{PUBLIC_URL}/health", timeout=15).status_code
print("Health:", health_status)
resp = requests.post(
    f"{PUBLIC_URL}/chat",
    json={"query": "How long is the return window?"},
    timeout=30,
)
print("Status:", resp.status_code)
print(textwrap.shorten(resp.text, width=400))
