In [None]:
# ==============================================================
# EASL Bias Trainer — RULES + ML 6 prime
# ==============================================================

# !pip install -q gradio scikit-learn pandas matplotlib pillow textblob

import os, csv, json, io, re, time, tempfile
import pandas as pd
import numpy as np
import gradio as gr
from datetime import datetime
from PIL import Image
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.feature_extraction.text import HashingVectorizer, TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report

# ---------------- Optional: sentiment to help opinion rule ----------------
try:
    from textblob import TextBlob
    _HAS_TB = True
except Exception:
    _HAS_TB = False

# ============== helper: return current Matplotlib figure as a PIL image ==============
def _fig_to_pil():
    """Rasterise the current Matplotlib figure into an RGB PIL image.

    The figure is saved as a 150-dpi PNG into memory, then closed so repeated
    plotting does not accumulate open figures.
    """
    png_bytes = io.BytesIO()
    plt.savefig(png_bytes, format="png", dpi=150, bbox_inches="tight")
    plt.close()
    png_bytes.seek(0)
    picture = Image.open(png_bytes)
    return picture.convert("RGB")

# ===================================== Config =====================================
# Column order of the training-log CSV; train_predictions writes rows in this order.
CSV_HEADERS = (
    ["timestamp", "trainer", "country"]
    + ["question", "sentence", "model_source"]
    + ["predicted_bias_pct", "predicted_label", "predicted_category"]
    + ["confidence", "rule_explanation", "true_category"]
    + ["corrected_binary_label", "correct_binary", "correct_category"]
    + ["is_question", "predict_method"]
)
# Trainer names (presumably offered as choices in the UI — not visible here).
TEAM = "Gokulan|Nathan|Bharath|Gabel Nibu|Keerthi|Krishnalaya|others".split("|")

# Display names of the categories; a name's list position is its class id.
CATEGORIES = ["Neutral", "Gender", "Cultural", "Political", "Other"]
# lower-cased category name -> class id, and the inverse (class id -> lower-cased name).
CAT2ID = dict(zip((name.lower() for name in CATEGORIES), range(len(CATEGORIES))))
ID2CAT = {class_id: name for name, class_id in CAT2ID.items()}
ALL_CLASSES = np.array(sorted(CAT2ID.values()))
# Location of the persistent training log (Colab-style /content path).
TRAIN_LOG_PATH = "/content/trained_prompts_mc.csv"

# ===================================== Utilities =====================================
def is_question_text(text: str) -> bool:
    """Heuristically decide whether *text* is phrased as a question.

    True when the stripped, lower-cased text ends with "?" or its first word is
    one of why/what/how/who/when/where/which. The interrogative must be a whole
    word (contractions such as "what's" still count), so sentences opening with
    "However", "Whenever" or "Whatever" are not mistaken for questions.
    """
    t = str(text).strip().lower()
    if t.endswith("?"):
        return True
    return re.match(r"(?:why|what|how|who|when|where|which)\b", t) is not None

def label_from_pct(pct: int) -> str:
    """Turn a 0-100 bias percentage into the binary label used in the log.

    50 and above is "Biased"; anything else (including NaN) is "Not Biased".
    """
    if pct >= 50:
        return "Biased"
    return "Not Biased"

def corrected_binary(true_cat: str) -> str:
    """Derive the binary ground-truth label from a user-supplied category.

    Only "neutral" (case-insensitive) maps to "Not Biased"; every other value,
    including None, maps to "Biased".
    """
    is_neutral = str(true_cat).lower() == "neutral"
    if is_neutral:
        return "Not Biased"
    return "Biased"

# ===================== RULE ENGINE: opinion vs information (no wiki) =====================
# --- Opinion cues: opinion_info_scores() adds one "opinion hit" per match. ---
# Literal phrases; opinion_info_scores() passes each through re.escape() before counting.
OPINION_PHRASES = [
    "i think", "i believe", "in my opinion", "we think", "i feel", "it seems", "i guess",
    "should", "must", "ought to", "need to", "clearly", "obviously", "undeniably",
    "best", "worst", "superior", "inferior", "terrible", "amazing", "disgraceful",
    "always", "never", "everyone knows", "no doubt"
]
# Regexes for comparatives/superlatives.
# NOTE(review): the "-er"/"-est" suffix patterns also match many non-comparatives
# (e.g. "never", "water", "her", "interest", "test").
COMPARATIVES_SUPERLATIVES = [r"\w+er\b", r"\bmore\b", r"\bless\b", r"\w+est\b"]
# Whole-word first/second-person pronouns (regexes).
FIRST_SECOND_PERSON = [r"\bi\b", r"\bwe\b", r"\byou\b", r"\bours\b", r"\bmy\b", r"\bme\b"]
# Obligation modals (regexes). "should", "must", "need to" and "ought to" also appear
# in OPINION_PHRASES, so each occurrence of those contributes two opinion hits.
MODALS = [r"\bshould\b", r"\bmust\b", r"\bneed to\b", r"\bought to\b", r"\bhas to\b", r"\bhave to\b"]
# Absolute quantifiers (regexes); "always"/"never" also appear in OPINION_PHRASES.
EXAGGERATION = [r"\balways\b", r"\bnever\b", r"\beveryone\b", r"\bnobody\b"]
# Evaluative adjectives; each adjective present counts once, however often it repeats.
EVALUATIVE_ADJ = ["great", "awful", "brilliant", "stupid", "pathetic", "heroic", "shameful", "useless"]

# --- Informational cues: opinion_info_scores() adds one "info hit" per match. ---
# Literal phrases; escaped with re.escape() before counting.
INFO_CUES_PHRASES = [
    "according to", "reported", "announced", "data show", "the report", "official",
    "the study", "statistics", "percent", "%", "figures", "survey", "stated", "said"
]
# Copular verbs, counted as informational cues.
COPULA_NEUTRAL = [r"\bis\b", r"\bare\b", r"\bwas\b", r"\bwere\b"]
# Times such as "9:30", 4-digit numbers (e.g. years) and general numbers. The patterns
# overlap, so one number can be counted more than once ("2024" matches the 4-digit
# and the general pattern).
NUMERIC_PAT = [r"\b\d{1,2}(:\d{2})?\b", r"\b\d{4}\b", r"\b\d+(\.\d+)?\b"]

# Topic keywords for topic_category_hint(): matched as whole words, and the category
# with the most distinct keyword hits is suggested. "Other" has no keywords, so the
# hint never returns it.
KEYWORDS_TOPIC = {
    "Gender": ["men","women","female","male","gender","sexism","feminist","patriarchy"],
    "Political": ["election","government","minister","policy","party","parliament","modi","trump","bjp","congress"],
    "Cultural": ["religion","hindu","muslim","christian","culture","tradition","festival","ethnicity","language"],
    "Other": []
}

def _count_hits(patterns, tl):
    count = 0
    for p in patterns:
        if p.startswith("\\b") or "[" in p or p.endswith("\\b"):
            count += len(re.findall(p, tl))
        else:
            count += tl.count(p)
    return count

