In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/climate-change-qa-enlarged/climate_change_qa_enlarged.json


In [5]:
# ============================================================
# ONE-CELL Kaggle: Local deps override + GPT-2 Fine-tune (Train/Val + Val Loss) + Save + Demo
# Fixes: huggingface-hub version mismatch with system transformers (>=0.34,<1.0)
# Also: auto-fallback to CPU if P100 kernel mismatch occurs ("no kernel image")
# ============================================================

import os, sys, subprocess, shutil

DEPS_DIR = "/kaggle/working/_deps"

def fresh_deps_dir():
    if os.path.isdir(DEPS_DIR):
        shutil.rmtree(DEPS_DIR, ignore_errors=True)
    os.makedirs(DEPS_DIR, exist_ok=True)
    if DEPS_DIR not in sys.path:
        sys.path.insert(0, DEPS_DIR)

def pip_install_local(pkgs):
    cmd = [sys.executable, "-m", "pip", "install", "-q", "--no-cache-dir", "-t", DEPS_DIR] + pkgs
    subprocess.check_call(cmd)

def force_local_imports():
    # Remove already-loaded modules so they re-import from DEPS_DIR
    for mod in list(sys.modules.keys()):
        if mod.startswith(("huggingface_hub", "fsspec")):
            del sys.modules[mod]

def setup_local_overrides():
    fresh_deps_dir()
    # Key fix: transformers 4.57.1 wants huggingface-hub >=0.34,<1.0
    pip_install_local([
        "huggingface-hub==0.35.1",
        "fsspec[http]==2025.10.0",
    ])
    force_local_imports()

    import huggingface_hub, fsspec
    print("‚úÖ huggingface_hub:", huggingface_hub.__version__, "from", huggingface_hub.__file__)
    print("‚úÖ fsspec:", fsspec.__version__, "from", fsspec.__file__)

setup_local_overrides()

# ----------------------------
# Imports (NO Trainer!)
# ----------------------------
import json
import math
import random
from typing import Dict, List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader, Subset

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.optimization import get_linear_schedule_with_warmup

# ----------------------------
# Config
# ----------------------------
DATA_PATH = "/kaggle/input/climate-change-qa-enlarged/climate_change_qa_enlarged.json"
MODEL_NAME = "gpt2"
SAVE_DIR = "./fine_tuned_gpt2"

MAX_LENGTH = 256
EPOCHS = 3
BATCH_SIZE = 4
EVAL_BATCH_SIZE = 8
GRAD_ACCUM = 2
LR = 5e-5
WARMUP_RATIO = 0.06
WEIGHT_DECAY = 0.01
SEED = 42
LOG_EVERY = 50
VAL_RATIO = 0.10

# ----------------------------
# Utilities
# ----------------------------
def set_seed(seed: int) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def load_json_qa(path: str) -> List[Dict[str, str]]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list) or not data:
        raise ValueError("Dataset JSON must be a non-empty list of {question, answer}.")
    if "question" not in data[0] or "answer" not in data[0]:
        raise ValueError("Each item must contain keys: 'question' and 'answer'.")
    return data

def train_val_split(n: int, val_ratio: float, seed: int) -> Tuple[List[int], List[int]]:
    idx = list(range(n))
    rng = random.Random(seed)
    rng.shuffle(idx)
    val_n = max(1, int(n * val_ratio))
    return idx[val_n:], idx[:val_n]

@torch.no_grad()
def eval_loss(model: GPT2LMHeadModel, dataloader: DataLoader, device: torch.device) -> float:
    model.eval()
    total = 0.0
    steps = 0
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        total += out.loss.item()
        steps += 1
    return total / max(1, steps)

@torch.no_grad()
def generate_answer(model: GPT2LMHeadModel, tokenizer: GPT2Tokenizer, question: str, max_new_tokens: int = 80) -> str:
    device = next(model.parameters()).device
    prompt = f"### Question:\n{question}\n\n### Answer:\n"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    out = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        repetition_penalty=1.15,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    return text.split("### Answer:", 1)[1].strip() if "### Answer:" in text else text.strip()

