# Sterile Field Detector: Comparative Experimentation

This notebook adapts the original Sterile Field Detector pipeline to benchmark preprocessing techniques. Every variant is fed to the same CNN classifier and autoencoder architectures, so the comparison isolates the effect of preprocessing.

**Objective:** Compare the "Current Best" (Top-Hat Preprocessing + CNN/AE) against alternative preprocessing variants (plain grayscale frames, generic Gaussian filtering) to validate the choice of preprocessing.

**Experiments:**
1. **Baseline:** Original Morphological Top-Hat Transform.
2. **No Filter:** Grayscale video frames with no further filtering (to test robustness without preprocessing).
3. **Gaussian Blur:** Generic noise reduction (to test if specific texture extraction is necessary).

In [None]:
# --- 1. Setup & Dependencies ---
!pip install inference-sdk python-dotenv

import os
import shutil
import pathlib
import cv2
import gdown
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models, Input
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, recall_score, precision_score
from tqdm import tqdm

# Seed the global NumPy and TensorFlow random generators with the same value
# so repeated runs of the notebook draw repeatable random numbers.
for _seed_fn in (np.random.seed, tf.random.set_seed):
    _seed_fn(42)
del _seed_fn

print("TensorFlow Version: " + tf.__version__)

In [None]:
# --- 2. Dataset Download & Extraction ---
DATASET_LINK = "https://drive.google.com/drive/folders/1gLRc8noJhQjpkD0F6hnvUO92qi5YrbXH?usp=drive_link"
# All dataset paths are anchored at the (resolved) working directory.
ROOT = pathlib.Path(".").resolve()
DOWNLOAD_DIR = ROOT.joinpath("dataset", "download")  # files fetched from Google Drive
FRAMES_DIR = ROOT.joinpath("dataset", "frames")      # extracted JPEGs, one sub-folder per class
VIDEO_EXTS = {".avi", ".mkv", ".mov", ".mp4", ".webm"}

def download_dataset():
    """Fetch the Google Drive dataset folder into DOWNLOAD_DIR (best effort).

    Does nothing if DOWNLOAD_DIR already contains any entry; delete its
    contents to force a fresh download. A failed download is reported rather
    than raised, so the notebook can continue with videos that were copied
    into DOWNLOAD_DIR by hand.
    """
    DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    if any(DOWNLOAD_DIR.iterdir()):
        print("Dataset already downloaded.")
        return
    print("Downloading dataset...")
    try:
        downloaded = gdown.download_folder(DATASET_LINK, output=str(DOWNLOAD_DIR), quiet=False)
    except Exception as err:  # deliberate best-effort boundary: report, don't crash
        print(f"Automatic download failed ({type(err).__name__}: {err})")
        downloaded = None
    # NOTE(review): gdown can also signal failure by returning None instead of
    # raising, so treat an empty/None result as a failure too.
    if not downloaded:
        print("Please ensure the dataset videos are in", DOWNLOAD_DIR)

def extract_frames(video_path, target_count=500):
    """Save up to ``target_count`` evenly spaced frames of one video as JPEGs.

    The class label is the first whitespace-separated word of the file stem,
    lower-cased (``"Clean 2.mp4"`` -> ``"clean"``). Frames are written to
    ``FRAMES_DIR/<class>/<stem>_frame_NNNN.jpg`` (spaces in the stem replaced
    by underscores), so several videos of the same class share one folder
    without overwriting each other.

    Args:
        video_path: pathlib.Path of the video file.
        target_count: Maximum number of frames to sample from the video.

    A video whose first output frame already exists is skipped, so re-running
    the cell is cheap. Videos that cannot be opened or that report no frame
    count are reported and skipped. NOTE: existing files in the class folder
    that do not follow this naming scheme are neither detected nor removed.
    """
    class_name = video_path.stem.split()[0].lower()
    out_dir = FRAMES_DIR / class_name
    prefix = video_path.stem.replace(" ", "_")

    # Skip per video, not per class folder: a per-folder check would silently
    # drop every additional video of a class that already has frames.
    if (out_dir / f"{prefix}_frame_0000.jpg").exists():
        return

    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
            return
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Some containers report no frame count; seeking by index is meaningless then.
            print(f"Could not determine frame count for: {video_path}")
            return

        # np.unique removes repeated indices when the video has fewer than
        # target_count frames; otherwise the same frame would be saved several
        # times and copies could end up in both the train and test splits.
        indices = np.unique(np.linspace(0, total_frames - 1, target_count, dtype=int))

        out_dir.mkdir(parents=True, exist_ok=True)
        count = 0
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if not ok:
                continue
            cv2.imwrite(str(out_dir / f"{prefix}_frame_{count:04d}.jpg"), frame)
            count += 1
    finally:
        cap.release()