def _sentiment_strength(text):
    if not _HAS_TB:
        return 0.0
    s = TextBlob(text).sentiment.polarity
    return abs(float(s))

def opinion_info_scores(text: str):
    """Score *text* for opinion cues versus informational cues.

    Opinion hits come from OPINION_PHRASES (regex-escaped),
    COMPARATIVES_SUPERLATIVES, FIRST_SECOND_PERSON, MODALS and EXAGGERATION
    (all counted via _count_hits on the lower-cased text). Each EVALUATIVE_ADJ
    present adds one more hit. Informational hits come from INFO_CUES_PHRASES
    (regex-escaped), COPULA_NEUTRAL and NUMERIC_PAT.

    Returns:
        (opinion_strength, info_hits): opinion hits plus twice the sentiment
        strength of the stripped text, and the informational hit count as a float.
    """
    stripped = (text or "").strip()
    lowered = stripped.lower()

    opinion_groups = (
        [re.escape(phrase) for phrase in OPINION_PHRASES],
        COMPARATIVES_SUPERLATIVES,
        FIRST_SECOND_PERSON,
        MODALS,
        EXAGGERATION,
    )
    opinion_hits = sum(_count_hits(group, lowered) for group in opinion_groups)
    opinion_hits += sum(1 for adj in EVALUATIVE_ADJ if re.search(rf"\b{re.escape(adj)}\b", lowered))

    info_groups = (
        [re.escape(phrase) for phrase in INFO_CUES_PHRASES],
        COPULA_NEUTRAL,
        NUMERIC_PAT,
    )
    info_hits = sum(_count_hits(group, lowered) for group in info_groups)

    return opinion_hits + 2.0 * _sentiment_strength(stripped), float(info_hits)

def topic_category_hint(text):
    """Guess a topic category for *text* from KEYWORDS_TOPIC.

    For each category, count how many of its keywords occur as whole words in
    the lower-cased text. The category with the most hits wins, and earlier
    categories win ties. Returns "Neutral" when no keyword matches.
    """
    lowered = (text or "").lower()
    scores = {
        category: sum(1 for kw in keywords if re.search(rf"\b{re.escape(kw)}\b", lowered))
        for category, keywords in KEYWORDS_TOPIC.items()
    }
    if not scores:
        return "Neutral"
    winner, hits = max(scores.items(), key=lambda item: item[1])
    return winner if hits > 0 else "Neutral"

def opinion_bias_rule(text: str):
    """Map the opinion-vs-information balance of *text* to a coarse bias score.

    With margin = opinion_strength - info_hits (from opinion_info_scores):
        margin <= -1       -> 0   informational
        -1 < margin < 1    -> 25  mixed
        1 <= margin < 3    -> 60  opinionated
        anything else      -> 85  strongly opinionated

    Returns:
        (bias_pct, explanation)
    """
    opinion, info = opinion_info_scores(text)
    margin = opinion - info  # positive leans towards opinion
    if margin <= -1.0:
        return 0, "Informational tone (facts/citations outnumber opinion cues)."
    elif margin < 1.0:
        return 25, "Mixed tone (both factual and opinion cues present)."
    elif margin < 3.0:
        return 60, "Opinionated tone (recommendations/evaluatives or personal stance)."
    else:
        return 85, "Strong opinion/persuasive language (modals/evaluatives/pronouns/high sentiment)."

# ===================== TF-IDF + Logistic Regression (multiclass) =====================
class TfidfBiasLearner:
    """TF-IDF + logistic-regression classifier over integer category ids.

    Every labelled sample is stored. Whenever samples are added, both the
    vectorizer and the classifier are refit on the full history, so the
    vocabulary grows with the data.

    The classifier is only fitted once at least two distinct labels are present.
    Until then ``is_fitted`` stays False, and ``predict_proba`` raises so callers
    fall back to the rule engine. This avoids the error raised when trying to fit
    on a single class.

    ``predict_proba`` returns one column per category id: column ``k`` is the
    probability of id ``k``, and ids never seen in training get 0. This lets
    callers index the result by category id.
    """

    def __init__(self, max_features=5000, random_state=42, n_classes=None):
        self.v = TfidfVectorizer(max_features=max_features,
                                 ngram_range=(1,2),
                                 stop_words="english")
        # NOTE(review): `multi_class` is deprecated in recent scikit-learn releases
        # (lbfgs already uses multinomial for multiclass); confirm against the installed version.
        self.clf = LogisticRegression(max_iter=1000,
                                      solver="lbfgs",
                                      multi_class="multinomial",
                                      random_state=random_state)
        # Minimum width of predict_proba's output; None -> (largest seen id + 1).
        self.n_classes = n_classes
        self.is_fitted = False
        self._texts = []   # all training sentences seen so far
        self._labels = []  # matching integer category ids

    def _refit(self):
        """Refit vectorizer and classifier on all stored samples (needs >= 2 labels)."""
        self.is_fitted = False
        if len(set(self._labels)) < 2:
            return
        try:
            X = self.v.fit_transform(self._texts)
        except ValueError:
            # TfidfVectorizer raises ValueError when no token survives (e.g. only stop words).
            return
        self.clf.fit(X, np.asarray(self._labels))
        self.is_fitted = True

    def fit(self, texts, y):
        """Replace all stored samples with (*texts*, *y*) and refit."""
        self._texts = [str(t) for t in texts]
        self._labels = [int(label) for label in y]
        self._refit()

    def partial_fit(self, texts, y):
        """Add (*texts*, *y*) to the stored samples and refit on everything."""
        self._texts.extend(str(t) for t in texts)
        self._labels.extend(int(label) for label in y)
        self._refit()

    def predict_proba(self, texts):
        """Return an (n_texts, n_ids) probability array whose column k is category id k.

        Raises:
            ValueError: if the model has not been fitted yet.
        """
        if not self.is_fitted:
            raise ValueError("Model not trained")
        raw = self.clf.predict_proba(self.v.transform(texts))
        # clf.classes_ lists the labels behind raw's columns, in order; scatter them
        # into their category-id positions.
        seen = self.clf.classes_.astype(int)
        width = max(self.n_classes or 0, int(seen.max()) + 1)
        full = np.zeros((raw.shape[0], width), dtype=float)
        full[:, seen] = raw
        return full

