In [1]:
# === Cell 1: Install dependencies ===

!pip install -q \
  transformers \
  accelerate \
  sentence-transformers \
  faiss-cpu \
  langchain \
  beautifulsoup4 \
  requests \
  readability-lxml


In [None]:
# === Cell 2: Imports, HF LLM, embeddings ===

import time
import textwrap
from typing import List, Literal, Optional

import numpy as np
import faiss
import torch
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin, urlparse
from urllib import robotparser
from readability import Document as ReadabilityDocument  # NEW

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer

# Light LangChain usage (just Document type)
try:
    from langchain_core.documents import Document
except ImportError:
    from langchain.docstore.document import Document

# ------------------------------------------------
# HF LLM setup
# ------------------------------------------------

# Instruction-tuned Qwen2 checkpoint used for answer generation.
# Author's notes: the model has a ~32K-token limit while prompts in this
# notebook are around 2K tokens; the .sql dump yielded 268 chunks of roughly
# 15 to 800 characters each; answers are capped at about 200 new tokens
# (call_llm defaults to max_new_tokens=192).
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# device_map="auto" leaves device placement to the loader, so the same cell
# runs with or without a GPU; half precision is only requested when CUDA is
# available.
# NOTE(review): the recorded cell output warns that `torch_dtype` is
# deprecated in favour of `dtype` - confirm against the installed
# transformers version before renaming the argument.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)

# Wrap the loaded model and tokenizer in a text-generation pipeline.
# This is an inference helper only; nothing is trained in this notebook.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

print(" Loaded HF model:", MODEL_NAME)


def call_llm(system_prompt: str, user_prompt: str,
             temperature: float = 0.1,
             max_new_tokens: int = 192,
             do_sample: bool = False) -> str:
    """
    Generate a single assistant reply for a system + user prompt pair.

    Args:
        system_prompt: Instructions for the assistant (whitespace-trimmed).
        user_prompt: The user turn, including any retrieved context.
        temperature: Sampling temperature. Only forwarded to the pipeline when
            ``do_sample`` is True; greedy decoding ignores it, and passing it
            anyway makes transformers emit an "ignored flag" warning.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Enable sampling. Defaults to False (greedy decoding), which
            is what this function has always done, so existing callers keep
            deterministic output.

    Returns:
        The newly generated text only, stripped of surrounding whitespace.
    """
    system_text = system_prompt.strip()
    user_text = user_prompt.strip()

    if getattr(tokenizer, "chat_template", None):
        # Instruct checkpoints are trained on their own chat markup, so let the
        # tokenizer render the turns instead of inventing role labels.
        messages = [
            {"role": "system", "content": system_text},
            {"role": "user", "content": user_text},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        # Fallback for tokenizers without a chat template: plain role labels.
        prompt = (
            f"System:\n{system_text}\n\n"
            f"User:\n{user_text}\n\n"
            f"Assistant:"
        )

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        # Ask the pipeline for the continuation only. Slicing the output by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # character-for-character (e.g. special tokens dropped on decode).
        "return_full_text": False,
    }
    if do_sample:
        generation_kwargs["temperature"] = temperature

    generated = llm(prompt, **generation_kwargs)[0]["generated_text"]
    return generated.strip()


# ------------------------------------------------
# Shared embedding model
# ------------------------------------------------
# Sentence-embedding model shared by both retrievers and by answer verification.
# The FAISS indexes built from these embeddings do k-nearest-neighbour search
# using L2 distance (IndexFlatL2).
EMBEDDER_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embedder = SentenceTransformer(EMBEDDER_MODEL_NAME)
print(" Loaded embedder:", EMBEDDER_MODEL_NAME)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

In [None]:
# === Cell 3: SQL knowledge base + FAISS index (+ .sql dump) ===

import os
from typing import List

# Whitespace-normalized text of every document in the SQL knowledge base;
# position i corresponds to vector i in sql_index.
sql_docs: List[str] = []
# FAISS index over the sql_docs embeddings; stays None until
# build_sql_knowledge_base() has run successfully.
sql_index = None


