# 📊🧪 Literature Screening – Multi‑class Model Evaluation
Aggregate results for **train** and **test** splits, report core metrics for the binary *Included vs Excluded* task, then evaluate the extra fields returned *only* for items predicted as **Included**.

We focus on the structured fields that are straightforward to score:
* **domain** → matches column **Social, Behavioural or Implementation Science?**
* **dmf_stage** → matches column **DMF - Identify the issue and its context, assess risks and benefits, identify and analyze options, select a strategy, implement the strategy, monitor and evaluate results, involve interested and affected parties**
* **decision_type** → matches column **DMF - Are the decisions regulatory, policy, or other? Please describe the “other” if applicable.**  
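  For **decision_type** a prediction that *contains* the word **other** is considered correct when the ground‑truth field also contains **other** (ignoring any free‑text description).

The identity fields **BIPOC**, **Indigenous**, **Sex** and **Gender** are scored as well, against the corresponding **IS - …** ground‑truth columns: a pair counts as a match when both values contain **yes** or both contain **not reported**; otherwise the normalised strings must be equal.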

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 1 – Imports and helpers 🔌                ║
# ╚════════════════════════════════════════════════╝
import json
from pathlib import Path
from collections import defaultdict
from typing import Iterable

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)

from rapidfuzz import fuzz

sns.set(style="whitegrid")

import re
import unicodedata

# Words dropped during normalisation because they carry no signal for
# matching (extend the string below to tune the list).
_STOPWORDS: set[str] = set(
    "the a an of and to in on for with "
    "at by from about as into that this".split()
)

# A token is a maximal run of lowercase ASCII letters and digits.
_token_re = re.compile(r"[a-z0-9]+")


# ── text normalisation helpers ─────────────────── #
def _ascii_fold(text: str) -> str:
    """Strip diacritics and every other non-ASCII character from *text*.

    The string is first decomposed with NFKD, so an accented letter becomes
    its base letter followed by combining marks. Encoding to ASCII with
    ``errors="ignore"`` then drops whatever has no ASCII representation.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    ascii_bytes = decomposed.encode("ascii", errors="ignore")
    return ascii_bytes.decode("ascii")


def _normalize(text: str) -> str:
    """
    Aggressive normalisation that ignores superficial differences:
      • accents, case, punctuation, apostrophes (straight and typographic)
      • extra whitespace / newlines
      • common stop-words
      • token order (tokens are deduped & sorted)

    ``None`` and float NaN are treated as missing and yield ``""``; any other
    value is converted with ``str()`` first.

    Returns a single space-separated string of the surviving tokens.
    """
    if text is None or (isinstance(text, float) and pd.isna(text)):
        return ""

    # Apostrophes must be turned into spaces *before* ASCII folding: the
    # typographic apostrophe (U+2019) has no ASCII decomposition, so folding
    # first deletes it and "don’t" would tokenise as "dont" while "don't"
    # tokenises as "don" + "t".
    txt = str(text).lower().replace("’", " ").replace("'", " ")
    txt = _ascii_fold(txt)

    # Set comprehension dedupes; sorting makes the result order-independent.
    tokens = {t for t in _token_re.findall(txt) if t not in _STOPWORDS}
    return " ".join(sorted(tokens))


# ── fuzzy comparison wrappers ──────────────────── #
def _fuzzy_equal(a: str, b: str, threshold: int = 90) -> bool:
    """
    Compare two strings after ``_normalize``.

    Returns True when ``fuzz.ratio`` of the two normalised strings is at
    least *threshold*. If the module-level name ``fuzz`` is ``None`` the
    normalised strings are compared for exact equality instead.
    """
    # NOTE(review): this file imports `fuzz` unconditionally, so the exact
    # equality path is only reached if something rebinds `fuzz` to None.
    left, right = _normalize(a), _normalize(b)
    if fuzz is not None:
        return fuzz.ratio(left, right) >= threshold
    return left == right


def _match_decision_type(pred: str, truth: str) -> bool:
    """
    Decision-type matching with relaxed “other” rule:
      • a missing ground truth (``pd.isna``) never matches
      • if the ground truth contains the word “other”, the prediction is
        correct when it also contains the word “other” (after normalisation)
      • otherwise, use fuzzy equality with a threshold of 92
    """
    if pd.isna(truth):
        return False

    p_norm = _normalize(pred)
    t_norm = _normalize(truth)

    # Compare whole tokens, not substrings: a plain `"other" in t_norm` test
    # would also fire on words such as "another" or "otherwise".
    if "other" in t_norm.split() and "other" in p_norm.split():
        return True
    return _fuzzy_equal(p_norm, t_norm, threshold=92)

def _match_identity(pred: str, truth: str) -> bool:
    """
    Relaxed matcher for the bipoc, indigenous, sex and gender fields.

    The pair matches when both values (lower-cased) contain "yes", or when
    both contain "not reported". Non-string values are treated as empty for
    these two checks. If neither shortcut applies, the result is whether the
    normalised forms of *pred* and *truth* are equal.
    """
    truth_lc = truth.lower() if isinstance(truth, str) else ""
    pred_lc = pred.lower() if isinstance(pred, str) else ""

    for marker in ("yes", "not reported"):
        if marker in truth_lc and marker in pred_lc:
            return True
    return _normalize(pred) == _normalize(truth)

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 2 – Locate outputs and datasets 🔍         ║
# ╚════════════════════════════════════════════════╝
# Anchor every path on this file's folder when run as a script (`__file__`
# is defined), otherwise on the current working directory.
if "__file__" in globals():
    notebook_dir = Path(__file__).parent
else:
    notebook_dir = Path.cwd()

root_dir = notebook_dir  # all_class_files/
outputs_root = root_dir / "outputs"
datasets_dir = root_dir.parent / "datasets"

# Fail early, with the offending path in the message, if either folder is missing.
if not outputs_root.exists():
    raise RuntimeError(f"Could not find outputs directory at: {outputs_root}")
if not datasets_dir.exists():
    raise RuntimeError(f"Could not find datasets directory at: {datasets_dir}")

# Each sub-folder of outputs/ holds one model's results; a folder literally
# named "datasets" is skipped.
model_dirs = [
    entry
    for entry in outputs_root.iterdir()
    if entry.is_dir() and entry.name != "datasets"
]
if not model_dirs:
    raise RuntimeError(f"No model result folders found inside '{outputs_root}/'")

print("Models found:", ", ".join(entry.name for entry in model_dirs))

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 3 – Load predictions and merge ground-truth 🗄️ ║
# ╚════════════════════════════════════════════════╝
# Build one merged DataFrame (predictions + ground truth) per model and split.
# Structure: {model: {split: DataFrame}}
all_predictions = defaultdict(dict)

# Ground-truth CSV column -> short column name used in the merged frames.
truth_cols = {
    "id": "id",
    "label": "ground_truth",
    "Social, Behavioural or Implementation Science?": "domain_gt",
    "DMF - Identify the issue and its context, assess risks and benefits, identify and analyze options, select a strategy, implement the strategy, monitor and evaluate results, involve interested and affected parties": "dmf_stage_gt",
    "DMF - Are the decisions regulatory, policy, or other? Please describe the “other” if applicable.": "decision_type_gt",
    "IS - Does your submission include or intersect with Black, Indigenous or racialized groups?": "bipoc_gt",
    "IS - Does your submission include or intersect with Indigenous Peoples?": "indigenous_gt",
    "IS - Have you included Sex in your study:": "sex_gt",
    "IS - Have you included Gender in your study?": "gender_gt",
    "IS - Have you considered identity factors other than sex and gender?": "identity_factors_gt",
}

# The ground truth of a split is the same for every model, so each CSV is
# read at most once and reused.
truth_by_split = {}

for mdir in model_dirs:
    model_name = mdir.name

    for split in ("train", "test"):
        preds_dir = mdir / split / "predictions"
        if not preds_dir.exists():
            continue

        # ----- ground-truth dataset (cached per split) ----- #
        if split not in truth_by_split:
            csv_path = datasets_dir / f"{split}_dataset.csv"
            if not csv_path.exists():
                raise FileNotFoundError(csv_path)
            # Prediction ids come from file names and are therefore strings.
            # Reading the CSV ids as strings too keeps the merge keys the
            # same dtype (pandas refuses to merge object with int64 keys)
            # and preserves ids exactly as written, e.g. leading zeros.
            df_truth = pd.read_csv(csv_path, dtype={"id": str})
            truth_by_split[split] = df_truth.rename(columns=truth_cols)[
                list(truth_cols.values())
            ]
        df_truth = truth_by_split[split]

        # ----- predictions ----- #
        rows = []
        # sorted() makes the row order reproducible; glob order is arbitrary.
        for jf in sorted(preds_dir.glob("*.json")):
            with open(jf, encoding="utf-8") as f:
                data = json.load(f)

            # A missing, null or non-dict "prediction" becomes an empty
            # block so the file still yields a row (with all-None fields)
            # instead of aborting the whole load.
            pred_block = data.get("prediction") if isinstance(data, dict) else None
            if not isinstance(pred_block, dict):
                pred_block = {}

            rows.append(
                {
                    "id": jf.stem,
                    "pred_class": pred_block.get("classification"),
                    "raw_rationale": pred_block.get("classification_rationale"),
                    "domain_pred": pred_block.get("domain"),
                    "dmf_stage_pred": pred_block.get("dmf_stage"),
                    "decision_type_pred": pred_block.get("decision_type"),
                    "bipoc_pred": pred_block.get("BIPOC"),
                    "indigenous_pred": pred_block.get("Indigenous"),
                    "sex_pred": pred_block.get("Sex"),
                    "gender_pred": pred_block.get("Gender"),
                }
            )

        if not rows:
            continue

        df_pred = pd.DataFrame(rows)
        # Left merge: every prediction is kept; ids absent from the CSV get
        # NaN in all ground-truth columns.
        df = pd.merge(df_pred, df_truth, on="id", how="left")
        all_predictions[model_name][split] = df

        unparsable = (df["pred_class"] == "ParseError").sum()
        print(f"{model_name} [{split}] -> {len(df):,} rows, {unparsable} unparsable")


In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 4 – Core binary‑classification metrics 📋 ║
# ╚════════════════════════════════════════════════╝
# Accuracy / precision / recall / F1 of the Included-vs-Excluded decision,
# one row per (model, split). "Included" is the positive class.
metrics_cls = []
metric_columns = [
    "model", "split", "n_total", "n_unparsed", "n_no_truth",
    "accuracy", "precision", "recall", "f1",
]

for model, split_dict in all_predictions.items():
    for split, df in split_dict.items():
        parsable = df[df["pred_class"].isin(["Included", "Excluded"])]
        unparsed = len(df) - len(parsable)

        # Predictions whose id was not found in the ground-truth CSV carry a
        # NaN label after the left merge; they cannot be scored and would
        # otherwise reach sklearn as an extra, non-string class.
        scored = parsable[parsable["ground_truth"].notna()]
        if scored.empty:
            continue

        y_true = scored["ground_truth"]
        y_pred = scored["pred_class"]

        metrics_cls.append(
            {
                "model": model,
                "split": split,
                "n_total": len(df),
                "n_unparsed": unparsed,
                "n_no_truth": len(parsable) - len(scored),
                "accuracy": accuracy_score(y_true, y_pred),
                # zero_division=0 keeps the value sklearn already reports
                # for an undefined metric but without the warning.
                "precision": precision_score(y_true, y_pred, pos_label="Included", zero_division=0),
                "recall": recall_score(y_true, y_pred, pos_label="Included", zero_division=0),
                "f1": f1_score(y_true, y_pred, pos_label="Included", zero_division=0),
            }
        )

# Passing the column list explicitly keeps set_index working even when no
# (model, split) pair produced a row.
metrics_cls_df = (
    pd.DataFrame(metrics_cls, columns=metric_columns)
    .set_index(["model", "split"])
    .sort_index()
)
metrics_cls_df.style.format({"accuracy": "{:.3f}", "precision": "{:.3f}", "recall": "{:.3f}", "f1": "{:.3f}"})

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 5 – Visualise binary metrics 📊           ║
# ╚════════════════════════════════════════════════╝
# One figure per split with three bar charts (accuracy, F1, share of
# unparsed predictions), one bar per model.
available_splits = set(metrics_cls_df.index.get_level_values("split"))

for split in ("train", "test"):
    # .xs() raises KeyError for a label that is absent from the index, so
    # a split without any metrics has to be skipped before the lookup.
    if split not in available_splits:
        continue
    subset = metrics_cls_df.xs(split, level="split")

    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    panels = (
        (subset["accuracy"], f"Accuracy ({split})"),
        (subset["f1"], f"F1‑score ({split})"),
        # plotted as a fraction on the same 0–1 axis as the other panels
        (subset["n_unparsed"] / subset["n_total"], f"Unparsed % ({split})"),
    )
    for ax, (series, title) in zip(axes, panels):
        series.plot(kind="bar", ax=ax)
        ax.set_title(title)
        ax.set_ylim(0, 1)

    plt.suptitle(f"Model comparison on {split} split")
    plt.tight_layout()
    plt.show()

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 6 – Extra-field scoring 🏷️               ║
# ╚════════════════════════════════════════════════╝
# Per-field accuracy of the extra attributes, computed only on rows the
# model classified as "Included".
extra_metrics = []
# (prediction column, ground-truth column, field name)
fields = [
    ("domain_pred",       "domain_gt",       "domain"),
    ("dmf_stage_pred",    "dmf_stage_gt",    "dmf_stage"),
    ("decision_type_pred","decision_type_gt","decision_type"),
    ("bipoc_pred",        "bipoc_gt",        "bipoc"),
    ("indigenous_pred",   "indigenous_gt",   "indigenous"),
    ("sex_pred",          "sex_gt",          "sex"),
    ("gender_pred",       "gender_gt",       "gender")
]


def _match_exact(pred, truth) -> bool:
    """Default scorer: the normalised strings must be identical."""
    return _normalize(pred) == _normalize(truth)


# Fields with a dedicated matcher; every other field uses _match_exact.
_field_scorers = {
    "decision_type": _match_decision_type,
    **dict.fromkeys(("bipoc", "indigenous", "sex", "gender"), _match_identity),
}
_gt_columns = [tcol for _, tcol, _ in fields]

for model, split_dict in all_predictions.items():
    for split, df in split_dict.items():

        # evaluate only items predicted as "Included"
        keep = df["pred_class"] == "Included"

        # skip items whose ground-truth label is 'Included' but whose
        # ground-truth extra fields are all missing
        if _gt_columns:
            unlabelled = (df["ground_truth"] == "Included") & df[_gt_columns].isna().all(axis=1)
            keep = keep & ~unlabelled
        df_inc = df[keep].copy()

        if df_inc.empty:
            continue

        for pcol, tcol, name in fields:
            scorer = _field_scorers.get(name, _match_exact)
            matches = [scorer(p, t) for p, t in zip(df_inc[pcol], df_inc[tcol])]

            acc = float(np.mean(matches)) if matches else float("nan")

            extra_metrics.append(
                {
                    "model"   : model,
                    "split"   : split,
                    "field"   : name,
                    "n_scored": int(len(matches)),
                    "accuracy": round(acc, 3),
                }
            )

# ---- pivot & pretty-print ---- #
a_extra = pd.DataFrame(extra_metrics)
if a_extra.empty:
    print("No Included predictions found, cannot score extra fields.")
else:
    pivot = a_extra.pivot(index=["model", "split"], columns="field", values="accuracy")
    # show the columns in the order the fields are declared above
    pivot = pivot[[name for _, _, name in fields]]
    display(pivot.style.format("{:.3f}"))

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 7 – Visualise extra-field accuracy 📊🏷️   ║
# ╚════════════════════════════════════════════════╝
# One bar chart per extra field: models on the x-axis, one bar per split.
if a_extra.empty:
    print("No extra-field metrics to plot.")
else:
    for field in ["domain", "dmf_stage", "decision_type", "bipoc", "indigenous", "sex", "gender"]:
        fig, ax = plt.subplots(figsize=(9, 4))
        field_rows = a_extra[a_extra["field"] == field]
        subset = field_rows.pivot(index="model", columns="split", values="accuracy")
        subset.plot(kind="bar", ax=ax, rot=0)  # one bar-group per model
        ax.set(
            title=f"{field} – accuracy by model and split",
            ylabel="accuracy",
            xlabel="model",
            ylim=(0, 1),
        )
        ax.legend(title="split")
        plt.tight_layout()
        plt.show()


In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 8 – Per-model feature accuracy 📊         ║
# ╚════════════════════════════════════════════════╝
# One bar chart per model: extra fields on the x-axis, one bar per split.
if a_extra.empty:
    print("No extra-field metrics to plot.")
else:
    for model_name in a_extra["model"].unique():
        fig, ax = plt.subplots(figsize=(8, 4))
        model_rows = a_extra[a_extra["model"] == model_name]
        by_field = model_rows.pivot(index="field", columns="split", values="accuracy")
        # fixed field order so every model's chart reads the same way
        subset = by_field.reindex(
            ["domain", "dmf_stage", "decision_type", "bipoc", "indigenous", "sex", "gender"]
        )
        subset.plot(kind="bar", ax=ax, rot=0)
        ax.set(
            title=f"{model_name} – extra-field accuracy",
            ylabel="accuracy",
            xlabel="field",
            ylim=(0, 1),
        )
        ax.legend(title="split")
        plt.tight_layout()
        plt.show()

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 9 – Confusion matrices for binary task 🔲 ║
# ╚════════════════════════════════════════════════╝
# Confusion-matrix heatmap and text classification report for every
# (model, split) pair; rows are the actual label, columns the predicted one.
class_labels = ["Included", "Excluded"]

for model, split_dict in all_predictions.items():
    for split, df in split_dict.items():
        # Keep only rows with a usable prediction AND a ground-truth label:
        # ids missing from the ground-truth CSV carry NaN after the left
        # merge and would reach sklearn as an extra, non-string class.
        scorable = df["pred_class"].isin(class_labels) & df["ground_truth"].notna()
        parsable = df[scorable]
        if parsable.empty:
            continue

        y_true = parsable["ground_truth"]
        y_pred = parsable["pred_class"]

        cm = confusion_matrix(y_true, y_pred, labels=class_labels)
        plt.figure(figsize=(4, 3))
        sns.heatmap(
            cm,
            annot=True,
            fmt="d",
            cmap="Purples",
            xticklabels=class_labels,
            yticklabels=class_labels,
        )
        plt.title(f"Confusion Matrix – {model} ({split})")
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.tight_layout()
        plt.show()

        print(f"Classification report for {model} ({split})")
        # zero_division=0 reports undefined metrics as 0 without a warning
        print(classification_report(y_true, y_pred, digits=3, zero_division=0))