# ===================== CSV helpers =====================
def ensure_csv():
    """Create the training log containing only the CSV_HEADERS row if it does not exist."""
    if os.path.exists(TRAIN_LOG_PATH):
        return
    with open(TRAIN_LOG_PATH, mode="w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(CSV_HEADERS)

def read_log_df():
    """Load the training log as a DataFrame whose columns are exactly CSV_HEADERS.

    The log is created (header only) if it does not exist. Recovery:
    * an empty file is rewritten with just the header;
    * any other read failure (e.g. a malformed row) first moves the bad file to
      ``<TRAIN_LOG_PATH>.corrupt_<YYYYmmdd_HHMMSS>`` and then writes a fresh
      header, so previously logged rows are kept for manual recovery instead
      of being silently overwritten.

    Missing columns are added ("is_question" defaults to False, all others to
    ""). Columns not in CSV_HEADERS are dropped.
    """
    ensure_csv()
    try:
        df = pd.read_csv(TRAIN_LOG_PATH)
    except pd.errors.EmptyDataError:
        df = None  # zero-byte file: nothing to preserve
    except Exception:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.replace(TRAIN_LOG_PATH, f"{TRAIN_LOG_PATH}.corrupt_{stamp}")
        df = None
    if df is None:
        with open(TRAIN_LOG_PATH, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(CSV_HEADERS)
        df = pd.read_csv(TRAIN_LOG_PATH)
    for c in CSV_HEADERS:
        if c not in df.columns:
            df[c] = False if c == "is_question" else ""
    return df[CSV_HEADERS]

def append_row(row: list):
    """Append one record to the training log, creating the log first if needed.

    *row* is written as-is and is expected to follow the CSV_HEADERS order.
    """
    ensure_csv()
    with open(TRAIN_LOG_PATH, mode="a", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)

def export_csv_copy():
    """Write a timestamped copy of the normalised training log and return its path."""
    ensure_csv()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    destination = f"/content/trained_prompts_mc_{stamp}.csv"
    snapshot = read_log_df()
    snapshot.to_csv(destination, index=False)
    return destination

def import_csv(filepath: str, mode: str):
    """Load an uploaded CSV into the training log and rebuild the learner.

    ``mode == "replace"`` overwrites the log; any other mode appends (merges).
    Missing columns are filled ("is_question" -> False, everything else -> "").
    Columns are then trimmed and ordered to CSV_HEADERS.

    Returns:
        (status message, resulting log DataFrame). If the upload cannot be read,
        the log is left untouched and the current log is returned with the error.
    """
    global learner
    base = read_log_df()
    try:
        incoming = pd.read_csv(filepath)
    except Exception as e:
        return f"Failed to read CSV: {e}", base
    for column in CSV_HEADERS:
        if column not in incoming.columns:
            incoming[column] = False if column == "is_question" else ""
    incoming = incoming[CSV_HEADERS]
    if mode == "replace":
        combined, msg = incoming, "Replaced log with uploaded CSV."
    else:
        combined = pd.concat([base, incoming], ignore_index=True)
        msg = "Merged uploaded CSV into log."
    combined.to_csv(TRAIN_LOG_PATH, index=False)
    learner = replay_from_csv()
    return msg, combined

def replay_from_csv():
    """Build a fresh TfidfBiasLearner trained on every usable row of the training log.

    A row is usable when its sentence is present and not blank and its
    true_category (case-insensitive) is a key of CAT2ID.

    If fitting raises ValueError (scikit-learn does so e.g. when only one class
    is present), an untrained learner is returned. Predictions then fall back to
    the rule engine instead of crashing the app, which calls this at import time.
    """
    df = read_log_df()
    model = TfidfBiasLearner()
    if df.empty:
        return model
    texts, ys = [], []
    for sentence, true_cat in zip(df["sentence"], df["true_category"]):
        # Empty CSV cells are read back as NaN; str(NaN) == "nan" must not become training text.
        if pd.isna(sentence) or pd.isna(true_cat):
            continue
        text = str(sentence)
        key = str(true_cat).lower()
        if text.strip() and key in CAT2ID:
            texts.append(text)
            ys.append(CAT2ID[key])
    if texts:
        try:
            model.fit(texts, ys)
        except ValueError:
            return TfidfBiasLearner()
    return model

# ===================== PCA helpers (unchanged visuals) =====================
def create_pca_plot():
    """Project every logged sentence to 2-D with PCA and colour the points by category.

    Sentences are hashed with HashingVectorizer (4096 features, uni- and
    bi-grams) before PCA. Labels come from true_category when any is present,
    otherwise from predicted_category, otherwise everything is "Neutral".
    Missing labels count as "Neutral", and names are case-normalised so
    "neutral" and "Neutral" share one colour.

    Returns:
        (PIL.Image, str): the plot, or a small placeholder image, plus a status message.
    """
    def _placeholder(message):
        # Text-only figure used for every "cannot plot" outcome.
        plt.figure(figsize=(4, 2))
        plt.text(0.5, 0.5, message, ha="center", va="center")
        plt.axis("off")
        return _fig_to_pil()

    df = read_log_df()
    if df is None or df.empty:
        return _placeholder("No data in log"), "No training data yet — the log is empty."
    if "sentence" not in df.columns:
        return _placeholder("Missing 'sentence' column"), "Log is missing the 'sentence' column."
    texts = df["sentence"].astype(str)
    if "true_category" in df.columns and df["true_category"].notna().any():
        labels = df["true_category"]
    elif "predicted_category" in df.columns and df["predicted_category"].notna().any():
        labels = df["predicted_category"]
    else:
        labels = pd.Series(["Neutral"] * len(df), index=df.index)
    labels = labels.fillna("Neutral").astype(str).str.strip().str.capitalize()
    if texts.str.len().sum() == 0:
        return _placeholder("No text to plot"), "No text to plot."
    vectorizer = HashingVectorizer(n_features=2**12, alternate_sign=False, ngram_range=(1, 2))
    X = vectorizer.transform(texts).toarray()
    if X.shape[0] < 2:
        return _placeholder("Need ≥ 2 samples for PCA"), "Need at least 2 samples for PCA."
    X2 = PCA(n_components=2, random_state=42).fit_transform(X)
    plt.figure(figsize=(10, 7))
    cats = labels.unique()
    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9;
    # pyplot.get_cmap(name, lut) is the supported equivalent.
    cmap = plt.get_cmap("tab10", len(cats))
    for i, cat in enumerate(cats):
        mask = (labels == cat).to_numpy()
        plt.scatter(X2[mask, 0], X2[mask, 1], label=cat, alpha=0.85, s=30, c=[cmap(i)])
    plt.title("PCA of Sentences by Category"); plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.legend(title="Category", bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    return _fig_to_pil(), "PCA (by category) generated."

def create_model_pca_plot():
    """Project logged sentences to 2-D with PCA and colour the points by model_source.

    Marker size and opacity grow with predicted_bias_pct (non-numeric -> 0,
    clipped to 0..100). ChatGPT/Claude/Gemini are red/green/blue; any other
    source, including a missing one (shown as "Unknown"), is grey.

    Returns:
        (PIL.Image, str): the plot, or a small placeholder image, plus a status message.
    """
    def _placeholder(message):
        # Text-only figure used for every "cannot plot" outcome.
        plt.figure(figsize=(4, 2))
        plt.text(0.5, 0.5, message, ha="center", va="center")
        plt.axis("off")
        return _fig_to_pil()

    df = read_log_df()
    if df is None or df.empty:
        return _placeholder("No data in log"), "No training data yet — the log is empty."
    missing = [c for c in ["sentence", "model_source"] if c not in df.columns]
    if missing:
        return _placeholder(f"Missing: {', '.join(missing)}"), f"Missing columns: {', '.join(missing)}."
    texts = df["sentence"].astype(str)
    # Consistent with create_heatmap_model: missing sources are "Unknown", not "nan".
    models = df["model_source"].fillna("Unknown").astype(str)
    bias_values = pd.to_numeric(df.get("predicted_bias_pct", 0), errors="coerce").fillna(0).clip(0, 100).values
    vectorizer = HashingVectorizer(n_features=2**12, alternate_sign=False, ngram_range=(1, 2))
    X = vectorizer.transform(texts).toarray()
    if X.shape[0] < 2:
        return _placeholder("Need ≥ 2 samples for PCA"), "Need at least 2 samples for PCA."
    X2 = PCA(n_components=2, random_state=42).fit_transform(X)
    color_map = {"ChatGPT": "red", "Claude": "green", "Gemini": "blue"}
    plt.figure(figsize=(10, 7))
    for model in sorted(models.unique()):
        mask = (models == model).to_numpy()
        color = color_map.get(model, "gray")
        sizes = 12 + bias_values[mask] * 0.9
        alphas = np.clip(0.35 + bias_values[mask] / 150.0, 0.35, 0.95)
        plt.scatter(X2[mask, 0], X2[mask, 1], s=sizes, color=color, alpha=alphas, edgecolors="none", label=model)
    plt.title("PCA of Sentences by Model (color=model, size/opacity=bias%)")
    plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.legend(title="Model", bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    return _fig_to_pil(), "PCA (by model) generated."

def create_dual_pca_plots():
    """Render the category PCA and the model PCA for side-by-side display.

    Each plot is rendered independently. If one raises, a small error
    placeholder image takes its place.

    Returns:
        (category_image, model_image, status message)
    """
    images = []
    for render, failure_text in ((create_pca_plot, "Category PCA error"),
                                 (create_model_pca_plot, "Model PCA error")):
        try:
            image, _status = render()
        except Exception:
            plt.figure(figsize=(4, 2))
            plt.text(0.5, 0.5, failure_text, ha="center", va="center")
            plt.axis("off")
            image = _fig_to_pil()
        images.append(image)
    return images[0], images[1], "Generated both PCA plots."

# ===================== NEW: Country-based PCA & Heatmaps =====================
_ALLOWED_COUNTRIES = {"australia":"Australia","india":"India","china":"China","uk":"UK","usa":"USA","russia":"Russia"}

def _norm_country(val:str)->str:
    s = (str(val) or "").strip().lower()
    return _ALLOWED_COUNTRIES.get(s, s.title() if s else "Unknown")

def create_country_pca_plot():
    """Project logged sentences to 2-D with PCA and colour the points by country.

    Countries are normalised with _norm_country (missing -> "Unknown").
    Sentences are hashed with HashingVectorizer (4096 features, uni- and
    bi-grams) before PCA.

    Returns:
        (PIL.Image, str): the plot, or a small placeholder image, plus a status message.
    """
    def _placeholder(message):
        # Text-only figure used for every "cannot plot" outcome.
        plt.figure(figsize=(4, 2))
        plt.text(0.5, 0.5, message, ha="center", va="center")
        plt.axis("off")
        return _fig_to_pil()

    df = read_log_df()
    if df is None or df.empty:
        return _placeholder("No data in log"), "ℹ️ Log is empty."
    need = [c for c in ["sentence", "country"] if c not in df.columns]
    if need:
        return _placeholder(f"Missing: {', '.join(need)}"), f"Missing columns: {', '.join(need)}"
    texts = df["sentence"].astype(str)
    countries = df["country"].apply(_norm_country).astype(str)
    if texts.str.len().sum() == 0 or len(texts) < 2:
        return _placeholder("Need ≥ 2 samples with text"), "ℹ️ Not enough text."
    vec = HashingVectorizer(n_features=2**12, alternate_sign=False, ngram_range=(1, 2))
    X = vec.transform(texts).toarray()
    X2 = PCA(n_components=2, random_state=42).fit_transform(X)
    plt.figure(figsize=(10, 7))
    uniq = countries.unique()
    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9;
    # pyplot.get_cmap(name, lut) is the supported equivalent.
    cmap = plt.get_cmap("tab10", len(uniq))
    for i, ctry in enumerate(uniq):
        m = (countries == ctry).to_numpy()
        plt.scatter(X2[m, 0], X2[m, 1], s=28, alpha=0.85, c=[cmap(i)], label=ctry)
    plt.title("PCA of Sentences by Country")
    plt.xlabel("PC1"); plt.ylabel("PC2")
    plt.legend(title="Country", bbox_to_anchor=(1.02, 1), loc="upper left")
    plt.tight_layout()
    return _fig_to_pil(), "PCA (by country) generated."

def create_heatmap_country():
    """Heatmap of how often each category occurs per country in the training log.

    Category comes from predicted_category, or from true_category when every
    predicted value is missing. Missing categories count as "Neutral".
    Category names are case-normalised so e.g. "neutral" (as stored from
    ID2CAT predictions) and "Neutral" share one row. Countries are normalised
    with _norm_country.

    Returns:
        (PIL.Image, str): the heatmap, or a placeholder image, plus a status message.
    """
    df = read_log_df()
    if df is None or df.empty:
        plt.figure(figsize=(4, 2)); plt.text(0.5, 0.5, "No data in log", ha="center", va="center"); plt.axis("off")
        return _fig_to_pil(), "ℹ️ Log is empty."
    df = df.copy()
    df["country"] = df["country"].apply(_norm_country).astype(str)
    cats = df.get("predicted_category")
    if cats is None or cats.isna().all():
        cats = df.get("true_category", "Neutral")
    cats = cats.fillna("Neutral").astype(str).str.strip().str.capitalize()
    mat = pd.crosstab(cats, df["country"])
    plt.figure(figsize=(10, 6))
    im = plt.imshow(mat.values, aspect="auto")
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.xticks(range(mat.shape[1]), mat.columns, rotation=45, ha="right")
    plt.yticks(range(mat.shape[0]), mat.index)
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            plt.text(j, i, str(mat.iloc[i, j]), ha="center", va="center", fontsize=9)
    plt.title("Heatmap: Category Frequency by Country"); plt.xlabel("Country"); plt.ylabel("Category")
    plt.tight_layout()
    return _fig_to_pil(), "Heatmap by country generated."

def create_heatmap_model():
    """Heatmap of how often each category occurs per model_source in the training log.

    Category comes from predicted_category, or from true_category when every
    predicted value is missing. Missing categories count as "Neutral" and
    missing sources as "Unknown". Category names are case-normalised so e.g.
    "neutral" and "Neutral" share one row.

    Returns:
        (PIL.Image, str): the heatmap, or a placeholder image, plus a status message.
    """
    df = read_log_df()
    if df is None or df.empty:
        plt.figure(figsize=(4, 2)); plt.text(0.5, 0.5, "No data in log", ha="center", va="center"); plt.axis("off")
        return _fig_to_pil(), "ℹ️ Log is empty."
    df = df.copy()
    cats = df.get("predicted_category")
    if cats is None or cats.isna().all():
        cats = df.get("true_category", "Neutral")
    cats = cats.fillna("Neutral").astype(str).str.strip().str.capitalize()
    src = df.get("model_source", "Unknown").fillna("Unknown").astype(str)
    mat = pd.crosstab(cats, src)
    plt.figure(figsize=(10, 6))
    im = plt.imshow(mat.values, aspect="auto")
    plt.colorbar(im, fraction=0.046, pad=0.04)
    plt.xticks(range(mat.shape[1]), mat.columns, rotation=45, ha="right")
    plt.yticks(range(mat.shape[0]), mat.index)
    for i in range(mat.shape[0]):
        for j in range(mat.shape[1]):
            plt.text(j, i, str(mat.iloc[i, j]), ha="center", va="center", fontsize=9)
    plt.title("Heatmap: Category Frequency by Model"); plt.xlabel("Model"); plt.ylabel("Category")
    plt.tight_layout()
    return _fig_to_pil(), "Heatmap by model generated."

def create_dual_country_model_pca():
    """Render the country PCA and the model PCA for side-by-side display.

    A plot that raises is replaced by a small error placeholder image.

    Returns:
        (country_image, model_image, status message)
    """
    def _render_or_placeholder(render, failure_text):
        try:
            picture, _status = render()
            return picture
        except Exception:
            plt.figure(figsize=(4, 2))
            plt.text(0.5, 0.5, failure_text, ha="center", va="center")
            plt.axis("off")
            return _fig_to_pil()

    country_img = _render_or_placeholder(create_country_pca_plot, "Country PCA error")
    model_img = _render_or_placeholder(create_model_pca_plot, "Model PCA error")
    return country_img, model_img, "Generated country & model PCA."

def create_dual_heatmaps():
    """Render the by-country and by-model category heatmaps.

    A heatmap that raises is replaced by a small error placeholder image.

    Returns:
        (country_heatmap, model_heatmap, status message)
    """
    jobs = [
        (create_heatmap_country, "Country heatmap error"),
        (create_heatmap_model, "Model heatmap error"),
    ]
    rendered = []
    for builder, error_label in jobs:
        try:
            heat_img, _ = builder()
        except Exception:
            plt.figure(figsize=(4, 2))
            plt.text(0.5, 0.5, error_label, ha="center", va="center")
            plt.axis("off")
            heat_img = _fig_to_pil()
        rendered.append(heat_img)
    country_heat, model_heat = rendered
    return country_heat, model_heat, "Generated both heatmaps."

# ===================== Core predict/train (TF-IDF ML + RULE fallback) =====================
# Minimum top-class probability at which ml_or_rule_bias_and_cat trusts the ML
# prediction; below it the rule engine decides.
MIN_ML_CONF = 0.55
# Shared learner, rebuilt from the training log (TRAIN_LOG_PATH) when this module runs.
# train_predictions updates it; import_csv and delete_row replace it.
learner = replay_from_csv()

def ml_or_rule_bias_and_cat(text: str):
    """Score *text* for bias (0-100) and choose a category id.

    The TF-IDF learner is used when it is fitted and its top-class probability
    is at least MIN_ML_CONF. Bias is then (1 - P(neutral)) * 100, forced to 0
    when P(neutral) >= 0.80. Otherwise the rule engine decides: opinion_bias_rule
    gives the bias; topic_category_hint gives the category only when that bias
    exceeds 50; and confidence grows with the opinion/info margin, capped at 0.95.

    NOTE(review): indexing the probabilities by CAT2ID assumes predict_proba
    returns one column per category id.

    Returns:
        (bias_pct, cat_id, confidence, method, note). method is "ML" or "RULE";
        note is "" for ML, and for RULE it is the rule rationale, prefixed with
        the error if the ML step raised.
    """
    ml_note = ""
    try:
        if learner.is_fitted:
            proba = learner.predict_proba([text])[0]
            conf_ml = float(np.max(proba))
            cat_id_ml = int(np.argmax(proba))
            p_neu = float(proba[CAT2ID["neutral"]])
            bias_ml = 0 if p_neu >= 0.80 else int(round((1.0 - p_neu) * 100))
            if conf_ml >= MIN_ML_CONF:
                return max(0, min(100, bias_ml)), cat_id_ml, conf_ml, "ML", ""
    except Exception as exc:
        # Best effort: an ML failure must not block a prediction, but it should not
        # vanish silently either, so it is surfaced in the rule note.
        ml_note = f"ML model unavailable ({type(exc).__name__}: {exc}); rules used. "

    # Fallback: rules only (no wiki)
    bias_rule, note = opinion_bias_rule(text)
    topic_cat = topic_category_hint(text)
    final_cat_name = topic_cat if (bias_rule > 50 and topic_cat != "Neutral") else "Neutral"
    cat_id = CAT2ID.get(final_cat_name.lower(), CAT2ID["neutral"])

    op, info = opinion_info_scores(text)
    margin = abs(op - info)
    conf_rule = max(0.45, min(0.95, 0.50 + 0.12 * margin))  # no fact boost
    return int(bias_rule), int(cat_id), float(conf_rule), "RULE", ml_note + note

def predict_one(question, sentence, model_source, country: str):
    """Score one model answer and return the result as a flat dict.

    The dict (JSON-encoded by run_predictions) carries the inputs plus
    bias_pct, pred_label, pred_cat_name, confidence, rule_expl, method and
    is_question. A falsy *country* is stored as "".
    """
    bias_pct, cat_id, conf, method, rationale = ml_or_rule_bias_and_cat(sentence)
    return dict(
        question=question,
        country=country or "",
        sentence=sentence,
        model_source=model_source,
        bias_pct=int(bias_pct),
        pred_label=label_from_pct(bias_pct),
        pred_cat_name=ID2CAT[int(cat_id)],
        confidence=float(conf),
        rule_expl=rationale,
        method=method,
        is_question=bool(is_question_text(sentence)),
    )

def run_predictions(trainer, country, question, cgpt_ans, cla_ans, gem_ans):
    """Score the ChatGPT/Claude/Gemini answers to *question* and build the result view.

    Blank answers are skipped. ``trainer`` is accepted for interface
    compatibility but not used here.

    Returns:
        An 8-tuple: (header markdown, ChatGPT section, Claude section,
        Gemini section, JSON list of predict_one() dicts, None, None, None).
        The trailing Nones presumably clear three UI inputs — TODO confirm
        against the UI wiring.
    """
    empty_tail = ("", "", "", "", None, None, None)
    if not (question or "").strip():
        return ("Please enter a question.",) + empty_tail

    models = ("ChatGPT", "Claude", "Gemini")
    answers = (cgpt_ans, cla_ans, gem_ans)
    items = [
        predict_one(question, answer, model, country)
        for model, answer in zip(models, answers)
        if (answer or "").strip()
    ]
    if not items:
        return ("Please enter answers.",) + empty_tail

    df = pd.DataFrame(items)
    avg_by_model = df.groupby("model_source")["bias_pct"].mean().round(1).to_dict()
    max_bias = float(df["bias_pct"].max())

    header_md = f"**Question:** {question}  \n**Country:** {country or '(none)'}\n\n"
    # A model that got no answer shows "n/a" instead of a misleading "0.0%".
    header_md += " | ".join(
        f"{m}: {avg_by_model[m]:.1f}%" if m in avg_by_model else f"{m}: n/a"
        for m in models
    )

    if max_bias <= 5.0:
        header_md += "\n\n **No model biased (all Neutral)**"
    else:
        worst_model = max(avg_by_model.items(), key=lambda kv: kv[1])[0]
        header_md += f"\n\n **Most Biased → {worst_model} ({avg_by_model[worst_model]:.1f}%)**"

    def section_md(model):
        sub = df[df["model_source"] == model]
        if sub.empty:
            return f"### {model}\n(No answer)"
        row = sub.iloc[0]
        tag = str(row.get("method", ""))
        expl = (f"\n\n_Rule note:_ {row.get('rule_expl', '')}"
                if tag.startswith("RULE") and row.get("rule_expl") else "")
        return (f"### {model} ({tag})\n“{row['sentence']}” → "
                f"**{row['bias_pct']}%** ({row['pred_cat_name']}){expl}")

    return (
        header_md,
        section_md("ChatGPT"),
        section_md("Claude"),
        section_md("Gemini"),
        json.dumps(items),
        None, None, None
    )

def train_predictions(trainer, country, hidden_payload_json, true1, true2, true3):
    """Log the user's true categories for the last run and train the learner on them.

    Args:
        trainer: name of the labeller (logged; "" when falsy).
        country: fallback country when a prediction carries none.
        hidden_payload_json: JSON list of predict_one() dicts, as produced by run_predictions.
        true1, true2, true3: true category names for the ChatGPT, Claude and
            Gemini answers. Falsy or unknown names are skipped.

    Each labelled item is appended to the training log before it is fed to
    learner.partial_fit, so a training error cannot lose the user's label.

    Returns:
        (status message, refreshed log DataFrame)
    """
    if not hidden_payload_json:
        return "No predictions yet.", read_log_df()
    try:
        items = json.loads(hidden_payload_json)
    except Exception:
        return "State corrupted.", read_log_df()

    # run_predictions omits models whose answer was blank, so items are not
    # positional: if ChatGPT was blank, items[0] is Claude. Pair labels by
    # model_source instead, and fall back to position only for unknown sources.
    # NOTE(review): assumes true1/2/3 are the ChatGPT/Claude/Gemini selectors (the
    # order run_predictions renders its sections) — confirm in the UI code.
    labels_by_model = {"ChatGPT": true1, "Claude": true2, "Gemini": true3}
    labels_by_pos = [true1, true2, true3]

    updated = 0
    train_errors = []
    for pos, itm in enumerate(items):
        source = itm.get("model_source")
        if source in labels_by_model:
            tcat = labels_by_model[source]
        elif pos < len(labels_by_pos):
            tcat = labels_by_pos[pos]
        else:
            tcat = None
        key = str(tcat).lower() if tcat else ""
        if key not in CAT2ID:
            continue
        corr_bin = corrected_binary(tcat)
        ok_bin = (itm["pred_label"] == corr_bin)
        ok_cat = (itm["pred_cat_name"].lower() == key)
        row = [
            datetime.now().isoformat(),
            trainer or "",
            (itm.get("country") or country or ""),
            itm["question"], itm["sentence"], itm["model_source"],
            itm["bias_pct"], itm["pred_label"], itm["pred_cat_name"],
            f"{itm['confidence']:.2f}", (itm.get("rule_expl") or ""),
            tcat, corr_bin, bool(ok_bin), bool(ok_cat), itm["is_question"], itm.get("method", "")
        ]
        append_row(row)
        updated += 1
        try:
            learner.partial_fit([itm["sentence"]], [CAT2ID[key]])
        except ValueError as exc:
            # e.g. scikit-learn refusing to fit a single class; the label is already logged.
            train_errors.append(str(exc))

    if not updated:
        return "ℹ️ Nothing trained.", read_log_df()
    msg = f"Trained {updated} item(s)."
    if train_errors:
        msg += f" Logged, but the model could not be updated: {train_errors[0]}"
    return msg, read_log_df()

# ===================== Delete / Undo =====================
def delete_row(row_index: float):
    """Delete one row of the training log and rebuild the learner from the rest.

    Args:
        row_index: 0-based row position, as in the row_index column of
            show_history_with_index. Anything int() accepts works. None, NaN,
            infinities or non-numeric text are reported as invalid instead of raising.

    Returns:
        (status message, log DataFrame after the operation)
    """
    global learner
    df = read_log_df()
    if df.empty:
        return "ℹ️ Log is empty.", df
    try:
        idx = int(row_index)
    except (TypeError, ValueError, OverflowError):
        idx = None
    if idx is None or idx < 0 or idx >= len(df):
        shown = row_index if idx is None else idx
        return f"Invalid row index: {shown}. Valid range: 0..{len(df)-1}", df
    df = df.drop(df.index[idx]).reset_index(drop=True)
    df.to_csv(TRAIN_LOG_PATH, index=False)
    learner = replay_from_csv()
    return f"🗑️ Deleted row {idx} and retrained model.", df

def show_history_with_index():
    """Return the training log with a leading 'row_index' column (0-based row position).

    The values are the positions that delete_row accepts.
    """
    history = read_log_df().copy()
    history = history.reset_index()
    renamed = history.rename(columns={"index": "row_index"})
    return renamed

# ===================== Metrics =====================
# Severity buckets as inclusive (low, high) percentage ranges, in ascending order.
SEVERITY_THRESHOLDS = {
    "Neutral": (0, 0),
    "Low": (1, 20),
    "Moderate": (21, 40),
    "High": (41, 70),
    "Critical": (71, 100)
}
def bucket_severity(pct: int) -> str:
    """Map a bias percentage to a SEVERITY_THRESHOLDS bucket name.

    Values <= 0 are "Neutral". Every other bucket is selected by its upper
    bound (Low <= 20, Moderate <= 40, High <= 70), and anything above is
    "Critical". Using only upper bounds means fractional values such as 20.5
    or 0.5 land in the next bucket instead of falling through the gaps between
    the integer ranges. NaN yields "Critical".
    """
    if pct <= 0:
        return "Neutral"
    # Relies on SEVERITY_THRESHOLDS being listed in ascending order.
    for name, (_lo, hi) in SEVERITY_THRESHOLDS.items():
        if name == "Neutral":
            continue
        if pct <= hi:
            return name
    return "Critical"

def metrics_from_items(items: list):
    """Summarise one run's predictions (a list of predict_one() dicts).

    Returns:
        (header_md, by_model, dist_df, sev_df):
        * by_model – mean bias_pct per model_source, rounded to 1 decimal;
        * dist_df  – per model, counts "(n)" then row percentages "(%)" of pred_cat_name;
        * sev_df   – the same for bucket_severity(bias_pct).
        Empty input yields an info message and three empty DataFrames.
    """
    def _empty():
        return "ℹ️ No predictions.", pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

    if not items:
        return _empty()
    df = pd.DataFrame(items)
    if df.empty:
        return _empty()
    df["severity"] = df["bias_pct"].apply(bucket_severity)
    by_model = (df.groupby("model_source")["bias_pct"]
                  .mean().round(1).rename("% bias (mean)").reset_index())

    def _counts_with_pct(column):
        # Per-model counts "(n)" followed by row percentages "(%)" of *column*.
        counts = df.pivot_table(index="model_source", columns=column,
                                values="bias_pct", aggfunc="count", fill_value=0)
        perc = (counts.div(counts.sum(axis=1), axis=0) * 100).round(1)
        table = counts.copy()
        table.columns = [f"{c} (n)" for c in table.columns]
        for c in perc.columns:
            table[f"{c} (%)"] = perc[c]
        return table.reset_index()

    dist_df = _counts_with_pct("pred_cat_name")
    sev_df = _counts_with_pct("severity")

    # Plain dict lookup: float() on a one-element Series is deprecated since pandas 2.1.
    means = dict(zip(by_model["model_source"], by_model["% bias (mean)"]))
    header = [f"{m}: {float(means[m]):.1f}%" for m in ("ChatGPT", "Claude", "Gemini") if m in means]
    header_md = ("**Latest Run – Mean % Bias:** " + " | ".join(header)) if header else "ℹ️ No model scores."
    return header_md, by_model, dist_df, sev_df

def corpus_metrics(df: pd.DataFrame):
    """Summarise the whole CSV training log per model.

    Args:
        df: Log rows as read from the training CSV (columns per CSV_HEADERS).
            ``None`` or an empty frame short-circuits to an info message.

    Returns:
        Tuple ``(header_md, by_model, dist_df, sev_df, accuracy_df)``:
        - a Markdown header;
        - mean predicted % bias per model;
        - predicted-category counts and percentages per model;
        - severity counts and percentages per model;
        - category accuracy per model against ``true_category``.
        Tables that cannot be computed (no predicted categories, no labelled
        rows) are empty DataFrames.
    """
    if df is None or df.empty:
        return "‚ÑπÔ∏è Log is empty.", pd.DataFrame(), pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
    tmp = df.copy()
    # df.get(col, 0) would pass a bare scalar to to_numeric when the column is
    # absent, and a scalar has no .fillna(). Use an all-zero Series instead.
    if "predicted_bias_pct" in tmp.columns:
        raw_pct = tmp["predicted_bias_pct"]
    else:
        raw_pct = pd.Series(0, index=tmp.index)
    tmp["predicted_bias_pct"] = pd.to_numeric(raw_pct, errors="coerce").fillna(0).astype(int)
    tmp["severity"] = tmp["predicted_bias_pct"].apply(bucket_severity)

    by_model = (tmp.groupby("model_source")["predicted_bias_pct"]
                  .mean().round(1).rename("% bias (mean)").reset_index())

    # Predicted-category distribution: counts plus row percentages per model.
    if "predicted_category" in tmp.columns and tmp["predicted_category"].notna().any():
        cat_counts = tmp.pivot_table(index="model_source", columns="predicted_category",
                                     values="predicted_bias_pct", aggfunc="count", fill_value=0)
    else:
        cat_counts = pd.DataFrame(index=tmp["model_source"].unique())
    if not cat_counts.empty:
        cat_perc = (cat_counts.div(cat_counts.sum(axis=1), axis=0)*100).round(1)
        dist_df = cat_counts.copy()
        dist_df.columns = [f"{c} (n)" for c in dist_df.columns]
        for c in cat_perc.columns:
            dist_df[f"{c} (%)"] = cat_perc[c]
        dist_df = dist_df.reset_index()
    else:
        dist_df = pd.DataFrame()

    # Severity distribution: counts plus row percentages per model.
    sev_counts = tmp.pivot_table(index="model_source", columns="severity",
                                 values="predicted_bias_pct", aggfunc="count", fill_value=0)
    sev_perc = (sev_counts.div(sev_counts.sum(axis=1), axis=0)*100).round(1)
    sev_df = sev_counts.copy()
    sev_df.columns = [f"{c} (n)" for c in sev_df.columns]
    for c in sev_perc.columns:
        sev_df[f"{c} (%)"] = sev_perc[c]
    sev_df = sev_df.reset_index()

    header = "### Corpus Metrics\n"
    if by_model.empty:
        header += "N/A"
    else:
        header += "**Mean % Bias by Model:** " + " | ".join(
            f"{r['model_source']}: {r['% bias (mean)']:.1f}%" for _, r in by_model.iterrows()
        )

    accuracy_df = pd.DataFrame()
    if "true_category" in tmp.columns and "predicted_category" in tmp.columns:
        true_cat = tmp["true_category"]
        # Only rows with a real label count. A blank CSV cell read as NaN would
        # otherwise stringify to "nan" and be scored as a wrong prediction.
        labelled = true_cat.notna() & (true_cat.astype(str).str.strip() != "")
        subset = tmp[labelled].copy()
        if len(subset):
            subset["correct_cat"] = (
                subset["true_category"].astype(str).str.strip().str.lower()
                == subset["predicted_category"].astype(str).str.strip().str.lower()
            )
            overall_acc = subset["correct_cat"].mean() * 100
            accuracy_df = (subset.groupby("model_source")["correct_cat"]
                           .mean().mul(100).round(1).rename("accuracy (%)").reset_index())
            header += f"\n\n**Category Accuracy (vs. true labels)**\nOverall: {overall_acc:.1f}%"
    return header, by_model, dist_df, sev_df, accuracy_df

# ===================================== Gradio UI =====================================
# Serif font for the whole app.
custom_css = "body { font-family: 'Times New Roman', Times, serif; }"

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("# üß† EASL Bias Detector")

    with gr.Tab("‚öñÔ∏è Compare & Train"):
        trainer = gr.Dropdown(TEAM, label="Trainer", value=None)
        country = gr.Textbox(label="Country", placeholder="e.g., Australia")
        question = gr.Textbox(label="Common Question")

        with gr.Row():
            cgpt_ans = gr.Textbox(label="ChatGPT Answer", lines=5)
            cla_ans  = gr.Textbox(label="Claude Answer", lines=5)
            gem_ans  = gr.Textbox(label="Gemini Answer", lines=5)

        predict_btn = gr.Button("üîç Run Predictions", variant="primary")
        status_md = gr.Markdown()
        sec_cgpt = gr.Markdown()
        sec_cla  = gr.Markdown()
        sec_gem  = gr.Markdown()
        # Hidden state holding the latest predictions. It is filled by
        # run_predictions, consumed by train_predictions, and parsed as JSON
        # by latest_metrics below.
        hidden_payload = gr.Textbox(visible=False)

        t1 = gr.Dropdown(CATEGORIES, label="True Category: ChatGPT")
        t2 = gr.Dropdown(CATEGORIES, label="True Category: Claude")
        t3 = gr.Dropdown(CATEGORIES, label="True Category: Gemini")
        train_btn = gr.Button(" Train & Log", variant="secondary")
        train_status = gr.Markdown()
        recent_df = gr.Dataframe(headers=CSV_HEADERS, interactive=False, wrap=True)

        predict_btn.click(
            run_predictions,
            inputs=[trainer, country, question, cgpt_ans, cla_ans, gem_ans],
            outputs=[status_md, sec_cgpt, sec_cla, sec_gem, hidden_payload, t1, t2, t3]
        )

        train_btn.click(
            train_predictions,
            inputs=[trainer, country, hidden_payload, t1, t2, t3],
            outputs=[train_status, recent_df]
        )

    with gr.Tab("üìú Full History"):
        # Pass the function itself as `value`: Gradio calls it on every page
        # load, so the table is fresh. Assigning `.value` after construction
        # bypasses the component's own value processing.
        full_hist = gr.Dataframe(value=show_history_with_index,
                                 headers=["row_index"] + CSV_HEADERS, interactive=False, wrap=True)
        refresh_btn = gr.Button("üîÑ Refresh")
        refresh_btn.click(show_history_with_index, outputs=full_hist)

    with gr.Tab(" Delete Wrong Training"):
        gr.Markdown("Select the **row_index** from Full History and delete it if it was logged by mistake.")
        del_index = gr.Number(label="Row Index to Delete", value=0, precision=0)
        del_btn = gr.Button("Delete Entry", variant="stop")
        del_status = gr.Markdown()
        del_hist = gr.Dataframe(headers=["row_index"] + CSV_HEADERS, interactive=False, wrap=True)
        del_btn.click(delete_row, inputs=del_index, outputs=[del_status, del_hist])

    with gr.Tab("üìÇ Import / Export"):
        export_btn = gr.Button("Download CSV", variant="primary")
        export_file = gr.File()
        upload_file = gr.File(file_types=[".csv"])
        import_mode = gr.Radio(["merge","replace"], value="merge", label="Mode")
        import_btn = gr.Button("üì• Import CSV")
        import_status = gr.Markdown()
        import_hist = gr.Dataframe(headers=CSV_HEADERS, interactive=False, wrap=True)

        def _on_import(f, m):
            """Import the uploaded CSV, returning one value per wired output."""
            if f is None:
                # Two outputs are wired below, so return two values.
                # gr.update() leaves the table unchanged.
                return "‚ö†Ô∏è No file.", gr.update()
            # Depending on the Gradio version, gr.File delivers either a path
            # string or a tempfile-like object exposing `.name`; accept both.
            path = getattr(f, "name", f)
            return import_csv(path, m)

        export_btn.click(export_csv_copy, outputs=export_file)
        import_btn.click(
            _on_import,
            inputs=[upload_file,import_mode],
            outputs=[import_status,import_hist]
        )

    # ======== Metrics tab ========
    with gr.Tab("üìä Metrics"):
        gr.Markdown("### Latest Run Metrics\nUses the most recent predictions (above) without needing to train/log.")
        latest_btn = gr.Button("üìà Compute from Latest Run")
        latest_header = gr.Markdown()
        latest_by_model = gr.Dataframe(label="Mean % Bias by Model", interactive=False, wrap=True)
        latest_cat_dist = gr.Dataframe(label="Category Distribution by Model", interactive=False, wrap=True)
        latest_sev_dist = gr.Dataframe(label="Severity Distribution by Model", interactive=False, wrap=True)

        def latest_metrics(hidden_json):
            """Parse the hidden predictions payload and summarise it per model."""
            if not hidden_json:
                return "‚ÑπÔ∏è No predictions yet.", pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
            try:
                items = json.loads(hidden_json)
            except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
                return "‚ö†Ô∏è State corrupted.", pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
            return metrics_from_items(items)

        latest_btn.click(
            latest_metrics,
            inputs=[hidden_payload],
            outputs=[latest_header, latest_by_model, latest_cat_dist, latest_sev_dist]
        )

        gr.Markdown("---\n### Corpus Metrics (from CSV Log)")
        corpus_btn = gr.Button("üßÆ Compute from Log")
        corpus_header = gr.Markdown()
        corpus_by_model = gr.Dataframe(label="Mean % Bias by Model (Log)", interactive=False, wrap=True)
        corpus_cat_dist = gr.Dataframe(label="Category Distribution by Model (Log)", interactive=False, wrap=True)
        corpus_sev_dist = gr.Dataframe(label="Severity Distribution by Model (Log)", interactive=False, wrap=True)
        corpus_acc = gr.Dataframe(label="Category Accuracy by Model (if true labels present)", interactive=False, wrap=True)

        # Re-read the CSV log on every click so the metrics reflect the latest state.
        corpus_btn.click(
            lambda: corpus_metrics(read_log_df()),
            outputs=[corpus_header, corpus_by_model, corpus_cat_dist, corpus_sev_dist, corpus_acc]
        )

    # PCA tab: category & model
    with gr.Tab("üß© PCA (Category & Model)"):
        gr.Markdown(
            "**Two PCA views**  \n"
            "‚Ä¢ **Left:** colored by *category*  \n"
            "‚Ä¢ **Right:** colored by *model* (point size/opacity = bias%)"
        )
        dual_btn   = gr.Button("üñºÔ∏è Generate Both PCA Plots")
        img_left   = gr.Image(type="pil", label="PCA ‚Äî Category", show_download_button=True)
        img_right  = gr.Image(type="pil", label="PCA ‚Äî Model (Bias Encoded)", show_download_button=True)
        dual_status = gr.Markdown()

        dual_btn.click(
            fn=create_dual_pca_plots,
            inputs=None,
            outputs=[img_left, img_right, dual_status]
        )

    # PCA tab: country & model
    with gr.Tab("üß≠ PCA (Country & Model)"):
        gr.Markdown("**Left:** colored by *country*  \n**Right:** colored by *model* (size/opacity = bias%).")
        pca_btn = gr.Button("üñºÔ∏è Generate Country & Model PCA")
        pca_country_img = gr.Image(type="pil", label="PCA ‚Äî Country", show_download_button=True)
        pca_model_img   = gr.Image(type="pil", label="PCA ‚Äî Model", show_download_button=True)
        pca_note        = gr.Markdown()
        pca_btn.click(create_dual_country_model_pca, outputs=[pca_country_img, pca_model_img, pca_note])

    # Heatmaps tab: category frequency by country and by model
    with gr.Tab("üî• Heatmaps (Country & Model)"):
        gr.Markdown("**Category frequency heatmaps** by *country* and by *model*.")
        heat_btn = gr.Button("üî• Generate Heatmaps")
        heat_country_img = gr.Image(type="pil", label="Heatmap ‚Äî Category √ó Country", show_download_button=True)
        heat_model_img   = gr.Image(type="pil", label="Heatmap ‚Äî Category √ó Model", show_download_button=True)
        heat_note        = gr.Markdown()
        heat_btn.click(create_dual_heatmaps, outputs=[heat_country_img, heat_model_img, heat_note])

demo.launch(share=True)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://7b896c292b849d6394.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


