# Step 8: Inference and Serving (Pilot)

This notebook provides a simple `predict(text, image_path)` interface
for the **multimodal fusion model** trained in earlier steps.

It uses:
- `models/text_expert/`
- `models/vision_expert/`
- `models/mm_fusion/fusion_model.pt`
- Calibration and thresholds from `Step_7/`.

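To switch to full-data artifacts (e.g. `calibration_full.json`,
`thresholds_full.json`), set `MODE = "full"` in the configuration cell below;
the core logic does not change. If a full-data file is missing, the notebook
falls back to the corresponding pilot file.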

In [1]:
# Install required packages for Step 8 (run once per environment).
# You can skip this cell if everything is already installed.
# After installing or upgrading, restart the kernel so the new versions are
# the ones that get imported.
# NOTE(review): apart from huggingface-hub, no version specifiers are given,
# so a fresh install may resolve to different versions than the ones this
# notebook was last run with.

%pip install --upgrade pip

# Core libraries
%pip install torch torchvision torchaudio

# NLP / vision / utilities
%pip install transformers webdataset accelerate timm sentencepiece pillow

# Lightweight web UI (in-notebook) + compatible huggingface-hub
# huggingface-hub is constrained to the range >=0.33.5,<1.0, which is
# intended to stay compatible with transformers>=4.30,<5 and gradio.
%pip install "huggingface-hub>=0.33.5,<1.0" gradio


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from pathlib import Path
from typing import Any, Dict, Union, Optional

import json
import tempfile
import uuid

import torch
from torch import nn
from PIL import Image

import gradio as gr

from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    AutoModel,
    AutoModelForImageClassification,
    AutoModelForSequenceClassification,
)

# Locate the project root. If the current directory contains Step_3/ we are
# already at the root; otherwise we assume Jupyter was started one level
# below it (e.g. inside Step_8/).
cwd = Path.cwd().resolve()
root = cwd if (cwd / "Step_3").is_dir() else cwd.parent

# Model artifacts produced by the earlier steps.
models_root = root / "models"
text_expert_dir = models_root / "text_expert"
vision_expert_dir = models_root / "vision_expert"
mm_fusion_dir = models_root / "mm_fusion"
mm_fusion_path = mm_fusion_dir / "fusion_model.pt"

step7_dir = root / "Step_7"

# Which calibration/threshold artifacts to read.
# "pilot" is the default; switch to "full" once Steps 3–7 ran on full data.
MODE = "pilot"  # or "full" once you run Steps 3–7 on full data

calib_file = step7_dir / f"calibration_{MODE}.json"
thr_file = step7_dir / f"thresholds_{MODE}.json"

# If the MODE-specific file is missing, fall back to the pilot file when
# that one exists; otherwise keep the (missing) MODE-specific path.
pilot_calib_file = step7_dir / "calibration_pilot.json"
if not calib_file.exists() and pilot_calib_file.exists():
    calib_file = pilot_calib_file

pilot_thr_file = step7_dir / "thresholds_pilot.json"
if not thr_file.exists() and pilot_thr_file.exists():
    thr_file = pilot_thr_file

# Empty dicts mean "use the defaults" (T=1.0, threshold=0.5) further down.
calibrations: Dict[str, float] = {}
thresholds: Dict[str, Dict[str, float]] = {}

if calib_file.exists():
    calibrations = json.loads(calib_file.read_text(encoding="utf-8"))
else:
    print("[WARN] Calibration file not found, using T=1.0.", calib_file)

if thr_file.exists():
    thresholds = json.loads(thr_file.read_text(encoding="utf-8"))
else:
    print("[WARN] Thresholds file not found, using threshold=0.5.", thr_file)

# Keys under which Step 7 stores the per-model values, in the order
# (text expert, vision expert, fusion model).
model_keys = ("text_expert", "vision_expert", "mm_fusion")

# Temperature scaling parameters from Step 7 (default 1.0 = no scaling).
T_TEXT, T_VISION, T_FUSION = (
    float(calibrations.get(key, 1.0)) for key in model_keys
)

# Decision thresholds from Step 7 for the "abuse_hate" label (default 0.5).
THR_TEXT, THR_VISION, THR_FUSION = (
    float(thresholds.get(key, {}).get("abuse_hate", 0.5)) for key in model_keys
)

# Prefer the GPU when one is available.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Echo the resolved configuration so the cell output documents the run.
for caption, value in (
    ("Project root:", root),
    ("Using device:", device),
    ("Fusion model path:", mm_fusion_path),
    ("Using calibration file:", calib_file),
    ("Using thresholds file:", thr_file),
):
    print(caption, value)
print("T_TEXT =", T_TEXT, "T_VISION =", T_VISION, "T_FUSION =", T_FUSION)
print("THR_TEXT =", THR_TEXT, "THR_VISION =", THR_VISION, "THR_FUSION =", THR_FUSION)


Project root: /Users/yashwanthreddy/Documents/GitHub/DL_Proj
Using device: cpu
Fusion model path: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/mm_fusion/fusion_model.pt
Using calibration file: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_7/calibration_pilot.json
Using thresholds file: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_7/thresholds_pilot.json
T_TEXT = 0.40586668252944946 T_VISION = 1.647761344909668 T_FUSION = 0.7356588840484619
THR_TEXT = 0.2 THR_VISION = 0.15000000000000002 THR_FUSION = 0.15000000000000002


In [3]:
# Fusion model definition (same architecture as in Steps 5–7)


class FusionModel(nn.Module):
    """Fusion classifier: text representation + vision logits -> MLP head.

    The text encoder's pooled output (or, when it has none, the hidden state
    at sequence position 0) is concatenated with the vision model's logits
    and fed through a two-layer MLP that produces the class logits.

    Args:
        text_encoder: Module called with ``input_ids`` and ``attention_mask``.
        vision_encoder: Module called with ``pixel_values``; its ``.logits``
            are used as the image representation.
        t_dim: Size of the text representation.
        v_dim: Size of the vision representation (number of vision logits).
        hidden_dim: Width of the MLP's hidden layer.
        num_labels: Number of output classes.
    """

    def __init__(
        self,
        text_encoder: nn.Module,
        vision_encoder: nn.Module,
        t_dim: int,
        v_dim: int,
        hidden_dim: int,
        num_labels: int,
    ) -> None:
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder

        # The position of each layer inside the Sequential determines its
        # state-dict key (mlp.0.*, mlp.3.*), so the order must not change.
        head_layers = [
            nn.Linear(t_dim + v_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels),
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Return class logits for a batch of (text, image) pairs."""
        # Both encoders are treated as frozen feature extractors, so no
        # gradients are tracked through them.
        with torch.no_grad():
            text_outputs = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            pooled = getattr(text_outputs, "pooler_output", None)
            text_features = (
                pooled
                if pooled is not None
                else text_outputs.last_hidden_state[:, 0, :]
            )

            image_features = self.vision_encoder(pixel_values=pixel_values).logits

        fused = torch.cat((text_features, image_features), dim=-1)
        return self.mlp(fused)


def load_fusion_model(fusion_hidden: int = 512, num_labels: int = 2) -> FusionModel:
    """Rebuild the fusion model and load its trained weights.

    The text and vision encoders are re-created from ``text_expert_dir`` and
    ``vision_expert_dir``, frozen, wrapped in a ``FusionModel`` and then the
    checkpoint at ``mm_fusion_path`` is loaded on top.

    Args:
        fusion_hidden: Hidden width of the fusion MLP. Must match the value
            used when the checkpoint was trained (default 512).
        num_labels: Number of output classes of the fusion head (default 2).

    Returns:
        The fusion model on ``device``, in eval mode.

    Raises:
        FileNotFoundError: If the fusion checkpoint does not exist.
    """
    # Fail fast with a clear message before loading the (slow) encoders.
    if not mm_fusion_path.is_file():
        raise FileNotFoundError(f"Fusion checkpoint not found: {mm_fusion_path}")

    # Recreate encoders in the same way as training
    text_encoder = AutoModel.from_pretrained(text_expert_dir)
    vision_encoder = AutoModelForImageClassification.from_pretrained(vision_expert_dir)

    # Encoders are frozen feature extractors.
    for encoder in (text_encoder, vision_encoder):
        encoder.to(device)
        for param in encoder.parameters():
            param.requires_grad = False

    model = FusionModel(
        text_encoder=text_encoder,
        vision_encoder=vision_encoder,
        t_dim=text_encoder.config.hidden_size,
        # The vision representation is the classifier's logits vector.
        v_dim=vision_encoder.config.num_labels,
        hidden_dim=fusion_hidden,
        num_labels=num_labels,
    )

    # The file is passed straight to load_state_dict, so it only needs
    # tensors; weights_only=True restricts unpickling to tensors and plain
    # containers instead of allowing arbitrary pickled objects.
    state_dict = torch.load(mm_fusion_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


In [4]:
# Build every inference-time component: preprocessors, the two unimodal
# classifiers, and the multimodal fusion model.

# Preprocessors saved alongside the experts
text_tokenizer = AutoTokenizer.from_pretrained(text_expert_dir)
image_processor = AutoImageProcessor.from_pretrained(vision_expert_dir)

# Text expert classifier, moved to the target device and set to eval mode
text_cls_model = (
    AutoModelForSequenceClassification.from_pretrained(text_expert_dir)
    .to(device)
    .eval()
)

# Vision expert classifier, same treatment
vision_cls_model = (
    AutoModelForImageClassification.from_pretrained(vision_expert_dir)
    .to(device)
    .eval()
)

# Multimodal fusion model (builds its own encoder instances internally)
fusion_model = load_fusion_model()

# One summary line per model: temperature and decision threshold in use.
print("Loaded models:")
summary_rows = (
    (" - text_expert (T_TEXT =", T_TEXT, ", THR_TEXT =", THR_TEXT),
    (" - vision_expert (T_VISION =", T_VISION, ", THR_VISION =", THR_VISION),
    (" - mm_fusion (T_FUSION =", T_FUSION, ", THR_FUSION =", THR_FUSION),
)
for row in summary_rows:
    print(*row, ")")


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loaded models:
 - text_expert (T_TEXT = 0.40586668252944946 , THR_TEXT = 0.2 )
 - vision_expert (T_VISION = 1.647761344909668 , THR_VISION = 0.15000000000000002 )
 - mm_fusion (T_FUSION = 0.7356588840484619 , THR_FUSION = 0.15000000000000002 )


In [5]:
def _load_image(image_path: Union[str, Path]) -> Image.Image:
    """Load an image file and return it as an in-memory RGB image.

    Args:
        image_path: Path to the image file.

    Returns:
        The image converted to RGB mode.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    image_path = Path(image_path)
    if not image_path.exists():
        raise FileNotFoundError(f"Image not found: {image_path}")
    # Image.open only opens the file lazily. The context manager guarantees
    # the underlying file handle is closed even if decoding fails, while
    # convert() returns an independent copy that outlives the `with` block.
    with Image.open(image_path) as source_image:
        return source_image.convert("RGB")


def predict_text(text: str) -> Dict[str, Any]:
    """Classify a piece of text with the text expert alone.

    The text is tokenized (truncated to at most 256 tokens), the classifier
    logits are divided by the temperature ``T_TEXT`` and softmaxed, and the
    probability at class index 1 is compared against ``THR_TEXT``.

    Args:
        text: Input text.

    Returns:
        Dict with keys ``model``, ``prob_hate``, ``threshold``, ``label``
        (1 when ``prob_hate >= threshold``, else 0) and ``probs`` (the full
        probability vector as a list).
    """
    batch = text_tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    model_inputs = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        raw_logits = text_cls_model(**model_inputs).logits
        # Temperature scaling, then probabilities for the single input.
        class_probs = torch.softmax(raw_logits / T_TEXT, dim=-1)[0].cpu().numpy()

    hate_probability = float(class_probs[1])

    return dict(
        model="text_expert",
        prob_hate=hate_probability,
        threshold=THR_TEXT,
        label=int(hate_probability >= THR_TEXT),
        probs=class_probs.tolist(),
    )


def predict_image(image_path: Union[str, Path]) -> Dict[str, Any]:
    """Classify an image with the vision expert alone.

    The image is loaded from disk, preprocessed, and classified; the logits
    are divided by the temperature ``T_VISION`` and softmaxed, and the
    probability at class index 1 is compared against ``THR_VISION``.

    Args:
        image_path: Path to the image file.

    Returns:
        Dict with keys ``model``, ``prob_hate``, ``threshold``, ``label``
        (1 when ``prob_hate >= threshold``, else 0) and ``probs`` (the full
        probability vector as a list).
    """
    pil_image = _load_image(image_path)

    processed = image_processor(images=[pil_image], return_tensors="pt")
    pixels = processed["pixel_values"].to(device)

    with torch.no_grad():
        raw_logits = vision_cls_model(pixel_values=pixels).logits
        # Temperature scaling, then probabilities for the single input.
        class_probs = torch.softmax(raw_logits / T_VISION, dim=-1)[0].cpu().numpy()

    hate_probability = float(class_probs[1])

    return dict(
        model="vision_expert",
        prob_hate=hate_probability,
        threshold=THR_VISION,
        label=int(hate_probability >= THR_VISION),
        probs=class_probs.tolist(),
    )


def predict_fusion(text: str, image_path: Union[str, Path]) -> Dict[str, Any]:
    """Classify a (text, image) pair with the multimodal fusion model.

    The image is loaded first, then the text is tokenized (truncated to at
    most 256 tokens) and the image preprocessed. The fusion logits are
    divided by the temperature ``T_FUSION`` and softmaxed, and the
    probability at class index 1 is compared against ``THR_FUSION``.

    Args:
        text: Input text.
        image_path: Path to the image file.

    Returns:
        Dict with keys ``model``, ``prob_hate``, ``threshold``, ``label``
        (1 when ``prob_hate >= threshold``, else 0) and ``probs`` (the full
        probability vector as a list).
    """
    pil_image = _load_image(image_path)

    text_batch = text_tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    image_batch = image_processor(images=[pil_image], return_tensors="pt")

    # Move exactly the tensors the fusion model consumes onto the device.
    model_inputs = {
        "input_ids": text_batch["input_ids"].to(device),
        "attention_mask": text_batch["attention_mask"].to(device),
        "pixel_values": image_batch["pixel_values"].to(device),
    }

    with torch.no_grad():
        raw_logits = fusion_model(**model_inputs)
        # Temperature scaling from Step 7, then probabilities for the
        # single input.
        class_probs = torch.softmax(raw_logits / T_FUSION, dim=-1)[0].cpu().numpy()

    hate_probability = float(class_probs[1])

    return dict(
        model="mm_fusion",
        prob_hate=hate_probability,
        threshold=THR_FUSION,
        label=int(hate_probability >= THR_FUSION),
        probs=class_probs.tolist(),
    )


def predict_post(
    text: Optional[str] = None,
    image_path: Optional[Union[str, Path]] = None,
    strategy: str = "auto",
) -> Dict[str, Any]:
    """Route a post to the text, vision, or fusion model and predict.

    Args:
        text: Text string (OCR + caption or raw text).
        image_path: Path to the meme image file.
        strategy: Case-insensitive routing mode.
            - "auto": fusion if both inputs are given, else whichever
              modality is present.
            - "text": force text expert (requires text).
            - "image": force vision expert (requires image_path).
            - "fusion": force fusion model (requires both).

    Returns:
        Prediction dict with probability, label, model name, etc.

    Raises:
        ValueError: If neither input is given, the strategy is unknown, or
            the chosen strategy lacks an input it requires.
    """
    has_text = text is not None
    has_image = image_path is not None

    if not (has_text or has_image):
        raise ValueError("Provide at least one of `text` or `image_path`.")

    strategy = strategy.lower()
    if strategy not in {"auto", "text", "image", "fusion"}:
        raise ValueError(f"Unknown strategy: {strategy}")

    if strategy == "auto":
        # Resolve "auto" to the concrete strategy the available inputs
        # allow; at least one input is present at this point.
        if has_text and has_image:
            strategy = "fusion"
        elif has_text:
            strategy = "text"
        else:
            strategy = "image"

    if strategy == "text":
        if not has_text:
            raise ValueError("strategy='text' requires `text`.")
        return predict_text(text)

    if strategy == "image":
        if not has_image:
            raise ValueError("strategy='image' requires `image_path`.")
        return predict_image(image_path)

    # Only "fusion" is left.
    if not (has_text and has_image):
        raise ValueError("strategy='fusion' requires both `text` and `image_path`.")
    return predict_fusion(text, image_path)


# Backwards-compatible alias: `predict(text, image_path)` is the original API
# name. It always runs the fusion model and therefore needs both inputs; use
# `predict_post` for text-only / image-only routing.
predict = predict_fusion


In [6]:
# Example usage: run the fusion model on one (text, image) pair.

# Example combined text (OCR + caption). In practice, you can reuse the
# same text-building logic from Step 3 or Step 4, or pass raw meme text.
example_text = "[OCR] example ocr text [/OCR] [CAP] example caption text [/CAP] <lang=en>"

# Resolved against the detected project root rather than a machine-specific
# absolute path, so the cell also runs on other checkouts. Point this at any
# meme image from your dataset.
example_image_path = root / "98734.png"

if example_image_path.exists():
    result = predict(example_text, example_image_path)
    print(result)
else:
    # Keep Restart & Run All working when the sample image is not present.
    print(
        f"[WARN] Example image not found: {example_image_path}. "
        "Set example_image_path to an existing image to run the example."
    )


{'model': 'mm_fusion', 'prob_hate': 0.08315027505159378, 'threshold': 0.15000000000000002, 'label': 0, 'probs': [0.916849672794342, 0.08315027505159378]}
Ready to call predict(text, image_path). Set example_image_path and uncomment the call above.


In [7]:
# Lightweight Gradio UI (final cell)


def _ui_predict(text: str, image):
    """Gradio callback: classify the submitted text and/or image.

    Blank text is treated as missing. An uploaded image is written to a
    temporary PNG so it can be passed to ``predict_post`` by path; the file
    is removed again once the prediction has finished.

    Args:
        text: Contents of the text box (may be empty or None).
        image: Uploaded image object exposing ``save(path)`` (a PIL image
            when the component uses ``type="pil"``), or None.

    Returns:
        Tuple ``(explanation_markdown, raw_result_dict)``. When neither
        input is provided, a prompt message and an empty dict.
    """
    # Normalize inputs: whitespace-only text counts as "no text".
    text_in: Optional[str] = None
    if text is not None and str(text).strip():
        text_in = str(text).strip()

    if text_in is None and image is None:
        return "Please provide text, an image, or both.", {}

    tmp_file: Optional[Path] = None
    try:
        image_path: Optional[str] = None
        if image is not None:
            # Save uploaded image to a temporary file and pass its path
            tmp_file = Path(tempfile.gettempdir()) / f"mmui_{uuid.uuid4().hex}.png"
            image.save(tmp_file)
            image_path = str(tmp_file)

        result = predict_post(text=text_in, image_path=image_path, strategy="auto")
    finally:
        # Without this, every request would leave another PNG behind in the
        # temp directory. Cleanup is best-effort: a failed delete must not
        # hide the prediction result or the original error.
        if tmp_file is not None:
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass

    label = int(result.get("label", 0))
    prob_hate = float(result.get("prob_hate", 0.0))
    model_name = str(result.get("model", "unknown_model"))

    if label == 1:
        explanation = (
            f"**Predicted: HATEFUL / ABUSIVE**  \n"
            f"Model `{model_name}` gives P(hate) = {prob_hate:.3f}.  \n"
            "This system currently makes a binary decision (hate/abuse vs non-hate); "
            "it does not predict fine-grained types of hate."
        )
    else:
        explanation = (
            f"**Predicted: NOT hateful / abusive**  \n"
            f"Model `{model_name}` gives P(hate) = {prob_hate:.3f} "
            f"(so P(non-hate) ≈ {1.0 - prob_hate:.3f})."
        )

    return explanation, result


# Intro text rendered at the top of the demo page.
DEMO_DESCRIPTION = """## Multimodal Hate/Abuse Detection Demo (Pilot)

Provide text, an image, or both. The system will automatically choose
between text, image, or fusion models (using calibrated thresholds from
Steps 6–7) to decide whether the content is hateful/abusive.
"""

# Components are laid out in the order they are created inside the Blocks
# context: description, inputs row, button, then the two outputs.
with gr.Blocks() as demo:
    gr.Markdown(DEMO_DESCRIPTION)

    # Inputs, side by side
    with gr.Row():
        text_in = gr.Textbox(
            label="Post text (optional)",
            placeholder="Paste OCR+caption text or any post text here...",
            lines=4,
        )
        image_in = gr.Image(label="Image (optional)", type="pil")

    run_btn = gr.Button("Run")

    # Outputs: human-readable verdict plus the raw prediction dict
    explanation_out = gr.Markdown(label="Explanation")
    raw_out = gr.JSON(label="Raw model output")

    # Wire the button to the prediction callback.
    ui_inputs = [text_in, image_in]
    ui_outputs = [explanation_out, raw_out]
    run_btn.click(fn=_ui_predict, inputs=ui_inputs, outputs=ui_outputs)

# Launch inside the notebook
demo.launch()


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


