# Step 7: Thresholds, Calibration, and Robustness (Pilot)

This notebook implements Step 7 for the 50-example pilot shard.

It:
- Re-evaluates the **text expert**, **vision expert**, and **fusion model**.
- Fits **temperature scaling** per model for better calibration.
- Sweeps **decision thresholds** for `abuse_hate` and picks recommendations.
- Generates **CSV/JSON artifacts** and an **error table** under `Step_7/`.


In [1]:
# Install required packages for Step 7 (run once per environment).
# You can skip this cell if everything is already installed.

%pip install --upgrade pip

# Core libraries
%pip install torch torchvision torchaudio

# NLP / vision / datasets / training utilities
%pip install transformers datasets webdataset accelerate timm sentencepiece


Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
from pathlib import Path

import json
import csv
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import webdataset as wds

from transformers import (
    AutoTokenizer,
    AutoImageProcessor,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForImageClassification,
)

# Detect project root so this works whether you start Jupyter in the repo root
# or from inside Step_7/.
cwd = Path.cwd().resolve()
if (cwd / "Step_3").is_dir():
    root = cwd
else:
    root = cwd.parent

step3 = root / "Step_3"
shards_dir = step3 / "shards" / "train"
shard_pattern = str(shards_dir / "shard-000000.tar")  # 50-example pilot shard

models_root = root / "models"
text_expert_dir = models_root / "text_expert"
vision_expert_dir = models_root / "vision_expert"
mm_fusion_dir = models_root / "mm_fusion"
mm_fusion_path = mm_fusion_dir / "fusion_model.pt"

step7_dir = root / "Step_7"
step7_dir.mkdir(exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)
print("Shard pattern:", shard_pattern)
print("Text expert dir:", text_expert_dir)
print("Vision expert dir:", vision_expert_dir)
print("Fusion model path:", mm_fusion_path)
print("Step 7 output dir:", step7_dir)


Using device: cpu
Shard pattern: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_3/shards/train/shard-000000.tar
Text expert dir: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/text_expert
Vision expert dir: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/vision_expert
Fusion model path: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/models/mm_fusion/fusion_model.pt
Step 7 output dir: /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_7


In [3]:
def make_eval_examples(shard_pattern: str, max_samples: Optional[int] = None) -> List[Dict[str, Any]]:
    """Create a list of {text, image, label} from WebDataset shards for analysis."""

    ds = (
        wds.WebDataset(shard_pattern, shardshuffle=False)
        .decode("pil")
        .to_tuple("txt", "png", "json")
    )

    out: List[Dict[str, Any]] = []
    for text_obj, img, meta_obj in ds:
        # Decode text
        if isinstance(text_obj, (bytes, bytearray)):
            text = text_obj.decode("utf-8", errors="replace")
        else:
            text = str(text_obj)

        # Decode metadata
        if isinstance(meta_obj, (bytes, bytearray)):
            meta = json.loads(meta_obj.decode("utf-8"))
        else:
            meta = meta_obj

        labels = (meta or {}).get("labels", {})
        y = labels.get("abuse_hate")
        if y is None:
            continue

        out.append({
            "text": text,
            "image": img,
            "label": int(y),
        })

        if max_samples is not None and len(out) >= max_samples:
            break

    return out


eval_examples = make_eval_examples(shard_pattern, max_samples=1000)
print(f"Loaded {len(eval_examples)} evaluation examples for Step 7.")


Loaded 50 evaluation examples for Step 7.


In [4]:
class EvalDataset(Dataset):
    def __init__(self, examples: List[Dict[str, Any]]):
        self.examples = examples

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.examples[idx]