download_dataset()
# Accept every extension in VIDEO_EXTS (case-insensitively) and search
# sub-folders too, since the Drive folder download may keep its own folder
# structure. Sorting makes the processing order deterministic.
videos = sorted(
    path
    for path in DOWNLOAD_DIR.rglob("*")
    if path.is_file() and path.suffix.lower() in VIDEO_EXTS
)
for v in videos:
    extract_frames(v)

In [None]:
# --- 3. Define Preprocessing Variations ---
# We define multiple filter functions to test different hypotheses

def filter_baseline_tophat(image):
    """Reference preprocessing: morphological top-hat/black-hat boost, then CLAHE.

    Args:
        image: BGR image, as returned by ``cv2.imread``.

    Returns:
        Single-channel uint8 image.
    """
    # Edge-preserving denoise of the grayscale frame.
    base = cv2.bilateralFilter(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), 9, 75, 75)

    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (19, 19))
    # White top-hat keeps small features brighter than their surroundings
    # (meant for crumbs); black-hat keeps small darker features (meant for hair).
    bright_detail, dark_detail = (
        cv2.morphologyEx(base, op, ellipse)
        for op in (cv2.MORPH_TOPHAT, cv2.MORPH_BLACKHAT)
    )

    # Amplify bright detail and suppress dark detail, then clamp back to 8-bit.
    gain = 2.0
    boosted = base.astype(np.float32) + (bright_detail * gain) - (dark_detail * gain)
    boosted = np.clip(boosted, 0, 255).astype(np.uint8)

    # Local contrast equalisation as the final step.
    return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(boosted)

def filter_none(image):
    """Control condition: convert BGR to grayscale and apply no other filter."""
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return gray

def filter_gaussian(image):
    """Generic denoising control: grayscale conversion followed by a 5x5 Gaussian blur.

    A sigma of 0 lets OpenCV derive the standard deviation from the kernel size.
    """
    kernel_size = (5, 5)
    return cv2.GaussianBlur(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), kernel_size, 0)

In [None]:
# --- 4. Model Definitions ---
# We define factories for the models so we can instantiate fresh ones for each experiment

# Model input size in pixels. Both model factories build their Input layer as
# (IMG_SIZE[0], IMG_SIZE[1], 1), i.e. IMG_SIZE is read as (height, width).
# NOTE: cv2.resize takes its dsize argument as (width, height); keep the size
# square or swap the order wherever IMG_SIZE is passed to cv2.resize.
IMG_SIZE = (128, 128)

def build_cnn_classifier(num_classes=3):
    """Return a freshly compiled small convolutional classifier.

    Three 3x3 ReLU convolutions (32, 64, 64 filters, default 'valid' padding),
    with 2x2 max pooling after the first two, then a 64-unit dense layer and a
    softmax over ``num_classes``. The model is compiled with Adam and
    categorical cross-entropy, so labels must be one-hot encoded.
    """
    stack = [Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 1))]
    for filters, pool_after in ((32, True), (64, True), (64, False)):
        stack.append(layers.Conv2D(filters, (3, 3), activation='relu'))
        if pool_after:
            stack.append(layers.MaxPooling2D((2, 2)))
    stack += [
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ]

    model = models.Sequential(stack)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def build_autoencoder():
    """Return a freshly compiled convolutional autoencoder for anomaly scoring.

    The encoder halves the spatial size three times with stride-2 ReLU
    convolutions (32 -> 64 -> 128 filters). The decoder mirrors it with
    stride-2 transposed convolutions (64 -> 32 -> 1 filters), restoring the
    input size when it is divisible by 8. The sigmoid output matches inputs
    scaled to [0, 1]. The model is compiled with Adam on mean squared error.
    """
    conv_kwargs = dict(kernel_size=(3, 3), padding='same', strides=2)
    inputs = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 1))

    h = inputs
    for filters in (32, 64, 128):  # encoder
        h = layers.Conv2D(filters, activation='relu', **conv_kwargs)(h)
    for filters in (64, 32):  # decoder
        h = layers.Conv2DTranspose(filters, activation='relu', **conv_kwargs)(h)
    outputs = layers.Conv2DTranspose(1, activation='sigmoid', **conv_kwargs)(h)

    autoencoder = models.Model(inputs, outputs)
    autoencoder.compile(optimizer='adam', loss='mse')
    return autoencoder