def _decode_copy_text_field(field: str):
    """
    Decode one field written in PostgreSQL COPY text format.

    Returns None for the NULL marker (a lone backslash-N), otherwise the field
    with backslash escapes resolved: the named escapes (b, f, n, r, t, v) map
    to their control characters and any other escaped character stands for
    itself, which covers the doubled backslash.
    NOTE(review): octal/hex byte escapes are not interpreted; confirm the dump
    never contains them if it was not produced by pg_dump.
    """
    import re  # local import: this cell only imports os and typing

    if field == r"\N":
        return None
    if "\\" not in field:
        return field

    named = {"b": "\b", "f": "\f", "n": "\n", "r": "\r", "t": "\t", "v": "\v"}
    return re.sub(r"\\(.)", lambda m: named.get(m.group(1), m.group(1)), field)


def load_cs360_chunks_from_sql_dump(sql_path: str = "chatbot_db_export.sql") -> List[str]:
    """
    Parse the .sql dump and extract the TEXT column from
    COPY public.content_chunks (..., text) FROM stdin;

    Each data row is expected to hold six tab-separated fields:
    id, document_id, topic_id, page_number, chunk_index, text.
    The text field is decoded from COPY escaping (so an escaped newline or tab
    becomes real whitespace), NULL values are skipped, and whitespace is
    collapsed to single spaces.

    Args:
        sql_path: Path of the dump file to read.

    Returns:
        The non-empty text chunks in file order, or an empty list when the
        file does not exist or holds no matching COPY block.
    """
    print(f"[load_cs360_chunks_from_sql_dump] Looking for: {sql_path}")
    print(f"[load_cs360_chunks_from_sql_dump] CWD: {os.getcwd()}")
    print(f"[load_cs360_chunks_from_sql_dump] Files here: {os.listdir()}")

    if not os.path.exists(sql_path):
        print(f"[load_cs360_chunks_from_sql_dump] File '{sql_path}' NOT found, skipping.")
        return []

    chunks: List[str] = []
    in_copy = False

    with open(sql_path, "r", encoding="utf-8") as f:
        for line in f:
            # start of the COPY block
            if line.startswith("COPY public.content_chunks"):
                in_copy = True
                print("[load_cs360_chunks_from_sql_dump] Found COPY public.content_chunks block.")
                continue

            if not in_copy:
                continue

            # end of the COPY block
            if line.strip() == r"\.":
                print("[load_cs360_chunks_from_sql_dump] Reached end of COPY block.")
                break

            # Real tabs only separate fields; tabs inside a value are escaped,
            # so at most five splits isolate the text column.
            parts = line.rstrip("\n").split("\t", 5)
            if len(parts) < 6:
                continue

            text = _decode_copy_text_field(parts[5])
            if text is None:
                # NULL text column: nothing to index
                continue

            text = " ".join(text.split())  # normalize whitespace
            if text:
                chunks.append(text)

    print(f"[load_cs360_chunks_from_sql_dump] Loaded {len(chunks)} chunks from {sql_path}")
    return chunks


