# 🧠 Real-time Entity Classifier - CNN Architecture & Training

This notebook handles both definition and training of a MobileNetV2-based CNN for real-time webcam classification:

**Detection Categories**:
- 🐱 **Pet** (Felix - British Shorthair cat)
- 👤 **Owner** (Sebastian)
- 🧍 **Other People**
- 🚫 **Background/Nobody**

**Key Features**:
- ⚡ Real-time inference: ~30ms/frame (640x480 resolution)
- 🔒 Privacy-focused: All processing on-device
- 🌟 Robust to: Lighting changes, partial occlusions
- 🏗️ Transfer learning: New classifier head trained on a frozen, ImageNet-pretrained MobileNetV2 backbone

<hr>

## 📦 PyTorch & Project Imports Overview

### 🧠 Core PyTorch Components

- **`torch`**  
  Used for tensor operations, device management (`torch.cuda.is_available()`), and model saving (`torch.save()`).

- **`torch.nn`**  
  Used for building the classifier component of our model, including `nn.Linear`, `nn.ReLU`, `nn.Dropout`, and `nn.Sequential`. Also provides `CrossEntropyLoss`, used here with label smoothing to discourage overconfident predictions.

- **`torch.optim`**  
  Provides the `Adam` optimizer with learning rate and weight decay parameters for training our classifier layers.

- **`torch.utils.data.Dataset, DataLoader`**  
  `Dataset` is subclassed to create our custom `VideoFrameDataset`. `DataLoader` wraps this dataset to provide shuffled mini-batches through an iterator used by the training loop.

### 🖼️ Computer Vision & Image Processing

- **`torchvision.models`**  
  Used to load a pre-trained `mobilenet_v2` as our base feature extractor, which we then modify for our specific classification task.
    
- **`PIL.Image`**  
  Used to convert between NumPy arrays and PIL images for compatibility with torchvision transforms.

- **`cv2` (OpenCV)**  
  Handles video operations (`VideoCapture`) to read and extract frames from video files, and image color space conversion (`cvtColor`).

- **`torchvision.transforms`**  
  Builds image transformation pipelines for:  
  - Standard preprocessing: resize, crop, normalization  
  - Data augmentation: random flips, rotations, color jitter, perspective changes

### 📁 Data Handling & Utilities

- **`os`**  
  Checks file existence (`os.path.exists()`) and constructs file paths (`os.path.join()`).

- **`glob`**  
  Finds image files with specific extensions (`.jpg`, `.jpeg`, `.png`) within directories.

- **`random`**  
  Samples subsets of data when we have too many files (`random.sample()`), helping maintain dataset balance.

### 📊 Visualization & Progress Tracking

- **`matplotlib.pyplot as plt`**  
  Creates and saves training visualization plots with loss and accuracy metrics after training completes.

- **`tqdm`**  
  Wraps the training loop to provide a progress bar with real-time metrics (loss and accuracy) during model training.

These libraries together form a complete pipeline for processing video data, extracting frames, building and training a deep learning model for multi-class classification, and visualizing the results.

In [22]:
# ======================
# Core PyTorch Components
# ======================
import torch               # Base library for tensors, GPU ops, and autograd
import torch.nn as nn      # Neural network layers (Linear, Conv2d, etc.)
import torch.optim as optim  # Optimizer (Adam) for training
from torch.utils.data import Dataset, DataLoader  # Custom datasets + efficient batching

# ======================
# Computer Vision
# ======================
import cv2                        # Video capture and frame processing
from PIL import Image             # Image loading and conversion
from torchvision import transforms  # Image preprocessing/augmentations
from torchvision import models    # Pretrained models (MobileNetV2)
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights  # Pretrained weight configurations for MobileNetV2

# ======================
# Data Pipeline
# ======================
import os                        # File path operations
from glob import glob            # Pattern matching for image/video files
import random                    # Shuffling datasets and sampling

# ======================
# Training Utilities
# ======================
import matplotlib.pyplot as plt  # Plotting loss/accuracy curves
from tqdm import tqdm            # Progress bars for training loops

<hr>

## MobileNetV2 Transfer Learning Architecture Explanation 🚀

### Project Goal: 4-Way Classification System 🎯

This implementation adapts a pre-trained MobileNetV2 model for a specific 4-class recognition task using transfer learning principles. The model is designed to classify webcam frames into:

1. **Pet** – A specific British Shorthair cat named Felix 🐱  
2. **Owner** – Sebastian (the project creator) 👨  
3. **Other Person** – Any human who is not Sebastian 🧍  
4. **None/Background** – Empty frames with no people or pets 🚫

### Base Model: MobileNetV2 📱

```python
model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1).to(device)
```

#### Why MobileNetV2? 🤔

- **Efficiency** ⚡: MobileNetV2 is designed specifically for mobile and edge devices, making it computationally efficient while maintaining good accuracy
- **Depthwise Separable Convolutions** 🧩: Uses a factorized form of standard convolutions that drastically reduces parameter count and computational cost
- **Inverted Residuals** 🔄: Unlike traditional residual blocks, MobileNetV2 uses inverted residuals with linear bottlenecks, which help preserve information flow while keeping the model lightweight
- **Pre-trained Knowledge** 🧠: ImageNet pre-training provides the model with powerful feature extraction capabilities that can transfer well to new domains
- **Size-Performance Tradeoff** ⚖️: Offers an excellent balance between model size (~3.5M parameters, about 14 MB of 32-bit weights) and performance for real-time or resource-constrained applications
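
To make the depthwise separable saving concrete, the sketch below compares a standard 3×3 convolution with its depthwise + pointwise factorization in PyTorch. The channel sizes (144 in, 24 out) and the `n_params` helper are illustrative assumptions chosen for the example, not values read from this notebook.

```python
import torch
import torch.nn as nn


def n_params(module: nn.Module) -> int:
    """Count all parameters (weights and biases) in a module."""
    return sum(p.numel() for p in module.parameters())


c_in, c_out, k = 144, 24, 3  # illustrative channel sizes and kernel size

# Standard convolution: every output channel looks at every input channel.
standard = nn.Conv2d(c_in, c_out, kernel_size=k, padding=1, bias=False)

# Depthwise separable: a per-channel 3x3 conv (groups=c_in), then a 1x1 pointwise conv.
separable = nn.Sequential(
    nn.Conv2d(c_in, c_in, kernel_size=k, padding=1, groups=c_in, bias=False),  # depthwise
    nn.Conv2d(c_in, c_out, kernel_size=1, bias=False),                          # pointwise
)

x = torch.randn(1, c_in, 28, 28)
print(standard(x).shape, separable(x).shape)    # both torch.Size([1, 24, 28, 28])
print(n_params(standard), n_params(separable))  # 31104 4752
```

The factorized version needs `k*k*c_in + c_in*c_out` weights instead of `k*k*c_in*c_out`, about 6.5× fewer here, while producing an output of the same shape.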

### Transfer Learning Approach 🔄

```python
# Freeze the feature extraction layers
for param in model.parameters():
    param.requires_grad = False
```

#### Benefits of Feature Freezing ❄️

- **Training Efficiency** 🏎️: By freezing the convolutional backbone, we dramatically reduce the number of trainable parameters (about 329K in the new head instead of roughly 2.5M for the whole network)
- **Prevents Overfitting** 🛡️: With limited training data, updating only the classifier prevents the model from overfitting to peculiarities in our small dataset
- **Knowledge Preservation** 📚: Retains the robust feature extraction capabilities learned from ImageNet's diverse 1.2+ million images
- **Faster Convergence** 🏁: The classifier can adapt to the new task much more quickly when starting from well-formed features

### Custom Classifier Head 👑

```python
model.classifier[1] = nn.Sequential(
    nn.Linear(model.classifier[1].in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 4)
).to(device)
```

#### Classifier Design Choices 🧩

- **Intermediate Hidden Layer (256 neurons)** 🧬:
  - Provides greater representational capacity than a direct mapping
  - Allows the model to learn more complex decision boundaries between classes
  - 256 neurons balances expressivity and computational efficiency
  - Particularly helpful for distinguishing between similar human faces (owner vs. other person) 👥

- **ReLU Activation** ⚡:
  - Introduces non-linearity to capture complex relationships
  - Mitigates the vanishing gradient problem with its non-saturating form
  - Computationally efficient compared to tanh or sigmoid

- **Dropout (0.2)** 🎭:
  - Implements regularization by randomly deactivating 20% of neurons during training
  - Prevents co-adaptation of neurons (neurons becoming too dependent on each other)
  - Forces the network to learn redundant representations, improving generalization
  - Rate of 0.2 is conservative, providing regularization while preserving most information flow
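  - Particularly important for this task since the dataset likely contains many similar frames of the same subjects
  - MobileNetV2's original `classifier[0]` is already a `Dropout(p=0.2)` and is kept in place, so the 1280-dimensional features are dropped out once before this head and the hidden layer once more inside it (both are visible in the printed classifier structure)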

- **Output Layer (4 neurons)** 🎬:
  - One neuron per class (pet, owner, other person, background)
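  - Outputs raw logits; CrossEntropyLoss applies log-softmax internally, so no softmax layer is needed in the model (apply `torch.softmax` only at inference time if probabilities are wanted)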
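
Dropout behaves differently in training and evaluation mode. In PyTorch's implementation (inverted dropout), surviving activations are scaled by `1 / (1 - p)` during training so that no rescaling is needed at inference. The sketch below shows this on a tensor of ones; the seed and tensor size are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.2)
x = torch.ones(10_000)

dropout.train()  # training mode: random zeroing plus rescaling of the survivors
y_train = dropout(x)
kept = y_train[y_train != 0]
print(f"fraction zeroed: {(y_train == 0).float().mean().item():.3f}")  # close to 0.2
print(kept.min().item(), kept.max().item())  # every survivor is 1 / (1 - 0.2) = 1.25

dropout.eval()  # evaluation mode: dropout is the identity
y_eval = dropout(x)
print(torch.equal(y_eval, x))  # True
```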

### Weight Initialization Strategy 🎲

```python
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)
```

#### Xavier/Glorot Initialization Benefits ✨

- **Variance Control** 📊: Maintains the variance of activations and gradients across layers
- **Prevents Signal Vanishing/Exploding** 💥: Scaling weights based on the layer size helps signal propagate effectively
- **Uniform Distribution** 📈: Draws weights from U(−a, a) with `a = sqrt(6 / (fan_in + fan_out))`, so the range shrinks as layers get wider
- **Faster Convergence** 🚀: Well-scaled initial weights keep early activations and gradients in a useful range, which typically shortens the initial phase of training
- **Zero Bias Initialization** 0️⃣: Starting biases at zero is a standard practice that works well with ReLU when batch normalization isn't used
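
The uniform range can be checked numerically. This sketch reuses the notebook's `init_weights` on a stand-alone `nn.Linear(1280, 256)` (the shape of the head's first layer, built here only for illustration) and compares the sampled weights with the Glorot bound, using the default gain of 1.

```python
import math
import torch.nn as nn


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # U(-a, a), a = sqrt(6 / (fan_in + fan_out))
        nn.init.zeros_(m.bias)


layer = nn.Linear(1280, 256)
layer.apply(init_weights)

fan_in, fan_out = layer.in_features, layer.out_features
bound = math.sqrt(6.0 / (fan_in + fan_out))
print(f"bound = {bound:.4f}")                              # 0.0625 for 1280 -> 256
print(f"max |w| = {layer.weight.abs().max().item():.4f}")  # at or just below the bound
print("biases all zero:", bool((layer.bias == 0).all()))
```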

### Model Deployment 🚀

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
```

#### Hardware Acceleration 💻

- **Dynamic Device Selection** 🔍: Code automatically selects GPU if available, falling back to CPU if necessary
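- **Device Consistency** 💾: The model's parameters and every input batch must live on the same device, so input tensors also need `.to(device)` before the forward pass
- **Computation Speed** ⚡: Running on a GPU can be roughly 10-100x faster for neural network operations, depending on the GPU, the CPU, and the batch size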

### Architecture Summary 📝

This architecture exemplifies modern transfer learning best practices by:

1. Leveraging a pre-trained, efficient CNN architecture 🏗️
2. Freezing feature extraction layers to preserve learned representations ❄️
3. Implementing a purpose-built classifier for the specific task 🎯
4. Using appropriate regularization techniques to prevent overfitting 🛡️
5. Applying proven weight initialization strategies for faster convergence 🏁

The resulting model balances computational efficiency with classification performance, making it suitable for deployment in resource-constrained environments while still maintaining high accuracy for this pet and person recognition task. 🤖👍

### Application-Specific Advantages 🌟

For this specific pet/owner recognition task:

1. **Fine-Grained Recognition** 🔍: The model can learn subtle differences between the owner and other people, and between frames containing Felix and empty background
2. **Real-time Processing** ⏱️: MobileNetV2's efficiency enables real-time classification on webcam streams
3. **Low Resource Requirements** 💪: The architecture can run on modest hardware like laptops without dedicated GPUs
4. **Quick Training** 🏎️: By using transfer learning, the model can be trained with relatively few examples of each class

In [26]:
# ======================
# Model Definition
# ======================
# Set computation device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained MobileNetV2 with ImageNet weights
# MobileNetV2 is chosen for its efficiency, performance, and lightweight nature
model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1).to(device)

# Freeze the feature extraction layers to preserve pre-trained knowledge
# This implements transfer learning - we only train the classifier
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier with our custom head for 4-class classification
# Architecture: Input features → 256 neurons → ReLU → Dropout → 4 output classes
model.classifier[1] = nn.Sequential(
    nn.Linear(model.classifier[1].in_features, 256),  # Hidden layer with 256 neurons
    nn.ReLU(),                                        # Non-linearity
    nn.Dropout(0.2),                                  # Regularization to prevent overfitting
    nn.Linear(256, 4)                                 # Output layer for our 4 classes
).to(device)

# Define weight initialization function for better convergence
# Xavier/Glorot initialization helps control variance of activations across layers
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # Uniform distribution within proper bounds
        nn.init.zeros_(m.bias)             # Initialize biases to zero

# Apply weight initialization to our classifier only
model.classifier[1].apply(init_weights)

# Verify model setup
print(f"Model loaded on {device}")
print("Classifier structure:")
print(model.classifier)

Model loaded on cpu
Classifier structure:
Sequential(
  (0): Dropout(p=0.2, inplace=False)
  (1): Sequential(
    (0): Linear(in_features=1280, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=4, bias=True)
  )
)


<hr>

## Training Configuration: Optimizer and Loss Function 🛠️

### Optimizer: Adam with Selective Training 🎯

```python
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)
```

#### Parameter Selection Strategy 🔍

We use `filter(lambda p: p.requires_grad, model.parameters())` to select only trainable parameters. This is critical for our transfer learning approach:

- **Efficiency Boost** ⚡: By optimizing only the classifier parameters (~329K) instead of the entire model (~2.5M), we reduce computation time and memory requirements: frozen parameters need no gradients and no optimizer state
- **Focused Learning** 🔭: Updates are restricted to the task-specific components (the classifier), leaving the pre-trained feature extractor untouched
- **Lambda Function Elegance** 💻: The lambda function creates a clean, one-line filter that selects parameters where `requires_grad=True`
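
One detail of the `filter(...)` call: it returns a one-shot iterator, which `optim.Adam` consumes when it is constructed. The sketch below uses a tiny hypothetical stand-in (a frozen `nn.Linear` "backbone" plus a trainable head) to show that only the head's weight and bias reach the optimizer and that the iterator is empty afterwards, so it cannot be reused for a second optimizer.

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in: a frozen "backbone" layer followed by a trainable head.
backbone = nn.Linear(8, 8)
head = nn.Linear(8, 4)
for p in backbone.parameters():
    p.requires_grad = False
model = nn.Sequential(backbone, nn.ReLU(), head)

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)

print(len(optimizer.param_groups[0]["params"]))  # 2: the head's weight and bias
print(list(trainable_params))                    # []: the iterator is already exhausted
```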

#### Why Adam? 🤔

Adam (Adaptive Moment Estimation) combines two extensions of stochastic gradient descent, momentum and RMSProp:

- **Momentum** 🏎️: Keeps an exponential moving average of past gradients (the first moment), which smooths noisy updates and accelerates progress along consistent directions
- **RMSProp-style Scaling** 📉: Keeps an exponential moving average of squared gradients (the second moment) and divides each update by its square root
- **Adaptive Learning Rates** 📊: Together, these give each parameter its own effective step size based on its gradient history
- **Sparse Gradient Handling** 🌵: Performs well even when gradients are sparse or noisy
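
To illustrate how these pieces fit together, here is a from-scratch Adam update for a single scalar parameter, following the standard update with bias correction (β₁ = 0.9, β₂ = 0.999, ε = 1e-8, which are also PyTorch's defaults). It is a teaching sketch, not PyTorch's implementation, and it leaves out weight decay.

```python
import math


def adam_step(param, grad, m, v, t, lr=5e-5, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (no weight decay)."""
    m = beta1 * m + (1 - beta1) * grad       # first moment: running mean of gradients (momentum)
    v = beta2 * v + (1 - beta2) * grad ** 2  # second moment: running mean of squared gradients (RMSProp)
    m_hat = m / (1 - beta1 ** t)             # bias correction, since m and v start at zero
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# On the first step m_hat == grad and v_hat == grad**2, so the update is about lr * sign(grad).
for g in (1000.0, 0.001, -3.0):
    new_p, _, _ = adam_step(param=0.0, grad=g, m=0.0, v=0.0, t=1)
    print(f"grad={g:>8}: new param = {new_p:.2e}")
```

Because of the bias correction, the very first step moves each parameter by roughly `lr` in the direction opposite to its gradient, whether the gradient is 1000 or 0.001. This scale invariance is what "adaptive learning rates" means in practice.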

#### Hyperparameter Choices 🎛️

- **Learning Rate (5e-5)** 🏁: 
  - Because the backbone is frozen, this rate only governs the newly initialized classifier head; the pre-trained features are protected by freezing, not by the small learning rate.
  - 5e-5 is conservative for training a fresh head (Adam's default is 1e-3), so updates are small and gradual, trading slower convergence for stability.
  - Small steps reduce the risk of overshooting a good solution, promoting stable convergence, especially in later training stages.
  - If the backbone is later unfrozen for full fine-tuning, a rate in this range also helps avoid drastically altering the pre-trained weights.

- **Weight Decay (1e-5)** 🌱: 
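  - Implements L2 regularization by penalizing large weights; `torch.optim.Adam` adds `weight_decay * param` to each gradient before the adaptive scaling (`torch.optim.AdamW` applies decoupled weight decay instead)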
  - Helps prevent overfitting by encouraging the model to use smaller weights
  - Value of 1e-5 is conservative, providing gentle regularization
  - Particularly valuable for our problem where limited training data could lead to overfitting

### Loss Function: CrossEntropyLoss with Label Smoothing 📊

```python
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```

#### Why CrossEntropyLoss? 🎯

CrossEntropyLoss is the standard choice for multi-class classification problems:

- **Softmax Integration** 🔄: Takes raw logits and applies log-softmax internally, so the model needs no final softmax layer
- **Class Probability** 📊: Measures the difference between predicted probability distribution and the target distribution
- **Numerical Stability** 🛡️: Implements log-softmax and negative log-likelihood in a numerically stable way
- **Single-Label Focus** 🏷️: Optimized for problems where each example belongs to exactly one class

#### Label Smoothing (0.1) 🥤

Label smoothing is a regularization technique that modifies the target distribution:

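- **Target Softening** ☁️: Instead of hard labels (0, 0, 1, 0), it uses soft targets (0.025, 0.025, 0.925, 0.025): every class receives ε/K = 0.1/4 = 0.025, and the true class also keeps the remaining 1 − ε = 0.9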
- **Confidence Penalty** 📉: Discourages the model from becoming too confident in its predictions
- **Overfitting Prevention** 🛡️: Makes the model less likely to memorize training data noise or errors
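- **Not a Class-Imbalance Fix** ⚖️: Label smoothing treats every class the same, so it does not counter the imbalance in our data (e.g., 2,500 "other person" images vs. 915 pet frames); class weights (`CrossEntropyLoss(weight=...)`) or a `WeightedRandomSampler` address that directly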
- **Value Selection (0.1)** 🎚️: 
  - 0.1 is a moderate smoothing value that provides regularization benefits
  - High enough to prevent overconfidence
  - Low enough to maintain class separation
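
The smoothed targets can be checked directly. PyTorch mixes the one-hot target with a uniform distribution, `(1 - ε) * one_hot + ε / K`; this sketch uses arbitrary example logits and compares the built-in `label_smoothing=0.1` loss with cross-entropy computed by hand on the soft targets for K = 4 classes.

```python
import torch
import torch.nn.functional as F

num_classes, eps = 4, 0.1
logits = torch.tensor([[2.0, -1.0, 0.5, 0.0]])  # arbitrary example logits
target = torch.tensor([2])                      # index of the true class

# Soft targets: (1 - eps) * one_hot + eps / num_classes
one_hot = F.one_hot(target, num_classes).float()
soft_target = (1 - eps) * one_hot + eps / num_classes
print(soft_target)  # tensor([[0.0250, 0.0250, 0.9250, 0.0250]])

builtin = F.cross_entropy(logits, target, label_smoothing=eps)
manual = -(soft_target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(torch.allclose(builtin, manual))  # True
```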
  
### Combined Effect on Training Dynamics 🔄

Together, these choices create a training configuration that:

1. **Focuses computational effort** on adapting the classifier to our specific classes
2. **Adapts learning dynamically** based on gradient behavior during training
3. **Regularizes from multiple angles** (weight decay and label smoothing) to prevent overfitting
4. **Improves generalization**, especially with limited training data
5. **Typically accelerates convergence** compared to plain SGD

This configuration represents modern deep learning best practices for transfer learning on classification tasks with limited, potentially imbalanced datasets.
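
Label smoothing does not reweight classes. If the class imbalance turns out to hurt a minority class, one common option is inverse-frequency class weights passed through `CrossEntropyLoss`'s `weight` argument. This is a sketch under assumptions: the helper `inverse_frequency_weights` and the class counts are hypothetical; in the notebook the counts would come from `VideoFrameDataset.class_counts`.

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(class_counts: dict, label_mapping: dict) -> torch.Tensor:
    """Weight each class by total / (num_classes * count), ordered by label index."""
    total = sum(class_counts.values())
    num_classes = len(label_mapping)
    weights = torch.zeros(num_classes)
    for name, idx in label_mapping.items():
        weights[idx] = total / (num_classes * class_counts[name])
    return weights


label_mapping = {"owner": 0, "pet": 1, "other": 2, "background": 3}
# Hypothetical counts for illustration only.
class_counts = {"owner": 1000, "pet": 500, "other": 2000, "background": 500}

weights = inverse_frequency_weights(class_counts, label_mapping)
print(weights)  # owner 1.0, pet 2.0, other 0.5, background 2.0
criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
```

The weight tensor must be on the same device as the logits (`weights.to(device)`). A `torch.utils.data.WeightedRandomSampler` is the sampling-based alternative.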

In [17]:
# Optimizer (only train classifier parameters)
# Keep only parameters with requires_grad=True; after freezing, these belong to the custom classifier head.
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)
# Adam adapts a per-parameter step size from running gradient statistics, which usually converges faster than plain SGD.
# lr=5e-5 (= 0.00005): a conservative rate. The backbone is frozen, so this only controls how quickly
# the freshly initialized head learns; smaller values give more stable but slower training.
# weight_decay=1e-5: L2 penalty on the head's weights, discouraging large weights to help generalization.

# Loss function
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# CrossEntropyLoss combines log-softmax and negative log-likelihood for multi-class classification.
# label_smoothing=0.1: softens the one-hot targets so overconfident predictions are penalized,
# making the model less likely to fit noisy or incorrect labels (it does not rebalance classes).

<hr>

## 📦 DataLoader and Dataset Setup

Before training our model, we need to load and preprocess the image data. Because the raw data is a mix of videos and image folders, this is handled by a custom `VideoFrameDataset` (a `torch.utils.data.Dataset` subclass) wrapped in a PyTorch `DataLoader`.

### 🖼️ Image Preprocessing
We define a series of transformations to ensure consistency and improve model performance:
- **Resize + CenterCrop**: Each frame's shorter side is resized to **256** pixels, then the central **224x224** region is cropped. 224x224 is the input size MobileNetV2 was pre-trained on and gives every sample a uniform shape.
- **ToTensor**: Images are converted into PyTorch tensors with values scaled to [0, 1].
- **Normalize**: Pixel values are normalized per channel using the mean and standard deviation of the ImageNet dataset:
  - Mean: `[0.485, 0.456, 0.406]`
  - Standard Deviation: `[0.229, 0.224, 0.225]`

  Matching the statistics used during pre-training keeps inputs in the range the frozen backbone expects.

### 🗂️ Dataset Structure
Labels come from an explicit `label_mapping` (`owner` → 0, `pet` → 1, `other` → 2, `background` → 3) rather than from folder names:
- **Owner** and **pet** frames are sampled from videos (up to `max_frames_per_video` per video).
- **Other people** are randomly sampled from a directory of face images (up to 2,500).
- **Background** comes from videos or image directories (up to 400 images per directory).

The dataset stores only references (`source_type, path, frame_idx, label`) and loads each frame from disk on demand. With `augment=True`, every indexed frame appears `1 + num_augments` times: once unmodified and `num_augments` times with random color jitter, flips, rotation, perspective, and affine changes.

### 🛠️ DataLoader Creation
The example cell creates a single `DataLoader` with shuffled batches of **16** frames. `num_workers=0` keeps loading in the main process while testing; increasing it lets frames be decoded in parallel, and `pin_memory` is enabled when a GPU is available to speed up host-to-GPU transfers.


In [2]:
class VideoFrameDataset(Dataset):
    def __init__(self, video_paths, label_mapping, transform=None, 
                 augment=False, num_augments=3, bg_paths=None, 
                 face_image_dir=None, max_frames_per_video=300):
        self.video_paths = video_paths
        self.label_mapping = label_mapping
        self.transform = transform
        self.augment = augment
        self.num_augments = num_augments
        self.bg_paths = bg_paths or []
        self.face_image_dir = face_image_dir
        self.max_frames_per_video = max_frames_per_video
        
        # Instead of storing frames, store frame references
        self.frame_sources = []  # Will contain (source_type, path, frame_idx, label)
        self.class_counts = {name: 0 for name in label_mapping}
        
        self._index_all_data()
        
        if len(self.frame_sources) == 0:
            raise ValueError("No valid frames were indexed. Please check your input paths.")
        
        total_samples = len(self.frame_sources)
        if self.augment:
            total_samples *= (1 + self.num_augments)
            
        print(f"\nDataset created with {total_samples} samples")
        print("Class distribution:")
        for name, idx in label_mapping.items():
            count = self.class_counts[name] * (1 + self.num_augments if self.augment else 1)
            print(f"- {name}: {count} samples ({count/total_samples:.1%})")

    def _index_all_data(self):
        print("\n" + "="*50)
        print("STARTING DATA INDEXING")
        print("="*50)
        
        # Index owner video
        if "owner" in self.video_paths:
            print("\n" + "-"*20)
            print("INDEXING OWNER VIDEOS")
            print("-"*20)
            for path in self.video_paths["owner"]:
                print(f"\nIndexing: {path}")
                self._index_video(path, self.label_mapping["owner"], "owner")

        # Index pet video
        if "pet" in self.video_paths:
            print("\n" + "-"*20)
            print("INDEXING PET VIDEOS")
            print("-"*20)
            for path in self.video_paths["pet"]:
                print(f"\nIndexing: {path}")
                self._index_video(path, self.label_mapping["pet"], "pet")

        # Index face images if directory provided
        if self.face_image_dir:
            print("\n" + "-"*20)
            print("INDEXING FACE IMAGES FROM DIRECTORY")
            print("-"*20)
            print(f"Directory: {self.face_image_dir}")
            
            face_paths = glob(os.path.join(self.face_image_dir, "*.jpg")) + \
                        glob(os.path.join(self.face_image_dir, "*.jpeg")) + \
                        glob(os.path.join(self.face_image_dir, "*.png"))
            
            if face_paths:
                print(f"Found {len(face_paths)} face images")
                # Sample up to 2500 faces (or all of them if fewer are available)
                sample_size = min(2500, len(face_paths))
                print(f"Randomly sampling {sample_size} faces...")
                sampled_paths = random.sample(face_paths, sample_size)
                
                for i, img_path in enumerate(sampled_paths, 1):
                    if i % 100 == 0:  # Print progress every 100 images
                        print(f"Indexed {i}/{len(sampled_paths)} face images")
                    self._index_image(img_path, self.label_mapping["other"], "other")
                print(f"Finished indexing {len(sampled_paths)} face images")
            else:
                print(f"No face images found in {self.face_image_dir}")

        # Index background data
        if self.bg_paths:
            print("\n" + "-"*20)
            print("INDEXING BACKGROUND DATA")
            print("-"*20)
            for path in self.bg_paths:
                if os.path.isdir(path):
                    print(f"\nIndexing background images from: {path}")
                    img_paths = [
                        f for f in glob(os.path.join(path, "*")) 
                        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
                    ]
                    if img_paths:
                        sample_size = min(400, len(img_paths))
                        print(f"Found {len(img_paths)} images, sampling {sample_size}...")
                        sampled_paths = random.sample(img_paths, sample_size)
                        for i, img_path in enumerate(sampled_paths, 1):
                            if i % 100 == 0:
                                print(f"Indexed {i}/{len(sampled_paths)} background images")
                            self._index_image(img_path, self.label_mapping["background"], "background")
                    else:
                        print(f"No valid images found in {path}")
                else:
                    print(f"\nIndexing background video: {path}")
                    self._index_video(path, self.label_mapping["background"], "background")

        print("\n" + "="*50)
        print("DATA INDEXING COMPLETE")
        print("="*50)

    def _index_video(self, path, label, class_name):
        try:
            if not os.path.exists(path):
                print(f"Error: File not found - {path}")
                return
    
            cap = cv2.VideoCapture(path)
            if not cap.isOpened():
                print(f"Error: Could not open video - {path}")
                return
    
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames == 0:
                print(f"Warning: Video has 0 frames - {path}")
                cap.release()
                return
    
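            # Sample up to max_frames_per_video frames, spread evenly across the whole video.
            # (An integer step like total_frames // target_frames rounds down to 1 whenever
            # total_frames < 2 * target_frames, which would keep only the opening frames.)
            target_frames = min(self.max_frames_per_video, total_frames)

            print(f"Indexing {target_frames} frames from {total_frames} total frames")

            frames_indexed = 0
            frame_indices = [i * total_frames // target_frames for i in range(target_frames)]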
            
            # Just store references to the frames
            for frame_idx in frame_indices:
                self.frame_sources.append(("video", path, frame_idx, label))
                frames_indexed += 1
                self.class_counts[class_name] += 1
    
            cap.release()
            print(f"Indexed {frames_indexed} frames from {path}")
    
        except Exception as e:
            print(f"Error indexing video {path}: {str(e)}")

    def _index_image(self, img_path, label, class_name):
        try:
            if not os.path.exists(img_path):
                print(f"Error: File not found - {img_path}")
                return

            # Just store reference to the image
            self.frame_sources.append(("image", img_path, 0, label))
            self.class_counts[class_name] += 1
            
        except Exception as e:
            print(f"Error indexing image {img_path}: {str(e)}")

    def __len__(self):
        base_len = len(self.frame_sources)
        if self.augment:
            return base_len * (1 + self.num_augments)
        return base_len

    def __getitem__(self, idx):
        # If a sample fails to load, fall back to the next indexed frame instead of returning
        # a dummy tensor: a blank image labeled 0 would silently be trained as "owner".
        # Step over all augmented copies of the failing frame, since they share its source.
        stride = (1 + self.num_augments) if self.augment else 1
        last_error = None
        for attempt in range(5):
            current = (idx + attempt * stride) % len(self)
            try:
                return self._load_sample(current)
            except Exception as e:
                last_error = e
                print(f"Error getting item {current}: {str(e)}")
        raise RuntimeError(f"Failed to load item {idx} and its fallbacks") from last_error

    def _load_sample(self, idx):
        # Calculate which frame and whether this is an augmented version
        if self.augment:
            base_idx = idx // (1 + self.num_augments)
            aug_version = idx % (1 + self.num_augments)
        else:
            base_idx = idx
            aug_version = 0

        # Get the frame source info
        source_type, path, frame_idx, label = self.frame_sources[base_idx]

        # Load the frame on demand
        if source_type == "video":
            frame = self._load_video_frame(path, frame_idx)
        else:  # image
            frame = self._load_image(path)

        # Augment every copy except version 0 (the unmodified frame)
        if aug_version > 0:
            frame = self._augment_frame(frame)

        # Apply standard transformation
        if self.transform:
            if not isinstance(frame, Image.Image):
                frame = Image.fromarray(frame)
            frame = self.transform(frame)

        return frame, torch.tensor(label, dtype=torch.long)

    def _load_video_frame(self, video_path, frame_idx):
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        cap.release()
        
        if not ret or frame is None:
            raise ValueError(f"Failed to load frame {frame_idx} from {video_path}")
            
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def _load_image(self, img_path):
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image {img_path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _augment_frame(self, frame):
        augmenter = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
            ], p=0.8),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        ])
        return augmenter(frame)


# --- Example usage ---
if __name__ == "__main__":
    # Paths config
    label_mapping = {
        "owner": 0,    # You
        "pet": 1,      # Your cat
        "other": 2,    # Other people
        "background": 3  # Empty background
    }

    video_paths = {
        "owner": ["../data/owner/owner.mp4"],
        "pet": ["../data/pet/pet.mp4"],
    }

    bg_paths = [
        "../data/background/background.mp4"
    ]

    face_image_dir = "../data/other_people/other_people"

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    try:
        print("Creating dataset...")
        dataset = VideoFrameDataset(
            video_paths=video_paths,
            label_mapping=label_mapping,
            transform=transform,
            augment=True,
            num_augments=4,
            bg_paths=bg_paths,
            face_image_dir=face_image_dir,
            max_frames_per_video=1750
        )
        
        # Start with fewer workers for testing
        dataloader = DataLoader(
            dataset,
            batch_size=16,  # 16 frames per batch
            shuffle=True,
            num_workers=0,  # Start with 0 workers for testing
            pin_memory=torch.cuda.is_available()
        )
        
        print("Testing dataloader...")
        for i, (x, y) in enumerate(dataloader):
            print(f"Batch {i}: x shape = {x.shape}, y shape = {y.shape}")
            if i == 2:  # Limit output
                break
        print("Dataloader test successful!")
        
    except Exception as e:
        print(f"Error: {str(e)}")

Creating dataset...

STARTING DATA INDEXING

--------------------
INDEXING OWNER VIDEOS
--------------------

Indexing: ../data/owner/owner.mp4
Indexing 1750 frames from 3150 total frames
Indexed 1750 frames from ../data/owner/owner.mp4

--------------------
INDEXING PET VIDEOS
--------------------

Indexing: ../data/pet/pet.mp4
Indexing 915 frames from 915 total frames
Indexed 915 frames from ../data/pet/pet.mp4

--------------------
INDEXING FACE IMAGES FROM DIRECTORY
--------------------
Directory: ../data/other_people/other_people
Found 7219 face images
Randomly sampling 2500 faces...
Indexed 100/2500 face images
Indexed 200/2500 face images
Indexed 300/2500 face images
Indexed 400/2500 face images
Indexed 500/2500 face images
Indexed 600/2500 face images
Indexed 700/2500 face images
Indexed 800/2500 face images
Indexed 900/2500 face images
Indexed 1000/2500 face images
Indexed 1100/2500 face images
Indexed 1200/2500 face images
Indexed 1300/2500 face images
Indexed 1400/2500 face 

<hr>

## 🧠 Model Training Loop Explained

This section explains the logic behind training our deep learning model using PyTorch.

### 🔁 Epochs

We train the model for **10 epochs** (`num_epochs=10` in the call below). One epoch is a complete pass over the entire training dataset; repeating this lets the model learn gradually and improve over time.
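In practice an epoch is counted in batches: with `batch_size=16` and the DataLoader's default `drop_last=False`, one epoch is `ceil(N / 16)` optimizer steps for a dataset of N samples. The training log further down reports 2161 steps per epoch, which narrows N to a small range. A minimal sketch of that arithmetic (the helper names are hypothetical, not part of the notebook):

```python
import math

BATCH_SIZE = 16          # batch_size passed to the DataLoader above
STEPS_PER_EPOCH = 2161   # iterations per epoch reported by tqdm in the training log

def batches_per_epoch(num_samples: int, batch_size: int) -> int:
    # drop_last=False keeps the final, possibly smaller batch
    return math.ceil(num_samples / batch_size)

def dataset_size_range(steps: int, batch_size: int) -> tuple:
    # Smallest and largest N for which ceil(N / batch_size) == steps
    return (steps - 1) * batch_size + 1, steps * batch_size

low, high = dataset_size_range(STEPS_PER_EPOCH, BATCH_SIZE)
print(low, high)  # 34561 34576
```

Over 10 epochs that amounts to 21,610 optimizer steps.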

### 📊 Metrics Tracked

During training, we monitor two key metrics every epoch:
- **Loss** (📉): Measures how far off the model’s predictions are from the actual labels.
- **Accuracy** (✅): Tells us how many predictions were correct out of the total.

These are stored in `train_losses` and `train_accuracies` lists to help us visualize progress later.
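Both metrics come straight from the model's raw outputs (logits). The sketch below mirrors what the training loop does per batch, using a hypothetical batch of 4 frames over the 4 classes: `CrossEntropyLoss` averages the per-sample loss, and `outputs.max(1)` picks the highest-scoring class for each row.

```python
import torch
import torch.nn as nn

# Hypothetical logits for 4 frames over 4 classes
outputs = torch.tensor([
    [2.0, 0.1, 0.3, 0.0],   # highest score at index 0
    [0.2, 3.0, 0.1, 0.4],   # highest score at index 1
    [0.5, 0.2, 0.1, 1.5],   # highest score at index 3
    [0.1, 0.3, 2.5, 0.2],   # highest score at index 2
])
labels = torch.tensor([0, 1, 2, 2])

criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, labels)    # mean cross-entropy over the batch

_, predicted = outputs.max(1)        # argmax over the class dimension
correct = predicted.eq(labels).sum().item()
accuracy = 100.0 * correct / labels.size(0)
print(f"loss={loss.item():.4f} acc={accuracy:.1f}%")
```

Three of the four predictions match their labels, so the batch accuracy is 75%.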

### 🔄 Training Steps Per Epoch

Each epoch consists of several steps that happen for every batch of images:

1. **Set model to training mode** 🏋️  
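   Calling `model.train()` switches dropout and batch normalization layers to their training behavior: dropout randomly zeroes activations, and batch norm normalizes with per-batch statistics while updating its running averages.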

2. **Loop through batches of data** 📦  
   We use `tqdm` to wrap our DataLoader, which gives us a nice real-time progress bar.

3. **Zero out gradients** 🧽  
   PyTorch accumulates gradients across `.backward()` calls by default, so we reset them with `optimizer.zero_grad()` before each batch.

4. **Forward pass** 📤  
   The input images are passed through the model to generate predictions.

5. **Calculate loss** 📏  
   We use a loss function (CrossEntropyLoss) to compute how wrong the model was.

6. **Backward pass** 🧮  
   We call `.backward()` on the loss to compute gradients of the model parameters.

7. **Update weights** 🔧  
   The optimizer (Adam) updates the model parameters based on the gradients with `optimizer.step()`.

8. **Track loss and accuracy** 🧾  
   We add the loss to a running total and count how many predictions were correct.
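The steps above can be sketched on a toy problem. This assumes a hypothetical stand-in model (a tiny MLP on 8 synthetic features instead of MobileNetV2 on images) and one fixed synthetic batch in place of the tqdm-wrapped DataLoader loop (step 2), so it runs in seconds on a CPU.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real classifier: 8 input features -> 4 classes
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.2), nn.Linear(16, 4))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# One fixed synthetic batch (stands in for a batch from the DataLoader)
images = torch.randn(32, 8)
labels = torch.randint(0, 4, (32,))

def train_step(images, labels):
    model.train()                        # 1. training mode (dropout active)
    optimizer.zero_grad()                # 3. clear gradients from the previous step
    outputs = model(images)              # 4. forward pass
    loss = criterion(outputs, labels)    # 5. loss
    loss.backward()                      # 6. gradients
    optimizer.step()                     # 7. Adam update
    _, predicted = outputs.max(1)        # 8. statistics
    return loss.item(), predicted.eq(labels).sum().item()

first_loss, _ = train_step(images, labels)
for _ in range(200):
    last_loss, correct = train_step(images, labels)
print(f"first loss {first_loss:.3f} -> last loss {last_loss:.3f}")
```

Because the same batch is reused, the loss drops steadily; with real, shuffled batches it is noisier from step to step.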

### 🧮 After Each Epoch

Once all batches are processed in an epoch:
- We calculate the **average loss** and **accuracy** for that epoch.
- We print this info to the console.
- We append the values to our tracking lists for later use.
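One subtlety: `train_model` computes accuracy per sample (`correct / total`), but the epoch loss as `running_loss / len(train_loader)`, i.e. the mean of per-batch means. The two loss definitions only differ when batches have different sizes, which with `drop_last=False` can only be the last one. The hypothetical numbers below exaggerate the effect with a tiny three-batch epoch; with roughly 2161 batches per epoch the difference is negligible.

```python
# Hypothetical per-batch results: two full batches of 16 and a final batch of 4
batch_losses = [0.2, 0.4, 1.0]   # mean loss of each batch (what loss.item() returns)
batch_sizes = [16, 16, 4]

# What train_model computes: the mean of the batch means
mean_of_batch_means = sum(batch_losses) / len(batch_losses)

# Per-sample average: weight each batch mean by its batch size
per_sample_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(f"{mean_of_batch_means:.4f} vs {per_sample_mean:.4f}")  # 0.5333 vs 0.3778
```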

### 💾 Saving the Model

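After every epoch, we save the model's learned weights (its `state_dict`) to `entity_classifier.pth`, overwriting the previous checkpoint. This lets us reuse the model later without retraining it from scratch, and an interrupted run still leaves the latest completed epoch on disk. 💡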
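A `state_dict` stores only parameters and buffers, not the model class, so reloading means rebuilding the same architecture first. A minimal round-trip sketch, assuming a hypothetical `build_model()` stand-in (in this notebook it would rebuild the MobileNetV2-based classifier) and a temporary directory instead of the working directory:

```python
import os
import tempfile

import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Hypothetical stand-in; must produce exactly the architecture that was saved
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

torch.manual_seed(0)
model = build_model()

path = os.path.join(tempfile.mkdtemp(), "entity_classifier.pth")
torch.save(model.state_dict(), path)             # parameters only

restored = build_model()
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()                                  # inference mode for later use

x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```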

### 📈 Plotting Metrics

We use `matplotlib` to generate two side-by-side plots:
- **Training Loss per Epoch** (in red)
- **Training Accuracy per Epoch** (in green)

These plots show how the model learned over time: ideally, loss decreases while accuracy increases! 📉📈 The figure is also saved to `training_results.png`.

import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device='cuda'):
    print("🔥 Starting training process...")
    # Move model to the specified device
    model.to(device)
    
    # Lists to store metrics for plotting
    train_losses = []
    train_accuracies = []
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        # Initialize progress bar
        loop = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True)
        
        for batch_idx, (images, labels) in enumerate(loop):
            # Move data to device
            images = images.to(device)
            labels = labels.to(device)
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass: logits of shape (batch_size, num_classes)
            outputs = model(images)
            
            # Calculate loss
            loss = criterion(outputs, labels)
            
            # Backward pass: compute gradients
            loss.backward()
            # Update weights
            optimizer.step()
            
            # Update statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # Update progress bar
            loop.set_postfix(
                loss=running_loss/(batch_idx+1),
                acc=100.*correct/total
            )
        
        # Calculate epoch metrics
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        
        print(f'Epoch {epoch+1} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%')
    
        # Checkpoint after every epoch (overwrites the previous file)
        torch.save(model.state_dict(), 'entity_classifier.pth')
        print('Model saved to entity_classifier.pth')
    
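    # Plot results: loss (red) and accuracy (green) side by side, epochs numbered from 1
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, color='red', label='Train Loss')
    plt.title('Training Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, color='green', label='Train Acc')
    plt.title('Training Accuracy per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()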
    
    plt.tight_layout()
    plt.savefig('training_results.png')  # Save the plot before showing it
    plt.show()
    
    return model

# Complete setup and training code
def setup_and_train(pretrained_model, dataloader, learning_rate=0.001, num_epochs=10):
    # Define loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    
    # Configure optimizer
    optimizer = torch.optim.Adam(pretrained_model.parameters(), lr=learning_rate)
    
    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Train the model
    print("Starting model training...")
    trained_model = train_model(
        model=pretrained_model,
        train_loader=dataloader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        device=device
    )
    
    return trained_model

trained_model = setup_and_train(
    pretrained_model=model,  # MobileNetV2-based classifier defined earlier in this notebook
    dataloader=dataloader,   # DataLoader built from VideoFrameDataset above
    learning_rate=5e-5,      # overrides the default of 0.001; a small LR suits fine-tuning
    num_epochs=10            # same as the default
)

Starting model training...
🔥 Starting training process...


Epoch 1/10: 100%|███████████████████████████████████████████████| 2161/2161 [49:12<00:00,  1.37s/it, acc=95, loss=0.16]


Epoch 1 | Loss: 0.1600 | Acc: 95.01%
Model saved to entity_classifier.pth


Epoch 2/10:   4%|█▋                                           | 84/2161 [01:42<35:00,  1.01s/it, acc=97.6, loss=0.0668]

🎉 **Success!** Our model is trained, and a checkpoint is saved to `entity_classifier.pth` after every epoch! 🎉

The log above shows the model reaching 95.01% training accuracy (loss 0.1600) after the first epoch, with the running accuracy already at 97.6% early in epoch 2. Keep in mind that these are **training** metrics: they show how well the model fits frames it has already seen, not how it will handle new ones.

With solid performance on the training set, it's time to move on to **real-time inference**! 🚀

<hr>

### What’s next? 👀
Instead of continuing in a notebook, the next step is to integrate the trained model into a standalone Python application (built in PyCharm) for real-time use. The app runs the model on live data from a **webcam feed**, which lets us test how it performs in real-world conditions on new, unseen frames.

Let’s take the model for a spin and see how it performs when it’s really put to work in a live **webcam app**. 💻✨