In [None]:
# --- 5. Experiment Logic ---

def load_data_with_filter(filter_func, max_per_class=300):
    """Load the extracted frames of every class and preprocess them.

    Args:
        filter_func: Callable mapping a BGR uint8 image to a single-channel
            uint8 image (e.g. one of the ``filter_*`` functions).
        max_per_class: Upper bound on frames loaded per class, to keep each
            experiment fast. When a class has more frames, an evenly spaced
            subset of its sorted file list is used, so the choice is
            reproducible and spans the whole class rather than just its start.

    Returns:
        (X, y, classes). X is float32 of shape (N, IMG_SIZE[0], IMG_SIZE[1], 1),
        scaled to [0, 1]. y holds integer labels indexing into ``classes``, the
        sorted list of class sub-folder names of FRAMES_DIR.

    Raises:
        FileNotFoundError: if FRAMES_DIR does not exist.
        ValueError: if no frames could be loaded.
    """
    classes = sorted(d.name for d in FRAMES_DIR.iterdir() if d.is_dir())
    # cv2.resize expects dsize as (width, height), the reverse of the
    # (height, width) order the model Input layers use for IMG_SIZE.
    dsize = (IMG_SIZE[1], IMG_SIZE[0])

    X, y = [], []
    for label, cls in enumerate(classes):
        # Sort: glob order is filesystem-dependent, which made the subset vary.
        files = sorted((FRAMES_DIR / cls).glob("*.jpg"))
        if len(files) > max_per_class:
            picks = np.linspace(0, len(files) - 1, max_per_class, dtype=int)
            files = [files[i] for i in picks]

        for f in files:
            img = cv2.imread(str(f))
            if img is None:  # unreadable or corrupt file
                print(f"Skipping unreadable frame: {f}")
                continue
            processed = cv2.resize(filter_func(img), dsize)
            processed = processed.astype('float32') / 255.0  # normalise to [0, 1]
            X.append(processed[..., np.newaxis])  # add channel axis
            y.append(label)

    if not X:
        raise ValueError(f"No frames found under {FRAMES_DIR}")
    return np.array(X), np.array(y), classes