def build_sql_knowledge_base(sql_dump_path: str = "chatbot_db_export.sql"):
    """
    Create a SQL basics KB and build a FAISS index.
    Uses:
      - built-in SQL snippets
      - extra chunks parsed from chatbot_db_export.sql (if present)

    Args:
        sql_dump_path: Path of the optional .sql dump with extra chunks.

    Side effects:
        Rebinds the module-level ``sql_docs`` (list of normalized texts) and
        ``sql_index`` (FAISS IndexFlatL2 over their embeddings). Both are
        reset to empty/None when there is nothing to index.
    """
    global sql_docs, sql_index

    print("🔧 build_sql_knowledge_base: start")
    # 12 built-in snippets; their indentation is irrelevant because every
    # snippet is whitespace-normalized before it is embedded.
    sql_snippets = [
        """
        SQL (Structured Query Language) is used to interact with relational databases.
        It lets you create, read, update, and delete data (often abbreviated as CRUD).
        """,
        """
        Common SQL statements:
        - SELECT: read data from tables
        - INSERT: add new rows
        - UPDATE: modify existing rows
        - DELETE: remove rows
        """,
        """
        Basic SELECT:
        SELECT column1, column2
        FROM table_name
        WHERE condition;
        Example:
        SELECT name, age FROM employees WHERE age > 30;
        """,
        """
        WHERE filters rows:
        - Comparison operators: =, <>, <, >, <=, >=
        - Logical operators: AND, OR, NOT
        """,
        """
        ORDER BY sorts rows:
        SELECT * FROM employees ORDER BY salary DESC;
        """,
        """
        LIMIT restricts the number of rows:
        SELECT * FROM employees LIMIT 10;
        """,
        """
        Aggregation functions:
        - COUNT(*)
        - SUM(column)
        - AVG(column)
        - MIN(column)
        - MAX(column)
        Often used with GROUP BY.
        """,
        """
        GROUP BY groups rows by column(s):
        SELECT department, COUNT(*)
        FROM employees
        GROUP BY department;
        """,
        # RIGHT JOIN used to be described as "all rows from both tables",
        # which is FULL OUTER JOIN; the retriever would feed that wrong
        # statement to the LLM as ground truth.
        """
        Joins combine rows from multiple tables:
        - INNER JOIN: only matching rows
        - LEFT JOIN: all from left table + matches from right
        - RIGHT JOIN: all from right table + matches from left
        - FULL OUTER JOIN: all rows from both tables
        """,
        """
        INNER JOIN example:
        SELECT e.name, d.name
        FROM employees e
        INNER JOIN departments d
          ON e.department_id = d.id;
        """,
        """
        CREATE TABLE example:
        CREATE TABLE employees (
            id INT PRIMARY KEY,
            name VARCHAR(100),
            age INT,
            salary DECIMAL(10,2)
        );
        """,
        """
        INSERT example:
        INSERT INTO employees (id, name, age, salary)
        VALUES (1, 'Alice', 30, 5000.00);
        """,
    ]

    # base snippets (cleaned)
    docs: List[str] = [" ".join(sn.split()) for sn in sql_snippets]
    print(f"   Base SQL snippets: {len(docs)}")

    # extra chunks from .sql dump
    extra_chunks = load_cs360_chunks_from_sql_dump(sql_dump_path)
    if extra_chunks:
        docs.extend(extra_chunks)
        print(f"Added {len(extra_chunks)} chunks from {sql_dump_path} into SQL KB.")
    else:
        print("No extra chunks loaded from .sql dump.")

    if not docs:
        print(" No SQL docs found. SQL KB will be empty.")
        sql_docs = []
        sql_index = None
        return

    print(f" Total docs/chunks to index: {len(docs)}")

    # One embedding row per document, added in order so that FAISS ids are
    # positions in sql_docs.
    embeddings = embedder.encode(docs, convert_to_numpy=True, show_progress_bar=False)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)

    sql_docs = docs
    sql_index = index
    print("SQL KB built with", len(sql_docs), "documents.")


def retrieve_sql_context(query: str, top_k: int = 5) -> List[str]:
    """
    RAG retrieval for SQL mode.

    Args:
        query: The user question to embed and search with.
        top_k: Maximum number of documents to return.

    Returns:
        Up to ``top_k`` documents from ``sql_docs``, nearest first. Empty when
        the knowledge base has not been built or ``top_k`` is not positive.
    """
    if sql_index is None or not sql_docs:
        return []

    k = min(top_k, len(sql_docs))
    if k <= 0:
        # FAISS rejects a non-positive k, so answer "nothing" ourselves.
        return []

    q_emb = embedder.encode([query], convert_to_numpy=True, show_progress_bar=False)
    _, neighbor_ids = sql_index.search(q_emb, k)

    # FAISS pads missing neighbours with -1; without the range check that id
    # would silently index the last document.
    return [sql_docs[i] for i in neighbor_ids[0] if 0 <= i < len(sql_docs)]


