In [1]:
!pip install -q sentence-transformers faiss-cpu openai

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.6/23.6 MB[0m [31m37.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import difflib
import json

import faiss
import numpy as np
from openai import OpenAI
from sentence_transformers import SentenceTransformer

In [3]:
import os
from getpass import getpass

from openai import OpenAI

# Ask for the key only when it is not already present in the environment, so
# re-running this cell (or running where the variable is pre-set) does not
# prompt again or overwrite a working key.
if not os.environ.get("OPENAI_API_KEY"):
    api_key = getpass("Enter your OpenAI API key: ").strip()
    if not api_key:
        # Fail here rather than later with an opaque authentication error.
        raise ValueError("An OpenAI API key is required to run this notebook.")
    os.environ["OPENAI_API_KEY"] = api_key

client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

Enter your OpenAI API key: ··········


In [5]:
INDEX_PATH = "/content/pokemon_faiss.index"      # <- change if needed
META_PATH  = "/content/pokemon_metadata.json"    # <- change if needed

# Load FAISS index
index = faiss.read_index(INDEX_PATH)

# Load metadata list; must be in the same order as embeddings when you built the index
with open(META_PATH, "r", encoding="utf-8") as f:
    corpus_meta = json.load(f)

print(f"Loaded {len(corpus_meta)} chunks from metadata.")
print(f"FAISS index has {index.ntotal} vectors.")

# Search hits are mapped back to metadata by position (vector i <-> corpus_meta[i]),
# so a size mismatch would attach the wrong metadata to results or fail much later.
if index.ntotal != len(corpus_meta):
    raise ValueError(
        f"Index/metadata mismatch: {index.ntotal} vectors in {INDEX_PATH} "
        f"but {len(corpus_meta)} metadata entries in {META_PATH}."
    )

# Load the same embedding model used for indexing
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# OpenAI client (expects OPENAI_API_KEY in environment)
client = OpenAI()

Loaded 944 chunks from metadata.
FAISS index has 944 vectors.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [6]:
import re
from collections import defaultdict

# --- 1. Basic corpus stats ---

# Distinct Pokémon names and section names found in the corpus, sorted.
ALL_POKEMON = sorted(set(entry["pokemon"] for entry in corpus_meta))
ALL_SECTIONS = sorted(set(entry["section"] for entry in corpus_meta))

print("Example Pokémon:", ALL_POKEMON[:10])
print("Sections:", ALL_SECTIONS)

# Map each section name to the set of metadata keys observed in its chunks.
SECTION_TO_FIELDS = defaultdict(set)
for entry in corpus_meta:
    entry_fields = entry.get("metadata", {}) or {}
    if entry_fields:
        SECTION_TO_FIELDS[entry["section"]].update(entry_fields.keys())

print("\nMetadata keys per section:")
for section_name in ALL_SECTIONS:
    print(f"  {section_name}: {sorted(SECTION_TO_FIELDS[section_name])}")


# --- 2. Normalization helpers ---

def norm(s: str) -> str:
    """Canonicalize a label: trim it, lowercase it, and drop spaces, hyphens and underscores."""
    cleaned = s.strip().lower()
    for separator in (" ", "-", "_"):
        cleaned = cleaned.replace(separator, "")
    return cleaned


# --- 3. Section synonyms grounded in actual section names ---

def build_section_synonyms(all_sections):
    """
    Build a lookup from normalized alias -> canonical section name.

    Two sources of aliases are combined:
      * spelling variants of every real section name (case and hyphen/space
        variations), registered without overriding an earlier entry;
      * hand-written aliases, which override the automatic ones. Each manual
        alias lists its candidate sections in priority order and is bound to
        the first candidate that actually exists in ``all_sections``; if none
        exists the alias is skipped.

    Args:
        all_sections: iterable of section names present in the corpus.

    Returns:
        dict mapping ``norm(alias)`` to a section name taken from ``all_sections``.
    """
    all_sections = list(all_sections)
    known_sections = set(all_sections)
    syn = {}

    # Variants from actual section names
    for sec in all_sections:
        lowered = sec.lower()
        variants = {
            sec,
            lowered,
            sec.replace("-", " "),
            sec.replace("-", ""),
            lowered.replace("-", " "),
            lowered.replace("-", ""),
        }
        for v in variants:
            syn.setdefault(norm(v), sec)

    # Manual extras. Candidates are tried in order so a corpus that names its
    # stats section "statistics" instead of "base-stats" still resolves "stats".
    manual = {
        "core info": ("core",),
        "info": ("core",),
        "base stats": ("base-stats", "statistics"),
        "stats": ("base-stats", "statistics"),
        "evo": ("evolutions",),
        "evolution": ("evolutions",),
        "evolutions": ("evolutions",),
    }
    for raw, candidates in manual.items():
        canon = next((c for c in candidates if c in known_sections), None)
        if canon is not None:
            syn[norm(raw)] = canon

    return syn

# Alias table built once from the section names found in the corpus.
SECTION_SYNONYMS = build_section_synonyms(ALL_SECTIONS)

def resolve_section(phrase: str):
    """Return the canonical section name for ``phrase``, or None if it is not a known alias."""
    key = norm(phrase)
    return SECTION_SYNONYMS.get(key, None)


# --- 4. Field synonyms grounded in actual metadata keys ---

# Every distinct metadata key that appears in any chunk, sorted.
# Chunks whose "metadata" is missing or None contribute nothing.
ALL_META_FIELDS = sorted(
    set().union(*((m.get("metadata", {}) or {}).keys() for m in corpus_meta))
)

def build_field_synonyms(all_fields):
    """
    Build a lookup from normalized alias -> actual metadata key.

    Spelling variants of each real key are registered first (the first key to
    claim an alias keeps it). Hand-written aliases are applied afterwards and
    overwrite any automatic entry, but only for keys present in ``all_fields``.
    """
    alias_to_field = {}

    # Automatic aliases derived from the real key names.
    for field in all_fields:
        lowered = field.lower()
        spellings = {
            field,
            lowered,
            field.replace(" ", ""),
            lowered.replace(" ", ""),
            field.replace("-", ""),
            lowered.replace("-", ""),
        }
        for spelling in spellings:
            alias = norm(spelling)
            if alias not in alias_to_field:
                alias_to_field[alias] = field

    # Hand-written aliases, keyed by the real field name they point to.
    # "Generatin" is the key as it is spelled in the data.
    manual_aliases = (
        ("Generatin", ("generation", "gen", "gen.")),
        ("Name Etymology", ("name etymology", "name-etymology", "etymology")),
        ("Abilities", ("ability", "abilities", "hidden ability")),
        ("Height", ("height", "tall", "tallness")),
        ("Weight", ("weight", "heavy", "heaviness")),
    )
    for target, aliases in manual_aliases:
        if target in all_fields:
            for alias in aliases:
                alias_to_field[norm(alias)] = target

    return alias_to_field

# Alias table built once from the metadata keys found in the corpus.
FIELD_SYNONYMS = build_field_synonyms(ALL_META_FIELDS)

def resolve_field(phrase: str):
    """Return the metadata key that ``phrase`` is an alias of, or None if there is none."""
    key = norm(phrase)
    return FIELD_SYNONYMS.get(key, None)


Example Pokémon: ['Abra', 'Aerodactyl', 'Aipom', 'Alakazam', 'Ampharos', 'Arbok', 'Arcanine', 'Ariados', 'Articuno', 'Azumarill']
Sections: ['breeding', 'core', 'statistics', 'training']

Metadata keys per section:
  breeding: ['eggCycles', 'eggGroups', 'genderFemale', 'genderMale']
  core: ['Abilities', 'Generatin', 'Height', 'Name Etymology', 'National Dex Number', 'Species', 'Types', 'Weight']
  statistics: ['baseStatTotal', 'baseStats', 'maxStatsLevel100', 'minStatsLevel100']
  training: ['Base Experience', 'Base Friendship', 'Catch Rate', 'EV Yield', 'Growth Rate']


In [7]:
def guess_pokemon_from_query(query: str, fuzzy_cutoff: float = 0.85):
    """
    Guess which Pokémon (from ``ALL_POKEMON``) a query is about.

    Matching is case-insensitive and happens in two stages:
      1. Exact: the name must appear as a whole word, i.e. not glued to other
         letters or digits. A plain substring test would report "Natu" for any
         query containing "nature". If several names match, the longest wins
         (e.g. "Nidoran♀" over "Nidoran").
      2. Fuzzy fallback: if nothing matches exactly, each word of the query is
         compared with the names using ``difflib`` so that small misspellings
         or plurals ("bulbasour", "pikachus") still resolve.

    Args:
        query: free-text user question.
        fuzzy_cutoff: minimum similarity ratio in [0, 1] for the fuzzy stage.

    Returns:
        The matching name exactly as it appears in ``ALL_POKEMON``, or None.
    """
    q = query.lower()
    if not q.strip():
        return None

    # Lookarounds instead of \b so names ending in a non-word character
    # (such as "Nidoran♀") can still match.
    candidates = [
        name
        for name in ALL_POKEMON
        if re.search(rf"(?<!\w){re.escape(name.lower())}(?!\w)", q)
    ]
    if candidates:
        # Prefer longest match (handles things like "Nidoran" vs "Nidoran♀")
        return max(candidates, key=len)

    # Fuzzy fallback: keep the single best-scoring (word, name) pair.
    lowered_names = {name.lower(): name for name in ALL_POKEMON}
    best_name, best_score = None, 0.0
    for token in re.findall(r"\w+", q):
        close = difflib.get_close_matches(token, lowered_names, n=1, cutoff=fuzzy_cutoff)
        if not close:
            continue
        score = difflib.SequenceMatcher(None, close[0], token).ratio()
        if score > best_score:
            best_name, best_score = lowered_names[close[0]], score
    return best_name


def guess_sections_from_query(query: str):
    """
    Collect every section that the query names explicitly.

    Each run of one to four consecutive words of the lowercased query, plus the
    raw query as a whole, is passed to ``resolve_section``; all hits are returned
    as a set.
    """
    words = re.findall(r"\w+", query.lower())
    longest = min(4, len(words))

    # All candidate phrases: word windows (widest first), then the raw query.
    phrases = [
        " ".join(words[begin:begin + width])
        for width in range(longest, 0, -1)
        for begin in range(len(words) - width + 1)
    ]
    phrases.append(query)

    resolved = (resolve_section(phrase) for phrase in phrases)
    return {section for section in resolved if section}


def guess_fields_from_query(query: str):
    """
    Find the metadata fields a query mentions.

    Every window of one to four consecutive words of the lowercased query is
    looked up with ``resolve_field``; the canonical names of all hits are
    returned as a set.
    """
    words = re.findall(r"\w+", query.lower())
    found = set()

    for begin in range(len(words)):
        # Windows of up to four words that start at `begin`.
        last_end = min(begin + 4, len(words))
        for end in range(begin + 1, last_end + 1):
            field = resolve_field(" ".join(words[begin:end]))
            if field:
                found.add(field)

    return found


def infer_sections_from_fields(field_hints):
    """
    Return the sections whose known metadata keys (per ``SECTION_TO_FIELDS``)
    include at least one of ``field_hints`` (e.g. 'Height' -> 'core').
    """
    return {
        section
        for section, section_fields in SECTION_TO_FIELDS.items()
        if any(hint in section_fields for hint in field_hints)
    }


def parse_query_metadata(query: str):
    """
    Extract retrieval hints from a query.

    Returns a dict with three keys:
      'pokemon'  -> name guessed from the query, or None
      'sections' -> set of sections named in the query or implied by its fields
      'fields'   -> set of metadata fields mentioned in the query
    """
    pokemon_name = guess_pokemon_from_query(query)
    named_sections = guess_sections_from_query(query)
    field_names = guess_fields_from_query(query)

    # Sections mentioned outright, plus those that contain the mentioned fields.
    all_sections = set(named_sections)
    all_sections.update(infer_sections_from_fields(field_names))

    return {"pokemon": pokemon_name, "sections": all_sections, "fields": field_names}


In [8]:
def hybrid_search(
    query: str,
    k: int = 8,
    candidate_pool_size: int = 50,
    w_pokemon: float = 2.0,
    w_section: float = 1.0,
    w_field: float = 1.5,
    alpha: float = 1.0,
):
    """
    Hybrid retrieval using embeddings + metadata-aware re-ranking.

    Score = -alpha * distance
            + w_pokemon * 1[chunk's pokemon equals the hinted pokemon]
            + w_section * 1[chunk's section is one of the hinted sections]
            + w_field   * (# of hinted fields present in the chunk's metadata)

    Args:
        query: free-text user question.
        k: number of results to return; values <= 0 yield an empty list.
        candidate_pool_size: how many nearest neighbours to fetch before
            re-ranking. Raised to ``k`` when smaller, and capped at the index size.
        w_pokemon, w_section, w_field: bonus weights (see formula above).
        alpha: weight of the embedding distance.

    Returns:
        Up to ``k`` dicts sorted by descending score, each with keys
        'idx', 'dist', 'score', 'meta' and 'matched_fields' (a sorted list).
    """
    if k <= 0 or index.ntotal == 0:
        return []

    hints = parse_query_metadata(query)
    pokemon_hint = hints["pokemon"]
    section_hints = hints["sections"]
    field_hints = set(hints["fields"])

    # --- 1. global semantic search with FAISS ---
    # The pool must hold at least k candidates, otherwise fewer than k come back.
    pool_size = min(max(candidate_pool_size, k), index.ntotal)
    q_vec = embed_model.encode([query], convert_to_tensor=False)
    D, I = index.search(np.array(q_vec, dtype="float32"), pool_size)

    results = []
    for dist, idx in zip(D[0], I[0]):
        idx = int(idx)
        if idx < 0:
            # FAISS uses label -1 for "no neighbour found"; indexing
            # corpus_meta[-1] would silently return the last chunk instead.
            continue

        meta = corpus_meta[idx]
        mdata = meta.get("metadata", {}) or {}

        base_score = -alpha * float(dist)

        # pokemon bonus
        pokemon_bonus = w_pokemon if (pokemon_hint and meta["pokemon"] == pokemon_hint) else 0.0

        # section bonus: a chunk belongs to one section, so this applies at most once
        section_bonus = w_section if (section_hints and meta["section"] in section_hints) else 0.0

        # field bonus: one increment per hinted field present in this chunk's metadata
        matched_fields = sorted(field_hints.intersection(mdata))
        field_bonus = w_field * len(matched_fields)

        results.append({
            "idx": idx,
            "dist": float(dist),
            "score": float(base_score + pokemon_bonus + section_bonus + field_bonus),
            "meta": meta,
            "matched_fields": matched_fields,
        })

    # --- 2. sort by total score and return top-k ---
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[:k]


In [20]:
def build_context(chosen_chunks, max_chars: int = 4000) -> str:
    """
    Build a readable context string from retrieved chunks.

    Each chunk is rendered as a "[pokemon — section]" header line, an optional
    description line, and the chunk text. Rendered chunks are separated by a
    blank line and added in the given order until the next one would push the
    result past ``max_chars``.

    If even the first chunk does not fit, it is truncated to ``max_chars``
    rather than dropped, so the context is never empty just because the
    top-ranked chunk is long.

    Args:
        chosen_chunks: iterable of dicts with a "meta" dict holding "pokemon",
            "section", and optionally "description" and "text".
        max_chars: upper bound on the length of the returned string.

    Returns:
        The assembled context; "" when there are no chunks or ``max_chars`` <= 0.
    """
    parts = []
    used = 0  # length of "\n".join(parts) so far, separators included

    for chunk in chosen_chunks:
        meta = chunk["meta"]
        header = f"[{meta['pokemon']} — {meta['section']}]"
        # description and text may be missing or None
        desc = meta.get("description") or ""
        text = meta.get("text") or ""

        lines = [header, desc, text] if desc else [header, text]
        chunk_str = "\n".join(lines) + "\n"

        # The final join adds one separator character before every part but the first.
        separator = 1 if parts else 0
        if used + separator + len(chunk_str) > max_chars:
            if not parts and max_chars > 0:
                parts.append(chunk_str[:max_chars])
            break

        parts.append(chunk_str)
        used += separator + len(chunk_str)

    return "\n".join(parts)


def answer_with_rag(query: str, k: int = 8, debug: bool = True, model: str = "gpt-5-mini") -> str:
    """
    Answer a question with retrieval-augmented generation.

    Retrieves the top ``k`` chunks with ``hybrid_search``, renders them with
    ``build_context``, and asks the OpenAI Responses API to answer using only
    that context.

    Args:
        query: the user's question.
        k: number of chunks to retrieve.
        debug: when True, print the retrieved chunks (index, distance,
            Pokémon, section) before calling the model.
        model: name of the OpenAI model to call.

    Returns:
        The model's answer text (``response.output_text``).
    """
    # 1) Retrieve
    retrieved = hybrid_search(query, k=k)
    context = build_context(retrieved)

    if debug:
        print("---- RETRIEVED CHUNKS (DEBUG) ----")
        for i, c in enumerate(retrieved, 1):
            m = c["meta"]
            print(f"{i}. idx={c['idx']}, dist={c['dist']:.4f}, "
                  f"pokemon={m['pokemon']}, section={m['section']}")
        print("---- END CHUNKS ----\n")

    # 2) Construct prompt
    prompt = f"""
You are a helpful Pokémon encyclopedia assistant.
Use ONLY the provided context to answer the user's question.
If the answer is not clearly in the context, say you don't know.

Context:
{context}

User question: {query}
Answer in a concise paragraph, but include specific numbers, names, and conditions when relevant.
"""

    # 3) Call OpenAI (Responses API)
    response = client.responses.create(
        model=model,
        input=prompt,
    )

    # response.output_text gives the concatenated text from output tokens
    return response.output_text


In [22]:
# Smoke test: the Pokémon name in this query is misspelled ("bulbasour").
# With the default debug=True the retrieved chunks are listed before the answer.
print(answer_with_rag("What is bulbasour height?"))

---- RETRIEVED CHUNKS (DEBUG) ----
1. idx=0, dist=1.2441, pokemon=Bulbasaur, section=core
2. idx=420, dist=1.4012, pokemon=Hitmonlee, section=core
3. idx=448, dist=1.4046, pokemon=Chansey, section=core
4. idx=700, dist=1.4271, pokemon=Togetic, section=core
5. idx=932, dist=1.4344, pokemon=Stantler, section=core
6. idx=440, dist=1.4414, pokemon=Rhyhorn, section=core
7. idx=696, dist=1.4428, pokemon=Togepi, section=core
8. idx=720, dist=1.4505, pokemon=Ampharos, section=core
---- END CHUNKS ----

Bulbasaur, the Seed Pokémon, is 70 tall.
