In [1]:
import os
import unicodedata
import urllib.parse

import requests

# Used ChatGPT to generate this list of pages and download logic
# Percent-encoded English-Wikipedia page titles to download, kept in the
# exact order the pages are fetched. The grouping below is only for
# readability; the labels are rough descriptions, not a formal taxonomy.

# Overview / classification pages.
_OVERVIEW_PAGES = (
    "Chess_opening",
    "List_of_chess_openings",
    "King%27s_Pawn_Game",
    "Queen%27s_Pawn_Game",
    "Open_Game",
    "Semi-Open_Game",
    "Closed_Game",
    "Flank_opening",
)

# Roughly: Black replies to 1.e4.
_KING_PAWN_REPLIES = (
    "Sicilian_Defence",
    "French_Defence",
    "Caro%E2%80%93Kann_Defence",
    "Pirc_Defence",
    "Modern_Defense",
    "Alekhine%27s_Defence",
    "Scandinavian_Defense",
)

# Roughly: openings arising after 1.d4.
_QUEEN_PAWN_LINES = (
    "Nimzo-Indian_Defense",
    "Queen%27s_Indian_Defence",
    "Bogo-Indian_Defense",
    "King%27s_Indian_Defence",
    "Gr%C3%BCnfeld_Defence",
    "Benoni_Defense",
    "Benko_Gambit",
    "Dutch_Defense",
    "Slav_Defense",
    "Semi-Slav_Defense",
    "Queen%27s_Gambit",
    "Catalan_Opening",
)

# Roughly: named White openings and systems.
_WHITE_OPENINGS = (
    "Ruy_Lopez",
    "Italian_Game",
    "Scotch_Game",
    "Four_Knights_Game",
    "Vienna_Game",
    "English_Opening",
    "R%C3%A9ti_Opening",
    "Bird%27s_Opening",
    "London_System",
)

wiki_pages = [
    *_OVERVIEW_PAGES,
    *_KING_PAWN_REPLIES,
    *_QUEEN_PAWN_LINES,
    *_WHITE_OPENINGS,
]

# Directory that holds the downloaded knowledge-base text files.
# Set the CHESS_KB_DIR environment variable to run the notebook on another
# machine; when it is unset the original location is used unchanged.
kb_dir = os.environ.get(
    "CHESS_KB_DIR",
    "/Users/jpmalone/Documents/Applied_ML/applied_ml_hw4/kb",
)
os.makedirs(kb_dir, exist_ok=True)


# Sent with every Wikipedia API request so the client is identifiable.
HEADERS = {
    "User-Agent": "applied_ml_hw4/1.0 (contact: jpmalone@iu.edu)"
}

def download_wiki_page(title, timeout=30):
    """Fetch the plain-text extract of one English-Wikipedia page.

    Args:
        title: Page title; may be percent-encoded (e.g. ``King%27s_Pawn_Game``),
            it is decoded before being sent to the API.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled connection would block forever.

    Returns:
        The page extract as a string.

    Raises:
        requests.HTTPError: The API answered with an error status code.
        requests.RequestException: The request failed or timed out.
        ValueError: The response contains no page entry, the page is
            reported missing, or its extract is absent/blank.
    """
    # Decode %27, %C3%BC, %E2%80%93, etc.
    decoded_title = urllib.parse.unquote(title)

    params = {
        "action": "query",
        "prop": "extracts",
        "explaintext": True,
        "exsectionformat": "plain",
        "redirects": 1,
        "format": "json",
        "titles": decoded_title,
    }

    r = requests.get(
        "https://en.wikipedia.org/w/api.php",
        params=params,
        headers=HEADERS,
        timeout=timeout,
    )
    r.raise_for_status()

    # A payload without "query"/"pages" used to surface as a bare KeyError or
    # StopIteration; report it in the same style as the other failures.
    pages = r.json().get("query", {}).get("pages", {})
    if not pages:
        raise ValueError(f"Unexpected API response for page: {decoded_title}")

    # Only one title is requested, so the first entry is the page we want.
    page = next(iter(pages.values()))

    if "missing" in page:
        raise ValueError(f"Page not found: {decoded_title}")

    if "extract" not in page or not page["extract"].strip():
        raise ValueError(f"No extract available for page: {decoded_title}")

    return page["extract"]