In [None]:
# === Cell 4: sqlbolt.com crawler + FAISS index (robust) ===

from typing import List, Optional
from urllib.parse import urljoin, urlparse
from urllib import robotparser
import time
import requests
from bs4 import BeautifulSoup
from readability import Document as ReadabilityDocument
import faiss

# Text chunks collected by the sqlbolt.com crawler; position i corresponds to
# vector i in sqlbolt_index.
sqlbolt_docs: List[str] = []
# FAISS index over the sqlbolt_docs embeddings; stays None until
# crawl_and_build_sqlbolt_index() has collected at least one chunk.
sqlbolt_index = None

def build_robot_parser(base_url: str) -> Optional[robotparser.RobotFileParser]:
    """
    Load and parse the robots.txt that belongs to ``base_url``.

    Args:
        base_url: Any URL on the target site; robots.txt is resolved against
            its root.

    Returns:
        A parser that has read robots.txt, or None when reading failed.
        None means "no rules known", which ``is_allowed`` treats as allow-all.
    """
    robots_url = urljoin(base_url, "/robots.txt")
    rp = robotparser.RobotFileParser()
    try:
        rp.set_url(robots_url)
        rp.read()
    except Exception as e:
        print("Could not read robots.txt, assuming allow-all. Error:", e)
        # A parser that never read anything answers can_fetch() with False for
        # every URL, which would contradict the allow-all fallback announced
        # above. Returning None makes the fallback real.
        return None

    print(f"Loaded robots.txt from {robots_url}")
    return rp

def is_allowed(url: str, rp: Optional[robotparser.RobotFileParser], user_agent: str = "*") -> bool:
    """
    Tell whether ``url`` may be fetched according to the given robots parser.

    No parser (None) means no restrictions, and a parser that raises while
    being queried is treated the same way, so the answer falls back to True.
    """
    if rp is not None:
        try:
            return rp.can_fetch(user_agent, url)
        except Exception:
            # best effort: a broken parser must not stop the crawl
            pass
    return True

def fetch_page(url: str, timeout: int = 10) -> Optional[str]:
    """
    Download ``url`` and return its body when it is a successful HTML response.

    Returns None (after printing the reason) for non-200 responses, for
    non-HTML content types and for any error raised while fetching.
    """
    request_headers = {
        "User-Agent": "Mozilla/5.0 (compatible; SQLBoltRAGBot/1.0; +https://sqlbolt.com/)"
    }
    try:
        response = requests.get(url, headers=request_headers, timeout=timeout)
        is_html = "text/html" in response.headers.get("Content-Type", "")
        if response.status_code == 200 and is_html:
            return response.text
        print(f"Skip {url} (status {response.status_code}, type {response.headers.get('Content-Type')})")
    except Exception as e:
        print("Error fetching", url, "->", e)
    return None

def extract_clean_text(html: str) -> str:
    """
    Turn an HTML page into plain, whitespace-normalized text.

    The main article is extracted with readability-lxml first; if that fails
    for any reason the whole page is parsed instead. Boilerplate tags
    (scripts, styles, navigation, forms, ...) are removed before the text is
    read out.
    """
    boilerplate_tags = ["script", "style", "noscript", "header", "footer", "nav", "form"]

    try:
        # preferred: main content only
        article_html = ReadabilityDocument(html).summary()
        soup = BeautifulSoup(article_html, "html.parser")
    except Exception:
        # fallback: the complete page
        soup = BeautifulSoup(html, "html.parser")

    for node in soup.find_all(boilerplate_tags):
        node.decompose()

    raw_text = soup.get_text(" ", strip=True)
    return " ".join(raw_text.split())

