# Multimodal Hate/Abuse Detection — Master Pipeline (Steps 2–8)

This notebook orchestrates the **entire pipeline** from Step 2 to Step 8
for the multimodal hate/abuse detection system.

It is designed to be **adaptive**:
- You can run it on the **pilot subset** or the **full dataset**.
- All intermediate artifacts for this notebook are written under
  `Final/` (except **models**, which always live under `models/`).
- Each step is clearly separated so you can re-run only parts as needed.

> **Note:** This notebook *wraps* the existing scripts/notebooks
> (`Step_2`, `Step_3`, `Step_4`, `Step_5`, `Step_6`, `Step_7`, `Step_8`).
> If you want to debug internals of a specific step, you can still open
> the original step notebooks/scripts.


In [1]:
# Install all required packages for the full pipeline (run once per environment).
# You can skip this cell if everything is already installed.

%pip install --upgrade pip

# Core + training libs
%pip install torch torchvision torchaudio

# NLP / vision / datasets / utilities
%pip install transformers datasets webdataset accelerate timm sentencepiece pillow

# OCR stack for Step 2
%pip install paddleocr paddlepaddle

# Lightweight web UI for Step 8
%pip install "huggingface-hub>=0.33.5,<1.0" gradio


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Global configuration and paths

from pathlib import Path
from typing import Optional, Dict, Any, List

import sys

# Robust project root detection:
# Look for the Datasets/ folder which only exists at the true project root
ROOT = Path.cwd().resolve()

# Walk up until we find the actual project root (contains Datasets/ and Step_2/)
project_root = ROOT
for _ in range(5):  # max 5 levels up
    if (project_root / "Datasets").is_dir() and (project_root / "Step_2").is_dir():
        break
    project_root = project_root.parent

# Sanity check
if not (project_root / "Datasets").is_dir():
    raise RuntimeError(f"Could not find project root with Datasets/ folder. Searched from {ROOT}")

FINAL_DIR = project_root / "Final"
FINAL_DIR.mkdir(exist_ok=True)

# Per-step output roots for this master notebook
FINAL_STEP2 = FINAL_DIR / "Step_2"
FINAL_STEP3 = FINAL_DIR / "Step_3"
FINAL_STEP6 = FINAL_DIR / "Step_6"
FINAL_STEP7 = FINAL_DIR / "Step_7"

for p in [FINAL_STEP2, FINAL_STEP3, FINAL_STEP6, FINAL_STEP7]:
    p.mkdir(parents=True, exist_ok=True)

MODE = "pilot"  # "pilot" (small subset) or "full" (all data)

# Pilot limits (set to None for full runs)
if MODE == "pilot":
    OCR_LIMIT: Optional[int] = 50  # number of images for OCR
    PACK_EXAMPLES_LIMIT: Optional[int] = 100  # number of examples to pack
else:
    OCR_LIMIT = None
    PACK_EXAMPLES_LIMIT = None

print("Project root:", project_root)
print("Final outputs root:", FINAL_DIR)
print("MODE:", MODE, "(OCR_LIMIT=", OCR_LIMIT, "PACK_EXAMPLES_LIMIT=", PACK_EXAMPLES_LIMIT, ")")

# Make sure we can import Step_2/Step_3 modules from the project root
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print("sys.path includes project root:", str(project_root) in sys.path)

Project root: /Users/yashwanthreddy/Documents/GitHub/DL_Proj
Final outputs root: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final
MODE: pilot (OCR_LIMIT= 50 PACK_EXAMPLES_LIMIT= 100 )
sys.path includes project root: True


## Step 2 — Build Manifest (Final/Step_2/data_manifest.jsonl)

This section wraps `Step_2/build_data_manifest.py` and writes the
manifest for this master notebook under `Final/Step_2/data_manifest.jsonl`.

The manifest still reads raw datasets from `Datasets/`, but **no files
are written into `Step_2/`** by this cell.

In [3]:
import importlib

# Import the Step 2 manifest builder
manifest_mod = importlib.import_module("Step_2.build_data_manifest")

manifest_path = FINAL_STEP2 / "data_manifest.jsonl"
manifest_path.parent.mkdir(parents=True, exist_ok=True)

print("Building manifest at:", manifest_path)
manifest_mod.build_manifest(manifest_path)

