## Requirements

In [49]:
!pip install -q sentence-transformers faiss-cpu rank-bm25 openai

In [50]:
import json
import numpy as np
import faiss
import re
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from rank_bm25 import BM25Okapi

In [51]:
import os
from getpass import getpass

# Prompt only when the key is not already present in the environment, so that
# "Restart & Run All" does not block on input (or overwrite an exported key).
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

Enter your OpenAI API key: ··········


## Setup

In [52]:
INDEX_PATH = "/content/pokemon_faiss.index"      # <- change if needed
META_PATH  = "/content/pokemon_metadata.json"    # <- change if needed

# Load the FAISS index and the metadata list that search hits are resolved against.
index = faiss.read_index(INDEX_PATH)
with open(META_PATH, "r", encoding="utf-8") as f:
    corpus_meta = json.load(f)

print(f"Loaded {len(corpus_meta)} chunks from metadata.")
print(f"FAISS index has {index.ntotal} vectors.")

# Search hits are looked up as corpus_meta[faiss_id], so the two must line up
# one-to-one; fail loudly here instead of returning the wrong chunk later.
if index.ntotal != len(corpus_meta):
    raise ValueError(
        f"FAISS index has {index.ntotal} vectors but metadata has "
        f"{len(corpus_meta)} entries; they must match."
    )

# Embedding model used to encode queries; per the original note this is the
# same model that was used when the index was built.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Set up the OpenAI client (reads OPENAI_API_KEY from the environment).
client = OpenAI()

Loaded 944 chunks from metadata.
FAISS index has 944 vectors.


## Keyword Tokenizer Setup

In [53]:
def keyword_tokenize(text: str):
    """Lower-case ``text`` (coerced with ``str``) and return its ``\\w+`` runs."""
    return re.findall(r"\w+", str(text).lower())

def flatten_metadata(meta):
    """Recursively turn a metadata value into a flat list of keyword tokens.

    Handling per type:
      * ``dict``  - tokens of every key followed by the flattened value.
      * ``list`` / ``tuple`` - flattened items, in order.
      * ``str``   - if it decodes as a JSON object or array, the decoded
        structure is flattened; otherwise the string itself is tokenized.
      * ``None``  - contributes nothing (avoids a spurious ``"none"`` token).
      * anything else - tokenized via its ``str()`` form.

    Each value is tokenized exactly once, so term frequencies fed to BM25
    are not inflated by also tokenizing the ``str()`` of a container.
    """
    if meta is None:
        return []

    if isinstance(meta, str):
        try:
            loaded = json.loads(meta)
        except (ValueError, RecursionError):
            # Not JSON: treat it as ordinary text.
            return keyword_tokenize(meta)
        if isinstance(loaded, (dict, list)):
            return flatten_metadata(loaded)
        # JSON scalar (number, bool, quoted string): the raw text already
        # carries the same keywords.
        return keyword_tokenize(meta)

    tokens = []

    if isinstance(meta, dict):
        for k, v in meta.items():
            tokens.extend(keyword_tokenize(k))
            tokens.extend(flatten_metadata(v))
        return tokens

    if isinstance(meta, (list, tuple)):
        for item in meta:
            tokens.extend(flatten_metadata(item))
        return tokens

    # Scalars such as int / float / bool.
    return keyword_tokenize(meta)

In [54]:
# One token list per chunk: the searchable top-level fields first, followed by
# every keyword found in the chunk's nested "metadata" value.
bm25_corpus = []

for m in corpus_meta:
    field_tokens = [
        tok
        for key in ("id", "pokemon", "section")
        if m.get(key) is not None
        for tok in keyword_tokenize(m[key])
    ]
    meta_tokens = flatten_metadata(m.get("metadata", {}))
    bm25_corpus.append(field_tokens + meta_tokens)

# BM25 index over the tokenized chunks; row i corresponds to corpus_meta[i].
bm25 = BM25Okapi(bm25_corpus)

## Retrieval

In [55]:
def keyword_search(query: str, top_k: int = 50):
    """Rank chunks by BM25 score for ``query``.

    Returns up to ``top_k`` ``(chunk_index, score)`` pairs, best first,
    keeping only chunks whose score is strictly positive.
    """
    scores = bm25.get_scores(keyword_tokenize(query))
    best_first = np.argsort(scores)[::-1][:top_k]

    hits = []
    for i in best_first:
        if scores[i] > 0:
            hits.append((int(i), float(scores[i])))
    return hits