def chunk_text(text: str, max_chars: int = 700) -> List[str]:
    """
    Simple fixed-size character chunking.

    Whitespace is collapsed to single spaces first, then the text is cut into
    consecutive slices of at most ``max_chars`` characters (words may be split
    at slice borders).

    Args:
        text: The text to chunk.
        max_chars: Maximum chunk length; must be positive.

    Returns:
        The chunks in order; an empty list for empty or whitespace-only text.

    Raises:
        ValueError: If ``max_chars`` is not positive. Previously such a value
            made the loop run forever because the cursor never advanced.
    """
    if max_chars <= 0:
        raise ValueError("max_chars must be a positive integer")

    normalized = " ".join(text.split())
    return [
        normalized[start:start + max_chars]
        for start in range(0, len(normalized), max_chars)
    ]

def crawl_and_build_sqlbolt_index(
    start_url: str = "https://sqlbolt.com/",
    max_pages: int = 100,
    sleep_sec: float = 1.0
):
    """
    Crawl sqlbolt.com and build a FAISS index over collected text chunks.
    Logs how many pages and chunks we end up with.

    The crawl is breadth-first, stays on the host of ``start_url``, honours
    robots.txt (via ``is_allowed``) and ignores URL fragments.

    Args:
        start_url: First page to visit; its host defines the crawl domain.
        max_pages: Upper bound on visited URLs, counting skipped/failed ones.
        sleep_sec: Pause after each successfully fetched page.

    Side effects:
        Rebinds the module-level ``sqlbolt_docs`` and ``sqlbolt_index``; both
        are reset to empty/None when no text was collected.
    """
    global sqlbolt_docs, sqlbolt_index

    from collections import deque  # local import keeps this cell self-contained

    domain = urlparse(start_url).netloc
    rp = build_robot_parser(start_url)

    first_url = start_url.split("#")[0]
    visited = set()
    # Every URL that was ever queued. Checking this at enqueue time keeps the
    # queue free of duplicates (links repeated on every page, fragment
    # variants of the same page) instead of filtering them only when popped.
    enqueued = {first_url}
    queue = deque([first_url])  # O(1) pops from the left, unlike list.pop(0)

    collected_chunks: List[str] = []
    pages_with_text = 0

    while queue and len(visited) < max_pages:
        url = queue.popleft()
        visited.add(url)

        if not is_allowed(url, rp):
            print("Disallowed by robots.txt, skipping:", url)
            continue

        print(f"Crawling ({len(visited)}/{max_pages}): {url}")
        html = fetch_page(url)
        if not html:
            continue

        text = extract_clean_text(html)
        # VERY permissive now: keep even small pages
        if len(text) < 50:
            print("  -> very short page, skipping (len =", len(text), ")")
        else:
            chunks = chunk_text(text, max_chars=700)
            if chunks:
                pages_with_text += 1
                collected_chunks.extend(chunks)
                print(f"  -> kept {len(chunks)} chunks (total chunks: {len(collected_chunks)})")

        # discover more links
        soup = BeautifulSoup(html, "html.parser")
        for a in soup.find_all("a", href=True):
            href = a["href"].strip()
            if href.startswith("#"):
                continue
            # Resolve relative links and drop the fragment before the
            # duplicate check, so "page#a" and "page#b" count as one page.
            full_url = urljoin(url, href).split("#")[0]
            if urlparse(full_url).netloc == domain and full_url not in enqueued:
                enqueued.add(full_url)
                queue.append(full_url)

        time.sleep(sleep_sec)

    print(f"\nVisited pages: {len(visited)}, pages with text: {pages_with_text}")
    print(f"Total chunks collected: {len(collected_chunks)}")

    if not collected_chunks:
        print(" No text collected from sqlbolt.com. Index will NOT be built.")
        sqlbolt_docs = []
        sqlbolt_index = None
        return

    # One embedding row per chunk, added in order so that FAISS ids are
    # positions in sqlbolt_docs.
    sqlbolt_docs = collected_chunks
    embeddings = embedder.encode(sqlbolt_docs, convert_to_numpy=True, show_progress_bar=False)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    sqlbolt_index = index

    print("sqlbolt.com index built with", len(sqlbolt_docs), "chunks.")


