# Multimodal Hate/Abuse Detection — Master Pipeline (Steps 2–8)

This notebook orchestrates the **entire pipeline** from Step 2 to Step 8
for the multimodal hate/abuse detection system.

It is designed to be **adaptive**:
- You can run it on the **pilot subset** or the **full dataset**.
- All intermediate artifacts for this notebook are written under
  `Final/` (except **models**, which always live under `models/`).
- Each step is clearly separated so you can re-run only parts as needed.

> **Note:** This notebook *wraps* the existing scripts/notebooks
> (`Step_2`, `Step_3`, `Step_4`, `Step_5`, `Step_6`, `Step_7`, `Step_8`).
> If you want to debug internals of a specific step, you can still open
> the original step notebooks/scripts.


In [1]:
# Install all required packages for the full pipeline (run once per environment).
# You can skip this cell if everything is already installed.
# NOTE: versions are mostly unpinned, so results may differ between environments;
# pip may ask for a kernel restart after installing.

%pip install --upgrade pip

# Core + training libs
%pip install torch torchvision torchaudio

# NLP / vision / datasets / utilities
%pip install transformers datasets webdataset accelerate timm sentencepiece pillow

# OCR stack for Step 2
%pip install paddleocr paddlepaddle

# Lightweight web UI for Step 8
%pip install "huggingface-hub>=0.33.5,<1.0" gradio


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Global configuration and paths

from pathlib import Path
from typing import Dict, Any, List

import sys
import torch

# Project root detection: the root is the first directory, starting at the
# current working directory and climbing at most 5 levels, that contains a
# Datasets/ folder.
ROOT = Path.cwd().resolve()

project_root = ROOT
_levels_climbed = 0
while _levels_climbed < 5 and not (project_root / "Datasets").is_dir():
    project_root = project_root.parent
    _levels_climbed += 1

# Abort early if the climb ended on a directory without Datasets/.
if not (project_root / "Datasets").is_dir():
    raise RuntimeError(f"Could not find project root with Datasets/ folder. Searched from {ROOT}")

# Everything this master notebook produces goes under Final/.
FINAL_DIR = project_root / "Final"
FINAL_DIR.mkdir(exist_ok=True)

# One output directory per pipeline step handled by this notebook.
FINAL_STEP2, FINAL_STEP3, FINAL_STEP6, FINAL_STEP7 = (
    FINAL_DIR / step_name for step_name in ("Step_2", "Step_3", "Step_6", "Step_7")
)
for step_dir in (FINAL_STEP2, FINAL_STEP3, FINAL_STEP6, FINAL_STEP7):
    step_dir.mkdir(parents=True, exist_ok=True)

# Where the raw datasets live. Change this assignment to point elsewhere.
DATASETS_ROOT = project_root / "Datasets"

if not DATASETS_ROOT.exists():
    raise RuntimeError(f"DATASETS_ROOT does not exist: {DATASETS_ROOT}")

# Torch device: CUDA when available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lower-case alias kept for later cells that expect `device`.
device = DEVICE

# Make local modules importable from the project root.
_project_root_str = str(project_root)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

print("Project root:", project_root)
print("Datasets root:", DATASETS_ROOT)
print("Final outputs root:", FINAL_DIR)
print("Using torch device:", DEVICE)
print("sys.path includes project root:", str(project_root) in sys.path)

Project root: /Users/yashwanthreddy/Documents/GitHub/DL_Proj
Datasets root: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Datasets
Final outputs root: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final
Using torch device: cpu
sys.path includes project root: True


## Step 2 — Build Manifest (Final/Step_2/data_manifest.jsonl)

This section inlines the logic of `Step_2/build_data_manifest.py` and writes the
manifest for this master notebook to `Final/Step_2/data_manifest.jsonl`.

The manifest still reads raw datasets from `Datasets/`, but **no files
are written into `Step_2/`** by this cell.

In [3]:
# Step 2 — Build Manifest (inlined from Step_2/build_data_manifest.py)

import csv
import json
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

# Datasets live under the configurable root.
# Every iter_* reader below resolves its files relative to this path.
DATASETS = DATASETS_ROOT


def iter_hateful_memes() -> Iterable[Dict[str, Any]]:
    """Yield one manifest record per entry of the Hateful Memes JSONL files.

    Reads `train.jsonl`, `dev.jsonl` and `test.jsonl` under
    `DATASETS / "hateful_memes"`; missing files and blank lines are skipped.
    `image_path` is the resolved absolute path of the entry's `img` field,
    or None when the entry has no `img`.
    """
    base_dir = DATASETS / "hateful_memes"
    split_files = {"train": "train.jsonl", "dev": "dev.jsonl", "test": "test.jsonl"}

    for split_name, filename in split_files.items():
        split_file = base_dir / filename
        if split_file.exists():
            with split_file.open("r", encoding="utf-8") as handle:
                for raw in handle:
                    if raw.strip():
                        entry = json.loads(raw)
                        rel_img = entry.get("img")
                        if rel_img:
                            resolved_img = str((base_dir / rel_img).resolve())
                        else:
                            resolved_img = None
                        yield {
                            "id": f"hateful_memes_{split_name}_{entry.get('id')}",
                            "dataset": "hateful_memes",
                            "split": split_name,
                            "image_path": resolved_img,
                            "text_raw": entry.get("text", ""),
                            "labels_raw": entry.get("label"),
                        }


# Label columns of the MAMI CSV / test_labels.txt, in file order.
_MAMI_LABEL_NAMES = ("misogynous", "shaming", "stereotype", "objectification", "violence")


def _mami_dict_reader(f) -> "csv.DictReader":
    """Return a DictReader for an open MAMI CSV, picking tab or comma from its header line."""
    header = f.readline()
    f.seek(0)
    return csv.DictReader(f, delimiter="\t" if "\t" in header else ",")


def iter_mami() -> Iterable[Dict[str, Any]]:
    """Yield manifest records for the MAMI train and test splits.

    Train records come from `TRAINING/training.csv`; test records from
    `test/Test.csv`, joined with the labels in `test_labels.txt` when present.
    Both CSVs may be tab- or comma-delimited: the delimiter is detected from
    the header line. Rows without a file name are skipped. Missing or blank
    train label cells are treated as 0; test records without a matching
    6-column label line get `labels_raw=None`.
    """
    ds_root = DATASETS / "mami"

    # Train: TRAINING/training.csv + images
    training_dir = ds_root / "TRAINING"
    train_csv = training_dir / "training.csv"
    train_img_root: Optional[Path] = None
    # Heuristic: images are either directly under TRAINING or in a subfolder
    # whose name starts with "train" (case-insensitive); prefer the subfolder.
    if training_dir.exists():
        for sub in training_dir.iterdir():
            if sub.is_dir() and sub.name.lower().startswith("train"):
                train_img_root = sub
                break
        if train_img_root is None:
            train_img_root = training_dir

    if train_csv.exists() and train_img_root is not None:
        # newline="" is what the csv module expects for files it parses.
        with train_csv.open("r", encoding="utf-8", newline="") as f:
            for row in _mami_dict_reader(f):
                # Some copies of the header carry a trailing space on the column name.
                file_name = row.get("file_name") or row.get("file_name ")
                if not file_name:
                    continue
                img_path = (train_img_root / file_name).resolve()
                yield {
                    "id": f"mami_train_{file_name}",
                    "dataset": "mami",
                    "split": "train",
                    "image_path": str(img_path),
                    "text_raw": row.get("Text Transcription", ""),
                    # `or 0` covers both a missing column and a blank/None cell.
                    "labels_raw": {name: int(row.get(name) or 0) for name in _MAMI_LABEL_NAMES},
                }

    # Test: test/Test.csv + test_labels.txt + images
    test_txt = ds_root / "test_labels.txt"
    test_dir = ds_root / "test"
    test_csv = test_dir / "Test.csv"
    test_img_root: Optional[Path] = test_dir if test_dir.exists() else None

    # file name -> five integer labels, from tab-separated lines of exactly 6 fields.
    labels_by_name: Dict[str, List[int]] = {}
    if test_txt.exists():
        with test_txt.open("r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) != 6:
                    continue
                try:
                    labels_by_name[parts[0]] = [int(x) for x in parts[1:]]
                except ValueError:
                    # Non-numeric label line (e.g. a header): ignore it.
                    continue

    if test_csv.exists() and test_img_root is not None:
        with test_csv.open("r", encoding="utf-8", newline="") as f:
            # Same delimiter detection as the train split, so a tab-delimited
            # Test.csv is no longer silently read as a single column.
            for row in _mami_dict_reader(f):
                file_name = row.get("file_name")
                if not file_name:
                    continue
                img_path = (test_img_root / file_name).resolve()
                raw_labels = labels_by_name.get(file_name)
                labels = None
                if raw_labels is not None and len(raw_labels) == 5:
                    labels = dict(zip(_MAMI_LABEL_NAMES, raw_labels))
                yield {
                    "id": f"mami_test_{file_name}",
                    "dataset": "mami",
                    "split": "test",
                    "image_path": str(img_path),
                    "text_raw": row.get("Text Transcription", ""),
                    "labels_raw": labels,
                }


def iter_mmhs150k() -> Iterable[Dict[str, Any]]:
    """Yield one manifest record per entry of `MMHS150K_GT.json`.

    Yields nothing when the ground-truth file is missing. The split comes
    from `splits/{train,val,test}_ids.txt`; ids found in none of them are
    tagged "unsplit". The image path is `img_resized/<id>.jpg`, resolved but
    not checked for existence.
    """
    base_dir = DATASETS / "mmhs150k"
    annotations_file = base_dir / "MMHS150K_GT.json"
    if not annotations_file.exists():
        return

    with annotations_file.open("r", encoding="utf-8") as handle:
        annotations = json.load(handle)

    # Example id -> split name; later files override earlier ones.
    split_of: Dict[str, str] = {}
    for split_name in ("train", "val", "test"):
        ids_file = base_dir / "splits" / f"{split_name}_ids.txt"
        if ids_file.exists():
            with ids_file.open("r", encoding="utf-8") as ids_handle:
                stripped_ids = (raw.strip() for raw in ids_handle)
                split_of.update((tweet_id, split_name) for tweet_id in stripped_ids if tweet_id)

    images_dir = base_dir / "img_resized"
    for tweet_id, entry in annotations.items():
        yield {
            "id": f"mmhs150k_{tweet_id}",
            "dataset": "mmhs150k",
            "split": split_of.get(tweet_id, "unsplit"),
            "image_path": str((images_dir / f"{tweet_id}.jpg").resolve()),
            "text_raw": entry.get("tweet_text", ""),
            "labels_raw": {
                "labels": entry.get("labels"),
                "labels_str": entry.get("labels_str"),
            },
        }


def iter_memotion1() -> Iterable[Dict[str, Any]]:
    """Yield one manifest record per row of Memotion 1 `labels.csv`.

    Yields nothing when the CSV is missing; rows without `image_name` are
    skipped. Text prefers `text_corrected`, then `text_ocr`, then "".
    Every record is tagged split "train".
    """
    base_dir = DATASETS / "memotion1"
    labels_file = base_dir / "labels.csv"
    images_dir = base_dir / "images"
    label_columns = ("humour", "sarcasm", "offensive", "motivational", "overall_sentiment")

    if labels_file.exists():
        with labels_file.open("r", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                image_name = row.get("image_name")
                if image_name:
                    yield {
                        "id": f"memotion1_{image_name}",
                        "dataset": "memotion1",
                        "split": "train",  # dataset does not define explicit splits
                        "image_path": str((images_dir / image_name).resolve()),
                        "text_raw": row.get("text_corrected") or row.get("text_ocr") or "",
                        "labels_raw": {column: row.get(column) for column in label_columns},
                    }


def iter_memotion2() -> Iterable[Dict[str, Any]]:
    """Yield manifest records for the Memotion 2 train, val and test splits.

    Train/val rows come from `Memotion2/memotion_{train,val}.csv`, with
    images expected under `image folder/{train,val}_images/<Id>.jpg`; a split
    is skipped when its CSV or image folder is missing. Test rows come from
    `memotion_test.csv`; their `image_path` is set only if `<Id>.jpg` exists
    in the train or val image folder, otherwise None.
    """
    base_dir = DATASETS / "memotion2"
    train_images = base_dir / "image folder" / "train_images"
    val_images = base_dir / "image folder" / "val_images"
    labelled_columns = (
        "humour",
        "sarcastic",
        "offensive",
        "motivational",
        "overall_sentiment",
        "classification_based_on",
    )

    def _iter_split(csv_path: Path, img_root: Path, split_name: str) -> Iterable[Dict[str, Any]]:
        """Yield the records of one labelled split (nothing if files are missing)."""
        if not (csv_path.exists() and img_root.exists()):
            return
        with csv_path.open("r", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                example_id = row.get("Id")
                if example_id:
                    yield {
                        "id": f"memotion2_{split_name}_{example_id}",
                        "dataset": "memotion2",
                        "split": split_name,
                        "image_path": str((img_root / f"{example_id}.jpg").resolve()),
                        "text_raw": row.get("ocr_text", ""),
                        "labels_raw": {column: row.get(column) for column in labelled_columns},
                    }

    yield from _iter_split(base_dir / "Memotion2" / "memotion_train.csv", train_images, "train")
    yield from _iter_split(base_dir / "Memotion2" / "memotion_val.csv", val_images, "val")

    # Test split has no labels; include if desired
    test_file = base_dir / "memotion_test.csv"
    if not test_file.exists():
        return
    with test_file.open("r", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            example_id = row.get("Id")
            if not example_id:
                continue
            # Test images are usually not provided locally: look in the train
            # folder first, then the val folder, and leave None if absent.
            found_image = None
            for folder in (train_images, val_images):
                if found_image is None and folder.exists():
                    candidate = folder / f"{example_id}.jpg"
                    if candidate.exists():
                        found_image = str(candidate.resolve())
            yield {
                "id": f"memotion2_test_{example_id}",
                "dataset": "memotion2",
                "split": "test",
                "image_path": found_image,
                "text_raw": row.get("ocr_text", ""),
                "labels_raw": {
                    "classification_based_on": row.get("classification_based_on"),
                },
            }


def iter_hatexplain() -> Iterable[Dict[str, Any]]:
    """Yield one text-only manifest record per post in HateXplain `dataset.json`.

    Yields nothing unless both `Data/dataset.json` and
    `Data/post_id_divisions.json` exist. Text is the space-joined
    `post_tokens`; `labels_raw` is the raw `annotators` list; posts absent
    from the divisions file are tagged "unsplit".
    """
    data_dir = DATASETS / "hatexplain" / "Data"
    posts_file = data_dir / "dataset.json"
    divisions_file = data_dir / "post_id_divisions.json"
    if not (posts_file.exists() and divisions_file.exists()):
        return

    with posts_file.open("r", encoding="utf-8") as handle:
        posts = json.load(handle)
    with divisions_file.open("r", encoding="utf-8") as handle:
        divisions = json.load(handle)

    # Post id -> split name; a later split wins if an id is listed twice.
    split_of: Dict[str, str] = {}
    for split_name, post_ids in divisions.items():
        split_of.update((post_id, split_name) for post_id in post_ids)

    for post_id, post in posts.items():
        yield {
            "id": f"hatexplain_{post_id}",
            "dataset": "hatexplain",
            "split": split_of.get(post_id, "unsplit"),
            "image_path": None,
            "text_raw": " ".join(post.get("post_tokens", [])),
            "labels_raw": post.get("annotators", []),
        }


def iter_olid() -> Iterable[Dict[str, Any]]:
    """Yield text-only manifest records for the OLID train and level-A test sets.

    Train rows come from `olid-training-v1.0.tsv` with all three subtask
    labels. Test rows come from `testset-levela.tsv`; their `subtask_a`
    label is looked up in `labels-levela.csv` (two-column rows only) and is
    None when no label is found. Rows without an `id` are skipped.
    """
    base_dir = DATASETS / "olid"

    # Train
    train_file = base_dir / "olid-training-v1.0.tsv"
    if train_file.exists():
        with train_file.open("r", encoding="utf-8") as handle:
            for row in csv.DictReader(handle, delimiter="\t"):
                tweet_id = row.get("id")
                if tweet_id:
                    yield {
                        "id": f"olid_train_{tweet_id}",
                        "dataset": "olid",
                        "split": "train",
                        "image_path": None,
                        "text_raw": row.get("tweet", ""),
                        "labels_raw": {
                            task: row.get(task) for task in ("subtask_a", "subtask_b", "subtask_c")
                        },
                    }

    # Test (level a): load the id -> label table first
    label_by_id: Dict[str, str] = {}
    labels_file = base_dir / "labels-levela.csv"
    if labels_file.exists():
        with labels_file.open("r", encoding="utf-8") as handle:
            label_by_id = {cells[0]: cells[1] for cells in csv.reader(handle) if len(cells) == 2}

    test_file = base_dir / "testset-levela.tsv"
    if test_file.exists():
        with test_file.open("r", encoding="utf-8") as handle:
            for row in csv.DictReader(handle, delimiter="\t"):
                tweet_id = row.get("id")
                if tweet_id:
                    yield {
                        "id": f"olid_test_{tweet_id}",
                        "dataset": "olid",
                        "split": "test",
                        "image_path": None,
                        "text_raw": row.get("tweet", ""),
                        "labels_raw": {"subtask_a": label_by_id.get(tweet_id)},
                    }


def build_manifest_step2(output_path: Path) -> None:
    """Write the unified manifest as JSONL to `output_path`.

    Creates the parent directory if needed, overwrites any existing file,
    and appends the records of every dataset reader in a fixed order, one
    JSON object per line (non-ASCII characters are written unescaped).
    """
    dataset_readers = (
        iter_hateful_memes,
        iter_mami,
        iter_mmhs150k,
        iter_memotion1,
        iter_memotion2,
        iter_hatexplain,
        iter_olid,
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as sink:
        for read_dataset in dataset_readers:
            sink.writelines(
                json.dumps(record, ensure_ascii=False) + "\n" for record in read_dataset()
            )


# Build the manifest for this master pipeline under Final/Step_2.
manifest_path = FINAL_STEP2 / "data_manifest.jsonl"
print("Building manifest at:", manifest_path)
build_manifest_step2(manifest_path)

# Sanity check: echo the first three manifest lines.
print("\nFirst 3 lines of manifest:")
with manifest_path.open("r", encoding="utf-8") as preview_f:
    for line_no, raw_line in enumerate(preview_f, start=1):
        print(raw_line.strip())
        if line_no == 3:
            break

Building manifest at: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_2/data_manifest.jsonl

First 3 lines of manifest:
{"id": "memotion1_image_1.jpg", "dataset": "memotion1", "split": "train", "image_path": "/Users/yashwanthreddy/Documents/GitHub/DL_Proj/Datasets/memotion1/images/image_1.jpg", "text_raw": "LOOK THERE MY FRIEND LIGHTYEAR NOW ALL SOHALIKUT TREND PLAY THE 10 YEARS CHALLENGE AT FACEBOOK imgflip.com", "labels_raw": {"humour": "hilarious", "sarcasm": "general", "offensive": "not_offensive", "motivational": "not_motivational", "overall_sentiment": "very_positive"}}
{"id": "memotion1_image_2.jpeg", "dataset": "memotion1", "split": "train", "image_path": "/Users/yashwanthreddy/Documents/GitHub/DL_Proj/Datasets/memotion1/images/image_2.jpeg", "text_raw": "The best of #10 YearChallenge! Completed in less the 4 years. Kudus to @narendramodi ji 8:05 PM - 16 Jan 2019 from Mumbai  India", "labels_raw": {"humour": "not_funny", "sarcasm": "general", "offensive": "not_offensi

## Step 2 — Run OCR (Final/Step_2/ocr.jsonl)

Run PaddleOCR over **all images** listed in the manifest that exist on disk
and write the results to `Final/Step_2/ocr.jsonl`. QC images (for images where
text was detected) are written to `Final/Step_2/ocr_qc/`.

In [4]:
# Step 2 — Run OCR (inlined from Step_2/run_ocr.py, outputs under Final/Step_2)

import json
import sys
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
from PIL import Image, ImageDraw

# QC images live under Final/Step_2/ocr_qc
# (created by run_ocr_step2; one PNG per record id when QC is enabled).
QC_DIR = FINAL_STEP2 / "ocr_qc"


@dataclass
class OcrLine:
    """One text line recognized by OCR: its text, confidence and bounding box."""
    text: str  # recognized text of the line
    conf: float  # recognition score reported by the OCR engine
    bbox: List[float]  # [x_min, y_min, x_max, y_max]


def _load_paddleocr():
    """Load PaddleOCR with the modern 3.x API settings.

    We try to initialize with GPU first (use_gpu=True) *if* the installed
    PaddleOCR supports that argument. Otherwise we rely on the library's
    own device selection (GPU build -> GPU, otherwise CPU).
    """
    try:
        from paddleocr import PaddleOCR
    except ImportError as e:
        raise RuntimeError(
            "PaddleOCR is not installed. Install with:\n"
            "  pip install paddleocr paddlepaddle pillow numpy\n"
        ) from e

    base_kwargs = dict(
        use_doc_orientation_classify=False,
        use_doc_unwarping=False,
        use_textline_orientation=False,
        lang="en",
    )

    # Some PaddleOCR versions expose a 'use_gpu' argument, others do not.
    # We inspect the signature to decide whether we can pass it.
    import inspect

    sig = inspect.signature(PaddleOCR.__init__)
    supports_use_gpu = "use_gpu" in sig.parameters

    if supports_use_gpu:
        try:
            ocr = PaddleOCR(**base_kwargs, use_gpu=True)
            print("PaddleOCR initialized with GPU (use_gpu=True).")
            return ocr
        except Exception as gpu_err:
            print(f"[INFO] PaddleOCR GPU init failed, falling back to CPU: {gpu_err}")
            try:
                ocr = PaddleOCR(**base_kwargs, use_gpu=False)
                print("PaddleOCR initialized on CPU (use_gpu=False).")
                return ocr
            except Exception as cpu_err:
                raise RuntimeError(
                    f"Failed to initialize PaddleOCR even on CPU: {cpu_err}"
                ) from cpu_err

    # If 'use_gpu' is not supported, fall back to defaults (library decides device)
    try:
        ocr = PaddleOCR(**base_kwargs)
        print("PaddleOCR initialized (no explicit use_gpu; using library defaults).")
        return ocr
    except Exception as e:
        raise RuntimeError(f"Failed to initialize PaddleOCR: {e}") from e


def _iter_manifest_ocr(manifest_path: Path) -> Iterable[Dict[str, Any]]:
    """Iterate over records in the manifest JSONL file."""
    with manifest_path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)


def _count_manifest_images(manifest_path: Path) -> int:
    """Count records with image_path in the manifest."""
    count = 0
    with manifest_path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rec = json.loads(line)
            if rec.get("image_path"):
                count += 1
    return count


def _resize_image_if_needed(img: Image.Image, max_side: int = 1024) -> Image.Image:
    """Resize image so long side <= max_side, preserving aspect ratio."""
    w, h = img.size
    long_side = max(w, h)
    if long_side > max_side:
        scale = max_side / float(long_side)
        new_size = (int(w * scale), int(h * scale))
        img = img.resize(new_size, Image.LANCZOS)
    return img


def _run_ocr_on_image(ocr_engine, image_path: Path) -> List[OcrLine]:
    """Run OCR on one image file and return the recognized lines.

    The image is loaded as RGB, downscaled so its long side is at most 1024,
    and passed to `ocr_engine.predict()`. Any exception is reported on
    stderr and produces an empty list, so one bad image cannot stop the run.
    """
    detected: List[OcrLine] = []

    try:
        # Load, normalize to RGB, cap the size, and convert to an array.
        with Image.open(image_path) as opened:
            rgb = _resize_image_if_needed(opened.convert("RGB"), max_side=1024)
            pixels = np.array(rgb)

        # PaddleOCR 3.x API: predict() returns one result per input image.
        predictions = ocr_engine.predict(input=pixels)
        if not predictions:
            return detected

        for prediction in predictions:
            # Expected layout: prediction.json == {'res': {'rec_texts': [...],
            # 'rec_scores': [...], 'rec_boxes': [...]}}
            payload = prediction.json if hasattr(prediction, "json") else {}
            fields = payload.get("res", {}) if isinstance(payload, dict) else {}

            texts = fields.get("rec_texts", [])
            scores = fields.get("rec_scores", [])
            boxes = fields.get("rec_boxes", [])

            for idx, text in enumerate(texts):
                # Missing score -> 0.0; missing or short box -> all-zero box.
                score = scores[idx] if idx < len(scores) else 0.0
                bbox = [0.0, 0.0, 0.0, 0.0]
                if idx < len(boxes):
                    raw_box = boxes[idx]
                    if isinstance(raw_box, np.ndarray):
                        raw_box = raw_box.tolist()
                    if len(raw_box) >= 4:
                        # [x_min, y_min, x_max, y_max]
                        bbox = [float(v) for v in raw_box[:4]]

                detected.append(OcrLine(text=str(text), conf=float(score), bbox=bbox))

    except Exception as e:
        # Log error but don't crash the whole pipeline
        print(f"  [WARN] OCR failed for {image_path}: {e}", file=sys.stderr)
        return []

    return detected


def _draw_qc(image_path: Path, lines: List[OcrLine], qc_path: Path) -> None:
    """Save a QC copy of the image with OCR boxes and labels drawn in red.

    The image is downscaled exactly as in `_run_ocr_on_image` (long side at
    most 1024) before drawing, because the boxes in `lines` are expressed in
    that resized coordinate space. QC is best-effort: a failure is reported
    on stderr and never raised.

    Args:
        image_path: source image file.
        lines: OCR lines whose `bbox` is [x_min, y_min, x_max, y_max].
        qc_path: destination file; its parent directory is created if needed.
    """
    try:
        with Image.open(image_path) as img:
            img = img.convert("RGB")
            # Match the coordinate space the OCR boxes were computed in.
            img = _resize_image_if_needed(img, max_side=1024)
            draw = ImageDraw.Draw(img)
            for ln in lines:
                x1, y1, x2, y2 = ln.bbox
                draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
                # Truncate long text so the label stays readable.
                label = (
                    f"{ln.text[:30]}... ({ln.conf:.2f})" if len(ln.text) > 30 else f"{ln.text} ({ln.conf:.2f})"
                )
                # Draw the label just above the box, clamped to the top edge.
                draw.text((x1, max(0, y1 - 12)), label, fill="red")
            qc_path.parent.mkdir(parents=True, exist_ok=True)
            img.save(qc_path)
    except Exception as e:
        # QC failures should not crash the main pipeline, but should be visible.
        print(f"  [WARN] QC rendering failed for {image_path}: {e}", file=sys.stderr)


def run_ocr_step2(
    manifest_path: Path,
    output_path: Path,
    qc: bool = True,
    limit: Optional[int] = None,
    progress_interval: int = 100,
) -> None:
    """Run OCR on the manifest's images and write one JSON line per image.

    Records without an `image_path`, or whose image file does not exist, are
    skipped. The output file is overwritten on every call.

    Args:
        manifest_path: JSONL manifest to read.
        output_path: JSONL file to write; its parent directory is created.
        qc: when True, save a QC image under `QC_DIR` for every image with
            at least one detected line.
        limit: maximum number of images to process; None or 0 means no limit.
        progress_interval: print a progress line every this many images.

    Raises:
        FileNotFoundError: if `manifest_path` does not exist.
    """
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    QC_DIR.mkdir(parents=True, exist_ok=True)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print("Loading PaddleOCR model...")
    ocr_engine = _load_paddleocr()
    print("PaddleOCR model loaded.")

    # Count total for progress (records with an image_path, capped by limit).
    print("Counting images in manifest...")
    total_images = _count_manifest_images(manifest_path)
    if limit:
        total_images = min(total_images, limit)
    print(f"Will process up to {total_images} images.")

    processed = 0
    no_text = 0
    conf_sum = 0.0
    conf_count = 0
    start_time = time.time()

    with output_path.open("w", encoding="utf-8") as out_f:
        for record in _iter_manifest_ocr(manifest_path):
            # Stop before counting another image, so `processed` never
            # exceeds `limit` in the summary and the failure rate.
            if limit and processed >= limit:
                break

            image_path = record.get("image_path")
            if not image_path:
                continue

            img_path = Path(image_path)
            if not img_path.exists():
                continue

            processed += 1

            # Run OCR
            lines = _run_ocr_on_image(ocr_engine, img_path)

            if not lines:
                no_text += 1

            for ln in lines:
                conf_sum += ln.conf
                conf_count += 1

            # Write result
            ocr_payload = {
                "id": record.get("id"),
                "dataset": record.get("dataset"),
                "split": record.get("split"),
                "image_path": image_path,
                "ocr": {
                    "lines": [asdict(ln) for ln in lines]
                },
            }
            out_f.write(json.dumps(ocr_payload, ensure_ascii=False) + "\n")

            # QC image
            if qc and lines:
                qc_path = QC_DIR / f"{record.get('id')}.png"
                _draw_qc(img_path, lines, qc_path)

            # Progress reporting
            if processed % progress_interval == 0:
                elapsed = time.time() - start_time
                rate = processed / elapsed if elapsed > 0 else 0
                eta = (total_images - processed) / rate if rate > 0 else 0
                print(
                    f"  Processed {processed}/{total_images} images "
                    f"({processed/total_images*100:.1f}%) - "
                    f"{rate:.1f} img/s - ETA: {eta/60:.1f} min"
                )

    # Final summary
    elapsed = time.time() - start_time
    failure_rate = float(no_text) / float(processed) if processed > 0 else 0.0
    avg_conf = float(conf_sum) / float(conf_count) if conf_count > 0 else 0.0

    print("\n" + "=" * 60)
    print("OCR SUMMARY")
    print("=" * 60)
    print(f"Total images processed: {processed}")
    print(f"Images with no OCR text: {no_text} (failure rate: {failure_rate:.3f})")
    print(f"Total text lines detected: {conf_count}")
    print(f"Mean OCR confidence: {avg_conf:.3f}")
    print(f"Total time: {elapsed/60:.1f} minutes ({elapsed:.0f} seconds)")
    print(f"Output written to: {output_path}")
    print("=" * 60)


# Run OCR for this master pipeline
# NOTE: the output file is opened in write mode, so re-running this cell
# starts OCR over from the first image.
ocr_output_path = FINAL_STEP2 / "ocr.jsonl"
print("Running OCR...")
run_ocr_step2(
    manifest_path=manifest_path,
    output_path=ocr_output_path,
    qc=True,
    limit=None,  # no cap on the number of images
    progress_interval=200,
)

print("OCR output:", ocr_output_path)

Running OCR...
Loading PaddleOCR model...


[32mCreating model: ('PP-OCRv5_server_det', None)[0m
[32mModel files already exist. Using cached files. To redownload, please delete the directory manually: `/Users/yashwanthreddy/.paddlex/official_models/PP-OCRv5_server_det`.[0m
[32mCreating model: ('en_PP-OCRv5_mobile_rec', None)[0m
[32mModel files already exist. Using cached files. To redownload, please delete the directory manually: `/Users/yashwanthreddy/.paddlex/official_models/en_PP-OCRv5_mobile_rec`.[0m


PaddleOCR initialized (no explicit use_gpu; using library defaults).
PaddleOCR model loaded.
Counting images in manifest...
Will process up to 6992 images.
  Processed 200/6992 images (2.9%) - 0.6 img/s - ETA: 200.9 min


KeyboardInterrupt: 

## Step 3 — Pack Manifest + OCR into WebDataset Shards (Final/Step_3)

Pack the unified manifest + OCR JSONL into WebDataset shards under
`Final/Step_3/shards/`. Also writes:

- `Final/Step_3/label_taxonomy.json`
- `Final/Step_3/splits.json`
- `Final/Step_3/stats.json`.

In [None]:
# Step 3 — Pack Manifest + OCR into WebDataset Shards (inlined from Step_3/pack_examples.py)

import json
import time
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
from PIL import Image
import webdataset as wds


def _iter_manifest_pack(manifest_path: Path) -> Iterable[Dict[str, Any]]:
    """Yield records from a JSONL manifest file."""
    with manifest_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            yield json.loads(line)


def _load_ocr_index(ocr_path: Path) -> Dict[str, List[Dict[str, Any]]]:
    """Load OCR results into an index keyed by example id.

    Each value is the `ocr.lines` list for that example.
    """
    index: Dict[str, List[Dict[str, Any]]] = {}
    if not ocr_path.exists():
        print(f"[WARN] OCR file not found, continuing without OCR: {ocr_path}")
        return index

    with ocr_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rec = json.loads(line)
            rec_id = rec.get("id")
            ocr = rec.get("ocr") or {}
            lines = ocr.get("lines") or []
            if rec_id:
                index[rec_id] = lines
    return index


def _build_text(ocr_lines: List[Dict[str, Any]], text_raw: Optional[str]) -> str:
    """Build combined text field from OCR lines and raw caption/text.

    Pattern (for now, assuming English):
    [OCR] <joined OCR lines> [/OCR] [CAP] <text_raw> [/CAP] <lang=en>
    """
    parts: List[str] = []

    texts_ocr: List[str] = []
    for ln in ocr_lines or []:
        t = (ln.get("text") or "").strip()
        try:
            conf = float(ln.get("conf", 0.0))
        except Exception:
            conf = 0.0
        if not t:
            continue
        # Drop extremely low confidence lines
        if conf < 0.3:
            continue
        texts_ocr.append(t)

    if texts_ocr:
        parts.append("[OCR] " + " ".join(texts_ocr) + " [/OCR]")

    if text_raw is not None:
        tr = str(text_raw).strip()
        if tr:
            parts.append("[CAP] " + tr + " [/CAP]")

    if parts:
        parts.append("<lang=en>")

    return " ".join(parts).strip()


def _map_labels(dataset: str, labels_raw: Any) -> Dict[str, Any]:
    """Map dataset-specific `labels_raw` into a unified label schema."""
    result: Dict[str, Any] = {
        "abuse_hate": None,
        "offensive": None,
        "target_type": "unknown",
        "target_group": [],
        "aux": {
            "dataset": dataset,
            "labels_raw": labels_raw,
        },
    }

    # Example: Hateful Memes uses 0/1 labels where 1 is hateful
    if dataset == "hateful_memes":
        try:
            v = int(labels_raw)
        except Exception:
            v = labels_raw
        if v in (0, 1):
            result["abuse_hate"] = v
            result["offensive"] = v
            result["target_type"] = "group" if v == 1 else "none"

    # TODO: Extend mappings for other datasets (MAMI, MMHS150K, Memotion, HateXplain, OLID)

    return result


def _ensure_default_taxonomy(taxonomy_path: Path) -> Dict[str, Any]:
    """Ensure a default label taxonomy file exists and return its contents."""
    if taxonomy_path.exists():
        with taxonomy_path.open("r", encoding="utf-8") as f:
            return json.load(f)

    taxonomy: Dict[str, Any] = {
        "schema": {
            "abuse_hate": {
                "type": "binary",
                "values": [0, 1],
                "description": "Hate/abuse present (1) or not (0)",
            },
            "offensive": {
                "type": "binary",
                "values": [0, 1],
                "description": "Offensive language present",
            },
            "target_type": {
                "type": "categorical",
                "values": ["none", "individual", "group", "other", "unknown"],
                "description": "Target granularity (if any)",
            },
            "target_group": {
                "type": "multilabel",
                "description": "Target groups (e.g. women, black, jewish, muslim, lgbtq)",
            },
            "aux": {
                "type": "object",
                "description": "Dataset-specific raw labels and metadata",
            },
        },
        "datasets": {
            "hateful_memes": {
                "notes": "labels_raw in {0,1} mapped to abuse_hate/offensive; target_type='group' if 1 else 'none'",
            },
        },
    }

    taxonomy_path.parent.mkdir(parents=True, exist_ok=True)
    with taxonomy_path.open("w", encoding="utf-8") as f:
        json.dump(taxonomy, f, indent=2, ensure_ascii=False)

    return taxonomy


def _open_shard_writers(output_dir: Path, shard_size: int) -> Dict[str, wds.ShardWriter]:
    """Build a registry of per-split ShardWriters that are opened on demand.

    The returned dict starts with a single entry, ``"__get_writer__"``, holding
    an accessor ``split -> ShardWriter``. The first request for a split creates
    ``output_dir/<split>/`` and a writer with pattern ``shard-%06d.tar`` and
    ``maxcount=shard_size``; the writer is then cached in the same dict under
    the split name.
    """
    registry: Dict[str, wds.ShardWriter] = {}

    def _writer_for(split: str) -> wds.ShardWriter:
        try:
            return registry[split]
        except KeyError:
            pass
        target_dir = output_dir / split
        target_dir.mkdir(parents=True, exist_ok=True)
        registry[split] = wds.ShardWriter(
            str(target_dir / "shard-%06d.tar"),
            maxcount=shard_size,
        )
        return registry[split]

    # The accessor travels inside the dict so callers receive a single object.
    registry["__get_writer__"] = _writer_for  # type: ignore
    return registry


def _close_writers(writers: Dict[str, wds.ShardWriter]) -> None:
    """Close every shard writer in the registry, skipping the accessor entry."""
    open_writers = [
        handle
        for name, handle in list(writers.items())
        if name != "__get_writer__"
    ]
    for handle in open_writers:
        handle.close()


def run_packing_step3(
    manifest_path: Path,
    ocr_path: Path,
    output_dir: Path,
    taxonomy_path: Path,
    splits_path: Path,
    examples_limit: Optional[int] = None,
    shard_size: int = 5000,
    image_size: int = 384,
    progress_interval: int = 1000,
    include_datasets: Optional[List[str]] = None,
) -> None:
    """Pack manifest records (plus OCR text) into per-split WebDataset shards.

    For each manifest record the image, when ``image_path`` is set, is converted
    to RGB, resized to ``image_size`` x ``image_size`` and PNG-encoded; OCR
    lines and ``text_raw`` are merged by ``_build_text``; ``labels_raw`` is
    mapped by ``_map_labels``; and the sample (``png`` / ``txt`` / ``json``) is
    written to the shard writer of the record's split (default ``"train"``).

    Args:
        manifest_path: JSONL manifest to read.
        ocr_path: OCR JSONL; a missing file only disables the OCR text.
        output_dir: Root directory for ``<split>/shard-%06d.tar`` files.
        taxonomy_path: Label taxonomy file; created with defaults if absent.
        splits_path: ``id -> split`` mapping; written only if it does not
            already exist.
        examples_limit: Stop once this many examples have been packed
            (``None`` = no limit).
        shard_size: Maximum number of samples per shard.
        image_size: Side length of the square resized image.
        progress_interval: Print progress every N manifest records
            (0/``None`` disables progress output).
        include_datasets: If given, only records of these datasets are packed.

    Raises:
        FileNotFoundError: If ``manifest_path`` does not exist.

    Side effects:
        Writes shards under ``output_dir``, ``stats.json`` under the notebook's
        ``FINAL_STEP3`` directory, and possibly ``splits_path`` and
        ``taxonomy_path``. Per-example failures are reported on stderr and
        counted, not raised.
    """
    if not manifest_path.exists():
        raise FileNotFoundError(f"Manifest not found: {manifest_path}")

    output_dir.mkdir(parents=True, exist_ok=True)

    print("Ensuring label taxonomy exists...")
    _ensure_default_taxonomy(taxonomy_path)

    print("Loading OCR index (this may take a while for large files)...")
    ocr_index = _load_ocr_index(ocr_path)
    print(f"Loaded OCR for {len(ocr_index)} examples.")

    include_set = set(include_datasets) if include_datasets else None

    writers = _open_shard_writers(output_dir, shard_size)
    get_writer = writers["__get_writer__"]  # type: ignore

    # Stats
    total_seen = 0
    total_packed = 0
    total_failed = 0
    split_counts: Dict[str, int] = {}
    split_packed: Dict[str, int] = {}
    dataset_packed: Dict[str, int] = {}
    splits_map: Dict[str, str] = {}

    start_time = time.time()

    # Close the writers even when the loop is interrupted (e.g. the cell is
    # stopped) or raises, so already-written shards are finalized on disk.
    try:
        for record in _iter_manifest_pack(manifest_path):
            total_seen += 1

            example_id = record.get("id")
            dataset = record.get("dataset", "unknown")
            split = record.get("split") or "train"

            if not example_id:
                total_failed += 1
                continue

            # The split map / counts describe the manifest, so they are updated
            # before any subset filtering or limit is applied.
            splits_map[example_id] = split
            split_counts[split] = split_counts.get(split, 0) + 1

            if include_set is not None and dataset not in include_set:
                # Skip but do not treat as failure; it's a user-chosen subset
                continue

            if examples_limit is not None and total_packed >= examples_limit:
                break

            try:
                image_path = record.get("image_path")
                text_raw = record.get("text_raw")
                labels_raw = record.get("labels_raw")

                # Image (optional)
                img_bytes: Optional[bytes] = None
                if image_path:
                    img_path = Path(image_path)
                    if img_path.exists():
                        with Image.open(img_path) as img:
                            img = img.convert("RGB")
                            img = img.resize((image_size, image_size), Image.BICUBIC)
                            buf = BytesIO()
                            img.save(buf, format="PNG")
                            img_bytes = buf.getvalue()
                    else:
                        # Missing image file; treat as a failure for this example
                        raise FileNotFoundError(f"Image not found: {img_path}")

                # OCR lines (may be empty)
                ocr_lines = ocr_index.get(example_id, [])

                # Combined text
                combined_text = _build_text(ocr_lines, text_raw)

                # Labels
                labels = _map_labels(dataset, labels_raw)

                meta = {
                    "id": example_id,
                    "dataset": dataset,
                    "split": split,
                    "labels": labels,
                }

                sample: Dict[str, Any] = {"__key__": example_id}
                if img_bytes is not None:
                    sample["png"] = img_bytes
                sample["txt"] = (combined_text or "").encode("utf-8")
                sample["json"] = json.dumps(meta, ensure_ascii=False).encode("utf-8")

                writer = get_writer(split)
                writer.write(sample)

                total_packed += 1
                split_packed[split] = split_packed.get(split, 0) + 1
                dataset_packed[dataset] = dataset_packed.get(dataset, 0) + 1

            except Exception as e:
                # Best-effort packing: one bad example must not abort the run.
                total_failed += 1
                print(f"[WARN] failed to pack id={example_id}: {e}", file=sys.stderr)

            if progress_interval and total_seen % progress_interval == 0:
                elapsed = time.time() - start_time
                rate = total_seen / elapsed if elapsed > 0 else 0.0
                print(
                    f"Processed {total_seen} records, packed={total_packed}, "
                    f"failed={total_failed} - {rate:.1f} rec/s"
                )
    finally:
        _close_writers(writers)

    elapsed = time.time() - start_time
    failure_rate = float(total_failed) / float(total_seen) if total_seen > 0 else 0.0

    stats = {
        "total_manifest_records": total_seen,
        "packed_examples": total_packed,
        "failed_examples": total_failed,
        "failure_rate": failure_rate,
        "split_counts": split_counts,
        "split_packed": split_packed,
        "dataset_packed": dataset_packed,
        "elapsed_seconds": elapsed,
    }

    stats_path = FINAL_STEP3 / "stats.json"
    with stats_path.open("w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)

    # Only create splits.json automatically if it does not exist yet.
    # Remember the decision *before* writing: re-checking exists() afterwards
    # would always report the file as pre-existing.
    splits_written = not splits_path.exists()
    if splits_written:
        splits_path.parent.mkdir(parents=True, exist_ok=True)
        with splits_path.open("w", encoding="utf-8") as f:
            json.dump(splits_map, f, indent=2, ensure_ascii=False)

    print("\n" + "=" * 60)
    print("STEP 3 PACKING SUMMARY")
    print("=" * 60)
    print(f"Total manifest records seen: {total_seen}")
    print(f"Packed examples: {total_packed}")
    print(f"Failed examples: {total_failed}")
    print(f"Failure rate: {failure_rate:.4f}")
    print(f"Elapsed time: {elapsed/60:.1f} minutes ({elapsed:.0f} seconds)")
    print(f"Stats written to: {stats_path}")
    if splits_written:
        print(f"Splits mapping written to: {splits_path}")
    else:
        print(f"Splits mapping already existed at: {splits_path}")
    print("=" * 60)


# Run packing for this master pipeline
# Shards go to Final/Step_3/shards/<split>/; the taxonomy and splits files sit
# next to that directory, directly under FINAL_STEP3.
packed_shards_dir = FINAL_STEP3 / "shards"
packed_shards_dir.mkdir(parents=True, exist_ok=True)

label_taxonomy_path = FINAL_STEP3 / "label_taxonomy.json"
splits_path = FINAL_STEP3 / "splits.json"

print("Packing WebDataset shards into:", packed_shards_dir)

# `ocr_output_path` comes from the Step 2 cell; if that file is missing the
# packer only warns and packs without OCR text.
# Progress is reported more frequently in pilot mode (every 500 records).
run_packing_step3(
    manifest_path=manifest_path,
    ocr_path=ocr_output_path,
    output_dir=packed_shards_dir,
    taxonomy_path=label_taxonomy_path,
    splits_path=splits_path,
    examples_limit=PACK_EXAMPLES_LIMIT,
    shard_size=5000,
    image_size=384,
    progress_interval=500 if MODE == "pilot" else 5000,
    include_datasets=None,
)

print("Shards directory:", packed_shards_dir)

## Step 4 — Train Unimodal Experts (Text & Image)

This section inlines the training logic from `Step_4/step4_unimodal_experts.ipynb`,
pointing it at the **Final/Step_3** shards. Models are still saved to
`models/text_expert/` and `models/vision_expert/`.

For the pilot mode we train **1 epoch** on the first shard; for full
mode you can point `shard_pattern` to a larger set of shards and/or
increase the number of epochs.

In [None]:
import json

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import webdataset as wds

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoImageProcessor,
    AutoModelForImageClassification,
)

from torchvision import transforms

# All trained models are stored under <project_root>/models.
models_root = project_root / "models"
models_root.mkdir(exist_ok=True)

# Shard pattern from Final/Step_3
# Only the first train shard is referenced here; widen the pattern to train on
# more shards.
train_shards_dir = FINAL_STEP3 / "shards" / "train"
shard_pattern = str(train_shards_dir / "shard-000000.tar")

print("Training unimodal experts from shards:", shard_pattern)

# Use the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Hugging Face checkpoints fine-tuned as the text and vision experts.
TEXT_MODEL_NAME = "microsoft/mdeberta-v3-base"
VISION_MODEL_NAME = "google/siglip-base-patch16-384"


class TextDataset(Dataset):
    """Map-style dataset that tokenizes ``{"text", "label"}`` records on access.

    Every item is padded/truncated to ``max_length`` tokens and returned as a
    dict of un-batched tensors plus a ``labels`` tensor of dtype ``long``.
    """

    def __init__(self, data: List[Dict[str, Any]], tokenizer, max_length: int = 256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        record = self.data[idx]
        text, label = record["text"], record["label"]
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # The tokenizer returns tensors with a leading batch axis of 1; drop it
        # so the DataLoader can stack items itself.
        features = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        features["labels"] = torch.tensor(label, dtype=torch.long)
        return features


class ImageDataset(Dataset):
    """Map-style dataset over ``{"image", "label"}`` records.

    In training mode each image goes through a random horizontal flip before
    the image processor; the processor output's ``pixel_values`` (batch axis
    removed) and a ``long`` label tensor are returned.
    """

    def __init__(self, data: List[Dict[str, Any]], image_processor, train: bool = True):
        self.data = data
        self.image_processor = image_processor
        self.train = train
        # Augmentation pipeline; only applied when ``train`` is true.
        flip = transforms.RandomHorizontalFlip()
        self.aug = transforms.Compose([flip])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        record = self.data[idx]
        picture, target = record["image"], record["label"]
        if self.train:
            picture = self.aug(picture)
        processed = self.image_processor(images=picture, return_tensors="pt")
        return {
            "pixel_values": processed["pixel_values"].squeeze(0),
            "labels": torch.tensor(target, dtype=torch.long),
        }


def make_text_dataset(shard_pattern: str, max_samples: Optional[int] = None):
    """Materialize ``{"text", "label"}`` records from WebDataset shards.

    Reads the ``txt`` and ``json`` entries of each sample in shard order (no
    shuffling), uses ``labels.abuse_hate`` from the metadata as the label and
    skips samples where it is ``None``. Stops early once ``max_samples``
    records have been collected.
    """
    pipeline = wds.WebDataset(shard_pattern, shardshuffle=False).to_tuple("txt", "json")

    collected: List[Dict[str, Any]] = []
    for raw_text, raw_meta in pipeline:
        # Entries arrive as raw bytes unless a decoder already converted them.
        text = (
            raw_text.decode("utf-8", errors="replace")
            if isinstance(raw_text, (bytes, bytearray))
            else str(raw_text)
        )
        meta = (
            json.loads(raw_meta.decode("utf-8"))
            if isinstance(raw_meta, (bytes, bytearray))
            else raw_meta
        )

        target = (meta or {}).get("labels", {}).get("abuse_hate")
        if target is None:
            continue

        collected.append({"text": text, "label": int(target)})
        if max_samples is not None and len(collected) >= max_samples:
            break
    return collected


def make_image_dataset(shard_pattern: str, max_samples: Optional[int] = None):
    """Materialize ``{"image", "label"}`` records from WebDataset shards.

    Samples are decoded with the ``"pil"`` decoder and read in shard order;
    the label is ``labels.abuse_hate`` from the ``json`` metadata, and samples
    where it is ``None`` are skipped. Stops early once ``max_samples`` records
    have been collected.

    NOTE(review): Step 3 can pack samples without a ``png`` entry; how
    ``to_tuple("png", "json")`` treats those is not visible here — confirm.
    """
    pipeline = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("png", "json")
    )

    collected: List[Dict[str, Any]] = []
    for picture, raw_meta in pipeline:
        meta = (
            json.loads(raw_meta.decode("utf-8"))
            if isinstance(raw_meta, (bytes, bytearray))
            else raw_meta
        )

        target = (meta or {}).get("labels", {}).get("abuse_hate")
        if target is None:
            continue

        collected.append({"image": picture, "label": int(target)})
        if max_samples is not None and len(collected) >= max_samples:
            break
    return collected


# Load the training examples into memory. Pilot mode caps each modality at
# 1000 labelled examples; full mode reads everything the shard pattern matches.
text_examples = make_text_dataset(shard_pattern, max_samples=None if MODE == "full" else 1000)
image_examples = make_image_dataset(shard_pattern, max_samples=None if MODE == "full" else 1000)
print(f"Loaded {len(text_examples)} text examples and {len(image_examples)} image examples.")

# --- Train text expert ---
# Fine-tunes every parameter of the pretrained text model as a 2-class
# classifier on the in-memory text examples, then saves model + tokenizer.

tokenizer_text = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
model_text = AutoModelForSequenceClassification.from_pretrained(
    TEXT_MODEL_NAME,
    num_labels=2,
)
model_text.to(device)

dataset_text = TextDataset(text_examples, tokenizer_text)
loader_text = DataLoader(dataset_text, batch_size=8, shuffle=True)

optimizer_text = torch.optim.AdamW(model_text.parameters(), lr=2e-5)
criterion_text = nn.CrossEntropyLoss()

# One epoch in pilot mode, two otherwise.
EPOCHS_TEXT = 1 if MODE == "pilot" else 2

model_text.train()
for epoch in range(EPOCHS_TEXT):
    total_loss = 0.0
    for batch in loader_text:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Labels are popped so the loss is computed explicitly below rather
        # than being passed into the model call.
        labels = batch.pop("labels")
        outputs = model_text(**batch)
        logits = outputs.logits
        loss = criterion_text(logits, labels)
        optimizer_text.zero_grad()
        loss.backward()
        optimizer_text.step()
        total_loss += loss.item()
    # Reported loss is the mean over batches for this epoch.
    print(f"[Text] Epoch {epoch + 1}/{EPOCHS_TEXT} - loss: {total_loss / len(loader_text):.4f}")

# Model and tokenizer are saved together so later steps can load both from the
# same directory.
text_expert_dir = models_root / "text_expert"
model_text.save_pretrained(text_expert_dir)
tokenizer_text.save_pretrained(text_expert_dir)
print("Saved text expert to", text_expert_dir)

# --- Train vision expert ---
# Fine-tunes every parameter of the pretrained vision model as a 2-class
# classifier on the in-memory image examples, then saves model + processor.

image_processor = AutoImageProcessor.from_pretrained(VISION_MODEL_NAME)
model_vision = AutoModelForImageClassification.from_pretrained(
    VISION_MODEL_NAME,
    num_labels=2,
)
model_vision.to(device)

# train=True enables the random horizontal flip augmentation in ImageDataset.
dataset_img = ImageDataset(image_examples, image_processor, train=True)
loader_img = DataLoader(dataset_img, batch_size=8, shuffle=True)

optimizer_v = torch.optim.AdamW(model_vision.parameters(), lr=1e-4)
criterion_v = nn.CrossEntropyLoss()

# One epoch in pilot mode, two otherwise.
EPOCHS_V = 1 if MODE == "pilot" else 2

model_vision.train()
for epoch in range(EPOCHS_V):
    total_loss = 0.0
    for batch in loader_img:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Labels are popped so the loss is computed explicitly below rather
        # than being passed into the model call.
        labels = batch.pop("labels")
        outputs = model_vision(**batch)
        logits = outputs.logits
        loss = criterion_v(logits, labels)
        optimizer_v.zero_grad()
        loss.backward()
        optimizer_v.step()
        total_loss += loss.item()
    # Reported loss is the mean over batches for this epoch.
    print(f"[Vision] Epoch {epoch + 1}/{EPOCHS_V} - loss: {total_loss / len(loader_img):.4f}")

# Model and image processor are saved together so later steps can load both
# from the same directory.
vision_expert_dir = models_root / "vision_expert"
model_vision.save_pretrained(vision_expert_dir)
image_processor.save_pretrained(vision_expert_dir)
print("Saved vision expert to", vision_expert_dir)


## Step 5 — Train Multimodal Fusion Model

This section mirrors `Step_5/step5_fusion.ipynb` but uses shards under
`Final/Step_3/shards/` and the unimodal experts saved to `models/`.

The fusion model weights are saved to `models/mm_fusion/fusion_model.pt`.

In [None]:
from torch.utils.data import Dataset as TorchDataset
from torchvision import transforms as T

from transformers import AutoModel


class FusionDataset(TorchDataset):
    """Map-style dataset yielding paired text and image tensors for fusion.

    Each ``{"text", "image", "label"}`` record is tokenized (padded/truncated
    to ``max_length``) and run through the image processor. In training mode
    the image is randomly flipped horizontally first.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        text_tokenizer,
        image_processor,
        max_length: int = 256,
        train: bool = True,
    ) -> None:
        self.data = data
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.train = train
        # Augmentation pipeline; only applied when ``train`` is true.
        flip = T.RandomHorizontalFlip()
        self.aug = T.Compose([flip])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        record = self.data[idx]
        text, picture, target = record["text"], record["image"], record["label"]

        if self.train:
            picture = self.aug(picture)

        tokens = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop the leading batch axis of 1 from every tokenizer tensor.
        tokens = {name: tensor.squeeze(0) for name, tensor in tokens.items()}

        processed = self.image_processor(images=picture, return_tensors="pt")

        return {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "pixel_values": processed["pixel_values"].squeeze(0),
            "labels": torch.tensor(target, dtype=torch.long),
        }


def make_fusion_examples(shard_pattern: str, max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
    """Materialize ``{"text", "image", "label"}`` records from WebDataset shards.

    Samples are decoded with the ``"pil"`` decoder and read in shard order; the
    label is ``labels.abuse_hate`` from the ``json`` metadata, and samples where
    it is ``None`` are skipped. Stops early once ``max_samples`` records have
    been collected.
    """
    pipeline = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("txt", "png", "json")
    )

    collected: List[Dict[str, Any]] = []
    for raw_text, picture, raw_meta in pipeline:
        # Text/metadata may still be raw bytes depending on the decoder.
        text = (
            raw_text.decode("utf-8", errors="replace")
            if isinstance(raw_text, (bytes, bytearray))
            else str(raw_text)
        )
        meta = (
            json.loads(raw_meta.decode("utf-8"))
            if isinstance(raw_meta, (bytes, bytearray))
            else raw_meta
        )

        target = (meta or {}).get("labels", {}).get("abuse_hate")
        if target is None:
            continue

        collected.append({"text": text, "image": picture, "label": int(target)})

        if max_samples is not None and len(collected) >= max_samples:
            break

    return collected


# Fusion training reads the same shard(s) the unimodal experts were trained on.
fusion_shard_pattern = shard_pattern  # reuse train shard from Final/Step_3
fusion_examples = make_fusion_examples(
    fusion_shard_pattern,
    max_samples=None if MODE == "full" else 1000,
)
print(f"Loaded {len(fusion_examples)} multimodal examples for fusion training.")

# Directories written by the Step 4 cell; redefined here so this cell does not
# depend on those variables still being in the kernel.
text_expert_dir = models_root / "text_expert"
vision_expert_dir = models_root / "vision_expert"

# Tokenizer / image processor are loaded from the saved expert directories.
text_tokenizer_fusion = AutoTokenizer.from_pretrained(text_expert_dir)
image_processor_fusion = AutoImageProcessor.from_pretrained(vision_expert_dir)

fusion_dataset = FusionDataset(
    fusion_examples,
    text_tokenizer=text_tokenizer_fusion,
    image_processor=image_processor_fusion,
    max_length=256,
    train=True,
)

loader_fusion = DataLoader(fusion_dataset, batch_size=8, shuffle=True)
print("Fusion batches per epoch:", len(loader_fusion))


class FusionModel(nn.Module):
    """Late-fusion classifier: fixed text/vision encoders feeding a small MLP.

    ``forward`` runs both encoders under ``torch.no_grad()``, so they act as
    fixed feature extractors; only ``self.mlp`` receives gradients. To keep the
    extracted features deterministic, the encoders are also pinned to eval mode
    (see ``train``), which disables their dropout during fusion training.

    Args:
        text_encoder: Module called with ``input_ids`` / ``attention_mask``
            whose output exposes ``pooler_output`` or ``last_hidden_state``.
        vision_encoder: Module called with ``pixel_values`` whose output
            exposes ``logits``.
        t_dim: Width of the text representation.
        v_dim: Width of the vision representation (the encoder's logits).
        hidden_dim: Hidden width of the fusion MLP.
        num_labels: Number of output classes.
    """

    def __init__(
        self,
        text_encoder: nn.Module,
        vision_encoder: nn.Module,
        t_dim: int,
        v_dim: int,
        hidden_dim: int,
        num_labels: int,
    ) -> None:
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder
        self.mlp = nn.Sequential(
            nn.Linear(t_dim + v_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels),
        )

    def train(self, mode: bool = True) -> "FusionModel":
        """Switch the MLP head to ``mode`` while keeping the encoders in eval.

        The encoders are never updated (their forward pass runs without
        gradients), so leaving them in train mode would only inject dropout
        noise into the features the MLP is trained on.
        """
        super().train(mode)
        self.text_encoder.eval()
        self.vision_encoder.eval()
        return self

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Return fusion logits for a batch of paired text and image inputs."""
        with torch.no_grad():
            text_out = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            # Prefer the pooled output; otherwise use the first token's state.
            if hasattr(text_out, "pooler_output") and text_out.pooler_output is not None:
                t_repr = text_out.pooler_output
            else:
                t_repr = text_out.last_hidden_state[:, 0, :]

            # The vision expert contributes its classification logits.
            vision_out = self.vision_encoder(pixel_values=pixel_values)
            v_repr = vision_out.logits

        h = torch.cat([t_repr, v_repr], dim=-1)
        logits = self.mlp(h)
        return logits


# Load the Step 4 experts: the text expert as a bare encoder (AutoModel), the
# vision expert with its classification head.
text_encoder_fusion = AutoModel.from_pretrained(text_expert_dir)
vision_encoder_fusion = AutoModelForImageClassification.from_pretrained(vision_expert_dir)

text_encoder_fusion.to(device)
vision_encoder_fusion.to(device)

# Freeze both encoders; only the fusion MLP is trained.
for p in text_encoder_fusion.parameters():
    p.requires_grad = False
for p in vision_encoder_fusion.parameters():
    p.requires_grad = False

# The text feature is the encoder's hidden state; the vision feature is the
# expert's logits, hence v_dim equals its number of labels.
t_dim = text_encoder_fusion.config.hidden_size
v_dim = vision_encoder_fusion.config.num_labels

fusion_hidden = 512
num_labels = 2

fusion_model = FusionModel(
    text_encoder=text_encoder_fusion,
    vision_encoder=vision_encoder_fusion,
    t_dim=t_dim,
    v_dim=v_dim,
    hidden_dim=fusion_hidden,
    num_labels=num_labels,
).to(device)

optimizer_fusion = torch.optim.AdamW(fusion_model.parameters(), lr=1e-4)
criterion_fusion = nn.CrossEntropyLoss()

# One epoch in pilot mode, two otherwise.
EPOCHS_FUSION = 1 if MODE == "pilot" else 2

for epoch in range(EPOCHS_FUSION):
    fusion_model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader_fusion:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Remaining keys (input_ids, attention_mask, pixel_values) match the
        # FusionModel.forward signature.
        labels = batch.pop("labels")

        logits = fusion_model(**batch)
        loss = criterion_fusion(logits, labels)

        optimizer_fusion.zero_grad()
        loss.backward()
        optimizer_fusion.step()

        # Running training loss / accuracy for the epoch summary.
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / max(1, len(loader_fusion))
    acc = correct / total if total > 0 else 0.0
    print(f"[Fusion] Epoch {epoch + 1}/{EPOCHS_FUSION} - loss: {avg_loss:.4f} - acc: {acc:.3f}")

# The saved state dict covers the whole FusionModel, i.e. the encoder
# submodules as well as the MLP head.
mm_dir = models_root / "mm_fusion"
mm_dir.mkdir(exist_ok=True)
mm_fusion_path = mm_dir / "fusion_model.pt"
torch.save(fusion_model.state_dict(), mm_fusion_path)
print("Saved fusion model weights to", mm_fusion_path)


## Step 6 — Evaluation (metrics only, results under Final/Step_6)

We evaluate:

- Text expert (`models/text_expert/`)
- Vision expert (`models/vision_expert/`)
- Fusion model (`models/mm_fusion/fusion_model.pt`)

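on the packed shard from `Final/Step_3/shards/train/shard-000000.tar` and
save metrics to `Final/Step_6/results_{MODE}.json`.

> **Caveat:** this is the same train shard the models were fitted on in
> Steps 4–5, so the metrics below are training-set figures and will be
> optimistic. Point the evaluation at a held-out split before drawing
> conclusions about generalization.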

In [None]:
import numpy as np

from torch.utils.data import DataLoader as TorchDataLoader


def compute_accuracy(preds: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of predictions equal to their label; 0.0 for empty input."""
    if len(labels) > 0:
        return float((preds == labels).mean())
    return 0.0


def compute_macro_f1(preds: np.ndarray, labels: np.ndarray, num_classes: int = 2) -> float:
    """Unweighted mean of per-class F1 over classes ``0..num_classes-1``.

    A class with no true positives (or with undefined precision and recall)
    contributes an F1 of 0.0. Returns 0.0 when ``num_classes`` is 0.
    """
    per_class: List[float] = []
    for cls in range(num_classes):
        predicted_cls = preds == cls
        true_pos = np.logical_and(predicted_cls, labels == cls).sum()
        false_pos = np.logical_and(predicted_cls, labels != cls).sum()
        false_neg = np.logical_and(preds != cls, labels == cls).sum()

        n_predicted = true_pos + false_pos
        n_actual = true_pos + false_neg
        precision = true_pos / n_predicted if n_predicted > 0 else 0.0
        recall = true_pos / n_actual if n_actual > 0 else 0.0

        if precision + recall == 0:
            per_class.append(0.0)
        else:
            per_class.append(2 * precision * recall / (precision + recall))

    if not per_class:
        return 0.0
    return float(np.mean(per_class))


def compute_brier_score(probs_pos: np.ndarray, labels: np.ndarray) -> float:
    """Mean squared gap between positive-class probability and label.

    Returns 0.0 for empty input.
    """
    if len(labels) > 0:
        squared_error = (probs_pos - labels) ** 2
        return float(np.mean(squared_error))
    return 0.0


def compute_ece(probs_pos: np.ndarray, labels: np.ndarray, num_bins: int = 10) -> float:
    """Expected calibration error of the positive-class probability.

    Probabilities are grouped into ``num_bins`` equal-width bins over [0, 1].
    For each non-empty bin the mean predicted positive probability is compared
    with the observed fraction of positive labels; the absolute gaps are
    averaged, weighted by bin size.

    Args:
        probs_pos: Predicted probability of class 1 for each sample.
        labels: Ground-truth labels in {0, 1}, same length as ``probs_pos``.
        num_bins: Number of equal-width bins.

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    n = len(labels)
    if n == 0:
        return 0.0

    probs_pos = np.asarray(probs_pos, dtype=float)
    labels = np.asarray(labels, dtype=float)
    bins = np.linspace(0.0, 1.0, num_bins + 1)

    ece = 0.0
    for i in range(num_bins):
        lower, upper = bins[i], bins[i + 1]
        if i == num_bins - 1:
            # Close the last bin on the right so p == 1.0 is not dropped.
            mask = (probs_pos >= lower) & (probs_pos <= upper)
        else:
            mask = (probs_pos >= lower) & (probs_pos < upper)
        if not np.any(mask):
            continue
        bin_conf = probs_pos[mask].mean()
        # Compare against the empirical positive rate. Using the accuracy of
        # thresholded predictions here would mis-score bins below 0.5, where
        # a calibrated model has a low positive rate but high accuracy.
        bin_pos_rate = labels[mask].mean()
        ece += (mask.sum() / n) * abs(bin_conf - bin_pos_rate)
    return float(ece)


def summarize_metrics(logits: torch.Tensor, labels: torch.Tensor) -> Dict[str, float]:
    """Compute accuracy, macro-F1, Brier score and ECE from 2-class logits.

    Logits are turned into probabilities with a softmax over the last axis;
    the prediction is the arg-max class and column 1 is treated as the
    positive-class probability.
    """
    class_probs = torch.softmax(logits, dim=-1).cpu().numpy()
    hard_preds = class_probs.argmax(axis=-1)
    y_true = labels.cpu().numpy()
    pos_probs = class_probs[:, 1]

    return {
        "accuracy": compute_accuracy(hard_preds, y_true),
        "macro_f1": compute_macro_f1(hard_preds, y_true, num_classes=2),
        "brier": compute_brier_score(pos_probs, y_true),
        "ece": compute_ece(pos_probs, y_true, num_bins=10),
    }


class EvalDataset(torch.utils.data.Dataset):
    """Thin map-style wrapper exposing a list of example dicts unchanged."""

    def __init__(self, examples: List[Dict[str, Any]]):
        self.examples = examples

    def __len__(self) -> int:
        n_examples = len(self.examples)
        return n_examples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        example = self.examples[idx]
        return example


def collate_batch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate raw examples without tensorizing text or images.

    Texts and images stay as Python lists (they are tokenized / processed per
    batch by the evaluation code); only the labels become a ``long`` tensor.
    """
    return {
        "texts": [example["text"] for example in batch],
        "images": [example["image"] for example in batch],
        "labels": torch.tensor([example["label"] for example in batch], dtype=torch.long),
    }


# Reuse fusion_examples format to build eval set from the same shard
# NOTE: `fusion_examples` is the data the fusion model was trained on, so the
# metrics computed from this loader are training-set metrics.

eval_examples = fusion_examples

# shuffle=False keeps batch order aligned with `eval_examples`; collate_batch
# leaves texts/images as lists so each evaluator can preprocess them itself.
loader_eval = TorchDataLoader(
    EvalDataset(eval_examples),
    batch_size=8,
    shuffle=False,
    collate_fn=collate_batch,
)
print("Eval batches:", len(loader_eval))


def evaluate_text_expert_eval(loader: TorchDataLoader) -> Dict[str, float]:
    """Score the saved text expert on ``loader`` and return summary metrics.

    The classifier is reloaded from ``text_expert_dir``; each batch's texts are
    tokenized with ``text_tokenizer_fusion`` (dynamic padding, 256-token cap)
    and the collected logits are passed to ``summarize_metrics``.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(text_expert_dir)
    classifier.to(device)
    classifier.eval()

    logit_chunks: List[torch.Tensor] = []
    label_chunks: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            raw_texts = batch["texts"]
            targets = batch["labels"].to(device)

            tokens = text_tokenizer_fusion(
                raw_texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

            batch_logits = classifier(**tokens).logits

            # Accumulate on CPU so results do not pile up in device memory.
            logit_chunks.append(batch_logits.cpu())
            label_chunks.append(targets.cpu())

    return summarize_metrics(
        torch.cat(logit_chunks, dim=0),
        torch.cat(label_chunks, dim=0),
    )


def evaluate_vision_expert_eval(loader: TorchDataLoader) -> Dict[str, float]:
    """Score the saved vision expert on ``loader`` and return summary metrics.

    The classifier is reloaded from ``vision_expert_dir``; each batch's images
    are preprocessed with ``image_processor_fusion`` and the collected logits
    are passed to ``summarize_metrics``.
    """
    classifier = AutoModelForImageClassification.from_pretrained(vision_expert_dir)
    classifier.to(device)
    classifier.eval()

    logit_chunks: List[torch.Tensor] = []
    label_chunks: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            pictures = batch["images"]
            targets = batch["labels"].to(device)

            processed = image_processor_fusion(images=pictures, return_tensors="pt")
            pixels = processed["pixel_values"].to(device)

            batch_logits = classifier(pixel_values=pixels).logits

            # Accumulate on CPU so results do not pile up in device memory.
            logit_chunks.append(batch_logits.cpu())
            label_chunks.append(targets.cpu())

    return summarize_metrics(
        torch.cat(logit_chunks, dim=0),
        torch.cat(label_chunks, dim=0),
    )


def evaluate_fusion_eval(loader: TorchDataLoader) -> Dict[str, float]:
    """Score the in-memory ``fusion_model`` on ``loader``.

    Uses the model object trained earlier in the notebook (not a reloaded
    checkpoint), switches it to eval mode, and feeds it per-batch tokenized
    text plus processed images. Returns the ``summarize_metrics`` dict.
    """
    fusion = fusion_model  # already trained
    fusion.eval()

    logit_chunks: List[torch.Tensor] = []
    label_chunks: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            raw_texts = batch["texts"]
            pictures = batch["images"]
            targets = batch["labels"].to(device)

            tokens = text_tokenizer_fusion(
                raw_texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

            processed = image_processor_fusion(images=pictures, return_tensors="pt")
            pixels = processed["pixel_values"].to(device)

            batch_logits = fusion(
                input_ids=tokens["input_ids"],
                attention_mask=tokens["attention_mask"],
                pixel_values=pixels,
            )

            # Accumulate on CPU so results do not pile up in device memory.
            logit_chunks.append(batch_logits.cpu())
            label_chunks.append(targets.cpu())

    return summarize_metrics(
        torch.cat(logit_chunks, dim=0),
        torch.cat(label_chunks, dim=0),
    )


# Metrics per model, keyed by model name; each value is the dict returned by
# summarize_metrics (accuracy, macro_f1, brier, ece).
results_eval: Dict[str, Dict[str, float]] = {}

print("Evaluating text expert...")
results_eval["text_expert"] = evaluate_text_expert_eval(loader_eval)
print("Text expert:", results_eval["text_expert"])

print("\nEvaluating vision expert...")
results_eval["vision_expert"] = evaluate_vision_expert_eval(loader_eval)
print("Vision expert:", results_eval["vision_expert"])

print("\nEvaluating fusion model...")
results_eval["mm_fusion"] = evaluate_fusion_eval(loader_eval)
print("Fusion model:", results_eval["mm_fusion"])

# Results are stored per MODE, so pilot and full runs do not overwrite each
# other.
FINAL_STEP6.mkdir(parents=True, exist_ok=True)
results_path = FINAL_STEP6 / f"results_{MODE}.json"
with results_path.open("w", encoding="utf-8") as f:
    json.dump(results_eval, f, indent=2)

print("\nSaved evaluation metrics to", results_path)


## Step 7 — Calibration, Thresholds, and Error Analysis (Final/Step_7)

This section mirrors `Step_7/step7_analysis.ipynb` but writes all
artifacts under `Final/Step_7/`:

- `calibration_{MODE}.json`
- `thresholds_{MODE}.json`
- `thresholds_curve_*.csv`
- `errors_{MODE}.csv`.

In [None]:
import csv  # Required for threshold curve and error CSV writing

# All Step 7 artifacts are written below this directory.
FINAL_STEP7.mkdir(parents=True, exist_ok=True)

# Ground-truth labels, one per entry of eval_examples, in list order.
# NOTE(review): everything below pairs row i of each logits tensor with
# labels_tensor[i]; that presumes loader_eval yields eval_examples in this same
# order (no shuffling, no dropped batch) — confirm against the loader definition.
labels_tensor = torch.tensor(
    [example["label"] for example in eval_examples],
    dtype=torch.long,
)

# The logits are not carried over from the Step 6 evaluation: each model is run
# over loader_eval once more below to obtain its raw logits.


def collect_logits_text() -> torch.Tensor:
    """Run the saved text expert over ``loader_eval`` and return its raw logits.

    The model is loaded from ``text_expert_dir``, moved to ``device`` and put in
    eval mode. Each batch's ``"texts"`` entry is tokenised with
    ``text_tokenizer_fusion`` (padding, truncation, max length 256).

    Returns:
        The per-batch logits, moved to CPU and concatenated along dim 0 in
        loader order.
    """
    text_expert = AutoModelForSequenceClassification.from_pretrained(text_expert_dir)
    text_expert.to(device)
    text_expert.eval()

    def _logits_for(batch) -> torch.Tensor:
        # Tokenise one batch, move every encoded tensor to the device, run the model.
        encoded = text_tokenizer_fusion(
            batch["texts"],
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        return text_expert(**on_device).logits.cpu()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        per_batch = [_logits_for(batch) for batch in loader_eval]

    return torch.cat(per_batch, dim=0)


def collect_logits_vision() -> torch.Tensor:
    """Run the saved vision expert over ``loader_eval`` and return its raw logits.

    The model is loaded from ``vision_expert_dir``, moved to ``device`` and put
    in eval mode. Each batch's ``"images"`` entry is preprocessed with
    ``image_processor_fusion``.

    Returns:
        The per-batch logits, moved to CPU and concatenated along dim 0 in
        loader order.
    """
    vision_expert = AutoModelForImageClassification.from_pretrained(vision_expert_dir)
    vision_expert.to(device)
    vision_expert.eval()

    def _logits_for(batch) -> torch.Tensor:
        # Preprocess one batch of images and run the classifier on the device.
        processed = image_processor_fusion(images=batch["images"], return_tensors="pt")
        pixels = processed["pixel_values"].to(device)
        return vision_expert(pixel_values=pixels).logits.cpu()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        per_batch = [_logits_for(batch) for batch in loader_eval]

    return torch.cat(per_batch, dim=0)


def collect_logits_fusion() -> torch.Tensor:
    """Run the in-memory ``fusion_model`` over ``loader_eval`` and return its logits.

    Unlike the two expert collectors, nothing is loaded from disk: the shared
    ``fusion_model`` object is switched to eval mode (a lasting change on that
    object) and used directly. Text is tokenised with ``text_tokenizer_fusion``
    (padding, truncation, max length 256) and images are preprocessed with
    ``image_processor_fusion``.

    Returns:
        The per-batch model outputs, moved to CPU and concatenated along dim 0
        in loader order.
    """
    fusion_model.eval()

    def _logits_for(batch) -> torch.Tensor:
        # Encode both modalities for one batch and run the fusion forward pass.
        tokens = text_tokenizer_fusion(
            batch["texts"],
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        )
        tokens = {name: tensor.to(device) for name, tensor in tokens.items()}

        processed = image_processor_fusion(images=batch["images"], return_tensors="pt")
        pixels = processed["pixel_values"].to(device)

        batch_logits = fusion_model(
            input_ids=tokens["input_ids"],
            attention_mask=tokens["attention_mask"],
            pixel_values=pixels,
        )
        return batch_logits.cpu()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        per_batch = [_logits_for(batch) for batch in loader_eval]

    return torch.cat(per_batch, dim=0)


# One full pass over loader_eval per model; the three tensors are kept separate
# because each model is calibrated and thresholded on its own further down.
logits_text = collect_logits_text()
logits_vision = collect_logits_vision()
logits_fusion = collect_logits_fusion()

# Quick visual check that the three models produced outputs of comparable shape.
print(
    "Logits shapes:",
    *(collected.shape for collected in (logits_text, logits_vision, logits_fusion)),
)


def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, max_iter: int = 200, lr: float = 0.01) -> float:
    """Fit a single positive temperature ``T`` for temperature scaling.

    Minimises the cross-entropy of ``logits / T`` against ``labels`` with Adam,
    starting from ``T = 1``.

    Args:
        logits: Uncalibrated class scores, one row per example.
        labels: Integer class index for each row of ``logits``.
        max_iter: Number of Adam steps to take (a fixed budget, not a
            convergence criterion).
        lr: Adam learning rate, applied in log-temperature space.

    Returns:
        The fitted temperature as a Python float. Always finite and strictly
        positive; ``1.0`` (no rescaling) when there is nothing to fit on.
    """
    # Work on detached float32/long copies so the caller's tensors are never
    # touched and no gradient flows back into them.
    logits = logits.detach().clone().to(torch.float32)
    labels = labels.detach().clone().to(torch.long)

    # Mean cross-entropy over zero examples is NaN, which would poison the
    # parameter; with no data the only sensible answer is "leave logits as is".
    if logits.numel() == 0:
        return 1.0

    # Optimise log(T) instead of T: T = exp(log_T) is positive by construction,
    # so a large step can never push the temperature to zero (division blow-up)
    # or below zero (which would invert every prediction).
    log_T = nn.Parameter(torch.zeros(1, device=logits.device))
    optimizer = torch.optim.Adam([log_T], lr=lr)

    for _ in range(max_iter):
        optimizer.zero_grad()
        scaled_logits = logits / log_T.exp()
        loss = nn.functional.cross_entropy(scaled_logits, labels)
        loss.backward()
        optimizer.step()

    return float(log_T.detach().exp().item())


# Per-model bookkeeping, keyed by the same model names used in the Step 6 results.
metrics_before: Dict[str, Dict[str, float]] = {}
metrics_after: Dict[str, Dict[str, float]] = {}
temperatures: Dict[str, float] = {}

# NOTE(review): temperatures are fitted and then scored on the very same
# examples (labels_tensor), so the "after" metrics are optimistic; a held-out
# calibration split would give an honest estimate — confirm whether intended.


def _calibrate_and_record(model_name: str, raw_logits: torch.Tensor):
    """Calibrate one model and record its metrics.

    Stores the metrics of the raw logits in ``metrics_before``, fits a
    temperature with ``fit_temperature``, stores the metrics of the rescaled
    logits in ``metrics_after`` and the temperature in ``temperatures``.

    Returns:
        ``(temperature, raw_logits / temperature)``.
    """
    metrics_before[model_name] = summarize_metrics(raw_logits, labels_tensor)
    temperature = fit_temperature(raw_logits, labels_tensor)
    rescaled = raw_logits / temperature
    metrics_after[model_name] = summarize_metrics(rescaled, labels_tensor)
    temperatures[model_name] = temperature
    return temperature, rescaled


# Same order as everywhere else in the notebook: text, vision, fusion.
T_text, logits_text_cal = _calibrate_and_record("text_expert", logits_text)
T_vision, logits_vision_cal = _calibrate_and_record("vision_expert", logits_vision)
T_fusion, logits_fusion_cal = _calibrate_and_record("mm_fusion", logits_fusion)

for _heading, _metrics_table in (
    ("Pre-calibration metrics:", metrics_before),
    ("\nPost-calibration metrics:", metrics_after),
):
    print(_heading)
    for _model_name, _model_metrics in _metrics_table.items():
        print(_model_name, _model_metrics)

# Only the temperatures are persisted; Step 8 reads this file back.
calib_path = FINAL_STEP7 / f"calibration_{MODE}.json"
with calib_path.open("w", encoding="utf-8") as f:
    json.dump(temperatures, f, indent=2)

print("\nSaved calibration temperatures to", calib_path)


def precision_recall_f1_at_threshold(probs_pos: np.ndarray, labels: np.ndarray, threshold: float):
    """Score a binary decision rule ``probs_pos >= threshold`` against ``labels``.

    Args:
        probs_pos: Positive-class probability for each example.
        labels: Ground-truth labels; ``1`` counts as positive, ``0`` as negative.
        threshold: Examples with probability at or above this value are
            predicted positive.

    Returns:
        ``(precision, recall, f1)`` as Python floats. Precision is ``0.0`` when
        nothing is predicted positive, recall is ``0.0`` when there are no
        positive labels, and F1 is ``0.0`` when both are zero.
    """
    flagged = probs_pos >= threshold
    is_positive = labels == 1
    is_negative = labels == 0

    true_pos = (flagged & is_positive).sum()
    false_pos = (flagged & is_negative).sum()
    false_neg = (~flagged & is_positive).sum()

    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg

    # Guard each ratio against an empty denominator instead of dividing by zero.
    precision = true_pos / predicted_pos if predicted_pos > 0 else 0.0
    recall = true_pos / actual_pos if actual_pos > 0 else 0.0

    pr_sum = precision + recall
    f1 = 0.0 if pr_sum == 0 else 2 * precision * recall / pr_sum

    return float(precision), float(recall), float(f1)


def threshold_sweep_from_logits(
    logits: torch.Tensor,
    labels: torch.Tensor,
    model_name: str,
    csv_path: Path,
    num_thresholds: int = 17,
):
    """Sweep decision thresholds for one model and write the resulting curve to CSV.

    The positive-class probability is column 1 of ``softmax(logits)``.
    ``num_thresholds`` evenly spaced thresholds between 0.1 and 0.9 (inclusive)
    are scored with ``precision_recall_f1_at_threshold``.

    Args:
        logits: Class scores, one row per example.
        labels: Ground-truth labels aligned with ``logits``.
        model_name: Written into the ``model`` column of every CSV row.
        csv_path: Destination file; overwritten if it exists.
        num_thresholds: How many thresholds to evaluate.

    Returns:
        ``(best_threshold, best_f1)``: the threshold with the highest F1, the
        lowest such threshold on ties. If no threshold was evaluated the result
        is ``(0.5, -1.0)``.
    """
    positive_probs = torch.softmax(logits, dim=-1).cpu().numpy()[:, 1]
    gold = labels.cpu().numpy()

    column_names = ["model", "threshold", "precision", "recall", "f1"]

    def _score(cutoff: float) -> Dict[str, Any]:
        # One CSV row: the metrics obtained when cutting at `cutoff`.
        precision, recall, f1 = precision_recall_f1_at_threshold(positive_probs, gold, cutoff)
        return {
            "model": model_name,
            "threshold": cutoff,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    rows: List[Dict[str, Any]] = [
        _score(float(cutoff)) for cutoff in np.linspace(0.1, 0.9, num_thresholds)
    ]

    # max() keeps the first of several equal maxima, i.e. the lowest threshold.
    winner = max(rows, key=lambda row: row["f1"], default=None)
    if winner is None:
        best_threshold, best_f1 = 0.5, -1.0
    else:
        best_threshold, best_f1 = winner["threshold"], winner["f1"]

    with csv_path.open("w", newline="", encoding="utf-8") as handle:
        curve_writer = csv.DictWriter(handle, fieldnames=column_names)
        curve_writer.writeheader()
        curve_writer.writerows(rows)

    return best_threshold, best_f1


# Sweep thresholds on the *calibrated* logits of each model. Every call also
# writes that model's full precision/recall/F1 curve to its own CSV file.
thr_text, f1_text = threshold_sweep_from_logits(
    logits=logits_text_cal,
    labels=labels_tensor,
    model_name="text_expert",
    csv_path=FINAL_STEP7 / f"thresholds_curve_text_expert_{MODE}.csv",
)

thr_vision, f1_vision = threshold_sweep_from_logits(
    logits=logits_vision_cal,
    labels=labels_tensor,
    model_name="vision_expert",
    csv_path=FINAL_STEP7 / f"thresholds_curve_vision_expert_{MODE}.csv",
)

thr_fusion, f1_fusion = threshold_sweep_from_logits(
    logits=logits_fusion_cal,
    labels=labels_tensor,
    model_name="mm_fusion",
    csv_path=FINAL_STEP7 / f"thresholds_curve_mm_fusion_{MODE}.csv",
)

# Summary consumed by Step 8: the chosen threshold is stored under "abuse_hate".
thresholds_summary = {
    _model_name: {"abuse_hate": _best_thr, "best_f1": _best_f1}
    for _model_name, _best_thr, _best_f1 in (
        ("text_expert", thr_text, f1_text),
        ("vision_expert", thr_vision, f1_vision),
        ("mm_fusion", thr_fusion, f1_fusion),
    )
}

thr_path = FINAL_STEP7 / f"thresholds_{MODE}.json"
with thr_path.open("w", encoding="utf-8") as f:
    json.dump(thresholds_summary, f, indent=2)

print("Recommended thresholds:")
for _model_name, _choice in thresholds_summary.items():
    print(_model_name, "-> threshold =", _choice["abuse_hate"], "best F1 =", _choice["best_f1"])

print("\nSaved threshold curves and summary to", FINAL_STEP7)

# Error table with fusion as primary model: one CSV row per evaluation example,
# holding every model's calibrated probability and thresholded decision, plus
# the outcome class (TP/TN/FP/FN) of the fusion decision.


def _positive_class_probs(calibrated_logits: torch.Tensor) -> np.ndarray:
    """Return column 1 of ``softmax(calibrated_logits)`` as a NumPy array."""
    return torch.softmax(calibrated_logits, dim=-1).cpu().numpy()[:, 1]


probs_text = _positive_class_probs(logits_text_cal)
probs_vision = _positive_class_probs(logits_vision_cal)
probs_fusion = _positive_class_probs(logits_fusion_cal)

labels_np = labels_tensor.cpu().numpy()

# Each model is cut at its own recommended threshold from the sweep above.
preds_text = (probs_text >= thr_text).astype(int)
preds_vision = (probs_vision >= thr_vision).astype(int)
preds_fusion = (probs_fusion >= thr_fusion).astype(int)

# Column order of the CSV.
_error_columns = [
    "index",
    "text",
    "label",
    "text_prob",
    "text_pred",
    "vision_prob",
    "vision_pred",
    "fusion_prob",
    "fusion_pred",
    "fusion_error_type",
    "all_models_agree",
    "fusion_correct_both_wrong",
]

# (fusion prediction, gold label) -> outcome; any other combination is reported as "FN".
_fusion_outcomes = {(1, 1): "TP", (0, 0): "TN", (1, 0): "FP"}

rows: List[Dict[str, Any]] = []

for _idx, _example in enumerate(eval_examples):
    _gold = int(labels_np[_idx])
    _text_pred = int(preds_text[_idx])
    _vision_pred = int(preds_vision[_idx])
    _fusion_pred = int(preds_fusion[_idx])

    rows.append({
        "index": _idx,
        "text": _example["text"],
        "label": _gold,
        "text_prob": float(probs_text[_idx]),
        "text_pred": _text_pred,
        "vision_prob": float(probs_vision[_idx]),
        "vision_pred": _vision_pred,
        "fusion_prob": float(probs_fusion[_idx]),
        "fusion_pred": _fusion_pred,
        "fusion_error_type": _fusion_outcomes.get((_fusion_pred, _gold), "FN"),
        # 1 when all three models made the same decision.
        "all_models_agree": int(_text_pred == _vision_pred == _fusion_pred),
        # 1 when fusion is right although both single-modality experts are wrong.
        "fusion_correct_both_wrong": int(
            _fusion_pred == _gold and _gold not in (_text_pred, _vision_pred)
        ),
    })

errors_path = FINAL_STEP7 / f"errors_{MODE}.csv"
with errors_path.open("w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=_error_columns)
    writer.writeheader()
    writer.writerows(rows)

# Tally the fusion outcomes once, then report them in a fixed order.
_outcome_counts = {
    _kind: sum(1 for _row in rows if _row["fusion_error_type"] == _kind)
    for _kind in ("TP", "TN", "FP", "FN")
}

print(f"\nSaved error analysis to {errors_path}")
print(f"Total examples: {len(rows)}")
print(f"Fusion TP: {_outcome_counts['TP']}")
print(f"Fusion TN: {_outcome_counts['TN']}")
print(f"Fusion FP: {_outcome_counts['FP']}")
print(f"Fusion FN: {_outcome_counts['FN']}")

## Step 8 — Inference & Simple UI (uses models + Final/Step_7)

Finally, we provide a thin inference wrapper and a small Gradio UI
(similar to `Step_8/step8_inference.ipynb`).

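This cell **does not write new pipeline artifacts** (images uploaded through the UI are only saved as temporary PNG files in the system temp directory), but it **loads**: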

- `models/text_expert/`
- `models/vision_expert/`
- `models/mm_fusion/fusion_model.pt`
- `Final/Step_7/calibration_{MODE}.json`
- `Final/Step_7/thresholds_{MODE}.json`

and exposes both a Python API and a simple browser UI.

In [None]:
import tempfile
import uuid

import gradio as gr
from PIL import Image  # Required for image loading

# Load calibration + thresholds from Final/Step_7.
# Both files are optional: when one is missing a warning is printed and the
# neutral defaults (temperature 1.0, threshold 0.5) are used instead.
calib_file = FINAL_STEP7 / f"calibration_{MODE}.json"
thr_file = FINAL_STEP7 / f"thresholds_{MODE}.json"


def _read_json_or_warn(path: Path, warning: str) -> Dict[str, Any]:
    """Return the parsed JSON at ``path``; if the file is absent, print ``warning`` and return ``{}``."""
    if not path.exists():
        print(warning, path)
        return {}
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


calibrations: Dict[str, float] = _read_json_or_warn(
    calib_file, "[WARN] Calibration file not found, using T=1.0."
)
thresholds: Dict[str, Dict[str, float]] = _read_json_or_warn(
    thr_file, "[WARN] Thresholds file not found, using threshold=0.5."
)

_step8_models = ("text_expert", "vision_expert", "mm_fusion")

# Per-model temperature; 1.0 leaves the logits unchanged.
T_TEXT, T_VISION, T_FUSION = (
    float(calibrations.get(_model_name, 1.0)) for _model_name in _step8_models
)

# Per-model decision threshold on the positive-class probability.
THR_TEXT, THR_VISION, THR_FUSION = (
    float(thresholds.get(_model_name, {}).get("abuse_hate", 0.5))
    for _model_name in _step8_models
)

print("Loaded calibration + thresholds from Final/Step_7:")
print("T_TEXT =", T_TEXT, "T_VISION =", T_VISION, "T_FUSION =", T_FUSION)
print("THR_TEXT =", THR_TEXT, "THR_VISION =", THR_VISION, "THR_FUSION =", THR_FUSION)


def _load_image(image_path: Path) -> Image.Image:
    """Load an image file from disk as an RGB PIL image.

    Args:
        image_path: Location of the image file. Anything ``Path()`` accepts
            (including a plain string) is allowed.

    Returns:
        A new RGB image that no longer depends on the underlying file.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    # Accept str paths too; Path(Path(...)) is a no-op for existing callers.
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Image.open() is lazy and keeps the file handle open. convert() forces the
    # pixel data to be read and returns a separate image, so the source handle
    # can be released deterministically when the with-block exits.
    with Image.open(image_path) as source:
        return source.convert("RGB")


# Holder for the lazily loaded text expert, so repeated predictions (e.g. every
# click in the UI) reuse one model instead of re-reading it from disk each time.
# Re-running this cell resets the holder and therefore picks up a retrained model.
_TEXT_EXPERT_CACHE: Dict[str, Any] = {}


def _get_text_expert():
    """Return the text expert, loading it from ``text_expert_dir`` on first use only."""
    if "model" not in _TEXT_EXPERT_CACHE:
        model = AutoModelForSequenceClassification.from_pretrained(text_expert_dir).to(device)
        # Explicit eval mode, matching what collect_logits_text does for the same model.
        model.eval()
        _TEXT_EXPERT_CACHE["model"] = model
    return _TEXT_EXPERT_CACHE["model"]


def predict_text(text: str) -> Dict[str, Any]:
    """Classify a single piece of text with the calibrated text expert.

    The text is tokenised with ``text_tokenizer_fusion`` (padding, truncation,
    max length 256), the logits are divided by ``T_TEXT``, and class index 1 of
    the softmax is treated as the hate/abuse probability.

    Args:
        text: The post text to classify.

    Returns:
        A dict with ``model`` (``"text_expert"``), ``prob_hate``, the
        ``threshold`` used (``THR_TEXT``), the binary ``label``
        (1 when ``prob_hate >= THR_TEXT``) and the full ``probs`` list.
    """
    enc_text = text_tokenizer_fusion(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    enc_text = {k: v.to(device) for k, v in enc_text.items()}

    model = _get_text_expert()

    with torch.no_grad():
        logits = model(**enc_text).logits
        logits = logits / T_TEXT
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    prob_hate = float(probs[1])
    label = int(prob_hate >= THR_TEXT)

    return {
        "model": "text_expert",
        "prob_hate": prob_hate,
        "threshold": THR_TEXT,
        "label": label,
        "probs": probs.tolist(),
    }


# Holder for the lazily loaded vision expert, so repeated predictions (e.g. every
# click in the UI) reuse one model instead of re-reading it from disk each time.
# Re-running this cell resets the holder and therefore picks up a retrained model.
_VISION_EXPERT_CACHE: Dict[str, Any] = {}


def _get_vision_expert():
    """Return the vision expert, loading it from ``vision_expert_dir`` on first use only."""
    if "model" not in _VISION_EXPERT_CACHE:
        model = AutoModelForImageClassification.from_pretrained(vision_expert_dir).to(device)
        # Explicit eval mode, matching what collect_logits_vision does for the same model.
        model.eval()
        _VISION_EXPERT_CACHE["model"] = model
    return _VISION_EXPERT_CACHE["model"]


def predict_image(image_path: Path) -> Dict[str, Any]:
    """Classify a single image file with the calibrated vision expert.

    The image is loaded with ``_load_image``, preprocessed with
    ``image_processor_fusion``, the logits are divided by ``T_VISION``, and
    class index 1 of the softmax is treated as the hate/abuse probability.

    Args:
        image_path: Location of the image file on disk.

    Returns:
        A dict with ``model`` (``"vision_expert"``), ``prob_hate``, the
        ``threshold`` used (``THR_VISION``), the binary ``label``
        (1 when ``prob_hate >= THR_VISION``) and the full ``probs`` list.
    """
    img = _load_image(image_path)
    enc_img = image_processor_fusion(images=[img], return_tensors="pt")
    pixel_values = enc_img["pixel_values"].to(device)

    model = _get_vision_expert()

    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
        logits = logits / T_VISION
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    prob_hate = float(probs[1])
    label = int(prob_hate >= THR_VISION)

    return {
        "model": "vision_expert",
        "prob_hate": prob_hate,
        "threshold": THR_VISION,
        "label": label,
        "probs": probs.tolist(),
    }


def predict_fusion_infer(text: str, image_path: Path) -> Dict[str, Any]:
    """Classify a text + image pair with the calibrated fusion model.

    The text is tokenised with ``text_tokenizer_fusion`` (padding, truncation,
    max length 256), the image is loaded with ``_load_image`` and preprocessed
    with ``image_processor_fusion``. The logits returned by ``fusion_model`` are
    divided by ``T_FUSION``, and class index 1 of the softmax is treated as the
    hate/abuse probability.

    Args:
        text: The post text.
        image_path: Location of the accompanying image file on disk.

    Returns:
        A dict with ``model`` (``"mm_fusion"``), ``prob_hate``, the
        ``threshold`` used (``THR_FUSION``), the binary ``label``
        (1 when ``prob_hate >= THR_FUSION``) and the full ``probs`` list.
    """
    img = _load_image(image_path)

    enc_text = text_tokenizer_fusion(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    enc_img = image_processor_fusion(images=[img], return_tensors="pt")

    input_ids = enc_text["input_ids"].to(device)
    attention_mask = enc_text["attention_mask"].to(device)
    pixel_values = enc_img["pixel_values"].to(device)

    # Do not rely on an earlier cell having left the shared model in eval mode
    # (collect_logits_fusion sets it explicitly too). In train mode, layers such
    # as dropout would make predictions for the same input vary between calls.
    fusion_model.eval()

    with torch.no_grad():
        logits = fusion_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )
        logits = logits / T_FUSION
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    prob_hate = float(probs[1])
    label = int(prob_hate >= THR_FUSION)

    return {
        "model": "mm_fusion",
        "prob_hate": prob_hate,
        "threshold": THR_FUSION,
        "label": label,
        "probs": probs.tolist(),
    }


def predict_post_api(text: Optional[str] = None, image_path: Optional[Path] = None, strategy: str = "auto") -> Dict[str, Any]:
    """Route a prediction request to the text, image or fusion predictor.

    Args:
        text: Post text, or ``None`` when only an image is available.
        image_path: Path to the post image, or ``None`` when only text is available.
        strategy: Case-insensitive routing mode. ``"auto"`` uses fusion when
            both inputs are given and otherwise the predictor matching the one
            input present; ``"text"``, ``"image"`` and ``"fusion"`` force a
            specific predictor.

    Returns:
        The result dict of the selected predictor.

    Raises:
        ValueError: If neither input is given, if the forced strategy lacks an
            input it needs, or if ``strategy`` is not one of the known modes.
    """
    has_text = text is not None
    has_image = image_path is not None

    if not (has_text or has_image):
        raise ValueError("Provide at least one of `text` or `image_path`.")

    strategy = strategy.lower()

    # Resolve "auto" to the richest predictor the available inputs allow; the
    # checks below are then trivially satisfied for the resolved mode.
    if strategy == "auto":
        if has_text and has_image:
            strategy = "fusion"
        elif has_text:
            strategy = "text"
        else:
            strategy = "image"

    if strategy == "text":
        if not has_text:
            raise ValueError("strategy='text' requires `text`.")
        return predict_text(text)

    if strategy == "image":
        if not has_image:
            raise ValueError("strategy='image' requires `image_path`.")
        return predict_image(image_path)

    if strategy == "fusion":
        if not (has_text and has_image):
            raise ValueError("strategy='fusion' requires both `text` and `image_path`.")
        return predict_fusion_infer(text, image_path)

    raise ValueError(f"Unknown strategy: {strategy}")


def _ui_predict(text: str, image):
    text_in: Optional[str] = text.strip() if text and text.strip() else None

    image_path: Optional[Path] = None
    if image is not None:
        tmp_dir = Path(tempfile.gettempdir())
        tmp_file = tmp_dir / f"mmui_{uuid.uuid4().hex}.png"
        image.save(tmp_file)
        image_path = tmp_file

    if text_in is None and image_path is None:
        return "Please provide text, an image, or both.", {}

    result = predict_post_api(text=text_in, image_path=image_path, strategy="auto")

    label = int(result.get("label", 0))
    prob_hate = float(result.get("prob_hate", 0.0))
    model_name = str(result.get("model", "unknown_model"))

    if label == 1:
        explanation = (
            f"**Predicted: HATEFUL / ABUSIVE**  \n"
            f"Model `{model_name}` gives P(hate) = {prob_hate:.3f}.  \n"
            "This system currently makes a binary decision (hate/abuse vs non-hate); "
            "it does not predict fine-grained types of hate."
        )
    else:
        explanation = (
            f"**Predicted: NOT hateful / abusive**  \n"
            f"Model `{model_name}` gives P(hate) = {prob_hate:.3f} "
            f"(so P(non-hate) ≈ {1.0 - prob_hate:.3f})."
        )

    return explanation, result


# Browser UI: components are laid out in the order they are created inside the
# Blocks context, so the statement order below is the visual order of the page.
with gr.Blocks() as demo:
    # Header text; MODE is interpolated so the page shows which run's
    # calibration/threshold files this session was configured with.
    gr.Markdown(
        f"""## Multimodal Hate/Abuse Detection Demo ({MODE})

Provide text, an image, or both. The system will automatically choose
between text, image, or fusion models (using calibrated thresholds from
Final/Step_7) to decide whether the content is hateful/abusive.
"""
    )

    # Inputs side by side. Both are optional; _ui_predict rejects the case
    # where neither is provided.
    with gr.Row():
        text_in = gr.Textbox(
            lines=4,
            label="Post text (optional)",
            placeholder="Paste OCR+caption text or any post text here...",
        )
        # type="pil": _ui_predict calls .save() on the value it receives.
        image_in = gr.Image(
            type="pil",
            label="Image (optional)",
        )

    run_btn = gr.Button("Run")

    # Outputs: a human-readable verdict and the raw result dict.
    explanation_out = gr.Markdown(label="Explanation")
    raw_out = gr.JSON(label="Raw model output")

    # _ui_predict returns (explanation, result), matching the two outputs in order.
    run_btn.click(
        fn=_ui_predict,
        inputs=[text_in, image_in],
        outputs=[explanation_out, raw_out],
    )

# Launch inside the notebook
demo.launch()