In [None]:
"""
H-DRIFT-M Quality Gate (Binary Classifier): good (x1) vs bad (x2)

Goal:
- Train a model that maps an image I -> y in {0,1}
  1 = good quality (x1)
  0 = bad quality / drifting (x2)

Why:
- Acts as a supervisory quality gate:
  - If bad: trigger re-focus / illumination check / pause / alert
  - If good: continue acquisition

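Dataset layout (required: class folders must be named exactly `bad` and `good`, because
ImageFolder sorts folder names alphabetically and the metrics assume bad=0, good=1):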
data/
  train/
    good/   (x1 images)
    bad/    (x2 images)
  val/
    good/
    bad/
  test/
    good/
    bad/

Usage:
  python train_quality_gate.py --data_dir data --epochs 20 --model resnet18
"""

In [10]:
import argparse
import os
import random
from dataclasses import dataclass
from typing import Tuple, Dict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [11]:
# --------------------------
# Reproducibility utilities
# --------------------------
def seed_everything(seed: int = 42, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs so training runs are repeatable.

    Args:
        seed: value used to seed every RNG.
        deterministic: when True (default), request deterministic cuDNN kernels
            and disable cuDNN autotuning (``benchmark``). The previous version
            hard-wired the opposite (``deterministic=False, benchmark=True``),
            which undermines the seeding on GPU. Pass False to trade
            reproducibility for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # Autotuning picks kernels per run, so it must be off for repeatability.
    torch.backends.cudnn.benchmark = not deterministic


In [12]:
# --------------------------
# MUSE-like degradations
# --------------------------
class RandomDefocus:
    """Simulates focus drift by applying Gaussian blur with a random sigma.

    Args:
        p: probability that the blur is applied to a given image.
        sigma_range: (min, max) sigma, sampled uniformly on every call.
        kernel_size: odd, positive blur kernel size in pixels (default 9, the
            previously hard-coded value). A 9-px kernel reaches +/-4 px, i.e.
            only about +/-2 sigma at sigma=2.0, so raise it when using larger
            sigmas.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """
    def __init__(self, p: float = 0.4, sigma_range: Tuple[float, float] = (0.3, 2.0),
                 kernel_size: int = 9):
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
        self.p = p
        self.sigma_range = sigma_range
        self.kernel_size = kernel_size

    def __call__(self, img):
        if random.random() > self.p:
            return img
        sigma = random.uniform(*self.sigma_range)
        # Functional form: same blur as transforms.GaussianBlur, without
        # constructing a new transform object for every image.
        return transforms.functional.gaussian_blur(
            img,
            kernel_size=[self.kernel_size, self.kernel_size],
            sigma=[sigma, sigma],
        )


In [13]:
class RandomIlluminationBias:
    """
    Randomly perturbs global illumination to mimic lamp decay / exposure bias.

    With probability ``p`` the image is:
      1. brightness-scaled by a factor drawn uniformly from ``brightness``;
      2. gamma-adjusted with an exponent drawn uniformly from ``gamma``.
    Otherwise it is returned unchanged. Both adjustments are global; no
    spatially varying (vignetting) effect is applied.
    """
    def __init__(self, p: float = 0.5, brightness=(0.6, 1.2), gamma=(0.7, 1.5)):
        self.p = p
        self.brightness = brightness
        self.gamma = gamma

    def __call__(self, img):
        if random.random() > self.p:
            return img

        tf = transforms.functional
        brightened = tf.adjust_brightness(img, random.uniform(*self.brightness))
        return tf.adjust_gamma(brightened, random.uniform(*self.gamma))


In [14]:
class RandomKnifeArtifacts:
    """
    Simulates knife/tissue surface artifacts on a PIL image:
    - 1-5 brightened horizontal streaks (each covers rows y-thickness..y+thickness-1,
      thickness drawn from 1..3, intensity boost from 0.05..0.25)
    - a global contrast reduction around mid-grey (factor drawn from 0.6..0.95)
    NOTE: This is a proxy; real knife artifacts are more complex.

    Args:
        p: probability that the artifacts are applied to a given image.
    """
    def __init__(self, p: float = 0.35):
        self.p = p

    def __call__(self, img):
        if random.random() > self.p:
            return img

        # Work in float [0, 1] so intensities and contrast apply directly.
        t = transforms.functional.pil_to_tensor(img).float() / 255.0  # [C,H,W]
        t = self._degrade(t)
        return transforms.functional.to_pil_image(self._to_uint8(t))

    @staticmethod
    def _degrade(t: torch.Tensor) -> torch.Tensor:
        """Return a degraded copy of ``t`` ([C,H,W] floats in [0, 1]); the input is not modified."""
        t = t.clone()
        _, h, _ = t.shape

        # Add a few random streaks
        for _ in range(random.randint(1, 5)):
            y = random.randint(0, h - 1)
            thickness = random.randint(1, 3)
            intensity = random.uniform(0.05, 0.25)
            y0 = max(0, y - thickness)
            y1 = min(h, y + thickness)
            t[:, y0:y1, :] = torch.clamp(t[:, y0:y1, :] + intensity, 0.0, 1.0)

        # Slight contrast reduction (mimic surface degradation)
        t = (t - 0.5) * random.uniform(0.6, 0.95) + 0.5
        return torch.clamp(t, 0.0, 1.0)

    @staticmethod
    def _to_uint8(t: torch.Tensor) -> torch.Tensor:
        """Convert [0, 1] floats to uint8 with rounding.

        A plain ``.byte()`` cast truncates, which darkens every pixel by about
        half a grey level on average after each pass through this transform.
        """
        return (t * 255.0).round().clamp(0, 255).to(torch.uint8)



In [15]:
# --------------------------
# Config
# --------------------------
@dataclass
class TrainConfig:
    """Hyper-parameters and paths for one quality-gate training run."""
    # Root holding train/, val/ and test/, each with bad/ and good/ sub-folders.
    # Defaults to "data" (the layout in the module docstring) because the
    # main() cells construct ``TrainConfig()`` without arguments; with no
    # default that call raised TypeError.
    data_dir: str = "data"
    model_name: str = "resnet18"  # backbone name passed to build_model
    image_size: int = 224  # images are resized to image_size x image_size
    batch_size: int = 32
    epochs: int = 20
    lr: float = 1e-3
    weight_decay: float = 1e-4
    seed: int = 42
    num_workers: int = 4  # DataLoader worker processes
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    save_path: str = "quality_gate_best.pt"  # best-validation checkpoint path


# --------------------------
# Model builder
# --------------------------
def build_model(model_name: str, num_classes: int = 2) -> nn.Module:
    """
    Return an ImageNet-pretrained torchvision classifier with a new output layer.

    Args:
        model_name: one of "resnet18", "resnet50", "vit_b_16".
        num_classes: number of outputs of the replacement classification head.

    Raises:
        ValueError: for any other ``model_name``.
    """
    if model_name not in ("resnet18", "resnet50", "vit_b_16"):
        raise ValueError(f"Unknown model_name={model_name}. Use resnet18/resnet50/vit_b_16.")

    if model_name == "vit_b_16":
        net = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
        return net

    # Both ResNet variants expose their classifier as ``fc``.
    net = (
        models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        if model_name == "resnet50"
        else models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    )
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


In [16]:
# --------------------------
# Metrics
# --------------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> Dict[str, float]:
    """
    Score a binary quality-gate classifier over every batch in ``loader``.

    Predictions are the argmax over dim 1 of the model output. Labels follow the
    ImageFolder mapping bad=0, good=1 (class folders sorted alphabetically), and
    "good" is the positive class for precision/recall/F1.

    Args:
        model: classifier; it is put in eval mode and left in eval mode.
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``acc``, ``precision_good``, ``recall_good``, ``f1_good`` and
        ``recall_bad``. ``recall_bad`` is the fraction of bad images predicted
        bad, i.e. how often the gate catches a drifting image. Every metric is
        0.0 for an empty loader.
    """
    model.eval()
    correct = total = 0
    tp = fp = tn = fn = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)

        correct += (pred == y).sum().item()
        total += y.numel()

        pred_good, true_good = pred == 1, y == 1
        pred_bad, true_bad = pred == 0, y == 0
        tp += (pred_good & true_good).sum().item()
        tn += (pred_bad & true_bad).sum().item()
        fp += (pred_good & true_bad).sum().item()
        fn += (pred_bad & true_good).sum().item()

    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    return {
        "acc": correct / max(total, 1),
        "precision_good": precision,
        "recall_good": recall,
        "f1_good": 2 * precision * recall / max(precision + recall, 1e-12),
        # tn used to be counted but never reported; for a gate, missed bad
        # images are the costly failure.
        "recall_bad": tn / max(tn + fp, 1),
    }



In [17]:
# --------------------------
# Training
# --------------------------
def main(cfg: TrainConfig) -> Tuple[datasets.ImageFolder, datasets.ImageFolder, datasets.ImageFolder]:
    """Build and return the train/val/test ImageFolder datasets for the quality gate.

    Data-preparation draft: no model is trained here.

    Args:
        cfg: run configuration (data_dir, image_size, seed are used).

    Returns:
        (train_ds, val_ds, test_ds)

    Raises:
        ValueError: if any split's class folders do not map to bad=0, good=1,
            which is the mapping ``evaluate`` assumes.
    """
    seed_everything(cfg.seed)

    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))

    # Label-preserving augmentation only. RandomDefocus, RandomIlluminationBias,
    # RandomKnifeArtifacts and contrast jitter all simulate the *bad* class.
    # Applying them to good/ images while keeping the "good" label teaches the
    # gate to accept exactly the drift it is meant to flag. Use them to
    # synthesise bad samples instead.
    train_tf = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    # ImageFolder assigns class indices alphabetically: bad/ -> 0, good/ -> 1.
    expected = {"bad": 0, "good": 1}
    loaded = []
    for split, tf in (("train", train_tf), ("val", eval_tf), ("test", eval_tf)):
        ds = datasets.ImageFolder(os.path.join(cfg.data_dir, split), transform=tf)
        if ds.class_to_idx != expected:
            raise ValueError(
                f"{split} split has class mapping {ds.class_to_idx}; expected {expected}. "
                "Name the class folders exactly 'bad' and 'good'."
            )
        loaded.append(ds)

    train_ds, val_ds, test_ds = loaded
    print("Class mapping:", train_ds.class_to_idx)
    return train_ds, val_ds, test_ds


In [23]:
# -------------------------
# Main training
# -------------------------
def main(cfg: "TrainConfig | None" = None):
    """Build the train/val/test DataLoaders for the quality gate.

    Data-pipeline draft: no model is trained here.

    Args:
        cfg: run configuration. Defaults to ``TrainConfig(data_dir="data")``,
            the layout described in the module docstring. The previous bare
            ``TrainConfig()`` call failed because ``data_dir`` was required.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if any split's class folders do not map to bad=0, good=1.
    """
    if cfg is None:
        cfg = TrainConfig(data_dir="data")
    seed_everything(cfg.seed)

    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    # Label-preserving augmentation only. The former always-on GaussianBlur
    # ("simulate defocus") blurred every training image, including good/
    # ones still labelled "good". That teaches the gate to accept defocus and
    # makes training images systematically blurrier than val/test images.
    train_tf = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    expected = {"bad": 0, "good": 1}
    splits = {}
    for split, tf in (("train", train_tf), ("val", eval_tf), ("test", eval_tf)):
        ds = datasets.ImageFolder(os.path.join(cfg.data_dir, split), transform=tf)
        if ds.class_to_idx != expected:
            raise ValueError(
                f"{split} split has class mapping {ds.class_to_idx}; expected {expected}. "
                "Name the class folders exactly 'bad' and 'good'."
            )
        splits[split] = ds

    print("Class mapping:", splits["train"].class_to_idx)

    # Pinned host memory only benefits host-to-GPU copies.
    pin = cfg.device.startswith("cuda")
    return tuple(
        DataLoader(
            splits[name],
            batch_size=cfg.batch_size,
            shuffle=(name == "train"),
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )
        for name in ("train", "val", "test")
    )

In [24]:

# -------------------------
# Main training
# -------------------------
def main(cfg: "TrainConfig | None" = None):
    """Build the train/val/test DataLoaders for the quality gate.

    Data-pipeline draft: no model is trained here.

    Args:
        cfg: run configuration. Defaults to ``TrainConfig(data_dir="data")``,
            the layout described in the module docstring. The previous bare
            ``TrainConfig()`` call failed because ``data_dir`` was required.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if any split's class folders do not map to bad=0, good=1.
    """
    if cfg is None:
        cfg = TrainConfig(data_dir="data")
    seed_everything(cfg.seed)

    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    # Label-preserving augmentation only. The former always-on GaussianBlur
    # ("simulate defocus") blurred every training image, including good/
    # ones still labelled "good". That teaches the gate to accept defocus and
    # makes training images systematically blurrier than val/test images.
    train_tf = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    expected = {"bad": 0, "good": 1}
    splits = {}
    for split, tf in (("train", train_tf), ("val", eval_tf), ("test", eval_tf)):
        ds = datasets.ImageFolder(os.path.join(cfg.data_dir, split), transform=tf)
        if ds.class_to_idx != expected:
            raise ValueError(
                f"{split} split has class mapping {ds.class_to_idx}; expected {expected}. "
                "Name the class folders exactly 'bad' and 'good'."
            )
        splits[split] = ds

    print("Class mapping:", splits["train"].class_to_idx)

    # Pinned host memory only benefits host-to-GPU copies.
    pin = cfg.device.startswith("cuda")
    return tuple(
        DataLoader(
            splits[name],
            batch_size=cfg.batch_size,
            shuffle=(name == "train"),
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )
        for name in ("train", "val", "test")
    )


In [26]:
# -------------------------
# Model builder
# -------------------------
def build_model(model_name: str, num_classes: int = 2) -> nn.Module:
    """
    Build an ImageNet-pretrained torchvision backbone with a fresh classification head.

    This definition shadows any earlier ``build_model`` in the notebook, so it
    must keep supporting every backbone. Narrowing it to resnet18 silently
    broke ``model_name="resnet50"`` / ``"vit_b_16"``.

    Args:
        model_name: one of "resnet18", "resnet50", "vit_b_16".
        num_classes: number of outputs of the new head.

    Raises:
        ValueError: for any other ``model_name``.
    """
    # Lambdas defer weight loading until the name has been validated.
    builders = {
        "resnet18": lambda: models.resnet18(weights=models.ResNet18_Weights.DEFAULT),
        "resnet50": lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT),
        "vit_b_16": lambda: models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT),
    }
    if model_name not in builders:
        raise ValueError(f"Unknown model_name={model_name}. Use resnet18/resnet50/vit_b_16.")

    model = builders[model_name]()
    if model_name.startswith("resnet"):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    return model

In [None]:
# -------------------------
# Evaluation
# -------------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> Dict[str, float]:
    """
    Score a binary quality-gate classifier over every batch in ``loader``.

    Predictions are the argmax over dim 1 of the model output. Labels follow the
    ImageFolder mapping bad=0, good=1 (class folders sorted alphabetically), and
    "good" is the positive class for precision/recall/F1. This definition
    shadows any earlier ``evaluate``, so it reports the full metric set rather
    than just acc/F1.

    Args:
        model: classifier; it is put in eval mode and left in eval mode.
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``acc``, ``precision_good``, ``recall_good``, ``f1_good`` and
        ``recall_bad``. ``recall_bad`` is the fraction of bad images predicted
        bad, i.e. how often the gate catches a drifting image. Every metric is
        0.0 for an empty loader.
    """
    model.eval()
    correct = total = 0
    tp = fp = tn = fn = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)

        correct += (pred == y).sum().item()
        total += y.numel()

        pred_good, true_good = pred == 1, y == 1
        pred_bad, true_bad = pred == 0, y == 0
        tp += (pred_good & true_good).sum().item()
        tn += (pred_bad & true_bad).sum().item()
        fp += (pred_good & true_bad).sum().item()
        fn += (pred_bad & true_good).sum().item()

    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    return {
        "acc": correct / max(total, 1),
        "precision_good": precision,
        "recall_good": recall,
        "f1_good": 2 * precision * recall / max(precision + recall, 1e-12),
        "recall_bad": tn / max(tn + fp, 1),
    }

In [28]:
# -------------------------
# Main training
# -------------------------
def _parse_args(argv=None) -> "TrainConfig":
    """Build a TrainConfig from command-line flags (see the module docstring's Usage).

    Unknown arguments are ignored, so the ``-f <kernel.json>`` flag that a
    Jupyter kernel puts in ``sys.argv`` does not abort the run.
    """
    defaults = TrainConfig(data_dir="data")
    parser = argparse.ArgumentParser(description="Train the H-DRIFT-M quality gate.",
                                     allow_abbrev=False)
    parser.add_argument("--data_dir", default=defaults.data_dir)
    parser.add_argument("--epochs", type=int, default=defaults.epochs)
    parser.add_argument("--model", dest="model_name", default=defaults.model_name)
    parser.add_argument("--image_size", type=int, default=defaults.image_size)
    parser.add_argument("--batch_size", type=int, default=defaults.batch_size)
    parser.add_argument("--lr", type=float, default=defaults.lr)
    parser.add_argument("--num_workers", type=int, default=defaults.num_workers)
    parser.add_argument("--save_path", default=defaults.save_path)
    args, _ignored = parser.parse_known_args(argv)
    return TrainConfig(**vars(args))


def _build_transforms(image_size: int):
    """Return ``(train_tf, eval_tf)``; training adds only label-preserving flips."""
    normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                     std=(0.229, 0.224, 0.225))
    # No blur/illumination/artifact degradations here: they mimic the "bad"
    # class. The former always-on GaussianBlur blurred every training image,
    # good ones included, under their original labels. That teaches the gate
    # to accept defocus and makes train images blurrier than val/test images.
    train_tf = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    eval_tf = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ])
    return train_tf, eval_tf


def _build_datasets(data_dir: str, train_tf, eval_tf):
    """Load the train/val/test ImageFolders and verify the bad=0 / good=1 mapping."""
    expected = {"bad": 0, "good": 1}
    loaded = []
    for split, tf in (("train", train_tf), ("val", eval_tf), ("test", eval_tf)):
        ds = datasets.ImageFolder(os.path.join(data_dir, split), transform=tf)
        if ds.class_to_idx != expected:
            raise ValueError(
                f"{split} split has class mapping {ds.class_to_idx}; expected {expected}. "
                "Name the class folders exactly 'bad' and 'good'."
            )
        loaded.append(ds)
    return tuple(loaded)


def _make_loader(ds, cfg, shuffle: bool) -> DataLoader:
    """Wrap ``ds`` in a DataLoader; memory is pinned only when training on CUDA."""
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=cfg.device.startswith("cuda"),
    )


def _class_weights(train_ds, device: str) -> torch.Tensor:
    """Inverse-frequency CrossEntropy weights, one per class in ``train_ds.classes``."""
    labels = [label for _, label in train_ds.samples]
    # minlength keeps one weight per class even if a class has no samples;
    # otherwise the weight vector would be too short for the loss.
    counts = np.bincount(labels, minlength=len(train_ds.classes))
    weights = counts.sum() / np.maximum(counts, 1)
    return torch.tensor(weights, dtype=torch.float32, device=device)


def _train_one_epoch(model: nn.Module, loader, criterion, optimizer, device: str) -> float:
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss."""
    model.train()
    running_loss = 0.0
    seen = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * y.size(0)
        seen += y.size(0)
    return running_loss / max(seen, 1)


# -------------------------
# Main training
# -------------------------
def main(cfg: "TrainConfig | None" = None) -> Dict[str, float]:
    """Train the quality gate, keep the best validation-F1 weights, and score the test split.

    Replaces the earlier version that was split across cells. Its fragments
    were indented outside any function and used ``main``'s locals, so they
    failed with NameError.

    Args:
        cfg: run configuration. When omitted it is read from the command line
            via ``_parse_args``.

    Returns:
        ``evaluate`` metrics on the test split for the best checkpoint.
    """
    if cfg is None:
        cfg = _parse_args()
    seed_everything(cfg.seed)

    # Data
    train_tf, eval_tf = _build_transforms(cfg.image_size)
    train_ds, val_ds, test_ds = _build_datasets(cfg.data_dir, train_tf, eval_tf)
    print("Class mapping:", train_ds.class_to_idx)

    train_loader = _make_loader(train_ds, cfg, shuffle=True)
    val_loader = _make_loader(val_ds, cfg, shuffle=False)
    test_loader = _make_loader(test_ds, cfg, shuffle=False)

    # Model, loss (class-weighted for imbalance), optimiser, schedule
    model = build_model(cfg.model_name, num_classes=2).to(cfg.device)
    criterion = nn.CrossEntropyLoss(weight=_class_weights(train_ds, cfg.device))
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)

    best_f1 = -1.0
    best_state = None
    for epoch in range(1, cfg.epochs + 1):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, cfg.device)
        scheduler.step()

        val_metrics = evaluate(model, val_loader, cfg.device)
        print(
            f"Epoch {epoch:02d}/{cfg.epochs} | "
            f"train_loss={train_loss:.4f} | "
            f"val_acc={val_metrics['acc']:.4f} | "
            f"val_f1_good={val_metrics['f1_good']:.4f}"
        )

        if val_metrics["f1_good"] > best_f1:
            best_f1 = val_metrics["f1_good"]
            torch.save(model.state_dict(), cfg.save_path)
            # In-memory copy so the test evaluation below needs no reload from disk.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    print("Training complete. Best model saved to:", cfg.save_path)

    # Report held-out performance of the selected checkpoint (test_loader was
    # previously built but never used).
    if best_state is not None:
        model.load_state_dict(best_state)
    test_metrics = evaluate(model, test_loader, cfg.device)
    print("Test metrics (best val-F1 checkpoint): "
          + ", ".join(f"{name}={value:.4f}" for name, value in test_metrics.items()))
    return test_metrics


if __name__ == "__main__":
    main()
