# Sterility Classification Model Training Notebook

This notebook trains a binary classification model that determines instrument sterility, using the EfficientNet-B0 architecture.

## Overview
- **Task**: Binary classification (Sterile vs Non-sterile instruments)
- **Architecture**: EfficientNet-B0 with ImageNet pre-training
- **Dataset**: Images in `dataset/0/` (sterile) and `dataset/1/` (non-sterile)
- **Features**: 
  - Data augmentation for training
  - Optional blur filtering
  - Class-balanced loss function
  - Learning rate scheduling

## Workflow
1. Import dependencies and setup configuration
2. Define utility functions and custom dataset
3. Prepare data loaders with augmentation
4. Setup model architecture and training components
5. Train model with validation monitoring
6. Evaluate final performance


## 1. Import Dependencies

In [None]:
import copy
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models

## 2. Configuration

In [None]:
class Config:
    """Single home for every tunable setting used by the notebook."""

    # --- data ---
    DATA_DIR = "dataset"          # path to the folder that contains the 0/ and 1/ class directories
    IMG_SIZE = 224
    REMOVE_BLURRY = False         # set to True to filter out blurry images
    BLUR_THRESHOLD = 100.0        # images scoring below this value are treated as blurry

    # --- optimisation ---
    BATCH_SIZE = 32
    NUM_EPOCHS = 15
    LR = 1e-4
    WEIGHT_DECAY = 1e-4

    # --- runtime ---
    NUM_WORKERS = 4
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Echo the active settings, one "caption: value" line per entry.
print(
    "\n".join(
        f"{caption}: {setting}"
        for caption, setting in (
            ("Using device", Config.DEVICE),
            ("Dataset directory", Config.DATA_DIR),
            ("Image size", Config.IMG_SIZE),
            ("Batch size", Config.BATCH_SIZE),
            ("Number of epochs", Config.NUM_EPOCHS),
            ("Learning rate", Config.LR),
            ("Remove blurry images", Config.REMOVE_BLURRY),
        )
    )
)


## 3. Utility Functions


In [None]:
def is_blurry(img_path, threshold=100.0):
    """
    Decide whether an image is out of focus from the variance of its Laplacian.

    Args:
        img_path (str): Location of the image file on disk
        threshold (float): Variance below which the image counts as blurry

    Returns:
        bool: True when the Laplacian variance is below ``threshold`` or when
        OpenCV could not read the file; False otherwise
    """
    bgr = cv2.imread(img_path)
    if bgr is not None:
        # Variance of the Laplacian of the grayscale image is the focus measure.
        focus_measure = cv2.Laplacian(
            cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY), cv2.CV_64F
        ).var()
        return focus_measure < threshold
    # cv2.imread returned None: report the unreadable file as blurry.
    return True


## 4. Custom Dataset Class


In [None]:
class SterilityDataset(Dataset):
    """
    Custom PyTorch Dataset for sterility classification.

    Expected directory structure:
    data_dir/
    +-- 0/  # Sterile instruments
    +-- 1/  # Non-sterile instruments

    Samples are stored in ``self.samples`` as ``(file_path, class_id)`` tuples,
    ordered by class and then by file name so that indices are reproducible.
    """

    def __init__(self, data_dir, transform=None, remove_blurry=False, blur_threshold=100.0):
        """
        Args:
            data_dir (str): Path to dataset directory
            transform (callable): Optional transform to be applied to samples
            remove_blurry (bool): Whether to filter out blurry images
            blur_threshold (float): Threshold for blur detection
        """
        self.transform = transform
        self.samples = []

        for class_id in ["0", "1"]:
            class_dir = os.path.join(data_dir, class_id)
            # A missing class directory is tolerated and simply contributes no samples.
            if not os.path.isdir(class_dir):
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices (and any seeded split built on them) stable across
            # machines and runs.
            for fname in sorted(os.listdir(class_dir)):
                if not fname.lower().endswith((".jpg", ".jpeg", ".png")):
                    continue
                fpath = os.path.join(class_dir, fname)
                if remove_blurry and is_blurry(fpath, blur_threshold):
                    continue
                self.samples.append((fpath, int(class_id)))

    def __len__(self):
        """Return the number of samples that survived filtering."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx (int): Index into ``self.samples``

        Returns:
            tuple: ``(image, label)`` where ``image`` is the RGB image after the
            optional transform and ``label`` is the class id as a Python float
        """
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file handle promptly;
        # convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, float(label)