# ----------------------------
# Dataset
# ----------------------------
class QACausalLMDataset(Dataset):
    def __init__(self, items: List[Dict[str, str]], tokenizer: GPT2Tokenizer, max_length: int):
        self.items = items
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        q = str(self.items[idx]["question"]).strip()
        a = str(self.items[idx]["answer"]).strip()
        text = f"### Question:\n{q}\n\n### Answer:\n{a}"

        enc = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attention_mask = enc["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# ----------------------------
# Robust device selection
# ----------------------------
def choose_device() -> torch.device:
    if not torch.cuda.is_available():
        return torch.device("cpu")
    try:
        _ = (torch.tensor([1.0], device="cuda") * 2.0)
        torch.cuda.synchronize()
        return torch.device("cuda")
    except Exception as e:
        print("‚ö†Ô∏è CUDA present but not runnable:", repr(e))
        print("‚û°Ô∏è Falling back to CPU.")
        return torch.device("cpu")

# ----------------------------
# Train loop
# ----------------------------
def main() -> None:
    set_seed(SEED)
    device = choose_device()
    print("Device:", device)

    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    items = load_json_qa(DATA_PATH)
    full_ds = QACausalLMDataset(items, tokenizer, MAX_LENGTH)

    train_idx, val_idx = train_val_split(len(full_ds), VAL_RATIO, SEED)
    train_ds = Subset(full_ds, train_idx)
    val_ds = Subset(full_ds, val_idx)

    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_dl = DataLoader(val_ds, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=2)

    total_steps = math.ceil(len(train_dl) / GRAD_ACCUM) * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    use_amp = (device.type == "cuda")
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    print(f"Train size: {len(train_ds)} | Val size: {len(val_ds)}")
    print(f"Steps: {total_steps} | Warmup: {warmup_steps} | AMP: {use_amp}")

    global_step = 0
    running = 0.0

    for epoch in range(1, EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{EPOCHS} ===")
        model.train()
        optimizer.zero_grad(set_to_none=True)

        for step, batch in enumerate(train_dl, start=1):
            batch = {k: v.to(device) for k, v in batch.items()}

            try:
                with torch.amp.autocast("cuda", enabled=use_amp):
                    out = model(**batch)
                    loss = out.loss / GRAD_ACCUM

                scaler.scale(loss).backward()
                running += loss.item()

                if step % GRAD_ACCUM == 0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()

                    global_step += 1
                    if global_step % LOG_EVERY == 0:
                        print(f"step={global_step} train_loss={running/LOG_EVERY:.4f}")
                        running = 0.0

            except RuntimeError as e:
                if "no kernel image" in str(e).lower():
                    print("\n‚ö†Ô∏è GPU kernel mismatch detected. Switching to CPU and continuing...")
                    device = torch.device("cpu")
                    model.to(device)
                    use_amp = False
                    scaler = torch.amp.GradScaler("cuda", enabled=False)
                    batch = {k: v.to(device) for k, v in batch.items()}
                    out = model(**batch)
                    (out.loss / GRAD_ACCUM).backward()
                else:
                    raise

        vloss = eval_loss(model, val_dl, device)
        print(f"Epoch {epoch} validation_loss={vloss:.4f}")
        print("\n[Demo] Q: What is climate change?")
        print("[Demo] A:", generate_answer(model, tokenizer, "What is climate change?"))

    print("\nSaving to:", SAVE_DIR)
    model.save_pretrained(SAVE_DIR)
    tokenizer.save_pretrained(SAVE_DIR)
    print("‚úÖ Saved.")

    print("\n--- FINAL INFERENCE DEMO ---")
    model.eval()
    for q in [
        "What are greenhouse gases?",
        "How does deforestation contribute to climate change?",
        "How can individuals reduce their carbon footprint?",
    ]:
        print("\nQ:", q)
        print("A:", generate_answer(model, tokenizer, q))

main()


     ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 57.7/57.7 kB 3.9 MB/s eta 0:00:00
     ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 75.1/75.1 kB 31.0 MB/s eta 0:00:00
   ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 563.3/563.3 kB 22.3 MB/s eta 0:00:00
   ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 201.0/201.0 kB 348.5 MB/s eta 0:00:00
   ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 1.8/1.8 MB 98.7 MB/s eta 0:00:00
   ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 3.3/3.3 MB 143.4 MB/s eta 0:00:00
   ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-adk 1.22.1 requires google-cloud-bigquery-storage>=2.0.0, which is not installed.
bigframes 2.26.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
preprocessing 0.1.13 requires nltk==3.2.4, but you have nltk 3.9.2 which is incompatible.
google-colab 1.0.0 requires google-auth==2.38.0, but you have google-auth 2.47.0 which is incompatible.
google-colab 1.0.0 requires jupyter-server==2.14.0, but you have jupyter-server 2.12.5 which is incompatible.
google-colab 1.0.0 requires requests==2.32.4, but you have requests 2.32.5 which is incompatible.
dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.
langchain-core 0.3.79 requires packaging<26.0.0,>=23.2.0, but you have packaging 26.0 which is incompatible.
cudf-cu12 25.6.0 requires pyarr

‚úÖ huggingface_hub: 0.35.1 from /kaggle/working/_deps/huggingface_hub/__init__.py
‚úÖ fsspec: 2025.10.0 from /kaggle/working/_deps/fsspec/__init__.py


2026-01-26 20:06:54.460773: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769458014.658810     124 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769458014.714019     124 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769458015.201064     124 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769458015.201095     124 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769458015.201098     124 computation_placer.cc:177] computation placer alr

Device: cuda


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Train size: 117 | Val size: 13
Steps: 45 | Warmup: 2 | AMP: True

=== Epoch 1/3 ===


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Epoch 1 validation_loss=0.4477

[Demo] Q: What is climate change?
[Demo] A: This issue arose in the late 1970s when a global warming alarmism and political pressure were building to control carbon dioxide emissions.

=== Epoch 2/3 ===
Epoch 2 validation_loss=0.2892

[Demo] Q: What is climate change?
[Demo] A: Climate change affects the weather, which causes extreme precipitation events that cause flooding and storms.

=== Epoch 3/3 ===
Epoch 3 validation_loss=0.2777

[Demo] Q: What is climate change?
[Demo] A: Climate Change affects ecosystems, affecting wildlife and the environment. It disrupts food supply chains that sustain human activities, such as logging and agriculture; destroys livelihood-building opportunities for people; or causes widespread famine across many countries.

Saving to: ./fine_tuned_gpt2
‚úÖ Saved.

--- FINAL INFERENCE DEMO ---

Q: What are greenhouse gases?
A: Greenhouse gas emissions from burning fossil fuels, such as coal, generate heat and produce carbon diox

In [6]:
# ============================================================
# Kaggle Notebook UI: Textbox -> Search question in dataset -> Show answer
# Works with your JSON format: [{"question": "...", "answer": "..."}, ...]
# ============================================================

import json
import re
from difflib import SequenceMatcher

import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output

DATA_PATH = "/kaggle/input/climate-change-qa-enlarged/climate_change_qa_enlarged.json"

def load_qa(path: str):
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError("Dataset JSON must be a list of objects.")
    for i, item in enumerate(data[:5]):
        if "question" not in item or "answer" not in item:
            raise ValueError(f"Item {i} missing 'question' or 'answer'.")
    return data

QA = load_qa(DATA_PATH)

def normalize(text: str) -> str:
    text = text.strip().lower()
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^\w\s]", "", text)
    return text