def retrieve_sqlbolt_context(query: str, top_k: int = 5) -> List[str]:
    """
    RAG retrieval for sqlbolt mode.

    Args:
        query: The user question to embed and search with.
        top_k: Maximum number of chunks to return.

    Returns:
        Up to ``top_k`` chunks from ``sqlbolt_docs``, nearest first. Empty
        when the index has not been built or ``top_k`` is not positive.
    """
    if sqlbolt_index is None or not sqlbolt_docs:
        return []

    k = min(top_k, len(sqlbolt_docs))
    if k <= 0:
        # FAISS rejects a non-positive k, so answer "nothing" ourselves.
        return []

    q_emb = embedder.encode([query], convert_to_numpy=True, show_progress_bar=False)
    _, neighbor_ids = sqlbolt_index.search(q_emb, k)

    # FAISS pads missing neighbours with -1; without the range check that id
    # would silently index the last chunk.
    return [sqlbolt_docs[i] for i in neighbor_ids[0] if 0 <= i < len(sqlbolt_docs)]


In [None]:
# === Cell 5: Answer verification (anti-hallucination, tuned) ===

from typing import List
import numpy as np

def clean_llm_output(raw: str) -> str:
    """
    Remove echoed role labels and prompt fragments from raw model output.

    Leading role labels are peeled off one after another (in a fixed order),
    then the text is cut at the first role label or context tag that still
    occurs in it.
    """
    leading_labels = ("Assistant:", "ASSISTANT:", "System:", "User:")
    cut_markers = ("System:", "User:", "<CONTEXT>", "</CONTEXT>")

    cleaned = raw.strip()

    # 1) drop labels the answer starts with
    for label in leading_labels:
        if cleaned.startswith(label):
            cleaned = cleaned[len(label):].strip()

    # 2) keep only what precedes an echoed marker
    for marker in cut_markers:
        head, found, _ = cleaned.partition(marker)
        if found:
            cleaned = head.strip()

    return cleaned.strip()


def verify_answer(
    answer: str,
    context_texts: List[str],
    threshold: float = 0.45,   # 0.65 turned out to be too strict
    min_overlap: int = 1       # minimum number of answer words found in the context
) -> str:
    """
    Accept an answer only if it is supported by the retrieved context.

    Two checks are applied after whitespace normalization:
      1) the highest cosine similarity between the answer embedding and any
         context embedding must reach ``threshold``;
      2) at least ``min_overlap`` answer words longer than three characters
         must occur (as substrings, case-insensitively) in the joined context.

    Returns the normalized answer when both checks pass, otherwise the fixed
    string "Not enough information." (also returned for an empty context).
    """
    rejection = "Not enough information."

    if not context_texts:
        return rejection

    answer_text = " ".join(answer.split())
    normalized_context = [" ".join(piece.split()) for piece in context_texts]

    # --- check 1: embedding similarity ---
    answer_emb = embedder.encode(
        [answer_text], convert_to_numpy=True, show_progress_bar=False
    )
    context_embs = embedder.encode(
        normalized_context, convert_to_numpy=True, show_progress_bar=False
    )

    dot_products = np.dot(answer_emb, context_embs.T)
    norm_products = np.linalg.norm(answer_emb) * np.linalg.norm(context_embs, axis=1)
    best_similarity = float(np.max(dot_products / norm_products))

    if best_similarity < threshold:
        return rejection

    # --- check 2: shared words ---
    context_blob = " ".join(normalized_context).lower()
    shared_count = sum(
        1
        for word in answer_text.lower().split()
        if len(word) > 3 and word in context_blob
    )

    return answer_text if shared_count >= min_overlap else rejection


In [None]:
# === Cell 6: RAG answer generation for SQL + sqlbolt.com QA (generative) ===

import re
import textwrap
from typing import Literal, List