## 5. Data Preparation and Loading


In [None]:
# Preprocessing pipelines for the two splits.
_TARGET_SIZE = (Config.IMG_SIZE, Config.IMG_SIZE)
# Per-channel statistics used for input normalisation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: resize, light augmentation, then tensor conversion + normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Validation: the same pipeline without the random augmentation steps.
val_transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

for _message in (
    "Data transforms defined",
    "Training transforms: Resize, RandomFlip, ColorJitter, Rotation, Normalization",
    "Validation transforms: Resize, Normalization only",
):
    print(_message)


In [None]:
# Create the dataset once (blur filtering, when enabled, reads every file).
full_dataset = SterilityDataset(
    Config.DATA_DIR,
    transform=train_transform,
    remove_blurry=Config.REMOVE_BLURRY,
    blur_threshold=Config.BLUR_THRESHOLD
)

# Fail early with a clear message instead of a ZeroDivisionError or an
# empty-sampler error further down.
if len(full_dataset) == 0:
    raise RuntimeError(
        f"No usable images found under '{Config.DATA_DIR}/0' or '{Config.DATA_DIR}/1'"
    )

# Split dataset into train/validation (80/20) with a fixed seed
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, _val_split = torch.utils.data.random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# Both subsets returned by random_split wrap the SAME full_dataset object, so
# assigning `val_split.dataset.transform` would also replace the training
# transform and silently disable augmentation. Validation therefore gets its
# own shallow copy (the sample list is shared, only `transform` differs) over
# exactly the same validation indices.
_val_source = copy.copy(full_dataset)
_val_source.transform = val_transform
val_dataset = torch.utils.data.Subset(_val_source, _val_split.indices)

print("Dataset loaded and split:")
print(f"   Total samples: {len(full_dataset)}")
print(f"   Training samples: {train_size}")
print(f"   Validation samples: {val_size}")

# Check class distribution in the full dataset
class_0_count = sum(1 for _, label in full_dataset.samples if label == 0)
class_1_count = sum(1 for _, label in full_dataset.samples if label == 1)
print(f"   Class 0 (Sterile): {class_0_count}")
print(f"   Class 1 (Non-sterile): {class_1_count}")
if class_0_count > 0:
    print(f"   Class balance ratio: {class_1_count/class_0_count:.3f}")
else:
    # Avoid dividing by zero when the sterile class is absent.
    print("   Class balance ratio: n/a (no class 0 samples)")


In [None]:
# Build the batch iterators; both splits share batch size and worker count.
_loader_options = {
    "batch_size": Config.BATCH_SIZE,
    "num_workers": Config.NUM_WORKERS,
}

# Only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

print("Data loaders created:")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Batch size: {Config.BATCH_SIZE}")
print(f"Number of workers: {Config.NUM_WORKERS}")


## 6. Model Setup and Training Configuration


In [None]:
# EfficientNet-B0 backbone initialised from the IMAGENET1K_V1 weights.
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

# Replace the final linear layer with a single-logit head for binary output.
_head_inputs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(_head_inputs, 1)

# Place the network on the configured device.
model = model.to(Config.DEVICE)

# Parameter statistics for the summary below.
_param_total = sum(p.numel() for p in model.parameters())
_param_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model initialized:")
print("   Architecture: EfficientNet-B0")
print("   Pre-trained: ImageNet weights")
print("   Output neurons: 1 (binary classification)")
print(f"   Device: {Config.DEVICE}")
print(f"   Total parameters: {_param_total:,}")
print(f"   Trainable parameters: {_param_trainable:,}")


In [None]:
# Setup loss function with class balancing.
# Class counts come from the actual training split rather than hard-coded
# numbers, so the weight stays correct when the dataset or blur filter changes.
_train_labels = [train_dataset.dataset.samples[i][1] for i in train_dataset.indices]
n0 = _train_labels.count(0)  # sterile
n1 = _train_labels.count(1)  # non-sterile
if n0 == 0 or n1 == 0:
    raise ValueError(
        f"Training split must contain both classes (class 0: {n0}, class 1: {n1})"
    )