# Precompute normalized questions for faster search
NQ = [normalize(x["question"]) for x in QA]

def best_match(query: str, top_k: int = 5):
    qn = normalize(query)
    if not qn:
        return []

    # 1) Exact normalized match first
    exact = [i for i, nq in enumerate(NQ) if nq == qn]
    if exact:
        return [(1.0, QA[exact[0]])]

    # 2) Contains match (fast heuristic)
    contains = [(0.92, QA[i]) for i, nq in enumerate(NQ) if qn in nq or nq in qn]
    if contains:
        return contains[:top_k]

    # 3) Fuzzy match
    scored = []
    for i, nq in enumerate(NQ):
        s = SequenceMatcher(None, qn, nq).ratio()
        if s >= 0.55:  # threshold
            scored.append((s, QA[i]))
    scored.sort(key=lambda x: x[0], reverse=True)
    return scored[:top_k]

# ----------------------------
# UI widgets
# ----------------------------
title = widgets.HTML("<h3>Climate Change Q&A Search</h3>")

question_box = widgets.Text(
    value="",
    placeholder="Type your question here...",
    description="Question:",
    layout=widgets.Layout(width="90%"),
)

search_btn = widgets.Button(
    description="Find Answer",
    button_style="success",
    tooltip="Search the dataset for the best matching question",
)