def run_experiment(name, filter_func):
    """Train and evaluate a fresh CNN classifier and autoencoder on one preprocessing variant.

    Args:
        name: Label used in progress messages.
        filter_func: Preprocessing function passed to ``load_data_with_filter``.

    Returns:
        dict with keys:
          "CNN Accuracy"  - classifier accuracy on the test split;
          "AE Recall", "AE Precision" - anomaly detection on the test split,
              where an anomaly is any non-clean frame;
          "AE Threshold"  - reconstruction-MSE cut-off, calibrated on held-out
              clean *training* frames.

    Raises:
        ValueError: if no class folder name contains "clean".

    NOTE(review): frames come from continuous video, so near-identical
    neighbouring frames can fall on both sides of the random split; scores are
    likely optimistic compared with a split by video.
    """
    print(f"\n--- Running Experiment: {name} ---")

    # 1. Prepare data; stratify so each class keeps its proportion in both splits.
    X, y, classes = load_data_with_filter(filter_func)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    y_train_cat = to_categorical(y_train, num_classes=len(classes))
    y_test_cat = to_categorical(y_test, num_classes=len(classes))

    # 2. Supervised CNN classifier.
    print(f"Training CNN ({name})...")
    cnn = build_cnn_classifier(len(classes))
    cnn.fit(X_train, y_train_cat, epochs=10, batch_size=32, validation_split=0.1, verbose=0)
    _, cnn_acc = cnn.evaluate(X_test, y_test_cat, verbose=0)

    # 3. Autoencoder trained only on the 'clean' class; high reconstruction
    #    error marks a frame as anomalous.
    clean_matches = [i for i, c in enumerate(classes) if 'clean' in c.lower()]
    if not clean_matches:
        raise ValueError(f"No 'clean' class among {classes}")
    clean_idx = clean_matches[0]
    X_clean = X_train[y_train == clean_idx]

    # Hold out part of the clean training frames to calibrate the threshold.
    # Calibrating on the clean test frames would leak test data into the metric.
    if len(X_clean) >= 10:
        X_clean_fit, X_clean_cal = train_test_split(X_clean, test_size=0.2, random_state=42)
    else:
        # Too few clean frames to hold any out; calibrate on the training frames.
        X_clean_fit = X_clean_cal = X_clean

    print(f"Training Autoencoder ({name})...")
    ae = build_autoencoder()
    ae.fit(X_clean_fit, X_clean_fit, epochs=10, batch_size=32, shuffle=True, verbose=0)

    def reconstruction_mse(batch):
        """Per-sample mean squared reconstruction error."""
        recon = ae.predict(batch, verbose=0)
        return np.mean(np.square(batch - recon), axis=(1, 2, 3))

    cal_mse = reconstruction_mse(X_clean_cal)
    threshold = np.mean(cal_mse) + 2 * np.std(cal_mse)

    test_mse = reconstruction_mse(X_test)
    y_pred_anomaly = (test_mse > threshold).astype(int)
    y_true_anomaly = (y_test != clean_idx).astype(int)

    return {
        "CNN Accuracy": cnn_acc,
        "AE Recall": recall_score(y_true_anomaly, y_pred_anomaly, zero_division=0),
        # Recall alone is maximised by any low threshold; report precision too.
        "AE Precision": precision_score(y_true_anomaly, y_pred_anomaly, zero_division=0),
        "AE Threshold": threshold,
    }

In [None]:
# --- 6. Execute Comparison ---
results = {}

# Display name -> preprocessing function under test.
experiments = {
    "Baseline (TopHat)": filter_baseline_tophat,
    "No Filter (Raw)": filter_none,
    "Gaussian Blur": filter_gaussian,
}

for exp_name, func in experiments.items():
    results[exp_name] = run_experiment(exp_name, func)

print("\n--- FINAL RESULTS ---")
print(f"{'Experiment':<20} | {'CNN Accuracy':<15} | {'AE Recall':<15}")
print("-" * 56)
for name, metrics in results.items():
    # Same column widths as the header; the '%' format type scales by 100 and
    # appends '%', so every cell stays aligned whatever the value's width.
    print(f"{name:<20} | {metrics['CNN Accuracy']:<15.2%} | {metrics['AE Recall']:<15.2%}")

In [None]:
# --- 7. Visualization ---
# One pair of bars (CNN accuracy, AE recall) per experiment.
labels = [*results]
cnn_scores, ae_scores = (
    [results[exp][metric] for exp in labels]
    for metric in ('CNN Accuracy', 'AE Recall')
)

x = np.arange(len(labels))
width = 0.35  # width of each bar; the pair is centred on its tick

fig, ax = plt.subplots(figsize=(10, 6))
rects1 = ax.bar(x - width / 2, cnn_scores, width, label='CNN Accuracy')
rects2 = ax.bar(x + width / 2, ae_scores, width, label='AE Recall')

ax.set_ylabel('Score')
ax.set_title('Performance Comparison: Preprocessing Impact')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
ax.set_ylim(0, 1.1)  # headroom above 1.0 for the value labels


def autolabel(rects):
    """Write each bar's height, as a percentage, just above the bar on ``ax``."""
    for bar in rects:
        top = bar.get_height()
        centre_x = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f'{top * 100:.1f}%',
            xy=(centre_x, top),
            xytext=(0, 3),  # 3 points above the bar top
            textcoords="offset points",
            ha='center',
            va='bottom',
        )


autolabel(rects1)
autolabel(rects2)
plt.show()