In [None]:
# --- Setup & Imports ---
import os
import sys
import math
import shutil
import pathlib
import argparse
from typing import Dict, List, Tuple

import cv2
import gdown
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tqdm import tqdm

# Ensure reproducible results
# Seeds NumPy's legacy global RNG and TensorFlow's global RNG. Python's built-in
# `random` module is not seeded here.
np.random.seed(123)
tf.random.set_seed(123)

# Report the TensorFlow version and the number of visible GPUs for this run.
print(f"TensorFlow Version: {tf.__version__}")
print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")

## 1. Dataset Preparation
Downloads the dataset from Google Drive and extracts frames.

In [None]:
# --- Configuration ---
# Google Drive folder link handed to gdown by download_dataset().
DATASET_LINK = "https://drive.google.com/drive/folders/1gLRc8noJhQjpkD0F6hnvUO92qi5YrbXH?usp=drive_link"
# All dataset paths are anchored at the (resolved) current working directory.
ROOT = pathlib.Path(".").resolve()
DOWNLOAD_DIR = ROOT / "dataset" / "download"  # where downloaded videos are stored
FRAMES_DIR = ROOT / "dataset" / "frames"  # root folder for extracted frames
# File suffixes treated as videos; compared against lower-cased suffixes.
VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
# NOTE(review): not referenced anywhere else in the visible notebook, and the
# later cells use "clean"/"hair"/"trash"/"trash-hair" instead — confirm whether
# this list is still needed.
CLASS_KEYS = ["sterile", "hair", "trash", "both"]

def slugify(stem: str) -> str:
    """Turn a file stem into a lowercase, dash-separated folder name.

    Spaces and underscores become dashes, every character that is neither
    alphanumeric nor a dash is dropped, and ``"video"`` is returned when
    nothing is left.
    """
    normalized = stem.lower().strip()
    for separator in (" ", "_"):
        normalized = normalized.replace(separator, "-")
    kept = [char for char in normalized if char == "-" or char.isalnum()]
    slug = "".join(kept)
    return slug if slug else "video"

def download_dataset(force: bool = False) -> None:
    """Download the dataset videos into DOWNLOAD_DIR unless they already exist.

    Args:
        force: When True, download even if local video files are already present.

    The Google Drive folder download is attempted first. If it raises, the
    reason is printed (instead of being silently discarded) and a plain
    ``gdown.download`` of the same link is attempted as a fallback.
    """
    DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    existing = [p for p in DOWNLOAD_DIR.rglob("*") if p.suffix.lower() in VIDEO_EXTS]
    if existing and not force:
        print(f"Found {len(existing)} local video(s); skipping download.")
        return
    print(f"Downloading dataset into: {DOWNLOAD_DIR}")
    try:
        gdown.download_folder(DATASET_LINK, output=str(DOWNLOAD_DIR), quiet=False, use_cookies=False)
    except Exception as exc:
        # Report why the folder download failed so the fallback is not a mystery.
        print(f"Folder download failed ({exc}); falling back to gdown.download.")
        # NOTE(review): DATASET_LINK is a folder URL and `output` is a directory
        # path; confirm this fallback actually yields usable video files.
        gdown.download(DATASET_LINK, output=str(DOWNLOAD_DIR), quiet=False)

def extract_frames(video_path: pathlib.Path, target_count: int, out_dir: pathlib.Path) -> int:
    """Save up to ``target_count`` evenly spaced frames of a video as PNG files.

    Args:
        video_path: Path of the video file to read.
        target_count: Maximum number of frames to extract; values <= 0 extract nothing.
        out_dir: Directory that receives ``frame_00000.png``, ``frame_00001.png``, ...
            It is created if missing, even when no frame ends up being written.

    Returns:
        The number of frames actually written to ``out_dir``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    if target_count <= 0:
        # Nothing to extract, so do not open the video at all.
        return 0

    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return 0
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        if total <= 0:
            return 0

        desired = min(target_count, total)
        # Evenly spaced frame indices across the whole clip.
        indices = np.linspace(0, total - 1, num=desired, dtype=np.int64)
        saved = 0
        for fidx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(fidx))
            ret, frame = cap.read()
            if not ret or frame is None:
                continue
            # Only count frames that were really written to disk.
            if cv2.imwrite(str(out_dir / f"frame_{saved:05d}.png"), frame):
                saved += 1
        return saved
    finally:
        # Release the capture on every exit path, including exceptions.
        cap.release()

# --- Run Dataset Builder ---
# Uncomment to run download and extraction
# download_dataset()
# videos = sorted([p for p in DOWNLOAD_DIR.rglob("*") if p.suffix.lower() in VIDEO_EXTS])
# for v in videos:
#     out = FRAMES_DIR / slugify(v.stem)
#     extract_frames(v, 1000, out)

## 2. Preprocessing
Applies morphological filters (TopHat) and contrast enhancement to highlight hair and crumbs while removing large noise (logos).

In [None]:
def remove_large_noise(tophat_image, threshold_value=15, max_area=80):
    """Keep only the small bright blobs of a top-hat image.

    The image is binarised at ``threshold_value``, external contours are
    extracted, and only contours whose area lies strictly between 0 and
    ``max_area`` are painted (filled) into a mask. The input is then masked
    with it, so everything outside those small blobs becomes zero.
    """
    _, binary = cv2.threshold(tophat_image, threshold_value, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    keep_mask = np.zeros_like(tophat_image)
    for contour in contours:
        area = cv2.contourArea(contour)
        if area <= 0 or area >= max_area:
            # Zero-area and large contours are discarded.
            continue
        cv2.drawContours(keep_mask, [contour], -1, 255, -1)

    return cv2.bitwise_and(tophat_image, tophat_image, mask=keep_mask)

def morphological_contrast_enhancement(image, kernel_size=19, crumb_boost=4.0, hair_boost=4.0, shadow_gamma=0.6):
    """Build a contrast-enhanced grayscale image from white/black top-hat responses.

    Steps: convert BGR to gray, bilateral-smooth, apply a gamma curve
    (``shadow_gamma``), compute white and black top-hat responses with an
    elliptical kernel, then compose ``128 + crumb_boost * white - hair_boost * black``
    on a flat background, clip to uint8 and apply CLAHE.

    Args:
        image: Input image; converted with ``cv2.COLOR_BGR2GRAY``.
        kernel_size: Side length of the elliptical structuring element.
        crumb_boost: Gain applied to the cleaned white top-hat response.
        hair_boost: Gain applied to the black top-hat response.
        shadow_gamma: Exponent applied to the normalised, smoothed image.

    Returns:
        The CLAHE-equalised uint8 image.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.bilateralFilter(gray, d=9, sigmaColor=75, sigmaSpace=75)

    # Gamma curve on the [0, 1]-normalised image, then back to uint8.
    gamma_lifted = np.power(smoothed.astype(np.float32) / 255.0, shadow_gamma)
    base = (gamma_lifted * 255).astype(np.uint8)

    structuring = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    bright_detail = cv2.morphologyEx(base, cv2.MORPH_TOPHAT, structuring)
    bright_detail = remove_large_noise(bright_detail, threshold_value=10, max_area=60)
    dark_detail = cv2.morphologyEx(base, cv2.MORPH_BLACKHAT, structuring)

    # Flat mid-gray canvas: bright details are added, dark details subtracted.
    canvas = np.full_like(base, 128, dtype=np.float32)
    canvas = canvas + (bright_detail.astype(np.float32) * crumb_boost)
    canvas -= (dark_detail.astype(np.float32) * hair_boost)

    composed = np.clip(canvas, 0, 255).astype(np.uint8)
    equalizer = cv2.createCLAHE(clipLimit=5.0, tileGridSize=(8, 8))
    return equalizer.apply(composed)

def process_images(input_dir, output_dir, num_images_per_folder=1000):
    """Run the contrast enhancement over every image folder under ``input_dir``.

    ``output_dir`` is deleted and recreated, then each immediate subfolder of
    ``input_dir`` (or ``input_dir`` itself when it has no subfolders) is mirrored
    into it. The first ``num_images_per_folder`` ``.png``/``.jpg`` files (sorted by
    name) of each folder are processed and written as ``<name>_filtered.png``.
    Files that OpenCV cannot read are skipped.

    Args:
        input_dir: Directory containing the source images or class subfolders.
        output_dir: Destination directory. WARNING: any existing content is removed.
        num_images_per_folder: Maximum number of images processed per folder.

    Raises:
        FileNotFoundError: ``input_dir`` does not exist.
        NotADirectoryError: ``input_dir`` is not a directory.
        ValueError: ``input_dir`` and ``output_dir`` are the same directory.
    """
    # Validate the input BEFORE wiping the output, so a typo in input_dir cannot
    # destroy previously processed results.
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory not found: {input_dir}")
    if not os.path.isdir(input_dir):
        raise NotADirectoryError(f"Input path is not a directory: {input_dir}")
    if os.path.realpath(input_dir) == os.path.realpath(output_dir):
        raise ValueError("output_dir must differ from input_dir because it is deleted before processing.")

    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    subfolders = [f for f in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, f))]
    if not subfolders:
        # No subfolders: treat input_dir itself as the only folder.
        subfolders = ['.']

    for folder in subfolders:
        in_path = os.path.join(input_dir, folder)
        out_path = os.path.join(output_dir, folder)
        os.makedirs(out_path, exist_ok=True)

        files = sorted(f for f in os.listdir(in_path) if f.lower().endswith(('.png', '.jpg')))
        files = files[:num_images_per_folder]
        for f in tqdm(files, desc=f"Processing {folder}"):
            img = cv2.imread(os.path.join(in_path, f))
            if img is None:
                continue
            res = morphological_contrast_enhancement(img)
            cv2.imwrite(os.path.join(out_path, f"{os.path.splitext(f)[0]}_filtered.png"), res)

# --- Run Preprocessing ---
# process_images('dataset/frames', 'filtering/processed')

## 3. Autoencoder (Anomaly Detection)
Trains a Convolutional Autoencoder on clean images. Anomalies are detected based on reconstruction error.

In [None]:
# --- Autoencoder Configuration ---
AE_BATCH_SIZE = 32  # batch size of the autoencoder training dataset
IMG_SIZE = (128, 128)  # size images are resized to; also used by the classifier cell
AE_EPOCHS = 20  # epochs used by the (commented-out) autoencoder fit call

# Root of the preprocessed images, with one subfolder per class.
data_dir = pathlib.Path("filtering/processed")
clean_dir = data_dir / "clean"
hair_dir = data_dir / "hair"
trash_dir = data_dir / "trash"
trash_hair_dir = data_dir / "trash-hair"

def load_image_ae(path):
    """Read a PNG file as a single-channel image resized to IMG_SIZE and scaled by 1/255."""
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_png(raw_bytes, channels=1)
    resized = tf.image.resize(decoded, IMG_SIZE)
    return resized / 255.0

def get_files(directory, limit=None):
    """List the ``.png``/``.jpg`` files directly inside ``directory`` as sorted strings.

    Returns an empty list when the directory does not exist. When ``limit`` is
    truthy only the first ``limit`` paths are returned; ``None`` or ``0`` means
    no limit.
    """
    if not directory.exists():
        return []

    wanted_suffixes = ('.png', '.jpg')
    matches = [str(entry) for entry in directory.glob("*") if entry.suffix.lower() in wanted_suffixes]
    matches.sort()

    if not limit:
        return matches
    return matches[:limit]

# Prepare Datasets
# The first 100 clean images (sorted by path) are held out for testing; the
# remainder is used to train the autoencoder.
all_clean = get_files(clean_dir)
test_clean_paths = all_clean[:100]
train_clean_paths = all_clean[100:]
if not train_clean_paths:
    print(f"⚠️ No autoencoder training images found under {clean_dir}. Run preprocessing first.")

# An explicit string dtype keeps the pipeline buildable when the path list is
# empty; an empty Python list would otherwise become a float32 tensor and the
# file-reading map below would fail while being traced.
train_ds_ae = tf.data.Dataset.from_tensor_slices(tf.constant(train_clean_paths, dtype=tf.string))
# Autoencoder targets are the inputs themselves: (x, x).
train_ds_ae = train_ds_ae.map(load_image_ae).map(lambda x: (x, x))
# Cache decoded images, shuffle individual images (not whole batches), then batch.
train_ds_ae = train_ds_ae.cache().shuffle(1000).batch(AE_BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Build Model
class ConvolutionalAutoencoder(models.Model):
    """Convolutional autoencoder for 128x128 single-channel images.

    The encoder applies three stride-2 convolutions (32, 64, 128 filters); the
    decoder mirrors them with three stride-2 transposed convolutions
    (128, 64, 32 filters) followed by a sigmoid convolution producing one channel.
    """

    def __init__(self):
        super().__init__()

        downsample = dict(activation='relu', padding='same', strides=2)
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(128, 128, 1)),
            layers.Conv2D(32, (3, 3), **downsample),
            layers.Conv2D(64, (3, 3), **downsample),
            layers.Conv2D(128, (3, 3), **downsample),
        ])

        upsample = dict(kernel_size=3, strides=2, activation='relu', padding='same')
        self.decoder = tf.keras.Sequential([
            layers.Conv2DTranspose(128, **upsample),
            layers.Conv2DTranspose(64, **upsample),
            layers.Conv2DTranspose(32, **upsample),
            layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same'),
        ])

    def call(self, x):
        """Encode ``x`` and return its reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)

# Instantiate the autoencoder and compile it with the Adam optimizer and a
# mean-squared-error loss.
autoencoder = ConvolutionalAutoencoder()
autoencoder.compile(optimizer='adam', loss='mse')

# Train
# history_ae = autoencoder.fit(train_ds_ae, epochs=AE_EPOCHS)

## 4. CNN Classifier
Trains a supervised CNN to classify images into 4 categories: Clean, Hair, Trash, Trash-Hair.

In [None]:
# --- Classifier Configuration ---
CLS_BATCH_SIZE = 32
CLS_EPOCHS = 15

print("Loading Classifier Data...")
# Note: Ensure 'filtering/processed' exists and has subfolders
if data_dir.exists():
    # Same seed and validation_split in both calls so the two subsets are
    # complementary halves of one 80/20 split.
    train_ds_cls = tf.keras.utils.image_dataset_from_directory(
        data_dir, validation_split=0.2, subset="training", seed=123,
        color_mode='grayscale', image_size=IMG_SIZE, batch_size=CLS_BATCH_SIZE
    )
    val_ds_cls = tf.keras.utils.image_dataset_from_directory(
        data_dir, validation_split=0.2, subset="validation", seed=123,
        color_mode='grayscale', image_size=IMG_SIZE, batch_size=CLS_BATCH_SIZE
    )

    # Read class_names before the datasets are wrapped by cache/shuffle/prefetch.
    class_names = train_ds_cls.class_names
    print(f"Classes: {class_names}")

    train_ds_cls = train_ds_cls.cache().shuffle(1000).prefetch(tf.data.AUTOTUNE)
    val_ds_cls = val_ds_cls.cache().prefetch(tf.data.AUTOTUNE)

    # Build Model
    # The input shape is derived from IMG_SIZE so it always matches the size the
    # datasets above are resized to (grayscale -> one channel).
    classifier = models.Sequential([
        layers.Input(shape=tuple(IMG_SIZE) + (1,)),
        layers.Rescaling(1./255),
        layers.Conv2D(32, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        # Raw logits, one per class; the loss below is configured with from_logits=True.
        layers.Dense(len(class_names))
    ])

    classifier.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    # Train
    # history_cls = classifier.fit(train_ds_cls, validation_data=val_ds_cls, epochs=CLS_EPOCHS)
else:
    print("Data directory not found. Run preprocessing first.")

## 5. Roboflow Workflow
The Roboflow workflow `find-white-dots-and-thin-black-lines` uses a SAM-based approach to detect:
- **White dots** (Trash/Crumbs)
- **Thin black lines** (Hair)

### Evaluation Results
The following results were obtained by running the workflow on the last 100 images of each class:

| Class | Avg White Dots | Avg Black Lines |
| :--- | :--- | :--- |
| **Clean** | 0.00 | 1.13 |
| **Hair** | 0.00 | 0.03 |
| **Trash** | 1.65 | 0.65 |
| **Trash-Hair** | 0.00 | 1.70 |

**Observations:**
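- **Trash Detection**: The model detects "White dots" in the Trash class (Avg 1.65) while reporting none in Clean/Hair (0.00). However, "Trash-Hair" also averages 0.00 white dots, so trash appears to be missed when hair is also present.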
- **Hair Detection**: The "Thin black lines" detection is inconsistent. While "Trash-Hair" shows the highest count (1.70), the pure "Hair" class shows very few (0.03), and surprisingly, "Clean" images show some false positives (1.13). This suggests the "Thin black lines" prompt or model might need refinement to distinguish actual hair from background artifacts.

In [None]:
# --- Install Roboflow Dependencies ---
# Uncomment the line below to install the required packages in Colab or your local environment.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# %pip install inference-sdk python-dotenv

In [None]:
# --- Roboflow Evaluation Script ---
import os
import glob
import statistics
from pathlib import Path
from dotenv import load_dotenv
from inference_sdk import InferenceHTTPClient
from tqdm import tqdm

# Load environment variables (create a .env file with ROBOFLOW_API_KEY=...)
load_dotenv()

# You can set your API key here directly if not using .env
# os.environ["ROBOFLOW_API_KEY"] = "YOUR_API_KEY"

api_key = os.getenv("ROBOFLOW_INFERENCE_API_KEY") or os.getenv("ROBOFLOW_API_KEY")


def _collect_predictions(result):
    """Flatten one workflow response into a list of prediction entries.

    A list response is reduced to its first element. Each value of the resulting
    dict contributes predictions when it is either a dict holding a
    'predictions' entry or a non-empty list whose first item is a dict with a
    'class' key.
    """
    if isinstance(result, list):
        result = result[0]

    predictions = []
    for value in result.values():
        if isinstance(value, dict) and 'predictions' in value:
            predictions.extend(value['predictions'])
        elif isinstance(value, list):
            if value and isinstance(value[0], dict) and 'class' in value[0]:
                predictions.extend(value)
    return predictions


def _count_detections(predictions):
    """Return ``(white_dots, black_lines)`` counts for a list of predictions.

    Labels are compared case-insensitively so that e.g. "White dots",
    "white dots" and "Thin black lines" are all recognised. Entries that are
    not dicts or carry no 'class' key are ignored.
    """
    wd_count = 0
    bl_count = 0
    for pred in predictions:
        if not isinstance(pred, dict) or 'class' not in pred:
            continue
        label = str(pred['class']).lower()
        if "white dots" in label:
            wd_count += 1
        elif "black lines" in label:
            bl_count += 1
    return wd_count, bl_count


if not api_key:
    print("⚠️ ROBOFLOW_API_KEY not found. Please set it in a .env file or environment variables to run this cell.")
else:
    try:
        client = InferenceHTTPClient(
            api_url="https://serverless.roboflow.com",
            api_key=api_key
        )

        # Configuration
        WORKSPACE_NAME = "dontloseyourheadsu"
        WORKFLOW_ID = "find-white-dots-and-thin-black-lines"
        CLASSES = ["clean", "hair", "trash", "trash-hair"]

        # Adjust path for Notebook environment
        BASE_DIR = pathlib.Path("filtering/processed")
        LIMIT = 100

        results_summary = {}

        print(f"Starting evaluation on last {LIMIT} images per class...")

        for class_name in CLASSES:
            class_dir = BASE_DIR / class_name
            if not class_dir.exists():
                print(f"Directory not found: {class_dir}")
                continue

            # Get all images, sort them to get the 'last' ones consistently
            images = sorted(list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.png")))

            if not images:
                print(f"No images found for {class_name}")
                continue

            # Take last N images (a negative slice already returns everything
            # when fewer than LIMIT images exist).
            test_images = images[-LIMIT:]
            print(f"Processing {len(test_images)} images for class '{class_name}'...")

            white_dots_counts = []
            black_lines_counts = []
            failed_count = 0

            for img_path in tqdm(test_images):
                try:
                    # Run workflow
                    result = client.run_workflow(
                        workspace_name=WORKSPACE_NAME,
                        workflow_id=WORKFLOW_ID,
                        images={"image": str(img_path)},
                        use_cache=True
                    )

                    wd_count, bl_count = _count_detections(_collect_predictions(result))
                    white_dots_counts.append(wd_count)
                    black_lines_counts.append(bl_count)

                except Exception as e:
                    # Best effort: a failed image is skipped, but it is counted so
                    # the reader knows the averages cover fewer images.
                    failed_count += 1
                    print(f"Error processing {img_path.name}: {e}")

            if failed_count:
                print(f"⚠️ {failed_count} of {len(test_images)} images failed for class '{class_name}' and are excluded from the averages.")

            # Calculate stats
            avg_wd = statistics.mean(white_dots_counts) if white_dots_counts else 0
            avg_bl = statistics.mean(black_lines_counts) if black_lines_counts else 0

            results_summary[class_name] = {
                "avg_white_dots": avg_wd,
                "avg_black_lines": avg_bl
            }

        print("\n--- Roboflow SAM Evaluation Results ---")
        for cls, stats in results_summary.items():
            print(f"Class: {cls}")
            print(f"  Avg White Dots: {stats['avg_white_dots']:.2f}")
            print(f"  Avg Black Lines: {stats['avg_black_lines']:.2f}")

    except Exception as e:
        # This handler wraps the whole run, not just client creation.
        print(f"An error occurred during the Roboflow evaluation: {e}")