topk_slider = widgets.IntSlider(
    value=3,
    min=1,
    max=10,
    step=1,
    description="Top-K:",
    continuous_update=False,
)

output = widgets.Output()

def on_search(_):
    with output:
        clear_output()
        q = question_box.value.strip()
        if not q:
            display(Markdown("‚ö†Ô∏è Please type a question."))
            return

        matches = best_match(q, top_k=topk_slider.value)
        if not matches:
            display(Markdown("‚ùå No good match found in the dataset. Try rephrasing."))
            return

        best_score, best_item = matches[0]
        display(Markdown(f"### ‚úÖ Best Match (score: `{best_score:.3f}`)"))
        display(Markdown(f"**Dataset Question:** {best_item['question']}"))
        display(Markdown(f"**Answer:** {best_item['answer']}"))

        if len(matches) > 1:
            display(Markdown("---\n### Other close matches"))
            for s, item in matches[1:]:
                display(Markdown(f"- score `{s:.3f}` ‚Äî **Q:** {item['question']}"))

search_btn.on_click(on_search)

ui = widgets.VBox([
    title,
    widgets.HBox([question_box, search_btn]),
    topk_slider,
    output
])

display(ui)


VBox(children=(HTML(value='<h3>Climate Change Q&A Search</h3>'), HBox(children=(Text(value='', description='Qu‚Ä¶

In [8]:
# ============================================================
# Kaggle UI: Textbox Q -> Dataset search (exact+fuzzy+semantic) -> Answer
# NEW:
#  - Confidence label (High/Medium/Low) based on score + method
#  - Toggle: show Dataset answer, GPT answer, or BOTH side-by-side
# Dataset: /kaggle/input/climate-change-qa-enlarged/climate_change_qa_enlarged.json
# Optional fine-tuned GPT-2: ./fine_tuned_gpt2
# ============================================================

import os
import json
import re
import time
import math
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Optional
from difflib import SequenceMatcher

import torch
import torch.nn.functional as F

import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output

from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer

# ----------------------------
# Config
# ----------------------------
DATA_PATH = "/kaggle/input/climate-change-qa-enlarged/climate_change_qa_enlarged.json"
FINETUNED_DIR = "./fine_tuned_gpt2"

EMBED_MODEL_NAME = "distilbert-base-uncased"
EMBED_MAX_LEN = 192
EMBED_BATCH = 32

FUZZY_THRESHOLD = 0.62
SEMANTIC_THRESHOLD_DEFAULT = 0.45
TOPK_DEFAULT = 3

# ----------------------------
# Device helpers (robust)
# ----------------------------
def choose_device() -> torch.device:
    if not torch.cuda.is_available():
        return torch.device("cpu")
    try:
        _ = (torch.tensor([1.0], device="cuda") * 2.0)
        torch.cuda.synchronize()
        return torch.device("cuda")
    except Exception:
        return torch.device("cpu")

DEVICE = choose_device()

# ----------------------------
# Dataset
# ----------------------------
def load_qa(path: str) -> List[Dict[str, str]]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list) or not data:
        raise ValueError("Dataset JSON must be a non-empty list of objects with keys {question, answer}.")
    for i, item in enumerate(data[:5]):
        if "question" not in item or "answer" not in item:
            raise ValueError(f"Dataset item {i} missing 'question' or 'answer'.")
    return data

QA: List[Dict[str, str]] = load_qa(DATA_PATH)

def normalize(text: str) -> str:
    text = text.strip().lower()
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[^\w\s]", "", text)
    return text

NQ = [normalize(x["question"]) for x in QA]

# ----------------------------
# Confidence labeling
# ----------------------------
def confidence_label(score: float, method: str) -> str:
    """
    Heuristic confidence:
    - exact: always High
    - contains: High if >=0.90 else Medium
    - semantic: High >=0.65, Medium >=0.52, else Low
    - fuzzy: High >=0.85, Medium >=0.72, else Low
    """
    m = method.lower()
    if m == "exact":
        return "High"
    if m == "contains":
        return "High" if score >= 0.90 else "Medium"
    if m == "semantic":
        if score >= 0.65:
            return "High"
        if score >= 0.52:
            return "Medium"
        return "Low"
    if m == "fuzzy":
        if score >= 0.85:
            return "High"
        if score >= 0.72:
            return "Medium"
        return "Low"
    return "Low"

# ----------------------------
# Lexical search
# ----------------------------
def search_lexical(query: str, top_k: int, allow_fuzzy: bool) -> List[Tuple[float, Dict[str, str], str]]:
    qn = normalize(query)
    if not qn:
        return []

    # exact normalized
    for i, nq in enumerate(NQ):
        if nq == qn:
            return [(1.0, QA[i], "exact")]

    # contains heuristic
    contains = []
    for i, nq in enumerate(NQ):
        if qn in nq or nq in qn:
            contains.append((0.92, QA[i], "contains"))
    if contains:
        return contains[:top_k]

    if not allow_fuzzy:
        return []

    # fuzzy
    scored = []
    for i, nq in enumerate(NQ):
        s = SequenceMatcher(None, qn, nq).ratio()
        if s >= FUZZY_THRESHOLD:
            scored.append((s, QA[i], "fuzzy"))
    scored.sort(key=lambda x: x[0], reverse=True)
    return scored[:top_k]

# ----------------------------
# Semantic search (encoder embeddings)
# ----------------------------
@dataclass
class SemanticIndex:
    tokenizer: Any
    model: Any
    device: torch.device
    embeddings: Optional[torch.Tensor] = None
    built: bool = False

    @staticmethod
    def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden)
        summed = (last_hidden * mask).sum(dim=1)
        denom = mask.sum(dim=1).clamp(min=1e-9)
        return summed / denom

    @torch.no_grad()
    def encode_texts(self, texts: List[str], batch_size: int = 32, max_length: int = 192) -> torch.Tensor:
        all_vecs = []
        self.model.eval()

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            enc = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            enc = {k: v.to(self.device) for k, v in enc.items()}
            out = self.model(**enc)
            vec = self.mean_pool(out.last_hidden_state, enc["attention_mask"])
            vec = F.normalize(vec, p=2, dim=1)
            all_vecs.append(vec.cpu())

        return torch.cat(all_vecs, dim=0)

    def build(self, questions: List[str], batch_size: int = 32, max_length: int = 192) -> None:
        t0 = time.time()
        self.embeddings = self.encode_texts(questions, batch_size=batch_size, max_length=max_length)
        self.built = True
        print(f"‚úÖ Semantic index built: {self.embeddings.shape} in {time.time()-t0:.1f}s")

    @torch.no_grad()
    def query(self, text: str, top_k: int = 5, max_length: int = 192) -> List[Tuple[float, int]]:
        if not self.built or self.embeddings is None:
            return []
        q_vec = self.encode_texts([text], batch_size=1, max_length=max_length)[0]
        sims = (self.embeddings @ q_vec.unsqueeze(1)).squeeze(1)
        vals, idx = torch.topk(sims, k=min(top_k, sims.shape[0]))
        return [(float(v.item()), int(i.item())) for v, i in zip(vals, idx)]

SEM = SemanticIndex(
    tokenizer=AutoTokenizer.from_pretrained(EMBED_MODEL_NAME),
    model=AutoModel.from_pretrained(EMBED_MODEL_NAME).to(DEVICE),
    device=DEVICE,
)

# ----------------------------
# GPT fallback generator
# ----------------------------
@dataclass
class Generator:
    tokenizer: Any
    model: Any
    device: torch.device

    @staticmethod
    def load(device: torch.device) -> "Generator":
        model_dir = FINETUNED_DIR if os.path.isdir(FINETUNED_DIR) else "gpt2"
        tok = GPT2Tokenizer.from_pretrained(model_dir)
        tok.pad_token = tok.eos_token
        mdl = GPT2LMHeadModel.from_pretrained(model_dir).to(device)
        mdl.eval()
        return Generator(tok, mdl, device)

    @torch.no_grad()
    def answer(self, question: str, max_new_tokens: int = 140) -> str:
        prompt = f"### Question:\n{question}\n\n### Answer:\n"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        out = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.8,
            repetition_penalty=1.15,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        text = self.tokenizer.decode(out[0], skip_special_tokens=True)
        return text.split("### Answer:", 1)[1].strip() if "### Answer:" in text else text.strip()

GEN: Optional[Generator] = None

def safe_load_generator() -> Generator:
    global GEN
    if GEN is not None:
        return GEN
    try:
        GEN = Generator.load(DEVICE)
        _ = GEN.answer("Test?", max_new_tokens=5)
        print(f"‚úÖ Generator loaded on {DEVICE} (from {'./fine_tuned_gpt2' if os.path.isdir(FINETUNED_DIR) else 'gpt2'})")
        return GEN
    except Exception as e:
        print("‚ö†Ô∏è Generator failed on", DEVICE, "-> fallback CPU:", repr(e))
        GEN = Generator.load(torch.device("cpu"))
        print("‚úÖ Generator loaded on CPU")
        return GEN

# ----------------------------
# Combined search
# ----------------------------
def combined_search(query: str, top_k: int, use_semantic: bool, use_fuzzy: bool, semantic_threshold: float) -> List[Tuple[float, Dict[str, str], str]]:
    hits: List[Tuple[float, Dict[str, str], str]] = []
    hits.extend(search_lexical(query, top_k=top_k, allow_fuzzy=use_fuzzy))

    if use_semantic:
        if not SEM.built:
            SEM.build([x["question"] for x in QA], batch_size=EMBED_BATCH, max_length=EMBED_MAX_LEN)

        sem_hits = SEM.query(query, top_k=max(top_k, 8), max_length=EMBED_MAX_LEN)
        for score, idx in sem_hits:
            if score >= semantic_threshold:
                hits.append((score, QA[idx], "semantic"))

    # de-duplicate by question
    seen = set()
    deduped = []
    for s, item, m in sorted(hits, key=lambda x: x[0], reverse=True):
        key = normalize(item["question"])
        if key in seen:
            continue
        seen.add(key)
        deduped.append((s, item, m))
    return deduped[:top_k]

# ----------------------------
# UI widgets
# ----------------------------
title = widgets.HTML("<h3>Climate Change Q&A ‚Äî Dataset Search + Confidence + GPT Fallback</h3>")

question_box = widgets.Text(
    value="",
    placeholder="Type your question here...",
    description="Question:",
    layout=widgets.Layout(width="78%"),
)

search_btn = widgets.Button(
    description="Find Answer",
    button_style="success",
)

build_index_btn = widgets.Button(
    description="Build Semantic Index Now",
    button_style="info",
)

topk_slider = widgets.IntSlider(
    value=TOPK_DEFAULT,
    min=1, max=10, step=1,
    description="Top-K:",
    continuous_update=False,
)

min_score_slider = widgets.FloatSlider(
    value=SEMANTIC_THRESHOLD_DEFAULT,
    min=0.20, max=0.80, step=0.01,
    description="Min score:",
    continuous_update=False,
)

use_semantic_cb = widgets.Checkbox(value=True, description="Semantic search")
use_fuzzy_cb = widgets.Checkbox(value=True, description="Fuzzy search")
gen_fallback_cb = widgets.Checkbox(value=True, description="GPT fallback if no match")

# NEW: Output mode (Dataset / GPT / Both)
mode_dd = widgets.Dropdown(
    options=["Dataset only", "GPT only", "Both (compare)"],
    value="Both (compare)",
    description="Show:",
)

output = widgets.Output()

def render_dataset_result(score: float, item: Dict[str, str], method: str) -> None:
    conf = confidence_label(score, method)
    display(Markdown(f"### üìö Dataset Answer ({method}, score `{score:.3f}`, confidence **{conf}**)"))
    display(Markdown(f"**Matched Question:** {item['question']}"))
    display(Markdown(f"**Answer:** {item['answer']}"))

def render_gpt_result(question: str) -> None:
    gen = safe_load_generator()
    ans = gen.answer(question, max_new_tokens=160)
    display(Markdown("### ü§ñ GPT Answer (generated)"))
    display(Markdown(f"**Answer:** {ans}"))

def on_build_index(_):
    with output:
        clear_output()
        if SEM.built:
            display(Markdown("‚úÖ Semantic index already built."))
            return
        display(Markdown("‚è≥ Building semantic index..."))
        SEM.build([x["question"] for x in QA], batch_size=EMBED_BATCH, max_length=EMBED_MAX_LEN)
        display(Markdown("‚úÖ Done."))

def on_search(_):
    with output:
        clear_output()

        q = question_box.value.strip()
        if not q:
            display(Markdown("‚ö†Ô∏è Please type a question."))
            return

        semantic_threshold = float(min_score_slider.value)

        hits = combined_search(
            q,
            top_k=int(topk_slider.value),
            use_semantic=bool(use_semantic_cb.value),
            use_fuzzy=bool(use_fuzzy_cb.value),
            semantic_threshold=semantic_threshold,
        )

        show_mode = mode_dd.value

        # If BOTH: show dataset best (if exists) + GPT answer
        if show_mode == "Both (compare)":
            if hits:
                best_score, best_item, best_method = hits[0]
                render_dataset_result(best_score, best_item, best_method)
            else:
                display(Markdown("### üìö Dataset Answer"))
                display(Markdown("‚ùå No dataset match found under current thresholds."))

            display(Markdown("---"))
            render_gpt_result(q)

            if hits and len(hits) > 1:
                display(Markdown("---\n### Other dataset matches"))
                for s, item, m in hits[1:]:
                    conf = confidence_label(s, m)
                    display(Markdown(f"- `{m}` score `{s:.3f}` conf **{conf}** ‚Äî **Q:** {item['question']}"))
            return

        # Dataset only
        if show_mode == "Dataset only":
            if hits:
                best_score, best_item, best_method = hits[0]
                render_dataset_result(best_score, best_item, best_method)

                if len(hits) > 1:
                    display(Markdown("---\n### Other matches"))
                    for s, item, m in hits[1:]:
                        conf = confidence_label(s, m)
                        display(Markdown(f"- `{m}` score `{s:.3f}` conf **{conf}** ‚Äî **Q:** {item['question']}"))
            else:
                display(Markdown("‚ùå No dataset match found under current thresholds."))
                if gen_fallback_cb.value:
                    display(Markdown("---"))
                    display(Markdown("Fallback enabled, but you chose Dataset only display. Switch 'Show' to GPT/Both."))
            return

        # GPT only
        if show_mode == "GPT only":
            render_gpt_result(q)
            if hits:
                best_score, best_item, best_method = hits[0]
                conf = confidence_label(best_score, best_method)
                display(Markdown("---\n### (FYI) Best dataset match"))
                display(Markdown(f"`{best_method}` score `{best_score:.3f}` conf **{conf}** ‚Äî **Q:** {best_item['question']}"))
            return

search_btn.on_click(on_search)
build_index_btn.on_click(on_build_index)

ui = widgets.VBox([
    title,
    widgets.HBox([question_box, search_btn]),
    widgets.HBox([mode_dd, topk_slider, min_score_slider]),
    widgets.HBox([use_semantic_cb, use_fuzzy_cb, gen_fallback_cb]),
    build_index_btn,
    output,
])

display(ui)


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

VBox(children=(HTML(value='<h3>Climate Change Q&A ‚Äî Dataset Search + Confidence + GPT Fallback</h3>'), HBox(ch‚Ä¶