# Quick sanity check: show a few lines
print("\nFirst 3 lines of manifest:")
with manifest_path.open("r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        if i >= 3:
            break
        print(line.strip())

Building manifest at: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_2/data_manifest.jsonl

First 3 lines of manifest:
{"id": "hateful_memes_train_42953", "dataset": "hateful_memes", "split": "train", "image_path": "/Users/yashwanthreddy/Documents/GitHub/DL_Proj/Datasets/hateful_memes/img/42953.png", "text_raw": "its their character not their color that matters", "labels_raw": 0}
{"id": "hateful_memes_train_23058", "dataset": "hateful_memes", "split": "train", "image_path": "/Users/yashwanthreddy/Documents/GitHub/DL_Proj/Datasets/hateful_memes/img/23058.png", "text_raw": "don't be afraid to love again everyone is not like your ex", "labels_raw": 0}
{"id": "hateful_memes_train_13894", "dataset": "hateful_memes", "split": "train", "image_path": "/Users/yashwanthreddy/Documents/GitHub/DL_Proj/Datasets/hateful_memes/img/13894.png", "text_raw": "putting bows on your pet", "labels_raw": 0}


## Step 2 — Run OCR (Final/Step_2/ocr.jsonl)

This wraps `Step_2/run_ocr.py` and writes OCR results to
`Final/Step_2/ocr.jsonl`. QC images are written to
`Final/Step_2/ocr_qc/`.

Use `MODE` + `OCR_LIMIT` from the config cell to control whether this
is a small pilot run or a full run.

In [4]:
# Import the Step 2 OCR runner and redirect QC directory into Final/Step_2
ocr_mod = importlib.import_module("Step_2.run_ocr")

# Override QC_DIR so QC overlays are written under Final/Step_2
ocr_mod.QC_DIR = FINAL_STEP2 / "ocr_qc"  # type: ignore[attr-defined]

ocr_output_path = FINAL_STEP2 / "ocr.jsonl"

print("Running OCR...")
ocr_mod.run_ocr(
    manifest_path=manifest_path,
    output_path=ocr_output_path,
    qc=True,
    limit=OCR_LIMIT,
    progress_interval=50 if MODE == "pilot" else 500,
)

print("OCR output:", ocr_output_path)


Running OCR...
Loading PaddleOCR model...


[32mCreating model: ('PP-OCRv5_server_det', None)[0m
[32mModel files already exist. Using cached files. To redownload, please delete the directory manually: `/Users/yashwanthreddy/.paddlex/official_models/PP-OCRv5_server_det`.[0m
[32mCreating model: ('en_PP-OCRv5_mobile_rec', None)[0m
[32mModel files already exist. Using cached files. To redownload, please delete the directory manually: `/Users/yashwanthreddy/.paddlex/official_models/en_PP-OCRv5_mobile_rec`.[0m


PaddleOCR model loaded.
Counting images in manifest...
Will process up to 50 images.
  Processed 50/50 images (100.0%) - 0.5 img/s - ETA: 0.0 min

OCR SUMMARY
Total images processed: 51
Images with no OCR text: 0 (failure rate: 0.000)
Total text lines detected: 166
Mean OCR confidence: 0.928
Total time: 1.6 minutes (93 seconds)
Output written to: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_2/ocr.jsonl
OCR output: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_2/ocr.jsonl


## Step 3 — Pack Manifest + OCR into WebDataset Shards (Final/Step_3)

This wraps `Step_3/pack_examples.py` and writes packed shards under
`Final/Step_3/shards/`. It also writes:

- `Final/Step_3/label_taxonomy.json`
- `Final/Step_3/splits.json`
- `Final/Step_3/stats.json`

We temporarily patch `STEP3` inside `pack_examples.py` so that its
internal stats file is also written under `Final/Step_3/` instead of
`Step_3/`.

In [5]:
# Import the Step 3 packer and redirect its STEP3 root into Final/Step_3
pack_mod = importlib.import_module("Step_3.pack_examples")

# Patch STEP3 so stats.json goes under Final/Step_3
if hasattr(pack_mod, "STEP3"):
    pack_mod.STEP3 = FINAL_STEP3  # type: ignore[attr-defined]

packed_shards_dir = FINAL_STEP3 / "shards"
packed_shards_dir.mkdir(parents=True, exist_ok=True)

label_taxonomy_path = FINAL_STEP3 / "label_taxonomy.json"
splits_path = FINAL_STEP3 / "splits.json"

print("Packing WebDataset shards into:", packed_shards_dir)

pack_mod.run_packing(
    manifest_path=manifest_path,
    ocr_path=ocr_output_path,
    output_dir=packed_shards_dir,
    taxonomy_path=label_taxonomy_path,
    splits_path=splits_path,
    examples_limit=PACK_EXAMPLES_LIMIT,
    shard_size=5000,
    image_size=384,
    progress_interval=500 if MODE == "pilot" else 5000,
    include_datasets=None,
)

print("Shards directory:", packed_shards_dir)


Packing WebDataset shards into: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_3/shards
Ensuring label taxonomy exists...
Loading OCR index (this may take a while for large files)...
Loaded OCR for 50 examples.
# writing /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_3/shards/train/shard-000000.tar 0 0.0 GB 0

STEP 3 PACKING SUMMARY
Total manifest records seen: 101
Packed examples: 100
Failed examples: 0
Failure rate: 0.0000
Elapsed time: 0.1 minutes (5 seconds)
Stats written to: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_3/stats.json
Splits mapping already existed at: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_3/splits.json
Shards directory: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_3/shards


## Step 4 — Train Unimodal Experts (Text & Image)

This section inlines the training logic from `Step_4/step4_unimodal_experts.ipynb`,
pointing it at the **Final/Step_3** shards. Models are still saved to
`models/text_expert/` and `models/vision_expert/`.

For the pilot mode we train **1 epoch** on the first shard; for full
mode you can point `shard_pattern` to a larger set of shards and/or
increase the number of epochs.

In [6]:
import json

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import webdataset as wds

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoImageProcessor,
    AutoModelForImageClassification,
)

from torchvision import transforms

models_root = project_root / "models"
models_root.mkdir(exist_ok=True)

# Shard pattern from Final/Step_3
train_shards_dir = FINAL_STEP3 / "shards" / "train"
shard_pattern = str(train_shards_dir / "shard-000000.tar")

print("Training unimodal experts from shards:", shard_pattern)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

TEXT_MODEL_NAME = "microsoft/mdeberta-v3-base"
VISION_MODEL_NAME = "google/siglip-base-patch16-384"


class TextDataset(Dataset):
    def __init__(self, data: List[Dict[str, Any]], tokenizer, max_length: int = 256):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        text = item["text"]
        label = item["label"]
        enc = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        enc = {k: v.squeeze(0) for k, v in enc.items()}
        enc["labels"] = torch.tensor(label, dtype=torch.long)
        return enc


class ImageDataset(Dataset):
    def __init__(self, data: List[Dict[str, Any]], image_processor, train: bool = True):
        self.data = data
        self.image_processor = image_processor
        self.train = train
        self.aug = transforms.Compose([
            transforms.RandomHorizontalFlip(),
        ])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        img = item["image"]
        label = item["label"]
        if self.train:
            img = self.aug(img)
        inputs = self.image_processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        return {
            "pixel_values": pixel_values,
            "labels": torch.tensor(label, dtype=torch.long),
        }


def make_text_dataset(shard_pattern: str, max_samples: Optional[int] = None):
    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .to_tuple("txt", "json")
    )

    out: List[Dict[str, Any]] = []
    for text_obj, meta_obj in ds:
        if isinstance(text_obj, (bytes, bytearray)):
            text = text_obj.decode("utf-8", errors="replace")
        else:
            text = str(text_obj)

        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue
        out.append({"text": text, "label": int(y)})
        if max_samples is not None and len(out) >= max_samples:
            break
    return out


def make_image_dataset(shard_pattern: str, max_samples: Optional[int] = None):
    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("png", "json")
    )

    out: List[Dict[str, Any]] = []
    for img, meta_obj in ds:
        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue
        out.append({"image": img, "label": int(y)})
        if max_samples is not None and len(out) >= max_samples:
            break
    return out


text_examples = make_text_dataset(shard_pattern, max_samples=None if MODE == "full" else 1000)
image_examples = make_image_dataset(shard_pattern, max_samples=None if MODE == "full" else 1000)
print(f"Loaded {len(text_examples)} text examples and {len(image_examples)} image examples.")

# --- Train text expert ---

tokenizer_text = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
model_text = AutoModelForSequenceClassification.from_pretrained(
    TEXT_MODEL_NAME,
    num_labels=2,
)
model_text.to(device)

dataset_text = TextDataset(text_examples, tokenizer_text)
loader_text = DataLoader(dataset_text, batch_size=8, shuffle=True)

optimizer_text = torch.optim.AdamW(model_text.parameters(), lr=2e-5)
criterion_text = nn.CrossEntropyLoss()

EPOCHS_TEXT = 1 if MODE == "pilot" else 2

model_text.train()
for epoch in range(EPOCHS_TEXT):
    total_loss = 0.0
    for batch in loader_text:
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")
        outputs = model_text(**batch)
        logits = outputs.logits
        loss = criterion_text(logits, labels)
        optimizer_text.zero_grad()
        loss.backward()
        optimizer_text.step()
        total_loss += loss.item()
    print(f"[Text] Epoch {epoch + 1}/{EPOCHS_TEXT} - loss: {total_loss / len(loader_text):.4f}")

text_expert_dir = models_root / "text_expert"
model_text.save_pretrained(text_expert_dir)
tokenizer_text.save_pretrained(text_expert_dir)
print("Saved text expert to", text_expert_dir)

# --- Train vision expert ---

image_processor = AutoImageProcessor.from_pretrained(VISION_MODEL_NAME)
model_vision = AutoModelForImageClassification.from_pretrained(
    VISION_MODEL_NAME,
    num_labels=2,
)
model_vision.to(device)

dataset_img = ImageDataset(image_examples, image_processor, train=True)
loader_img = DataLoader(dataset_img, batch_size=8, shuffle=True)

optimizer_v = torch.optim.AdamW(model_vision.parameters(), lr=1e-4)
criterion_v = nn.CrossEntropyLoss()

EPOCHS_V = 1 if MODE == "pilot" else 2

model_vision.train()
for epoch in range(EPOCHS_V):
    total_loss = 0.0
    for batch in loader_img:
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")
        outputs = model_vision(**batch)
        logits = outputs.logits
        loss = criterion_v(logits, labels)
        optimizer_v.zero_grad()
        loss.backward()
        optimizer_v.step()
        total_loss += loss.item()
    print(f"[Vision] Epoch {epoch + 1}/{EPOCHS_V} - loss: {total_loss / len(loader_img):.4f}")

vision_expert_dir = models_root / "vision_expert"
model_vision.save_pretrained(vision_expert_dir)
image_processor.save_pretrained(vision_expert_dir)
print("Saved vision expert to", vision_expert_dir)


Training unimodal experts from shards: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_3/shards/train/shard-000000.tar
Using device: cpu
Loaded 100 text examples and 100 image examples.


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/mdeberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Text] Epoch 1/1 - loss: 0.5573


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Saved text expert to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/text_expert


Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip-base-patch16-384 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Vision] Epoch 1/1 - loss: 1.0418
Saved vision expert to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/vision_expert


## Step 5 — Train Multimodal Fusion Model

This section mirrors `Step_5/step5_fusion.ipynb` but uses shards under
`Final/Step_3/shards/` and the unimodal experts saved to `models/`.

The fusion model weights are saved to `models/mm_fusion/fusion_model.pt`.

In [7]:
from torch.utils.data import Dataset as TorchDataset
from torchvision import transforms as T

from transformers import AutoModel


class FusionDataset(TorchDataset):
    def __init__(
        self,
        data: List[Dict[str, Any]],
        text_tokenizer,
        image_processor,
        max_length: int = 256,
        train: bool = True,
    ) -> None:
        self.data = data
        self.text_tokenizer = text_tokenizer
        self.image_processor = image_processor
        self.max_length = max_length
        self.train = train
        self.aug = T.Compose([
            T.RandomHorizontalFlip(),
        ])

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.data[idx]
        text = item["text"]
        img = item["image"]
        label = item["label"]

        if self.train:
            img = self.aug(img)

        text_enc = self.text_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        text_enc = {k: v.squeeze(0) for k, v in text_enc.items()}

        img_enc = self.image_processor(images=img, return_tensors="pt")
        pixel_values = img_enc["pixel_values"].squeeze(0)

        return {
            "input_ids": text_enc["input_ids"],
            "attention_mask": text_enc["attention_mask"],
            "pixel_values": pixel_values,
            "labels": torch.tensor(label, dtype=torch.long),
        }


def make_fusion_examples(shard_pattern: str, max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("txt", "png", "json")
    )

    out: List[Dict[str, Any]] = []
    for text_obj, img, meta_obj in ds:
        if isinstance(text_obj, (bytes, bytearray)):
            text = text_obj.decode("utf-8", errors="replace")
        else:
            text = str(text_obj)

        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue

        out.append({"text": text, "image": img, "label": int(y)})

        if max_samples is not None and len(out) >= max_samples:
            break

    return out


fusion_shard_pattern = shard_pattern  # reuse train shard from Final/Step_3
fusion_examples = make_fusion_examples(
    fusion_shard_pattern,
    max_samples=None if MODE == "full" else 1000,
)
print(f"Loaded {len(fusion_examples)} multimodal examples for fusion training.")

text_expert_dir = models_root / "text_expert"
vision_expert_dir = models_root / "vision_expert"

text_tokenizer_fusion = AutoTokenizer.from_pretrained(text_expert_dir)
image_processor_fusion = AutoImageProcessor.from_pretrained(vision_expert_dir)

fusion_dataset = FusionDataset(
    fusion_examples,
    text_tokenizer=text_tokenizer_fusion,
    image_processor=image_processor_fusion,
    max_length=256,
    train=True,
)

loader_fusion = DataLoader(fusion_dataset, batch_size=8, shuffle=True)
print("Fusion batches per epoch:", len(loader_fusion))


class FusionModel(nn.Module):
    def __init__(
        self,
        text_encoder: nn.Module,
        vision_encoder: nn.Module,
        t_dim: int,
        v_dim: int,
        hidden_dim: int,
        num_labels: int,
    ) -> None:
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder
        self.mlp = nn.Sequential(
            nn.Linear(t_dim + v_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        with torch.no_grad():
            text_out = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            if hasattr(text_out, "pooler_output") and text_out.pooler_output is not None:
                t_repr = text_out.pooler_output
            else:
                t_repr = text_out.last_hidden_state[:, 0, :]

            vision_out = self.vision_encoder(pixel_values=pixel_values)
            v_repr = vision_out.logits

        h = torch.cat([t_repr, v_repr], dim=-1)
        logits = self.mlp(h)
        return logits


text_encoder_fusion = AutoModel.from_pretrained(text_expert_dir)
vision_encoder_fusion = AutoModelForImageClassification.from_pretrained(vision_expert_dir)

text_encoder_fusion.to(device)
vision_encoder_fusion.to(device)

for p in text_encoder_fusion.parameters():
    p.requires_grad = False
for p in vision_encoder_fusion.parameters():
    p.requires_grad = False

t_dim = text_encoder_fusion.config.hidden_size
v_dim = vision_encoder_fusion.config.num_labels

fusion_hidden = 512
num_labels = 2

fusion_model = FusionModel(
    text_encoder=text_encoder_fusion,
    vision_encoder=vision_encoder_fusion,
    t_dim=t_dim,
    v_dim=v_dim,
    hidden_dim=fusion_hidden,
    num_labels=num_labels,
).to(device)

optimizer_fusion = torch.optim.AdamW(fusion_model.parameters(), lr=1e-4)
criterion_fusion = nn.CrossEntropyLoss()

EPOCHS_FUSION = 1 if MODE == "pilot" else 2

for epoch in range(EPOCHS_FUSION):
    fusion_model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader_fusion:
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch.pop("labels")

        logits = fusion_model(**batch)
        loss = criterion_fusion(logits, labels)

        optimizer_fusion.zero_grad()
        loss.backward()
        optimizer_fusion.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / max(1, len(loader_fusion))
    acc = correct / total if total > 0 else 0.0
    print(f"[Fusion] Epoch {epoch + 1}/{EPOCHS_FUSION} - loss: {avg_loss:.4f} - acc: {acc:.3f}")

mm_dir = models_root / "mm_fusion"
mm_dir.mkdir(exist_ok=True)
mm_fusion_path = mm_dir / "fusion_model.pt"
torch.save(fusion_model.state_dict(), mm_fusion_path)
print("Saved fusion model weights to", mm_fusion_path)


Loaded 100 multimodal examples for fusion training.
Fusion batches per epoch: 13
[Fusion] Epoch 1/1 - loss: 0.5155 - acc: 0.810
Saved fusion model weights to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/mm_fusion/fusion_model.pt


## Step 6 — Evaluation (metrics only, results under Final/Step_6)

We evaluate:

- Text expert (`models/text_expert/`)
- Vision expert (`models/vision_expert/`)
- Fusion model (`models/mm_fusion/fusion_model.pt`)

on the packed shard from `Final/Step_3/shards/train/shard-000000.tar` and
save metrics to `Final/Step_6/results_{MODE}.json`.

In [8]:
import numpy as np

from torch.utils.data import DataLoader as TorchDataLoader


def compute_accuracy(preds: np.ndarray, labels: np.ndarray) -> float:
    return float((preds == labels).mean()) if len(labels) > 0 else 0.0


def compute_macro_f1(preds: np.ndarray, labels: np.ndarray, num_classes: int = 2) -> float:
    f1s: List[float] = []
    for c in range(num_classes):
        tp = np.logical_and(preds == c, labels == c).sum()
        fp = np.logical_and(preds == c, labels != c).sum()
        fn = np.logical_and(preds != c, labels == c).sum()

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * precision * recall / (precision + recall)
        f1s.append(f1)
    return float(np.mean(f1s)) if f1s else 0.0


def compute_brier_score(probs_pos: np.ndarray, labels: np.ndarray) -> float:
    return float(np.mean((probs_pos - labels) ** 2)) if len(labels) > 0 else 0.0


def compute_ece(probs_pos: np.ndarray, labels: np.ndarray, num_bins: int = 10) -> float:
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    n = len(labels)
    if n == 0:
        return 0.0

    for i in range(num_bins):
        mask = (probs_pos >= bins[i]) & (probs_pos < bins[i + 1])
        if not np.any(mask):
            continue
        bin_conf = probs_pos[mask].mean()
        bin_acc = (labels[mask] == (probs_pos[mask] >= 0.5)).mean()
        ece += (mask.sum() / n) * abs(bin_conf - bin_acc)
    return float(ece)


def summarize_metrics(logits: torch.Tensor, labels: torch.Tensor) -> Dict[str, float]:
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    preds = probs.argmax(axis=-1)
    labels_np = labels.cpu().numpy()
    probs_pos = probs[:, 1]

    acc = compute_accuracy(preds, labels_np)
    macro_f1 = compute_macro_f1(preds, labels_np, num_classes=2)
    brier = compute_brier_score(probs_pos, labels_np)
    ece = compute_ece(probs_pos, labels_np, num_bins=10)

    return {
        "accuracy": acc,
        "macro_f1": macro_f1,
        "brier": brier,
        "ece": ece,
    }


class EvalDataset(torch.utils.data.Dataset):
    def __init__(self, examples: List[Dict[str, Any]]):
        self.examples = examples

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.examples[idx]


def collate_batch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    texts = [b["text"] for b in batch]
    images = [b["image"] for b in batch]
    labels = torch.tensor([b["label"] for b in batch], dtype=torch.long)
    return {"texts": texts, "images": images, "labels": labels}


# Reuse fusion_examples format to build eval set from the same shard

eval_examples = fusion_examples

loader_eval = TorchDataLoader(
    EvalDataset(eval_examples),
    batch_size=8,
    shuffle=False,
    collate_fn=collate_batch,
)
print("Eval batches:", len(loader_eval))


def evaluate_text_expert_eval(loader: TorchDataLoader) -> Dict[str, float]:
    model = AutoModelForSequenceClassification.from_pretrained(text_expert_dir)
    model.to(device)
    model.eval()

    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            texts = batch["texts"]
            labels = batch["labels"].to(device)

            enc = text_tokenizer_fusion(
                texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}

            outputs = model(**enc)
            logits = outputs.logits

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    return summarize_metrics(logits_cat, labels_cat)


def evaluate_vision_expert_eval(loader: TorchDataLoader) -> Dict[str, float]:
    model = AutoModelForImageClassification.from_pretrained(vision_expert_dir)
    model.to(device)
    model.eval()

    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            images = batch["images"]
            labels = batch["labels"].to(device)

            enc = image_processor_fusion(images=images, return_tensors="pt")
            pixel_values = enc["pixel_values"].to(device)

            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    return summarize_metrics(logits_cat, labels_cat)


def evaluate_fusion_eval(loader: TorchDataLoader) -> Dict[str, float]:
    model = fusion_model  # already trained
    model.eval()

    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            texts = batch["texts"]
            images = batch["images"]
            labels = batch["labels"].to(device)

            enc_text = text_tokenizer_fusion(
                texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            enc_text = {k: v.to(device) for k, v in enc_text.items()}

            enc_img = image_processor_fusion(images=images, return_tensors="pt")
            pixel_values = enc_img["pixel_values"].to(device)

            logits = model(
                input_ids=enc_text["input_ids"],
                attention_mask=enc_text["attention_mask"],
                pixel_values=pixel_values,
            )

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    return summarize_metrics(logits_cat, labels_cat)


results_eval: Dict[str, Dict[str, float]] = {}

print("Evaluating text expert...")
results_eval["text_expert"] = evaluate_text_expert_eval(loader_eval)
print("Text expert:", results_eval["text_expert"])

print("\nEvaluating vision expert...")
results_eval["vision_expert"] = evaluate_vision_expert_eval(loader_eval)
print("Vision expert:", results_eval["vision_expert"])

print("\nEvaluating fusion model...")
results_eval["mm_fusion"] = evaluate_fusion_eval(loader_eval)
print("Fusion model:", results_eval["mm_fusion"])

FINAL_STEP6.mkdir(parents=True, exist_ok=True)
results_path = FINAL_STEP6 / f"results_{MODE}.json"
with results_path.open("w", encoding="utf-8") as f:
    json.dump(results_eval, f, indent=2)

print("\nSaved evaluation metrics to", results_path)


Eval batches: 13
Evaluating text expert...
Text expert: {'accuracy': 0.85, 'macro_f1': 0.45945945945945943, 'brier': 0.12722012048731096, 'ece': 0.6669068136811255}

Evaluating vision expert...
Vision expert: {'accuracy': 0.82, 'macro_f1': 0.5738636363636364, 'brier': 0.13332780932850632, 'ece': 0.5733259925246239}

Evaluating fusion model...
Fusion model: {'accuracy': 0.85, 'macro_f1': 0.45945945945945943, 'brier': 0.12999686437238636, 'ece': 0.7615154530107976}

Saved evaluation metrics to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_6/results_pilot.json


## Step 7 — Calibration, Thresholds, and Error Analysis (Final/Step_7)

This section mirrors `Step_7/step7_analysis.ipynb` but writes all
artifacts under `Final/Step_7/`:

- `calibration_{MODE}.json`
- `thresholds_{MODE}.json`
- `thresholds_curve_*.csv`
- `errors_{MODE}.csv`.

In [9]:
import csv  # Required for threshold curve and error CSV writing

FINAL_STEP7.mkdir(parents=True, exist_ok=True)

labels_tensor = torch.tensor([ex["label"] for ex in eval_examples], dtype=torch.long)

# Reuse logits from evaluation by recomputing them once per model


def collect_logits_text() -> torch.Tensor:
    model = AutoModelForSequenceClassification.from_pretrained(text_expert_dir)
    model.to(device)
    model.eval()

    all_logits: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader_eval:
            texts = batch["texts"]
            enc = text_tokenizer_fusion(
                texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}
            outputs = model(**enc)
            all_logits.append(outputs.logits.cpu())

    return torch.cat(all_logits, dim=0)


def collect_logits_vision() -> torch.Tensor:
    model = AutoModelForImageClassification.from_pretrained(vision_expert_dir)
    model.to(device)
    model.eval()

    all_logits: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader_eval:
            images = batch["images"]
            enc = image_processor_fusion(images=images, return_tensors="pt")
            pixel_values = enc["pixel_values"].to(device)
            outputs = model(pixel_values=pixel_values)
            all_logits.append(outputs.logits.cpu())

    return torch.cat(all_logits, dim=0)


def collect_logits_fusion() -> torch.Tensor:
    model = fusion_model
    model.eval()

    all_logits: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader_eval:
            texts = batch["texts"]
            images = batch["images"]

            enc_text = text_tokenizer_fusion(
                texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            enc_text = {k: v.to(device) for k, v in enc_text.items()}

            enc_img = image_processor_fusion(images=images, return_tensors="pt")
            pixel_values = enc_img["pixel_values"].to(device)

            logits = model(
                input_ids=enc_text["input_ids"],
                attention_mask=enc_text["attention_mask"],
                pixel_values=pixel_values,
            )
            all_logits.append(logits.cpu())

    return torch.cat(all_logits, dim=0)


logits_text = collect_logits_text()
logits_vision = collect_logits_vision()
logits_fusion = collect_logits_fusion()

print("Logits shapes:", logits_text.shape, logits_vision.shape, logits_fusion.shape)


def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, max_iter: int = 200, lr: float = 0.01) -> float:
    logits = logits.clone().to(torch.float32)
    labels = labels.clone().to(torch.long)

    T = nn.Parameter(torch.ones(1))
    optimizer = torch.optim.Adam([T], lr=lr)

    for _ in range(max_iter):
        optimizer.zero_grad()
        scaled_logits = logits / T
        loss = nn.functional.cross_entropy(scaled_logits, labels)
        loss.backward()
        optimizer.step()

    return float(T.detach().item())


metrics_before: Dict[str, Dict[str, float]] = {}
metrics_after: Dict[str, Dict[str, float]] = {}
temperatures: Dict[str, float] = {}

# Text expert
metrics_before["text_expert"] = summarize_metrics(logits_text, labels_tensor)
T_text = fit_temperature(logits_text, labels_tensor)
logits_text_cal = logits_text / T_text
metrics_after["text_expert"] = summarize_metrics(logits_text_cal, labels_tensor)
temperatures["text_expert"] = T_text

# Vision expert
metrics_before["vision_expert"] = summarize_metrics(logits_vision, labels_tensor)
T_vision = fit_temperature(logits_vision, labels_tensor)
logits_vision_cal = logits_vision / T_vision
metrics_after["vision_expert"] = summarize_metrics(logits_vision_cal, labels_tensor)
temperatures["vision_expert"] = T_vision

# Fusion model
metrics_before["mm_fusion"] = summarize_metrics(logits_fusion, labels_tensor)
T_fusion = fit_temperature(logits_fusion, labels_tensor)
logits_fusion_cal = logits_fusion / T_fusion
metrics_after["mm_fusion"] = summarize_metrics(logits_fusion_cal, labels_tensor)
temperatures["mm_fusion"] = T_fusion

print("Pre-calibration metrics:")
for name, m in metrics_before.items():
    print(name, m)

print("\nPost-calibration metrics:")
for name, m in metrics_after.items():
    print(name, m)

calib_path = FINAL_STEP7 / f"calibration_{MODE}.json"
with calib_path.open("w", encoding="utf-8") as f:
    json.dump(temperatures, f, indent=2)

print("\nSaved calibration temperatures to", calib_path)


def precision_recall_f1_at_threshold(probs_pos: np.ndarray, labels: np.ndarray, threshold: float):
    preds = (probs_pos >= threshold).astype(int)
    tp = np.logical_and(preds == 1, labels == 1).sum()
    fp = np.logical_and(preds == 1, labels == 0).sum()
    fn = np.logical_and(preds == 0, labels == 1).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return float(precision), float(recall), float(f1)


def threshold_sweep_from_logits(
    logits: torch.Tensor,
    labels: torch.Tensor,
    model_name: str,
    csv_path: Path,
    num_thresholds: int = 17,
):
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    probs_pos = probs[:, 1]
    labels_np = labels.cpu().numpy()

    thresholds = np.linspace(0.1, 0.9, num_thresholds)
    rows: List[Dict[str, Any]] = []

    best_threshold = 0.5
    best_f1 = -1.0

    for thr in thresholds:
        precision, recall, f1 = precision_recall_f1_at_threshold(probs_pos, labels_np, float(thr))
        rows.append({
            "model": model_name,
            "threshold": float(thr),
            "precision": precision,
            "recall": recall,
            "f1": f1,
        })
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = float(thr)

    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["model", "threshold", "precision", "recall", "f1"])
        writer.writeheader()
        writer.writerows(rows)

    return best_threshold, best_f1


thr_text, f1_text = threshold_sweep_from_logits(
    logits_text_cal,
    labels_tensor,
    "text_expert",
    FINAL_STEP7 / f"thresholds_curve_text_expert_{MODE}.csv",
)

thr_vision, f1_vision = threshold_sweep_from_logits(
    logits_vision_cal,
    labels_tensor,
    "vision_expert",
    FINAL_STEP7 / f"thresholds_curve_vision_expert_{MODE}.csv",
)

thr_fusion, f1_fusion = threshold_sweep_from_logits(
    logits_fusion_cal,
    labels_tensor,
    "mm_fusion",
    FINAL_STEP7 / f"thresholds_curve_mm_fusion_{MODE}.csv",
)

thresholds_summary = {
    "text_expert": {"abuse_hate": thr_text, "best_f1": f1_text},
    "vision_expert": {"abuse_hate": thr_vision, "best_f1": f1_vision},
    "mm_fusion": {"abuse_hate": thr_fusion, "best_f1": f1_fusion},
}

thr_path = FINAL_STEP7 / f"thresholds_{MODE}.json"
with thr_path.open("w", encoding="utf-8") as f:
    json.dump(thresholds_summary, f, indent=2)

print("Recommended thresholds:")
for name, info in thresholds_summary.items():
    print(name, "-> threshold =", info["abuse_hate"], "best F1 =", info["best_f1"])

print("\nSaved threshold curves and summary to", FINAL_STEP7)

# Error table with fusion as primary model

probs_text = torch.softmax(logits_text_cal, dim=-1).cpu().numpy()[:, 1]
probs_vision = torch.softmax(logits_vision_cal, dim=-1).cpu().numpy()[:, 1]
probs_fusion = torch.softmax(logits_fusion_cal, dim=-1).cpu().numpy()[:, 1]

labels_np = labels_tensor.cpu().numpy()

preds_text = (probs_text >= thr_text).astype(int)
preds_vision = (probs_vision >= thr_vision).astype(int)
preds_fusion = (probs_fusion >= thr_fusion).astype(int)

rows: List[Dict[str, Any]] = []

for i, ex in enumerate(eval_examples):
    label = int(labels_np[i])
    pt = float(probs_text[i])
    pv = float(probs_vision[i])
    pf = float(probs_fusion[i])
    yt = int(preds_text[i])
    yv = int(preds_vision[i])
    yf = int(preds_fusion[i])

    if yf == 1 and label == 1:
        err_type = "TP"
    elif yf == 0 and label == 0:
        err_type = "TN"
    elif yf == 1 and label == 0:
        err_type = "FP"
    else:
        err_type = "FN"

    all_agree = int((yt == yv) and (yv == yf))
    fusion_correct_both_wrong = int((yf == label) and (yt != label) and (yv != label))

    rows.append({
        "index": i,
        "text": ex["text"],
        "label": label,
        "text_prob": pt,
        "text_pred": yt,
        "vision_prob": pv,
        "vision_pred": yv,
        "fusion_prob": pf,
        "fusion_pred": yf,
        "fusion_error_type": err_type,
        "all_models_agree": all_agree,
        "fusion_correct_both_wrong": fusion_correct_both_wrong,
    })

errors_path = FINAL_STEP7 / f"errors_{MODE}.csv"
with errors_path.open("w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=[
            "index",
            "text",
            "label",
            "text_prob",
            "text_pred",
            "vision_prob",
            "vision_pred",
            "fusion_prob",
            "fusion_pred",
            "fusion_error_type",
            "all_models_agree",
            "fusion_correct_both_wrong",
        ],
    )
    writer.writeheader()
    writer.writerows(rows)

print(f"\nSaved error analysis to {errors_path}")
print(f"Total examples: {len(rows)}")
print(f"Fusion TP: {sum(1 for r in rows if r['fusion_error_type'] == 'TP')}")
print(f"Fusion TN: {sum(1 for r in rows if r['fusion_error_type'] == 'TN')}")
print(f"Fusion FP: {sum(1 for r in rows if r['fusion_error_type'] == 'FP')}")
print(f"Fusion FN: {sum(1 for r in rows if r['fusion_error_type'] == 'FN')}")

Logits shapes: torch.Size([100, 2]) torch.Size([100, 2]) torch.Size([100, 2])
Pre-calibration metrics:
text_expert {'accuracy': 0.85, 'macro_f1': 0.45945945945945943, 'brier': 0.12722012048731096, 'ece': 0.6669068136811255}
vision_expert {'accuracy': 0.82, 'macro_f1': 0.5738636363636364, 'brier': 0.13332780932850632, 'ece': 0.5733259925246239}
mm_fusion {'accuracy': 0.85, 'macro_f1': 0.45945945945945943, 'brier': 0.12999686437238636, 'ece': 0.7615154530107976}

Post-calibration metrics:
text_expert {'accuracy': 0.85, 'macro_f1': 0.45945945945945943, 'brier': 0.12615271780413803, 'ece': 0.7029532849788666}
vision_expert {'accuracy': 0.82, 'macro_f1': 0.5738636363636364, 'brier': 0.11855185513197009, 'ece': 0.6846428109705448}
mm_fusion {'accuracy': 0.85, 'macro_f1': 0.45945945945945943, 'brier': 0.12635051153895138, 'ece': 0.7000653457641601}

Saved calibration temperatures to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Final/Step_7/calibration_pilot.json
Recommended thresholds:
text

## Step 8 — Inference & Simple UI (uses models + Final/Step_7)

Finally, we provide a thin inference wrapper and a small Gradio UI
(similar to `Step_8/step8_inference.ipynb`).

This cell **does not write new files**, but it **loads**:

- `models/text_expert/`
- `models/vision_expert/`
- `models/mm_fusion/fusion_model.pt`
- `Final/Step_7/calibration_{MODE}.json`
- `Final/Step_7/thresholds_{MODE}.json`

and exposes both a Python API and a simple browser UI.

In [11]:
import tempfile
import uuid

import gradio as gr
from PIL import Image  # Required for image loading

# Load calibration + thresholds from Final/Step_7
calib_file = FINAL_STEP7 / f"calibration_{MODE}.json"
thr_file = FINAL_STEP7 / f"thresholds_{MODE}.json"

calibrations: Dict[str, float] = {}
thresholds: Dict[str, Dict[str, float]] = {}

if calib_file.exists():
    with calib_file.open("r", encoding="utf-8") as f:
        calibrations = json.load(f)
else:
    print("[WARN] Calibration file not found, using T=1.0.", calib_file)

if thr_file.exists():
    with thr_file.open("r", encoding="utf-8") as f:
        thresholds = json.load(f)
else:
    print("[WARN] Thresholds file not found, using threshold=0.5.", thr_file)

T_TEXT = float(calibrations.get("text_expert", 1.0))
T_VISION = float(calibrations.get("vision_expert", 1.0))
T_FUSION = float(calibrations.get("mm_fusion", 1.0))

THR_TEXT = float(thresholds.get("text_expert", {}).get("abuse_hate", 0.5))
THR_VISION = float(thresholds.get("vision_expert", {}).get("abuse_hate", 0.5))
THR_FUSION = float(thresholds.get("mm_fusion", {}).get("abuse_hate", 0.5))

print("Loaded calibration + thresholds from Final/Step_7:")
print("T_TEXT =", T_TEXT, "T_VISION =", T_VISION, "T_FUSION =", T_FUSION)
print("THR_TEXT =", THR_TEXT, "THR_VISION =", THR_VISION, "THR_FUSION =", THR_FUSION)


def _load_image(image_path: Path) -> Image.Image:
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")
    return Image.open(image_path).convert("RGB")


def predict_text(text: str) -> Dict[str, Any]:
    enc_text = text_tokenizer_fusion(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    enc_text = {k: v.to(device) for k, v in enc_text.items()}

    with torch.no_grad():
        model = AutoModelForSequenceClassification.from_pretrained(text_expert_dir).to(device)
        logits = model(**enc_text).logits
        logits = logits / T_TEXT
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    prob_hate = float(probs[1])
    label = int(prob_hate >= THR_TEXT)

    return {
        "model": "text_expert",
        "prob_hate": prob_hate,
        "threshold": THR_TEXT,
        "label": label,
        "probs": probs.tolist(),
    }


def predict_image(image_path: Path) -> Dict[str, Any]:
    img = _load_image(image_path)
    enc_img = image_processor_fusion(images=[img], return_tensors="pt")
    pixel_values = enc_img["pixel_values"].to(device)

    with torch.no_grad():
        model = AutoModelForImageClassification.from_pretrained(vision_expert_dir).to(device)
        logits = model(pixel_values=pixel_values).logits
        logits = logits / T_VISION
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    prob_hate = float(probs[1])
    label = int(prob_hate >= THR_VISION)

    return {
        "model": "vision_expert",
        "prob_hate": prob_hate,
        "threshold": THR_VISION,
        "label": label,
        "probs": probs.tolist(),
    }


def predict_fusion_infer(text: str, image_path: Path) -> Dict[str, Any]:
    img = _load_image(image_path)

    enc_text = text_tokenizer_fusion(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    enc_img = image_processor_fusion(images=[img], return_tensors="pt")

    input_ids = enc_text["input_ids"].to(device)
    attention_mask = enc_text["attention_mask"].to(device)
    pixel_values = enc_img["pixel_values"].to(device)

    with torch.no_grad():
        logits = fusion_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
        )
        logits = logits / T_FUSION
        probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()

    prob_hate = float(probs[1])
    label = int(prob_hate >= THR_FUSION)

    return {
        "model": "mm_fusion",
        "prob_hate": prob_hate,
        "threshold": THR_FUSION,
        "label": label,
        "probs": probs.tolist(),
    }


def predict_post_api(text: Optional[str] = None, image_path: Optional[Path] = None, strategy: str = "auto") -> Dict[str, Any]:
    if text is None and image_path is None:
        raise ValueError("Provide at least one of `text` or `image_path`.")

    strategy = strategy.lower()
    if strategy == "auto":
        if text is not None and image_path is not None:
            return predict_fusion_infer(text, image_path)
        if text is not None:
            return predict_text(text)
        return predict_image(image_path)

    if strategy == "text":
        if text is None:
            raise ValueError("strategy='text' requires `text`.")
        return predict_text(text)

    if strategy == "image":
        if image_path is None:
            raise ValueError("strategy='image' requires `image_path`.")
        return predict_image(image_path)

    if strategy == "fusion":
        if text is None or image_path is None:
            raise ValueError("strategy='fusion' requires both `text` and `image_path`.")
        return predict_fusion_infer(text, image_path)

    raise ValueError(f"Unknown strategy: {strategy}")


def _ui_predict(text: str, image):
    text_in: Optional[str] = text.strip() if text and text.strip() else None

    image_path: Optional[Path] = None
    if image is not None:
        tmp_dir = Path(tempfile.gettempdir())
        tmp_file = tmp_dir / f"mmui_{uuid.uuid4().hex}.png"
        image.save(tmp_file)
        image_path = tmp_file

    if text_in is None and image_path is None:
        return "Please provide text, an image, or both.", {}

    result = predict_post_api(text=text_in, image_path=image_path, strategy="auto")

    label = int(result.get("label", 0))
    prob_hate = float(result.get("prob_hate", 0.0))
    model_name = str(result.get("model", "unknown_model"))

    if label == 1:
        explanation = (
            f"**Predicted: HATEFUL / ABUSIVE**  \n"
            f"Model `{model_name}` gives P(hate) = {prob_hate:.3f}.  \n"
            "This system currently makes a binary decision (hate/abuse vs non-hate); "
            "it does not predict fine-grained types of hate."
        )
    else:
        explanation = (
            f"**Predicted: NOT hateful / abusive**  \n"
            f"Model `{model_name}` gives P(hate) = {prob_hate:.3f} "
            f"(so P(non-hate) ≈ {1.0 - prob_hate:.3f})."
        )

    return explanation, result


with gr.Blocks() as demo:
    gr.Markdown(
        f"""## Multimodal Hate/Abuse Detection Demo ({MODE})

Provide text, an image, or both. The system will automatically choose
between text, image, or fusion models (using calibrated thresholds from
Final/Step_7) to decide whether the content is hateful/abusive.
"""
    )

    with gr.Row():
        text_in = gr.Textbox(
            lines=4,
            label="Post text (optional)",
            placeholder="Paste OCR+caption text or any post text here...",
        )
        image_in = gr.Image(
            type="pil",
            label="Image (optional)",
        )

    run_btn = gr.Button("Run")

    explanation_out = gr.Markdown(label="Explanation")
    raw_out = gr.JSON(label="Raw model output")

    run_btn.click(
        fn=_ui_predict,
        inputs=[text_in, image_in],
        outputs=[explanation_out, raw_out],
    )

# Launch inside the notebook
demo.launch()

Loaded calibration + thresholds from Final/Step_7:
T_TEXT = 0.8505802750587463 T_VISION = 0.5144212245941162 T_FUSION = 1.3515019416809082
THR_TEXT = 0.15000000000000002 THR_VISION = 0.2 THR_FUSION = 0.2
* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