In [56]:
def dense_search(query: str, top_k: int = 50):
    """Embed ``query`` and return its nearest neighbours from the FAISS index.

    Returns up to ``top_k`` ``(chunk_index, score)`` pairs in the order FAISS
    reports them. ``score`` is the negated FAISS distance.

    NOTE(review): negating assumes the index reports a distance where smaller
    is better (e.g. L2). If the index is an inner-product index the sign of
    ``score`` would be inverted - confirm against how the index was built.
    The result order itself is taken from FAISS and is unaffected.
    """
    if top_k <= 0:
        return []

    q_vec = embed_model.encode([query], convert_to_tensor=False).astype("float32")
    distances, ids = index.search(q_vec, top_k)

    # FAISS pads the result with id -1 when it finds fewer than top_k
    # neighbours; drop those so callers never resolve corpus_meta[-1].
    return [
        (int(idx), -float(dist))
        for dist, idx in zip(distances[0], ids[0])
        if idx >= 0
    ]

In [57]:
def _print_debug_hits(title: str, label: str, hits, limit: int) -> None:
    """Print the first ``limit`` ``(idx, score)`` hits with a text preview."""
    print(f"\n================ {title} ================\n")
    for rank, (idx, score) in enumerate(hits[:limit], start=1):
        doc = corpus_meta[idx]
        # `or ""` also covers an explicit None text, matching build_context.
        text = doc.get("text") or ""
        pokemon = doc.get("pokemon", "Unknown")
        section = doc.get("section", "unknown-section")
        print(f"{label} #{rank} | idx={idx} | score={score:.4f} | [{pokemon} — {section}]")
        print(text[:300].replace("\n", " "))
        if len(text) > 300:
            print("... [truncated]")
        print()


def hybrid_search(
    query: str,
    k: int = 8,
    dense_pool: int = 50,
    keyword_pool: int = 50,
    rrf_k: int = 60,
    dense_weight: float = 2.0,
    keyword_weight: float = 1.0,
    debug: bool = False,
):
    """Fuse dense and BM25 rankings with weighted reciprocal-rank fusion.

    Args:
        query: The user query.
        k: Number of fused results to return.
        dense_pool: Candidates requested from ``dense_search``.
        keyword_pool: Candidates requested from ``keyword_search``.
        rrf_k: Rank offset in the fusion term ``1 / (rrf_k + rank)``.
        dense_weight: Multiplier applied to the dense fusion term.
        keyword_weight: Multiplier applied to the keyword fusion term.
        debug: When true, print the top ``k`` hits of each retriever.

    Returns:
        Up to ``k`` dicts sorted by ``hybrid_score`` (descending) with keys
        ``idx``, ``hybrid_score``, ``dense_rank``, ``keyword_rank``,
        ``dense_score``, ``keyword_score`` and ``doc``. A rank is ``None``
        (and the matching score ``0.0``) when that retriever did not return
        the chunk.
    """
    dense_results = dense_search(query, top_k=dense_pool)
    keyword_results = keyword_search(query, top_k=keyword_pool)

    if debug:
        _print_debug_hits("DENSE RESULTS", "DENSE", dense_results, k)
        _print_debug_hits("KEYWORD (BM25) RESULTS", "BM25", keyword_results, k)

    # 1-based rank of each chunk within each retriever's result list.
    dense_rank = {idx: rank for rank, (idx, _) in enumerate(dense_results, start=1)}
    keyword_rank = {idx: rank for rank, (idx, _) in enumerate(keyword_results, start=1)}

    dense_score_map = dict(dense_results)
    keyword_score_map = dict(keyword_results)

    all_ids = set(dense_rank.keys()) | set(keyword_rank.keys())

    fused = []
    for idx in all_ids:
        score = 0.0

        dr = dense_rank.get(idx)
        if dr is not None:
            # Weighted dense term.
            score += dense_weight * (1.0 / (rrf_k + dr))

        kr = keyword_rank.get(idx)
        if kr is not None:
            # Weighted keyword term (weaker than dense with the defaults).
            score += keyword_weight * (1.0 / (rrf_k + kr))

        fused.append(
            {
                "idx": idx,
                "hybrid_score": float(score),
                "dense_rank": dr,
                "keyword_rank": kr,
                "dense_score": float(dense_score_map.get(idx, 0.0)),
                "keyword_score": float(keyword_score_map.get(idx, 0.0)),
                "doc": corpus_meta[idx],
            }
        )

    fused.sort(key=lambda r: r["hybrid_score"], reverse=True)
    return fused[:k]