# Extra factor applied on top of the balancing ratio n0/n1.
# A higher positive weight penalizes missed non-sterile cases more heavily,
# making the model more sensitive to non-sterile instruments.
POS_WEIGHT_MULTIPLIER = 3.0
pos_weight = torch.tensor(
    [n0 / n1 * POS_WEIGHT_MULTIPLIER], dtype=torch.float32, device=Config.DEVICE
)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Setup optimizer with weight decay
optimizer = optim.AdamW(
    model.parameters(),
    lr=Config.LR,
    weight_decay=Config.WEIGHT_DECAY
)

# Setup learning rate scheduler (cosine annealing over the whole run)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=Config.NUM_EPOCHS
)

print(f"""
TRAINING COMPONENTS CONFIGURATION

Loss Function       : BCEWithLogitsLoss
Train Class Counts  : {n0} sterile / {n1} non-sterile
Positive Weight     : {pos_weight.item():.4f} (increased {POS_WEIGHT_MULTIPLIER:g}x for safety)
Optimizer           : AdamW
Learning Rate       : {Config.LR}
Weight Decay        : {Config.WEIGHT_DECAY}
LR Scheduler        : CosineAnnealingLR
Decision Threshold  : 0.25 (conservative to catch all non-sterile cases)
""")


## 7. Training Loop


In [None]:
# Training loop with Recall-focused metric
best_val_recall = 0.0
# Accuracy of the best checkpoint, used only to break recall ties. Starting at
# -1 guarantees the first epoch is always checkpointed, so the model file
# exists even if recall never rises above 0.
best_val_acc = -1.0
DECISION_THRESHOLD = 0.25  # Lower threshold to be more sensitive to non-sterile cases
MODEL_PATH = "best_model_v0.0.1.pth"
training_history = {
    "epoch": [],
    "train_loss": [],
    "val_recall_class1": [],
    "val_acc": []
}

print(f"""
STARTING TRAINING

Total Epochs        : {Config.NUM_EPOCHS}
Primary Metric      : Recall for Class 1 (Non-sterile)
Decision Threshold  : {DECISION_THRESHOLD} (conservative)
""")

for epoch in range(Config.NUM_EPOCHS):
    # ---- Training phase ----
    model.train()
    train_loss = 0.0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(Config.DEVICE)
        # Labels arrive as Python floats collated by the DataLoader; cast them
        # to float32 so they match the dtype of the logits.
        labels = labels.to(Config.DEVICE).float()

        # Forward pass
        optimizer.zero_grad()
        # squeeze(1) drops only the output dimension. A bare squeeze() would
        # also drop the batch dimension when the final batch holds one sample,
        # causing a shape mismatch with `labels`.
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Print progress every 50 batches
        if (batch_idx + 1) % 50 == 0:
            print(f"Batch {batch_idx+1:4d}/{len(train_loader):4d} | Loss: {loss.item():.4f}")

    avg_train_loss = train_loss / len(train_loader)

    # ---- Validation phase with adjusted threshold ----
    model.eval()
    val_preds, val_labels = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(Config.DEVICE)
            outputs = model(images).squeeze(1)
            # Probabilities above the (deliberately low) threshold count as non-sterile
            preds = torch.sigmoid(outputs) > DECISION_THRESHOLD
            val_preds.extend(preds.long().cpu().flatten().tolist())
            val_labels.extend(labels.long().flatten().tolist())

    # Integer arrays keep the class ids aligned with `labels=[0, 1]` below
    val_preds_np = np.array(val_preds, dtype=int)
    val_labels_np = np.array(val_labels, dtype=int)

    # Calculate metrics
    val_acc = (val_preds_np == val_labels_np).mean()
    val_recall_class1 = recall_score(val_labels_np, val_preds_np, pos_label=1)

    # Passing the labels explicitly forces a 2x2 matrix even when one class is
    # missing from the labels or predictions, so the unpacking never fails.
    cm = confusion_matrix(val_labels_np, val_preds_np, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Save training history
    training_history["epoch"].append(epoch + 1)
    training_history["train_loss"].append(avg_train_loss)
    training_history["val_recall_class1"].append(val_recall_class1)
    training_history["val_acc"].append(val_acc)

    # Checkpoint on better recall; on equal recall prefer higher accuracy.
    # Otherwise the first epoch that reaches the maximum recall (possibly by
    # flagging everything as non-sterile) would be kept over later, more
    # precise epochs with the same recall.
    saved_marker = ""
    if (val_recall_class1, val_acc) > (best_val_recall, best_val_acc):
        best_val_recall = val_recall_class1
        best_val_acc = val_acc
        torch.save(model.state_dict(), MODEL_PATH)
        saved_marker = " [MODEL SAVED]"

    # Update learning rate
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    print(f"""
EPOCH {epoch+1}/{Config.NUM_EPOCHS}{saved_marker}

Training Loss                : {avg_train_loss:.4f}
Validation Accuracy          : {val_acc:.4f}
Validation Recall (PRIMARY)  : {val_recall_class1:.4f}
Current Learning Rate        : {current_lr:.6f}

Confusion Matrix:
                        Predicted
                    Sterile    Non-sterile
Actual Sterile      {tn:4d}       {fp:4d}
       Non-sterile  {fn:4d}       {tp:4d}

False Negatives (CRITICAL)   : {fn} (Non-sterile predicted as Sterile)
False Positives              : {fp} (Sterile predicted as Non-sterile)
""")

print(f"""
TRAINING COMPLETED

Best Recall (Non-sterile)    : {best_val_recall:.4f}
Model saved as               : {MODEL_PATH}
""")


NameError: name 'Config' is not defined

## 8. Final Evaluation and Metrics


In [None]:
# Final evaluation with detailed metrics.
# The report below describes the saved checkpoint, so the metrics must come
# from that checkpoint and not from whatever weights the last epoch left in
# memory. Note: this loads the checkpoint weights into `model`.
BEST_MODEL_PATH = "best_model_v0.0.1.pth"
if os.path.exists(BEST_MODEL_PATH):
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=Config.DEVICE))
    evaluated_weights = f"best checkpoint ({BEST_MODEL_PATH})"