def _kb_filename(title):
    """Turn a percent-encoded page title into a lowercase ASCII ``.txt`` name.

    The title is decoded first, so every encoded character is handled rather
    than a fixed list of escape sequences: apostrophes are dropped, the en
    dash becomes a hyphen, and accented letters are folded to their base
    letter (``ü`` -> ``u``, ``é`` -> ``e``).
    """
    readable = urllib.parse.unquote(title).lower()
    readable = readable.replace("'", "").replace("\u2013", "-")
    # NFKD splits accented letters into base letter + combining mark; the
    # ASCII encode with "ignore" then drops the marks.
    folded = unicodedata.normalize("NFKD", readable)
    return folded.encode("ascii", "ignore").decode("ascii") + ".txt"


# Download every page; a failure is reported and skipped so one bad page
# does not abort the whole run.
# NOTE(review): earlier runs wrote the Réti page as "r%c3%a9ti_opening.txt";
# if that file is still in kb_dir, delete it so the page is not indexed twice.
for title in wiki_pages:
    print(f"Downloading: {title}")
    try:
        text = download_wiki_page(title)
    except Exception as e:
        print(f"Skipped: {e}")
        continue

    path = os.path.join(kb_dir, _kb_filename(title))
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

Downloading: Chess_opening
Downloading: List_of_chess_openings
Downloading: King%27s_Pawn_Game
Downloading: Queen%27s_Pawn_Game
Downloading: Open_Game
Downloading: Semi-Open_Game
Downloading: Closed_Game
Downloading: Flank_opening
Downloading: Sicilian_Defence
Downloading: French_Defence
Downloading: Caro%E2%80%93Kann_Defence
Downloading: Pirc_Defence
Downloading: Modern_Defense
Downloading: Alekhine%27s_Defence
Downloading: Scandinavian_Defense
Downloading: Nimzo-Indian_Defense
Downloading: Queen%27s_Indian_Defence
Downloading: Bogo-Indian_Defense
Downloading: King%27s_Indian_Defence
Downloading: Gr%C3%BCnfeld_Defence
Downloading: Benoni_Defense
Downloading: Benko_Gambit
Downloading: Dutch_Defense
Downloading: Slav_Defense
Downloading: Semi-Slav_Defense
Downloading: Queen%27s_Gambit
Downloading: Catalan_Opening
Downloading: Ruy_Lopez
Downloading: Italian_Game
Downloading: Scotch_Game
Downloading: Four_Knights_Game
Downloading: Vienna_Game
Downloading: English_Opening
Downloading: R%C3

In [2]:
import os
import re
import numpy as np
import torch
import faiss

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

# Run the LLM on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face model identifiers for the generator and the embedder.
LLM_NAME = "google/gemma-3-1b-it"
EMB_NAME = "sentence-transformers/all-MiniLM-L6-v2"

tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)

# Half precision on the GPU, full precision on the CPU.
llm_dtype = torch.float32 if device != "cuda" else torch.float16
llm = AutoModelForCausalLM.from_pretrained(
    LLM_NAME,
    dtype=llm_dtype,
    low_cpu_mem_usage=True,
)
llm = llm.to(device)
llm.eval()

# The sentence embedder is always placed on the CPU.
embedder = SentenceTransformer(EMB_NAME, device="cpu")

def load_texts_from_dir(dir_path):
    """Read every ``.txt`` file in ``dir_path`` (extension matched case-insensitively).

    Files are visited in sorted filename order.

    Returns:
        A pair ``(texts, sources)`` of equal-length lists: the UTF-8 decoded
        file contents and the corresponding filenames.
    """
    txt_names = [
        name
        for name in sorted(os.listdir(dir_path))
        if name.lower().endswith(".txt")
    ]

    contents = []
    for name in txt_names:
        full_path = os.path.join(dir_path, name)
        with open(full_path, "r", encoding="utf-8") as handle:
            contents.append(handle.read())

    return contents, txt_names