def _extract_snippet(chunk: str, query: str, window: int = 260) -> str:
    """
    (Still here if you want snippet mode later, but NOT used in sqlbolt now)
    """
    text = " ".join(chunk.split())
    if not text:
        return ""

    words = [w for w in query.lower().split() if len(w) > 3]
    lower_text = text.lower()

    idx = -1
    for w in words:
        pos = lower_text.find(w)
        if pos != -1:
            idx = pos
            break

    if idx == -1:
        snippet = text[:window]
        return snippet + ("..." if len(text) > window else "")

    start = max(idx - window // 2, 0)
    end = min(len(text), start + window)
    snippet = text[start:end]
    if start > 0:
        snippet = "..." + snippet
    if end < len(text):
        snippet = snippet + "..."
    return snippet


def strip_inner_not_enough_information(text: str) -> str:
    """
    If the model wrote a real answer *and then* added some
    'not enough information' sentence, drop everything from that phrase on.
    If the entire answer is just that phrase, keep it.

    Args:
        text: The (already verified) answer text.

    Returns:
        - the trimmed text unchanged when the phrase does not occur or the
          whole answer is shorter than six words;
        - the part before the phrase otherwise;
        - the canonical "Not enough information." when nothing precedes the
          phrase, i.e. the model refused in a longer sentence. Cutting at the
          phrase used to leave an empty answer in that case.
    """
    marker = "not enough information"

    t = text.strip()
    marker_idx = t.lower().find(marker)

    if marker_idx == -1:
        return t

    # if whole answer is basically "not enough information", keep it
    if len(t.split()) < 6:
        return t

    # otherwise, cut off anything from 'not enough information' onwards
    kept = t[:marker_idx].rstrip()
    if not kept:
        return "Not enough information."
    return kept


def _answer_from_sqlbolt(query: str) -> str:
    """
    Generative answer using context from sqlbolt.com:

    - Retrieve relevant chunks from the crawled pages
    - Feed them to the LLM with a strict 'use ONLY this context' prompt
    - Verify the answer against the context

    Returns the verified answer, or "Not enough information." when nothing
    was retrieved or the answer failed verification.
    """
    ctx = retrieve_sqlbolt_context(query, top_k=8)
    if not ctx:
        return "Not enough information."

    system_prompt = textwrap.dedent("""
        You are an SQL teaching assistant.
        The context comes from sqlbolt.com.
        You MUST answer using ONLY the information in the provided context.
        If the context does not contain enough information to answer,
        reply exactly: Not enough information.
        Do NOT invent new examples, syntax, or explanations beyond the context.
        Keep your answer concise and clear.
    """)

    context_text = "\n\n---\n\n".join(ctx)

    # Assemble the prompt line by line. Running textwrap.dedent over an
    # f-string does nothing once the interpolated context has lines at column
    # 0 (the "---" separators), which left the template lines indented.
    user_prompt = (
        "<CONTEXT>\n"
        f"{context_text}\n"
        "</CONTEXT>\n"
        "\n"
        f"Question: {query}\n"
        "\n"
        "Answer the question using ONLY the information inside <CONTEXT>.\n"
        "If the answer is not fully supported by the context,\n"
        "reply exactly: Not enough information.\n"
    )

    raw = call_llm(system_prompt, user_prompt, temperature=0.1, max_new_tokens=192)
    cleaned = clean_llm_output(raw)
    final_answer = verify_answer(cleaned, ctx, threshold=0.45)
    final_answer = strip_inner_not_enough_information(final_answer)
    return final_answer


def generate_rag_answer(query: str, mode: Literal["sql", "sqlbolt"]) -> str:
    """
    Core RAG:

    - mode == "sql":
        * retrieve SQL context (local KB)
        * use LLM with strict instructions
        * verify answer against context

    - mode == "sqlbolt":
        * retrieve sqlbolt.com context
        * same pattern, but using sqlbolt chunks

    Returns the verified answer, "Not enough information." when nothing was
    retrieved or verification failed, or "Unknown mode." for any other mode.
    """
    if mode == "sql":
        context_chunks = retrieve_sql_context(query, top_k=8)
        if not context_chunks:
            return "Not enough information."

        system_prompt = textwrap.dedent("""
            You are an SQL teaching assistant.
            You MUST answer using ONLY the information in the provided context.
            If the context does not contain enough information to answer,
            reply exactly: Not enough information.
            Do NOT invent new examples, syntax, or explanations beyond the context.
            Keep your answer concise and clear.
        """)

        context_text = "\n\n---\n\n".join(context_chunks)

        # Assemble the prompt line by line. Running textwrap.dedent over an
        # f-string does nothing once the interpolated context has lines at
        # column 0 (the "---" separators), which left the template indented.
        user_prompt = (
            "<CONTEXT>\n"
            f"{context_text}\n"
            "</CONTEXT>\n"
            "\n"
            f"Question: {query}\n"
            "\n"
            "Answer the question using ONLY the information inside <CONTEXT>.\n"
            "If the answer is not fully supported by the context,\n"
            "reply exactly: Not enough information.\n"
        )

        raw = call_llm(system_prompt, user_prompt, temperature=0.1, max_new_tokens=192)
        cleaned = clean_llm_output(raw)
        final_answer = verify_answer(cleaned, context_chunks, threshold=0.45)
        final_answer = strip_inner_not_enough_information(final_answer)
        return final_answer

    elif mode == "sqlbolt":

        return _answer_from_sqlbolt(query)

    else:
        return "Unknown mode."


In [None]:
# === Cell 7: Build SQL KB + crawl sqlbolt.com & build index ===

# Build the local SQL knowledge base: built-in snippets plus, when the file is
# present, the chunks parsed from the default .sql dump.
print("Building SQL knowledge base...")
build_sql_knowledge_base()

# Crawl sqlbolt.com and index the collected text. The crawl visits at most
# max_pages URLs and pauses sleep_sec seconds after each fetched page, so a
# full run takes on the order of minutes.
print("\n Crawling sqlbolt.com and building index (this may take a bit)...")
crawl_and_build_sqlbolt_index(
    start_url="https://sqlbolt.com/",
    max_pages=100,   # adjustable
    sleep_sec=1.0
)

print("\n RAG indexes ready.")


In [None]:
# === Cell 8: Console chatbot (sql + sqlbolt) ===


# Interactive console loop: reads a line, handles the built-in commands
# (exit/quit/stop, "mode <name>") and sends everything else to the RAG
# pipeline in the currently selected mode.
print("Modes:")
print("  - 'sql'     → SQL tutor (local KB)")
print("  - 'sqlbolt' → sqlbolt.com QA")
print("Commands:")
print("  - type 'mode sql' or 'mode sqlbolt' to switch")
print("  - type 'exit' to quit\n")

current_mode: Literal["sql", "sqlbolt"] = "sql"
print(f"Current mode: {current_mode}")

while True:
    try:
        user_inp = input(f"[{current_mode}] You: ").strip()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or an interrupted cell ends the chat cleanly instead of
        # leaving a traceback in the notebook output.
        print("\nGoodbye!")
        break

    if not user_inp:
        continue

    low = user_inp.lower()

    if low in ["exit", "quit", "stop"]:
        print("Goodbye!")
        break

    # --- STRICT mode switching ---
    # Only the exact word "mode", alone or followed by one argument, is a
    # command. A plain prefix test also swallowed real questions such as
    # "model ..." or "modes of joins".
    tokens = low.split()
    if tokens[0] == "mode" and len(tokens) <= 2:
        target = tokens[1] if len(tokens) == 2 else None
        if target in ("sqlbolt", "sqlbot"):   # accept both spellings
            current_mode = "sqlbolt"
            print("🔁 Switched to sqlbolt.com QA mode.")
        elif target == "sql":
            current_mode = "sql"
            print("🔁 Switched to SQL tutor mode.")
        else:
            print("Unknown mode. Use 'mode sql' or 'mode sqlbolt'.")
        continue

    # --- RAG answer ---
    answer = generate_rag_answer(user_inp, mode=current_mode)
    print(f"AI: {answer}\n")