In [58]:
def build_context(results, max_chars: int = 4000):
    """Concatenate retrieved chunks into a prompt context of bounded length.

    Each result contributes a ``[pokemon — section]`` header line, an optional
    ``description`` line and the chunk ``text``; chunks are separated by a
    blank line. Chunks are taken in the given order until the next one would
    push the context past ``max_chars``.

    Args:
        results: Iterable of dicts with a ``"doc"`` mapping (as returned by
            ``hybrid_search``).
        max_chars: Upper bound on the length of the returned string,
            including the separators between chunks.

    Returns:
        The context string, never longer than ``max_chars``. If even the
        first chunk does not fit it is truncated to ``max_chars`` rather than
        dropped, so the best-ranked hit is not silently lost.
    """
    parts = []
    used = 0

    for r in results:
        m = r["doc"]
        header = f"[{m.get('pokemon', 'Unknown')} — {m.get('section', 'unknown-section')}]"
        desc = m.get("description") or ""
        text = m.get("text") or ""

        chunk_str = header + "\n"
        if desc:
            chunk_str += desc + "\n"
        chunk_str += text + "\n"

        # The final join inserts one "\n" between consecutive chunks.
        sep_len = 1 if parts else 0

        if used + sep_len + len(chunk_str) > max_chars:
            if not parts and max_chars > 0:
                # Keep as much of the top-ranked chunk as the budget allows.
                parts.append(chunk_str[:max_chars])
            break

        parts.append(chunk_str)
        used += sep_len + len(chunk_str)

    return "\n".join(parts)


## Text Generation

In [61]:
def answer_with_rag(
    query: str,
    k: int = 8,
    model: str = "gpt-5-mini",
    show_context: bool = True,
) -> str:
    """Answer ``query`` from the top hybrid-search chunks via the OpenAI client.

    Args:
        query: The user question.
        k: Number of fused chunks to retrieve for the context.
        model: Model name passed to ``client.responses.create``.
        show_context: When true (the default, matching earlier behaviour),
            print the assembled context before calling the model.

    Returns:
        The model's text output (``response.output_text``).
    """
    results = hybrid_search(
        query=query,
        k=k,
        dense_pool=50,
        keyword_pool=50,
        rrf_k=60,
        dense_weight=2.0,
        keyword_weight=1.0,
        debug=False,
    )

    context = build_context(results)

    prompt = f"""You are a helpful Pokédex assistant for all generations of Pokémon.
Use ONLY the provided context when answering – if the context does not contain
the answer, say you don't know.

Context:
{context}

User question: {query}
Answer in a concise paragraph, including specific numbers, names, and conditions when relevant.
"""

    if show_context:
        print(context)

    response = client.responses.create(
        model=model,
        input=prompt,
    )

    return response.output_text


In [64]:
# Example end-to-end query. answer_with_rag prints the retrieved context first,
# then this cell prints the model's answer.
# NOTE(review): "clefa" looks like a misspelling of "Cleffa" - confirm whether
# the typo is deliberate (e.g. to exercise retrieval on misspelled names).
print(answer_with_rag("What is clefa height?"))

[Cleffa — core]
Cleffa is a Fairy type Pokémon introduced in Generation 2.
Cleffa is a Fairy-type Pokémon introduced in Generation 2. It is number 173 in the National Pokédex. Cleffa is classified as the Star Shape Pokémon. Cleffa is 30 tall and weighs 300. Cleffa can have the abilities Cute Charm, Magic Guard, and Friend Guard. Its name is derived from clef: musical symbol.

[Cleffa — statistics]
Cleffa has a base stat total of 218, with base stats of 50 HP, 25 Attack, 28 Defense, 45 Special Attack, 55 Special Defense, and 15 Speed. At level 100, Cleffa's HP can range from 210 to 304, Attack from 49 to 163, Defense from 54 to 170, Special Attack from 85 to 207, Special Defense from 103 to 229, and Speed from 31 to 141, depending on its nature, IVs, and EVs.

[Clefable — statistics]
Clefable has a base stat total of 483, with base stats of 95 HP, 70 Attack, 73 Defense, 95 Special Attack, 90 Special Defense, and 60 Speed. At level 100, Clefable's HP can range from 300 to 394, Attack fro