else:
    evaluated_weights = "last-epoch weights (checkpoint file not found)"

# Re-run validation with the weights selected above
model.eval()
final_preds, final_labels = [], []
with torch.no_grad():
    for images, labels in val_loader:
        # squeeze(1) keeps the batch dimension even for a single-sample batch
        logits = model(images.to(Config.DEVICE)).squeeze(1)
        batch_preds = torch.sigmoid(logits) > DECISION_THRESHOLD
        final_preds.extend(batch_preds.long().cpu().flatten().tolist())
        final_labels.extend(labels.long().flatten().tolist())

final_preds_np = np.array(final_preds, dtype=int)
final_labels_np = np.array(final_labels, dtype=int)

# Calculate metrics; `labels=[0, 1]` keeps the outputs 2-class even if one
# class is missing, and `zero_division=0` makes undefined ratios an explicit 0.
final_acc = (final_preds_np == final_labels_np).mean()
cm = confusion_matrix(final_labels_np, final_preds_np, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
precision = precision_score(final_labels_np, final_preds_np, zero_division=0)
recall = recall_score(final_labels_np, final_preds_np, zero_division=0)
f1 = f1_score(final_labels_np, final_preds_np, zero_division=0)

# Generate classification report
class_report = classification_report(
    final_labels_np,
    final_preds_np,
    labels=[0, 1],
    target_names=["Sterile (0)", "Non-sterile (1)"],
    digits=4,
    zero_division=0
)

print(f"""
FINAL VALIDATION METRICS

Weights Evaluated            : {evaluated_weights}

Classification Report:
{class_report}

Confusion Matrix:
                        Predicted
                    Sterile    Non-sterile
Actual Sterile      {tn:4d}       {fp:4d}
       Non-sterile  {fn:4d}       {tp:4d}

False Negatives (CRITICAL)   : {fn} (Non-sterile marked as Sterile)
False Positives              : {fp} (Sterile marked as Non-sterile)

Performance Metrics:
Overall Accuracy             : {final_acc:.4f}
Recall (Non-sterile)         : {recall:.4f}  [PRIMARY METRIC]
Precision (Non-sterile)      : {precision:.4f}
F1-Score                     : {f1:.4f}
Decision Threshold           : {DECISION_THRESHOLD}

Model Information:
Model File                   : best_model_v0.0.1.pth
Architecture                 : EfficientNet-B0
Input Size                   : {Config.IMG_SIZE}x{Config.IMG_SIZE}
Training Epochs              : {Config.NUM_EPOCHS}
Best Recall (Non-sterile)    : {best_val_recall:.4f}
Decision Threshold           : {DECISION_THRESHOLD} (conservative)
Positive Class Weight        : 3x (penalizes missing non-sterile cases)
""")