def collate_batch(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    texts = [b["text"] for b in batch]
    images = [b["image"] for b in batch]
    labels = torch.tensor([b["label"] for b in batch], dtype=torch.long)
    return {"texts": texts, "images": images, "labels": labels}


dataset_eval = EvalDataset(eval_examples)
loader_eval = DataLoader(dataset_eval, batch_size=8, shuffle=False, collate_fn=collate_batch)

print("Batches per epoch:", len(loader_eval))


Batches per epoch: 7


In [5]:
def compute_accuracy(preds: np.ndarray, labels: np.ndarray) -> float:
    return float((preds == labels).mean()) if len(labels) > 0 else 0.0


def compute_macro_f1(preds: np.ndarray, labels: np.ndarray, num_classes: int = 2) -> float:
    f1s: List[float] = []
    for c in range(num_classes):
        tp = np.logical_and(preds == c, labels == c).sum()
        fp = np.logical_and(preds == c, labels != c).sum()
        fn = np.logical_and(preds != c, labels == c).sum()

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if precision + recall == 0:
            f1 = 0.0
        else:
            f1 = 2 * precision * recall / (precision + recall)
        f1s.append(f1)
    return float(np.mean(f1s)) if f1s else 0.0


def compute_brier_score(probs_pos: np.ndarray, labels: np.ndarray) -> float:
    return float(np.mean((probs_pos - labels) ** 2)) if len(labels) > 0 else 0.0


def compute_ece(probs_pos: np.ndarray, labels: np.ndarray, num_bins: int = 10) -> float:
    bins = np.linspace(0.0, 1.0, num_bins + 1)
    ece = 0.0
    n = len(labels)
    if n == 0:
        return 0.0

    for i in range(num_bins):
        mask = (probs_pos >= bins[i]) & (probs_pos < bins[i + 1])
        if not np.any(mask):
            continue
        bin_conf = probs_pos[mask].mean()
        bin_acc = (labels[mask] == (probs_pos[mask] >= 0.5)).mean()
        ece += (mask.sum() / n) * abs(bin_conf - bin_acc)
    return float(ece)


def summarize_metrics(logits: torch.Tensor, labels: torch.Tensor) -> Dict[str, float]:
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    preds = probs.argmax(axis=-1)
    labels_np = labels.cpu().numpy()
    probs_pos = probs[:, 1]

    acc = compute_accuracy(preds, labels_np)
    macro_f1 = compute_macro_f1(preds, labels_np, num_classes=2)
    brier = compute_brier_score(probs_pos, labels_np)
    ece = compute_ece(probs_pos, labels_np, num_bins=10)

    return {
        "accuracy": acc,
        "macro_f1": macro_f1,
        "brier": brier,
        "ece": ece,
    }


def precision_recall_f1_at_threshold(probs_pos: np.ndarray, labels: np.ndarray, threshold: float) -> Tuple[float, float, float]:
    preds = (probs_pos >= threshold).astype(int)
    tp = np.logical_and(preds == 1, labels == 1).sum()
    fp = np.logical_and(preds == 1, labels == 0).sum()
    fn = np.logical_and(preds == 0, labels == 1).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return float(precision), float(recall), float(f1)


In [6]:
# Load tokenizers / processors and collect logits for text and vision experts

text_tokenizer = AutoTokenizer.from_pretrained(text_expert_dir)
image_processor = AutoImageProcessor.from_pretrained(vision_expert_dir)


def get_text_logits(loader: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
    model = AutoModelForSequenceClassification.from_pretrained(text_expert_dir)
    model.to(device)
    model.eval()

    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            texts = batch["texts"]
            labels = batch["labels"].to(device)

            enc = text_tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}

            outputs = model(**enc)
            logits = outputs.logits

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    return logits_cat, labels_cat


def get_vision_logits(loader: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
    model = AutoModelForImageClassification.from_pretrained(vision_expert_dir)
    model.to(device)
    model.eval()

    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            images = batch["images"]
            labels = batch["labels"].to(device)

            enc = image_processor(images=images, return_tensors="pt")
            pixel_values = enc["pixel_values"].to(device)

            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    return logits_cat, labels_cat


print("Collecting logits for text expert...")
logits_text, labels_text = get_text_logits(loader_eval)
print("Text logits shape:", logits_text.shape)

print("Collecting logits for vision expert...")
logits_vision, labels_vision = get_vision_logits(loader_eval)
print("Vision logits shape:", logits_vision.shape)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Collecting logits for text expert...
Text logits shape: torch.Size([50, 2])
Collecting logits for vision expert...
Vision logits shape: torch.Size([50, 2])


In [7]:
# Fusion model definition and logits collection


class FusionModel(nn.Module):
    def __init__(
        self,
        text_encoder: nn.Module,
        vision_encoder: nn.Module,
        t_dim: int,
        v_dim: int,
        hidden_dim: int,
        num_labels: int,
    ) -> None:
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder
        self.mlp = nn.Sequential(
            nn.Linear(t_dim + v_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        with torch.no_grad():
            text_out = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            if hasattr(text_out, "pooler_output") and text_out.pooler_output is not None:
                t_repr = text_out.pooler_output
            else:
                t_repr = text_out.last_hidden_state[:, 0, :]

            vision_out = self.vision_encoder(pixel_values=pixel_values)
            v_repr = vision_out.logits

        h = torch.cat([t_repr, v_repr], dim=-1)
        logits = self.mlp(h)
        return logits


def load_fusion_model() -> FusionModel:
    text_encoder = AutoModel.from_pretrained(text_expert_dir)
    vision_encoder = AutoModelForImageClassification.from_pretrained(vision_expert_dir)

    text_encoder.to(device)
    vision_encoder.to(device)

    for p in text_encoder.parameters():
        p.requires_grad = False
    for p in vision_encoder.parameters():
        p.requires_grad = False

    t_dim = text_encoder.config.hidden_size
    v_dim = vision_encoder.config.num_labels

    fusion_hidden = 512
    num_labels = 2

    model = FusionModel(
        text_encoder=text_encoder,
        vision_encoder=vision_encoder,
        t_dim=t_dim,
        v_dim=v_dim,
        hidden_dim=fusion_hidden,
        num_labels=num_labels,
    )
    state_dict = torch.load(mm_fusion_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def get_fusion_logits(loader: DataLoader) -> Tuple[torch.Tensor, torch.Tensor]:
    model = load_fusion_model()

    all_logits: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in loader:
            texts = batch["texts"]
            images = batch["images"]
            labels = batch["labels"].to(device)

            enc_text = text_tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=256,
                return_tensors="pt",
            )
            enc_text = {k: v.to(device) for k, v in enc_text.items()}

            enc_img = image_processor(images=images, return_tensors="pt")
            pixel_values = enc_img["pixel_values"].to(device)

            logits = model(
                input_ids=enc_text["input_ids"],
                attention_mask=enc_text["attention_mask"],
                pixel_values=pixel_values,
            )

            all_logits.append(logits.cpu())
            all_labels.append(labels.cpu())

    logits_cat = torch.cat(all_logits, dim=0)
    labels_cat = torch.cat(all_labels, dim=0)
    return logits_cat, labels_cat


print("Collecting logits for fusion model...")
logits_fusion, labels_fusion = get_fusion_logits(loader_eval)
print("Fusion logits shape:", logits_fusion.shape)


Collecting logits for fusion model...
Fusion logits shape: torch.Size([50, 2])


In [8]:
# Temperature scaling per model and pre/post-calibration metrics

# Ensure labels are consistent across models
assert torch.equal(labels_text, labels_vision)
assert torch.equal(labels_text, labels_fusion)
labels = labels_fusion.clone()


def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, max_iter: int = 200, lr: float = 0.01) -> float:
    """Fit a single temperature parameter using cross-entropy minimization."""
    logits = logits.clone().to(torch.float32)
    labels = labels.clone().to(torch.long)

    T = nn.Parameter(torch.ones(1))
    optimizer = torch.optim.Adam([T], lr=lr)

    for _ in range(max_iter):
        optimizer.zero_grad()
        scaled_logits = logits / T
        loss = F.cross_entropy(scaled_logits, labels)
        loss.backward()
        optimizer.step()

    return float(T.detach().item())


metrics_before: Dict[str, Dict[str, float]] = {}
metrics_after: Dict[str, Dict[str, float]] = {}
temperatures: Dict[str, float] = {}

# Text expert
metrics_before["text_expert"] = summarize_metrics(logits_text, labels)
T_text = fit_temperature(logits_text, labels)
logits_text_cal = logits_text / T_text
metrics_after["text_expert"] = summarize_metrics(logits_text_cal, labels)
temperatures["text_expert"] = T_text

# Vision expert
metrics_before["vision_expert"] = summarize_metrics(logits_vision, labels)
T_vision = fit_temperature(logits_vision, labels)
logits_vision_cal = logits_vision / T_vision
metrics_after["vision_expert"] = summarize_metrics(logits_vision_cal, labels)
temperatures["vision_expert"] = T_vision

# Fusion model
metrics_before["mm_fusion"] = summarize_metrics(logits_fusion, labels)
T_fusion = fit_temperature(logits_fusion, labels)
logits_fusion_cal = logits_fusion / T_fusion
metrics_after["mm_fusion"] = summarize_metrics(logits_fusion_cal, labels)
temperatures["mm_fusion"] = T_fusion

print("Pre-calibration metrics:")
for name, m in metrics_before.items():
    print(name, m)

print("\nPost-calibration metrics (temperature scaling):")
for name, m in metrics_after.items():
    print(name, m)

calib_path = step7_dir / "calibration_pilot.json"
with calib_path.open("w", encoding="utf-8") as f:
    json.dump(temperatures, f, indent=2)

print("\nSaved calibration temperatures to", calib_path)


Pre-calibration metrics:
text_expert {'accuracy': 0.9, 'macro_f1': 0.4736842105263158, 'brier': 0.12276667188716842, 'ece': 0.6138291072845459}
vision_expert {'accuracy': 0.9, 'macro_f1': 0.4736842105263158, 'brier': 0.09003111768167933, 'ece': 0.8844988291338086}
mm_fusion {'accuracy': 0.9, 'macro_f1': 0.4736842105263158, 'brier': 0.08550730560385912, 'ece': 0.7489281076192857}

Post-calibration metrics (temperature scaling):
text_expert {'accuracy': 0.9, 'macro_f1': 0.4736842105263158, 'brier': 0.08715975735207579, 'ece': 0.8113224555552006}
vision_expert {'accuracy': 0.9, 'macro_f1': 0.4736842105263158, 'brier': 0.08018320326215748, 'ece': 0.8272676227986813}
mm_fusion {'accuracy': 0.9, 'macro_f1': 0.4736842105263158, 'brier': 0.08318219306679688, 'ece': 0.8204712355136872}

Saved calibration temperatures to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_7/calibration_pilot.json


In [9]:
# Threshold sweeps for each model (using calibrated logits)


def threshold_sweep_from_logits(
    logits: torch.Tensor,
    labels: torch.Tensor,
    model_name: str,
    csv_path: Path,
    num_thresholds: int = 17,
) -> Tuple[float, float]:
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    probs_pos = probs[:, 1]
    labels_np = labels.cpu().numpy()

    thresholds = np.linspace(0.1, 0.9, num_thresholds)
    rows: List[Dict[str, Any]] = []

    best_threshold = 0.5
    best_f1 = -1.0

    for thr in thresholds:
        precision, recall, f1 = precision_recall_f1_at_threshold(probs_pos, labels_np, float(thr))
        rows.append({
            "model": model_name,
            "threshold": float(thr),
            "precision": precision,
            "recall": recall,
            "f1": f1,
        })
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = float(thr)

    with csv_path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["model", "threshold", "precision", "recall", "f1"])
        writer.writeheader()
        writer.writerows(rows)

    return best_threshold, best_f1


thr_text, f1_text = threshold_sweep_from_logits(
    logits_text_cal,
    labels,
    "text_expert",
    step7_dir / "thresholds_curve_text_expert_pilot.csv",
)

thr_vision, f1_vision = threshold_sweep_from_logits(
    logits_vision_cal,
    labels,
    "vision_expert",
    step7_dir / "thresholds_curve_vision_expert_pilot.csv",
)

thr_fusion, f1_fusion = threshold_sweep_from_logits(
    logits_fusion_cal,
    labels,
    "mm_fusion",
    step7_dir / "thresholds_curve_mm_fusion_pilot.csv",
)

thresholds_summary = {
    "text_expert": {"abuse_hate": thr_text, "best_f1": f1_text},
    "vision_expert": {"abuse_hate": thr_vision, "best_f1": f1_vision},
    "mm_fusion": {"abuse_hate": thr_fusion, "best_f1": f1_fusion},
}

thr_path = step7_dir / "thresholds_pilot.json"
with thr_path.open("w", encoding="utf-8") as f:
    json.dump(thresholds_summary, f, indent=2)

print("Recommended thresholds (pilot):")
for name, info in thresholds_summary.items():
    print(name, "-> threshold =", info["abuse_hate"], "best F1 =", info["best_f1"])

print("\nSaved threshold curves and summary to", step7_dir)


Recommended thresholds (pilot):
text_expert -> threshold = 0.2 best F1 = 0.33333333333333337
vision_expert -> threshold = 0.15000000000000002 best F1 = 0.5714285714285715
mm_fusion -> threshold = 0.15000000000000002 best F1 = 0.4444444444444445

Saved threshold curves and summary to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_7


In [10]:
# Error analysis using the fusion model as primary decision maker

probs_text = torch.softmax(logits_text_cal, dim=-1).cpu().numpy()
probs_vision = torch.softmax(logits_vision_cal, dim=-1).cpu().numpy()
probs_fusion = torch.softmax(logits_fusion_cal, dim=-1).cpu().numpy()

probs_text_pos = probs_text[:, 1]
probs_vision_pos = probs_vision[:, 1]
probs_fusion_pos = probs_fusion[:, 1]

labels_np = labels.cpu().numpy()

preds_text = (probs_text_pos >= thr_text).astype(int)
preds_vision = (probs_vision_pos >= thr_vision).astype(int)
preds_fusion = (probs_fusion_pos >= thr_fusion).astype(int)

rows: List[Dict[str, Any]] = []

for i, ex in enumerate(eval_examples):
    label = int(labels_np[i])
    pt = float(probs_text_pos[i])
    pv = float(probs_vision_pos[i])
    pf = float(probs_fusion_pos[i])
    yt = int(preds_text[i])
    yv = int(preds_vision[i])
    yf = int(preds_fusion[i])

    if yf == 1 and label == 1:
        err_type = "TP"
    elif yf == 0 and label == 0:
        err_type = "TN"
    elif yf == 1 and label == 0:
        err_type = "FP"
    else:
        err_type = "FN"

    all_agree = int((yt == yv) and (yv == yf))
    fusion_correct_both_wrong = int((yf == label) and (yt != label) and (yv != label))

    rows.append({
        "index": i,
        "text": ex["text"],
        "label": label,
        "text_prob": pt,
        "text_pred": yt,
        "vision_prob": pv,
        "vision_pred": yv,
        "fusion_prob": pf,
        "fusion_pred": yf,
        "fusion_error_type": err_type,
        "all_models_agree": all_agree,
        "fusion_correct_both_wrong": fusion_correct_both_wrong,
    })

errors_path = step7_dir / "errors_pilot.csv"
with errors_path.open("w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=[
            "index",
            "text",
            "label",
            "text_prob",
            "text_pred",
            "vision_prob",
            "vision_pred",
            "fusion_prob",
            "fusion_pred",
            "fusion_error_type",
            "all_models_agree",
            "fusion_correct_both_wrong",
        ],
    )
    writer.writeheader()
    writer.writerows(rows)

# Print a brief summary of fusion error types
fusion_preds = preds_fusion
fp = int(np.logical_and(fusion_preds == 1, labels_np == 0).sum())
fn = int(np.logical_and(fusion_preds == 0, labels_np == 1).sum())
tp = int(np.logical_and(fusion_preds == 1, labels_np == 1).sum())
tn = int(np.logical_and(fusion_preds == 0, labels_np == 0).sum())

print("Saved detailed error table to", errors_path)
print("Fusion confusion counts (TP, TN, FP, FN):", tp, tn, fp, fn)
print("Examples where fusion is correct but both unimodal experts are wrong:", int(sum(r["fusion_correct_both_wrong"] for r in rows)))


Saved detailed error table to /Users/yashwanthreddy/Documents/GitHub/DL_Proj/Step_7/errors_pilot.csv
Fusion confusion counts (TP, TN, FP, FN): 2 43 2 3
Examples where fusion is correct but both unimodal experts are wrong: 1