def chunk_text(text, chunk_size=800, overlap=150):
    """Split ``text`` into overlapping, fixed-width character windows.

    Runs of whitespace are collapsed to a single space and the text is
    stripped before splitting.

    Args:
        text: The document to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must be
            smaller than ``chunk_size``.

    Returns:
        A list of chunk strings (empty when ``text`` is blank). The final
        chunk may be shorter than ``chunk_size``.

    Raises:
        ValueError: ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size``. Such values would keep the window
            from ever advancing, so the loop would never terminate.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    text = re.sub(r"\s+", " ", text).strip()

    # Each window starts this many characters after the previous one.
    step = chunk_size - overlap

    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window reaches the end so no redundant tail chunk
        # (fully contained in the previous one) is produced.
        if start + chunk_size >= len(text):
            break
    return chunks

# Load every knowledge-base document and its filename.
raw_texts, raw_sources = load_texts_from_dir(kb_dir)

# Flatten the documents into chunks; chunk_meta[j] records which file and
# which position within that file chunks[j] came from.
chunks = []
chunk_meta = []
for src, doc in zip(raw_sources, raw_texts):
    for i, c in enumerate(chunk_text(doc)):
        chunks.append(c)
        chunk_meta.append((src, i))

# Fail early with a clear message instead of a cryptic error further down
# when nothing was downloaded (the download cell skips failed pages).
if not chunks:
    raise RuntimeError(
        f"No text chunks found in {kb_dir!r}; run the download cell first."
    )

# Unit-length embeddings, so the inner product used below equals cosine similarity.
chunk_emb = embedder.encode(
    chunks,
    batch_size=32,
    normalize_embeddings=True,
    convert_to_numpy=True,
    show_progress_bar=True,
)

# Flat inner-product index over all chunk embeddings.
dim = chunk_emb.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(chunk_emb)

def retrieve(query, k=5):
    """Return the chunks most similar to ``query``, best match first.

    Args:
        query: The search text.
        k: Maximum number of chunks to return.

    Returns:
        A list of at most ``k`` dicts with keys ``score`` (inner-product
        similarity as a float), ``source`` (filename), ``chunk_id``
        (position of the chunk within its file) and ``text``. The list is
        empty when ``k`` is not positive or the index holds no vectors.
    """
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with label -1, which would otherwise index the *last* chunk.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    q_emb = embedder.encode(
        [query],
        batch_size=1,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    scores, idx = index.search(q_emb, k)
    return [
        {
            "score": float(s),
            "source": chunk_meta[j][0],
            "chunk_id": chunk_meta[j][1],
            "text": chunks[j],
        }
        for s, j in zip(scores[0], idx[0])
        if j >= 0  # defensive: skip any "no result" placeholder
    ]

def generate_rag_answer(query, k=5, max_new_tokens=250):
    """Answer ``query`` with the LLM, using retrieved chunks as context.

    Args:
        query: The user's question.
        k: Number of chunks to retrieve and place in the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A pair ``(answer, retrieved)``: the stripped generated text and the
        list of retrieval results that formed the context.
    """
    retrieved = retrieve(query, k)

    # One block per retrieved chunk, each labelled with its origin.
    context = "\n\n".join(
        f"[Source: {r['source']} | chunk {r['chunk_id']}]\n{r['text']}"
        for r in retrieved
    )

    prompt = (
        "Answer using ONLY the CONTEXT.\n"
        "If the answer is not in the context, output exactly:\n"
        "I don't know based on the provided sources.\n\n"
        f"CONTEXT:\n{context}\n\n"
        f"QUESTION:\n{query}\n\n"
        "ANSWER:\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]

    # Greedy decoding (do_sample=False) keeps the answer deterministic.
    with torch.no_grad():
        out = llm.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
        )

    # The output sequence starts with the prompt tokens. Decode only what
    # follows them instead of splitting the decoded text on "ANSWER:\n",
    # which would cut at the wrong place if the question itself contained
    # that marker.
    answer = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return answer.strip(), retrieved

  from .autonotebook import tqdm as notebook_tqdm
Batches: 100%|██████████| 28/28 [00:18<00:00,  1.54it/s]


In [3]:
# Ask one example question and show the answer with its supporting chunks.
question = "What are the main strategic ideas of the Sicilian Defense?"
answer, sources = generate_rag_answer(question, k=5)

# Keep only the text before the first blank line of the generated output.
answer = answer.partition("\n\n")[0].strip()

print(f"ANSWER:\n {answer}")
print("\nSOURCES:")
for hit in sources:
    print(f"- {hit['source']} (chunk {hit['chunk_id']}, score={hit['score']:.3f})")

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


ANSWER:
 The main strategic ideas of the Sicilian Defense are:
1.  White pushes a kingside pawn, asserting control over the d4 square.
2.  White develops the king's knight with 2.Nf3.
3.  Black replies with 2...Nc6, 2...d6, or 2...e6.
4.  White often holds the initiative on the side of the board.
5.  Black needs to avoid a quick attack.

SOURCES:
- sicilian_defence.txt (chunk 9, score=0.633)
- sicilian_defence.txt (chunk 0, score=0.626)
- sicilian_defence.txt (chunk 2, score=0.615)
- english_opening.txt (chunk 8, score=0.613)
- sicilian_defence.txt (chunk 15, score=0.602)
