# 🧠 Real-time Entity Classifier - CNN Architecture & Training

This notebook handles both definition and training of a MobileNetV2-based CNN for real-time webcam classification:

**Detection Categories**:
- 🐱 **Pet** (Cat)
- 👤 **Owner**
- 🧍 **Other People**
- 🚫 **Background/Nobody**

**Key Features**:
- ⚡ Real-time inference: ~30ms/frame (640x480 resolution)
- 🔒 Privacy-focused: All processing on-device
- 🌟 Robust to: Lighting changes, partial occlusions
- 🏗️ Transfer learning: New classifier head trained on top of a frozen, ImageNet-pretrained backbone

<hr>

## 📦 PyTorch & Project Imports Overview

### 🧠 Core PyTorch Components

- **`torch`**  
  Used for tensor operations, device management (`torch.cuda.is_available()`), and model saving (`torch.save()`).

- **`torch.nn`**  
  Used for building the classifier component of our model, including `nn.Linear`, `nn.ReLU`, `nn.Dropout`, and `nn.Sequential`. Also provides `CrossEntropyLoss` with label smoothing to discourage overconfident predictions.

- **`torch.optim`**  
  Provides the `Adam` optimizer with learning rate and weight decay parameters for training our classifier layers.

- **`torch.utils.data.Dataset, DataLoader`**  
  `Dataset` is subclassed to create our custom `VideoImageDataset`. `DataLoader` wraps this dataset to provide batching, shuffling, and an iterator for the training loop.

### 🖼️ Computer Vision & Image Processing

- **`PIL.Image`**  
  Used to convert between NumPy arrays and PIL images for compatibility with torchvision transforms.

- **`cv2` (OpenCV)**  
  Handles video operations (`VideoCapture`) to read and extract frames from video files, and image color space conversion (`cvtColor`).

- **`torchvision.models`**  
  Used to load a pre-trained `mobilenet_v2` as our base feature extractor, which we then modify for our specific classification task.

- **`torchvision.models.mobilenetv2.MobileNet_V2_Weights`**  
  Provides pre-trained weight configurations for MobileNetV2 from PyTorch's model zoo. Used to initialize our model with ImageNet weights (`IMAGENET1K_V1`).

- **`torchvision.transforms`**  
  Builds image transformation pipelines for:  
  - Standard preprocessing: resize, crop, normalization  
  - Data augmentation: random flips, rotations, color jitter, perspective changes

### 📁 Data Handling & Utilities

- **`os`**  
  Checks file existence (`os.path.exists()`) and constructs file paths (`os.path.join()`).

- **`glob`**  
  Finds image files with specific extensions (`.jpg`, `.jpeg`, `.png`) within directories.

- **`random`**  
  Samples subsets of data when we have too many files (`random.sample()`), helping maintain dataset balance.

- **`numpy as np`**  
  Used to transform images to NumPy arrays.

### 📊 Visualization & Progress Tracking

- **`matplotlib.pyplot as plt`**  
  Creates and saves training visualization plots with loss and accuracy metrics after training completes.

- **`tqdm`**  
  Wraps the training loop to provide a progress bar with real-time metrics (loss and accuracy) during model training.

These libraries together form a complete pipeline for processing video data, extracting frames, building and training a deep learning model for multi-class classification, and visualizing the results.

In [None]:
# ======================
# Core PyTorch Components
# ======================
import torch               # Base library for tensors, GPU ops, and autograd
import torch.nn as nn      # Neural network layers (Linear, Conv2d, etc.)
import torch.optim as optim  # Optimizer (Adam) for training
from torch.utils.data import Dataset, DataLoader  # Custom datasets + efficient batching

# ======================
# Computer Vision
# ======================
import cv2                        # Video capture and frame processing
from PIL import Image             # Image loading and conversion
from torchvision import models    # Pretrained models (MobileNetV2)
from torchvision.models.mobilenetv2 import ( # MobileNetV2-specific:
    MobileNet_V2_Weights                     # - Pretrained weight configurations
)
from torchvision import transforms  # Image preprocessing/augmentations

# ======================
# Data Pipeline
# ======================
import os                        # File path operations
from glob import glob            # Pattern matching for image/video files
import random                    # Shuffling datasets and sampling
import numpy as np               # Used to convert images to NumPy arrays

# ======================
# Training Utilities
# ======================
import matplotlib.pyplot as plt  # Plotting loss/accuracy curves
from tqdm import tqdm            # Progress bars for training loops

<hr>

## MobileNetV2 Transfer Learning Architecture Explanation 🚀

### Project Goal: 4-Way Classification System 🎯

This implementation adapts a pre-trained MobileNetV2 model for a specific 4-class recognition task using transfer learning principles. The model is designed to classify webcam frames into:

1. **Pet** – A specific cat 🐱  
2. **Owner** – The cat's owner 👨  
3. **Other Person** – Any human who is not the owner 🧍  
4. **None/Background** – Empty frames with no people or pets 🚫

### Base Model: MobileNetV2 📱

```python
model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1).to(device)
```

#### Why MobileNetV2? 🤔

- **Efficiency** ⚡: MobileNetV2 is designed specifically for mobile and edge devices, making it computationally efficient while maintaining good accuracy
- **Depthwise Separable Convolutions** 🧩: Uses a factorized form of standard convolutions that drastically reduces parameter count and computational cost
- **Inverted Residuals** 🔄: Unlike traditional residual blocks, MobileNetV2 uses inverted residuals with linear bottlenecks, which help preserve information flow while keeping the model lightweight
- **Pre-trained Knowledge** 🧠: ImageNet pre-training provides the model with powerful feature extraction capabilities that can transfer well to new domains
- **Size-Performance Tradeoff** ⚖️: Offers an excellent balance between model size (~3.5M parameters, roughly 14 MB of float32 weights) and performance for real-time or resource-constrained applications

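To make the savings of depthwise separable convolutions concrete, the sketch below compares the parameter count of a standard 3x3 convolution with a depthwise + pointwise pair in PyTorch. The channel count of 128 is an arbitrary illustration, not a value taken from MobileNetV2, whose blocks additionally use expansion layers and residual connections.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumption): 128 input/output channels, 3x3 kernel.
c_in, c_out, k = 128, 128, 3

# Standard convolution: each output channel mixes all input channels spatially.
standard = nn.Conv2d(c_in, c_out, kernel_size=k, padding=1, bias=False)

# Depthwise separable: per-channel 3x3 filter (groups=c_in), then a 1x1 pointwise mix.
depthwise = nn.Conv2d(c_in, c_in, kernel_size=k, padding=1, groups=c_in, bias=False)
pointwise = nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)

def n_params(*modules):
    """Total number of parameters across the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters())

standard_params = n_params(standard)                # k*k*c_in*c_out = 147,456
separable_params = n_params(depthwise, pointwise)   # k*k*c_in + c_in*c_out = 17,536

# Both paths map the same input shape to the same output shape.
x = torch.randn(1, c_in, 32, 32)
assert standard(x).shape == pointwise(depthwise(x)).shape

print(f"standard: {standard_params:,}  separable: {separable_params:,}  "
      f"ratio: {standard_params / separable_params:.1f}x")
```

For this layer size the separable version needs roughly 8x fewer parameters. The multiply-add count shrinks by the same factor, because both costs scale with the output's spatial size.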
### Transfer Learning Approach 🔄

```python
# Freeze the feature extraction layers
for param in model.parameters():
    param.requires_grad = False
```

#### Benefits of Feature Freezing ❄️

- **Training Efficiency** 🏎️: By freezing the convolutional backbone, we reduce the number of trainable parameters from roughly 2.5 million to about 329 thousand (the new classifier head alone)
- **Prevents Overfitting** 🛡️: With limited training data, updating only the classifier prevents the model from overfitting to peculiarities in our small dataset
- **Knowledge Preservation** 📚: Retains the robust feature extraction capabilities learned from ImageNet's diverse 1.2+ million images
- **Faster Convergence** 🏁: The classifier can adapt to the new task much more quickly when starting from well-formed features

### Custom Classifier Head 👑

```python
model.classifier[1] = nn.Sequential(
    nn.Linear(model.classifier[1].in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 4)
).to(device)
```

#### Classifier Design Choices 🧩

- **Intermediate Hidden Layer (256 neurons)** 🧬:
  - Provides greater representational capacity than a direct mapping
  - Allows the model to learn more complex decision boundaries between classes
  - 256 neurons balances expressivity and computational efficiency
  - Intended to help with fine-grained distinctions such as owner vs. other person 👥

- **ReLU Activation** ⚡:
  - Introduces non-linearity to capture complex relationships
  - Mitigates the vanishing gradient problem with its non-saturating form
  - Computationally efficient compared to tanh or sigmoid

- **Dropout (0.2)** 🎭:
  - Implements regularization by randomly deactivating 20% of neurons during training
  - Prevents co-adaptation of neurons (neurons becoming too dependent on each other)
  - Forces the network to learn redundant representations, improving generalization
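  - Rate of 0.2 is conservative, providing regularization while preserving most information flow. The original `Dropout(0.2)` at `model.classifier[0]` is kept, so features also pass through dropout before the hidden layer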
  - Particularly important for this task since the dataset likely contains many similar frames of the same subjects

- **Output Layer (4 neurons)** 🎬:
  - One neuron per class (pet, owner, other person, background)
  - Produces raw logits: `CrossEntropyLoss` applies log-softmax internally, so no softmax layer is added. Apply `torch.softmax` at inference time to obtain probabilities

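Because the head outputs raw logits, probabilities only appear at inference time. A minimal sketch, using random vectors in place of pooled 1280-dimensional MobileNetV2 features and a hypothetical `CLASS_NAMES` order that would need to match the real `label_mapping`:

```python
import torch
import torch.nn as nn

# Hypothetical class order; it must match the label_mapping used for training.
CLASS_NAMES = ["pet", "owner", "other_person", "background"]

# Same layer sizes as the notebook's custom head.
head = nn.Sequential(
    nn.Linear(1280, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 4),
)
head.eval()  # disables dropout for inference

features = torch.randn(2, 1280)  # stand-in for pooled backbone features of 2 frames
with torch.no_grad():
    logits = head(features)               # raw scores, shape (2, 4)
    probs = torch.softmax(logits, dim=1)  # each row sums to 1
    preds = probs.argmax(dim=1)           # same argmax as the logits

for p, idx in zip(probs, preds):
    print(CLASS_NAMES[idx.item()], f"{p[idx].item():.2f}")
```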
### Weight Initialization Strategy 🎲

```python
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)
```

#### Xavier/Glorot Initialization Benefits ✨

- **Variance Control** 📊: Maintains the variance of activations and gradients across layers
- **Prevents Signal Vanishing/Exploding** 💥: Scales weights by the layer's fan-in and fan-out so signals propagate without shrinking or blowing up
- **Uniform Distribution** 📈: Draws weights from `U(-a, a)` with `a = sqrt(6 / (fan_in + fan_out))`
- **Faster Convergence** 🚀: Well-scaled initial weights typically let training make progress from the first epochs
- **Zero Bias Initialization** 0️⃣: Starting biases at zero is standard practice. Xavier was derived for symmetric activations such as tanh; He (Kaiming) initialization is the ReLU-specific alternative, though Xavier remains a reasonable choice for a shallow two-layer head

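For the first head layer, `nn.Linear(1280, 256)`, the Xavier bound works out to exactly `sqrt(6 / 1536) = 0.0625`. The sketch checks that `nn.init.xavier_uniform_` samples stay inside that bound and that their variance matches `2 / (fan_in + fan_out)`, which is the variance of `U(-a, a)`:

```python
import math
import torch
import torch.nn as nn

layer = nn.Linear(1280, 256)  # first layer of the custom head

# nn.Linear stores its weight as (out_features, in_features).
fan_out, fan_in = layer.weight.shape
bound = math.sqrt(6.0 / (fan_in + fan_out))  # Xavier/Glorot uniform bound, gain=1

nn.init.xavier_uniform_(layer.weight)
nn.init.zeros_(layer.bias)

print(bound)                                   # 0.0625
print(layer.weight.abs().max().item() <= bound)
print(layer.weight.var().item(), 2 / (fan_in + fan_out))
```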
### Model Deployment 🚀

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
```

#### Hardware Acceleration 💻

- **Dynamic Device Selection** 🔍: Code automatically selects GPU if available, falling back to CPU if necessary
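- **Device Placement** 💾: Model parameters and input tensors must be on the same device; `model.to(device)` moves the weights once, after which each batch is moved with `inputs.to(device)`
- **Computation Speed** ⚡: Running on a GPU is typically much faster than on a CPU for convolutional networks, with the exact speedup depending on hardware, batch size, and input resolution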

### Architecture Summary 📝

This architecture exemplifies modern transfer learning best practices by:

1. Leveraging a pre-trained, efficient CNN architecture 🏗️
2. Freezing feature extraction layers to preserve learned representations ❄️
3. Implementing a purpose-built classifier for the specific task 🎯
4. Using appropriate regularization techniques to prevent overfitting 🛡️
5. Applying proven weight initialization strategies for faster convergence 🏁

The resulting model balances computational efficiency with classification performance, making it suitable for deployment in resource-constrained environments while aiming for high accuracy on this pet and person recognition task. 🤖👍

### Application-Specific Advantages 🌟

For this specific pet/owner recognition task:

1. **Fine-Grained Recognition** 🔍: The model can learn to recognize the specific cat and to pick up subtle differences between the owner and other people
2. **Real-time Processing** ⏱️: MobileNetV2's efficiency enables real-time classification on webcam streams
3. **Low Resource Requirements** 💪: The architecture can run on modest hardware like laptops without dedicated GPUs
4. **Quick Training** 🏎️: By using transfer learning, the model can be trained with relatively few examples of each class

In [None]:
# ======================
# Model Definition
# ======================
# Set computation device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained MobileNetV2 with ImageNet weights
# MobileNetV2 is chosen for its efficiency, performance, and lightweight nature
model = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1).to(device)

# Freeze the feature extraction layers to preserve pre-trained knowledge
# This implements transfer learning - we only train the classifier
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier with our custom head for 4-class classification
# Architecture: Input features → 256 neurons → ReLU → Dropout → 4 output classes
model.classifier[1] = nn.Sequential(
    nn.Linear(model.classifier[1].in_features, 256),  # Hidden layer with 256 neurons
    nn.ReLU(),                                        # Non-linearity
    nn.Dropout(0.2),                                  # Regularization to prevent overfitting
    nn.Linear(256, 4)                                 # Output layer for our 4 classes
).to(device)

# Define weight initialization function for better convergence
# Xavier/Glorot initialization helps control variance of activations across layers
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # Uniform distribution within proper bounds
        nn.init.zeros_(m.bias)             # Initialize biases to zero

# Apply weight initialization to our classifier only
model.classifier[1].apply(init_weights)

# Verify model setup
print(f"Model loaded on {device}")
print("Classifier structure:")
print(model.classifier)

<hr>

## Training Configuration: Optimizer and Loss Function 🛠️

### Optimizer: Adam with Selective Training 🎯

```python
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)
```

#### Parameter Selection Strategy 🔍

We use `filter(lambda p: p.requires_grad, model.parameters())` to select only trainable parameters. This is critical for our transfer learning approach:

- **Efficiency Boost** ⚡: By optimizing only the classifier head (~329K parameters) instead of the entire model (~2.5M), we reduce computation time, optimizer state, and memory requirements
- **Focused Learning** 🔭: Updates are restricted to the task-specific components (the classifier), leaving the pre-trained feature extractor untouched
- **Lambda Function Elegance** 💻: The lambda function creates a clean, one-line filter that selects parameters where `requires_grad=True`

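One practical detail of the `filter` approach: `filter()` returns a one-shot iterator, and the optimizer consumes it during construction. Reusing the same variable later (for logging or gradient clipping) silently yields nothing. A small sketch on a toy two-layer model, where the first layer stands in for the frozen backbone:

```python
import torch.nn as nn
import torch.optim as optim

# Toy model: the first layer plays the role of the frozen backbone.
model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)

# The optimizer already iterated over the filter object, so it is now empty.
remaining = list(trainable_params)
print(len(remaining))  # 0

# A list can be reused safely.
trainable_list = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable_list))  # 4*2 weights + 2 biases = 10
```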
#### Why Adam? 🤔

Adam (Adaptive Moment Estimation) combines two extensions of stochastic gradient descent, momentum and RMSProp:

- **Adaptive Learning Rates** 📊: Automatically scales the step size of each parameter based on its gradient history
- **Momentum** 🏎️: Keeps an exponential moving average of past gradients (the first moment), which smooths the update direction
- **RMSProp Integration** 📉: Divides each update by the square root of an exponential moving average of squared gradients (the second moment)
- **Sparse Gradient Handling** 🌵: Performs well even when gradients are sparse or noisy

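The update rule behind these bullets fits in a few lines. This scalar sketch in plain Python follows the standard Adam equations with PyTorch's default `beta1=0.9`, `beta2=0.999`, `eps=1e-8` and this notebook's `lr=5e-5`. It is for illustration only, not a replacement for `torch.optim.Adam`. Note that the first step moves the parameter by about `lr` regardless of the gradient's magnitude:

```python
import math

def adam_step(theta, grad, m, v, t, lr=5e-5, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad        # first moment: EMA of gradients (momentum)
    v = beta2 * v + (1 - beta2) * grad ** 2   # second moment: EMA of squared gradients (RMSProp)
    m_hat = m / (1 - beta1 ** t)              # bias correction: m and v start at zero
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# Run three steps on a toy gradient sequence and record the parameter after each.
theta, m, v = 1.0, 0.0, 0.0
history = []
for t, g in enumerate([0.5, 0.5, -2.0], start=1):
    theta, m, v = adam_step(theta, g, m, v, t)
    history.append(theta)
print(history)
```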
#### Hyperparameter Choices 🎛️

- **Learning Rate (5e-5)** 🏁: 
  - A small, conservative rate of the kind commonly used when fine-tuning pre-trained models.
  - Because the backbone is frozen, this rate only governs the new classifier head; it cannot alter the pre-trained features.
  - Small steps favor stable, gradual convergence at the cost of needing more epochs.
  - Adam's default of 1e-3 is a common alternative when training a freshly initialized head, so this value is worth tuning.

- **Weight Decay (1e-5)** 🌱: 
  - Implements L2 regularization by penalizing large weights
  - Helps prevent overfitting by encouraging the model to use smaller weights
  - Value of 1e-5 is conservative, providing gentle regularization
  - Particularly valuable for our problem where limited training data could lead to overfitting

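In `torch.optim.Adam`, `weight_decay` is applied by adding `weight_decay * param` to the gradient before the adaptive scaling (classic coupled L2). A side effect is that the decay term is normalized like any other gradient, so when the loss gradient is near zero even `1e-5` moves a weight by roughly `lr` per step. `torch.optim.AdamW` instead applies decay directly to the weights, decoupled from the adaptive scaling. The sketch isolates the coupled behavior with a single parameter and a zero loss gradient:

```python
import torch

# A single parameter whose loss gradient is exactly zero.
theta = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.Adam([theta], lr=5e-5, weight_decay=1e-5)

theta.grad = torch.zeros_like(theta)  # pretend loss.backward() produced a zero gradient
opt.step()

# The effective gradient is weight_decay * theta = 1e-5; after Adam's normalization
# the first step is lr * 1e-5 / (1e-5 + eps), i.e. almost exactly lr.
print(1.0 - theta.item())
```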
### Loss Function: CrossEntropyLoss with Label Smoothing 📊

```python
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
```

#### Why CrossEntropyLoss? 🎯

CrossEntropyLoss is the standard choice for multi-class classification problems:

- **Softmax Integration** 🔄: Internally applies log-softmax to the raw logits, so the model's final layer needs no softmax
- **Class Probability** 📊: Measures the difference between predicted probability distribution and the target distribution
- **Numerical Stability** 🛡️: Implements log-softmax and negative log-likelihood in a numerically stable way
- **Single-Label Focus** 🏷️: Optimized for problems where each example belongs to exactly one class

#### Label Smoothing (0.1) 🥤

Label smoothing is a regularization technique that modifies the target distribution:

- **Target Softening** ☁️: Instead of hard labels (0, 0, 1, 0), PyTorch mixes in a uniform distribution, `(1 - 0.1) * one_hot + 0.1 / 4`, giving soft targets (0.025, 0.025, 0.925, 0.025)
- **Confidence Penalty** 📉: Discourages the model from becoming too confident in its predictions
- **Overfitting Prevention** 🛡️: Makes the model less likely to memorize training data noise or errors
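- **Not a Class-Imbalance Fix** ⚖️: Label smoothing treats all classes the same, so it does not correct imbalance (e.g., more background frames than pet frames); class weights (`CrossEntropyLoss(weight=...)`) or a `WeightedRandomSampler` are the usual tools for that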
- **Value Selection (0.1)** 🎚️: 
  - 0.1 is a moderate smoothing value that provides regularization benefits
  - High enough to prevent overconfidence
  - Low enough to maintain class separation
  
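The smoothed targets and the resulting loss can be checked against PyTorch directly. This sketch builds the 4-class smoothed target by hand and compares a manual cross-entropy with `nn.CrossEntropyLoss(label_smoothing=0.1)`; the logits are arbitrary example values:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes, eps = 4, 0.1
target = torch.tensor([2])  # true class index 2

# Smoothed target: (1 - eps) * one_hot + eps / num_classes
one_hot = F.one_hot(target, num_classes).float()
smoothed = (1 - eps) * one_hot + eps / num_classes
print(smoothed)  # 0.025 for the wrong classes, 0.925 for the true class

# Cross-entropy against the soft target equals PyTorch's built-in smoothed loss.
logits = torch.tensor([[0.2, -1.0, 2.5, 0.1]])
manual = -(smoothed * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, target)
print(manual.item(), builtin.item())
```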
### Combined Effect on Training Dynamics 🔄

Together, these choices create a training configuration that:

1. **Focuses computational effort** on adapting the classifier to our specific classes
2. **Adapts learning dynamically** based on gradient behavior during training
3. **Regularizes from multiple angles** (weight decay and label smoothing) to prevent overfitting
4. **Improves generalization**, especially with limited training data
5. **Often accelerates convergence** compared to plain SGD, thanks to per-parameter adaptive step sizes

This configuration represents modern deep learning best practices for transfer learning on classification tasks with limited, potentially imbalanced datasets.

In [None]:
# Optimizer (only train classifier parameters)
trainable_params = filter(lambda p: p.requires_grad, model.parameters())  
# Keep only the parameters that require gradients (only the custom classifier will be trained)
# This is efficient because the feature extraction layers are frozen, so no gradients are computed for them.
optimizer = optim.Adam(trainable_params, lr=5e-5, weight_decay=1e-5)  
# Adam is chosen for its per-parameter adaptive learning rates, which often converge faster than plain SGD
# lr=5e-5 (=0.00005): a conservative learning rate commonly used when fine-tuning pre-trained models.
# Because the backbone is frozen, it only affects the new classifier layers; small steps favor stable convergence.
# weight_decay=1e-5: L2 regularization helps prevent overfitting by penalizing large weights (helps generalization).

# Loss function
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  
# CrossEntropyLoss is commonly used for multi-class classification tasks, as it combines log-softmax and negative log-likelihood loss
# label_smoothing=0.1: This technique reduces the confidence of the model when predicting the target class, 
# making it less likely to overfit on noisy or incorrect labels and improving generalization

<hr>

## 🎥📸 VideoImageDataset Class: A Comprehensive Overview

### 🏗️ Introduction

The `VideoImageDataset` class is a custom PyTorch `Dataset` implementation designed for handling **mixed media (videos 🎥 and images 📸)** in machine learning classification tasks. Although initially designed for person/pet 🐕 classification, it is general enough to apply to various domains like surveillance, smart home applications, or content moderation.

This class demonstrates advanced data engineering strategies to improve efficiency, flexibility, and robustness during model training.

### ⚙️ Core Functionality

This class addresses several key challenges in multimedia datasets:

1. **Mixed Media Handling** 🔄  
   Combines both video frames and image files into a unified dataset. No need to treat them separately during model training.

2. **Memory Efficiency** 🧠  
   Lazy loads frames only when needed and uses an internal cache to avoid repeated expensive disk I/O.

3. **Data Augmentation** 🎨  
   Augmentation strategies can be applied consistently across both images and video frames, which improves generalization and robustness.

4. **Robust Error Handling** 🛡️  
   Can skip unreadable or corrupted files without interrupting training — useful when datasets come from diverse or noisy sources.

### 🏛️ Class Architecture

#### 🔧 Initialization

Constructor parameters:

- `data_sources`: Dict mapping class names (e.g., "pet", "owner") to `{'video_paths': [...], 'image_dirs': [...]}` entries
- `label_mapping`: Maps each class name to a unique integer (used as label)
- `transform`: Optional torchvision transform pipeline (e.g., Resize, Normalize)
- `augment`: Enables or disables augmentation (`True/False`)
- `num_augments`: Number of augmentations to generate per frame/image (helps simulate diversity)
- `max_frames_per_video`: Limits how many frames are sampled per video
- `frame_interval`: Controls temporal sampling by picking every nth frame
- `cache_size`: Max number of decoded frames to keep in memory

**Initialization flow:**

1. 📥 Stores all input arguments  
2. 🧠 Initializes internal memory-efficient cache (FIFO by default)  
3. 🗂️ Builds an internal index of data with paths, types, labels, and metadata  
4. 🧪 Validates all media files (image size, video readability)  
5. 📊 Reports useful statistics (class distribution, number of samples, etc.)

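For reference, the `data_sources` layout described in the constructor's docstring looks like the dictionary below. The paths and the label order are hypothetical placeholders. The `effective_size` helper is also hypothetical; it mirrors the constructor's sample-count arithmetic, where each indexed frame contributes `1 + num_augments` samples when augmentation is on:

```python
# Hypothetical label order and file paths, for illustration only.
label_mapping = {"pet": 0, "owner": 1, "other_person": 2, "background": 3}

data_sources = {
    "pet": {"video_paths": ["data/pet/cat_clip.mp4"], "image_dirs": ["data/pet/images"]},
    "owner": {"video_paths": ["data/owner/webcam.mp4"], "image_dirs": []},
    "other_person": {"video_paths": [], "image_dirs": ["data/others"]},
    "background": {"video_paths": ["data/background/empty_room.mp4"], "image_dirs": []},
}

def effective_size(indexed_frames, augment, num_augments):
    """Hypothetical helper mirroring the constructor: 1 + num_augments samples per frame."""
    return indexed_frames * (1 + num_augments) if augment else indexed_frames

print(effective_size(1200, augment=True, num_augments=3))  # 4800

# With real files in place, construction would look like:
# dataset = VideoImageDataset(data_sources, label_mapping, transform=None,
#                             augment=True, num_augments=3)
```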
### 📂 Data Indexing

#### `_index_all_data()`

This core method:

- Iterates over each class in `data_sources`, skipping classes missing from `label_mapping`
- Indexes image directories first, then video paths, tagging each entry with its type (image or video)
- Randomly samples at most 500 images per directory to keep the dataset balanced
- Validates existence/readability early to avoid runtime errors
- Collects metadata: dimensions, number of frames, FPS, etc.

#### `_index_image(img_path, class_label, class_name)`

- ✔️ Confirms file exists and is readable
- 🖼️ Loads dimensions and filters out corrupt or extremely small files
- 🧾 Records metadata to avoid reloading this information later

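The body of `_index_image` is not shown in this section, so the following is only a sketch of the checks described above, written as a standalone function with Pillow. `MIN_SIDE` is a hypothetical threshold for "extremely small" images, and the returned dictionary format is an assumption:

```python
import os
import tempfile
from PIL import Image

MIN_SIDE = 32  # hypothetical minimum width/height in pixels

def index_image(path, min_side=MIN_SIDE):
    """Return basic metadata for a readable image, or None if it should be skipped."""
    if not os.path.isfile(path):
        return None
    try:
        with Image.open(path) as img:
            img.verify()               # integrity check without decoding all pixels
        with Image.open(path) as img:  # verify() leaves the object unusable, so reopen
            width, height = img.size
            mode = img.mode
    except (OSError, SyntaxError, ValueError):  # unreadable, truncated, or not an image
        return None
    if min(width, height) < min_side:
        return None                    # too small to be useful for training
    return {"width": width, "height": height, "mode": mode}

# Usage on throwaway files: one valid image, one tiny image, one corrupt file.
with tempfile.TemporaryDirectory() as tmp:
    good, tiny, broken = (os.path.join(tmp, n) for n in ("good.png", "tiny.png", "broken.jpg"))
    Image.new("RGB", (64, 48)).save(good)
    Image.new("RGB", (8, 8)).save(tiny)
    with open(broken, "wb") as f:
        f.write(b"not an image")
    good_meta, tiny_meta, broken_meta = index_image(good), index_image(tiny), index_image(broken)

print(good_meta, tiny_meta, broken_meta)
```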
#### `_index_video(path, label, class_name)`

- ✔️ Uses `cv2.VideoCapture` to extract video info
- ⏳ Computes duration, total frames, and ensures file is valid
- 🎯 Selects which frames to use (based on the sampling strategy)
- 📌 Saves metadata for lazy loading later

### 🎯 Data Sampling Strategies

Three customizable sampling strategies for videos:

1. **Uniform Sampling** ⏱️  
   Evenly spaced frames (default) — useful for general-purpose training.

2. **Front-Weighted Sampling** ⏩  
   Biases frame selection toward the beginning — useful for intro-heavy clips.

3. **Random Sampling** 🎲  
   Randomly selects a subset of frames — adds stochasticity to training.

Can be set via a sampling policy attribute or parameter.

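The three strategies can be expressed as index-selection rules. The helper below is hypothetical (the class's own sampling code is not shown here). It uses integer arithmetic so the indices are exact, and quadratic spacing is just one possible way to front-weight:

```python
import random

def select_frame_indices(total_frames, num_samples, strategy="uniform", seed=None):
    """Pick sorted frame indices from a video with `total_frames` frames (hypothetical helper)."""
    num_samples = min(num_samples, total_frames)
    if num_samples <= 0:
        return []
    if strategy == "uniform":
        # Evenly spaced: 0, T/n, 2T/n, ...
        return [i * total_frames // num_samples for i in range(num_samples)]
    if strategy == "front_weighted":
        # Quadratic spacing packs more indices near the start; duplicates are dropped.
        n2 = num_samples * num_samples
        return sorted({i * i * total_frames // n2 for i in range(num_samples)})
    if strategy == "random":
        # Seeded RNG makes the selection reproducible across runs.
        rng = random.Random(seed)
        return sorted(rng.sample(range(total_frames), num_samples))
    raise ValueError(f"unknown strategy: {strategy}")

print(select_frame_indices(300, 5))                    # [0, 60, 120, 180, 240]
print(select_frame_indices(100, 5, "front_weighted"))  # [0, 4, 16, 36, 64]
print(select_frame_indices(300, 5, "random", seed=0))
```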
### 📤 Data Loading

#### `__getitem__(index)`

- 🧮 Computes actual sample from the internal index  
- 🔍 Loads the specific frame or image  
- 🧠 Checks if it's already cached; if not, reads and caches it  
- 🎨 Applies augmentation if enabled  
- 🧪 Applies base transform (resizing, normalization, etc.)  
- 🏷️ Returns a `(tensor_image, label)` pair

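How a flat index maps to an original frame and an augmentation copy is not visible in this excerpt. One common layout, assumed here, keeps the `1 + num_augments` copies of each frame adjacent, with copy 0 being the unaugmented original. `map_index` is a hypothetical name:

```python
def map_index(idx, num_base, augment, num_augments):
    """Map a flat dataset index to (base_index, augmentation_id); id 0 is the original."""
    copies = 1 + num_augments if augment else 1
    if not 0 <= idx < num_base * copies:
        raise IndexError(idx)
    # Adjacent layout (assumption): frame 0 copies 0..k, then frame 1 copies 0..k, ...
    base_index, aug_id = divmod(idx, copies)
    return base_index, aug_id

# 3 indexed frames with num_augments=2 give 9 samples.
print([map_index(i, 3, True, 2) for i in range(9)])
```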
### 🖼️ Frame Loading Logic

#### `_load_video_frame(video_path, frame_index)`

- Checks frame cache  
- If missing, reads frame with OpenCV and adds to cache  
- Converts BGR → RGB and normalizes  
- Handles any reading or decoding errors gracefully

#### `_load_image(image_path)`

- Also supports caching  
- Performs format conversion and integrity check  
- Skips unreadable files and optionally returns dummy data

### 🎨 Data Augmentation

Augmentations can simulate realistic conditions, such as:

1. **Color Augmentations** 🌈  
   - Brightness/contrast shifts  
   - Hue/saturation jitter  

2. **Geometric Augmentations** 📐  
   - Random crops, flips, rotations  

3. **Mixed Augmentations** 🎭  
   - Combine both color and spatial transforms  

These can be toggled or extended with `torchvision.transforms`.

### 💾 Cache Management

- 🗃️ Internal dict stores decoded frames
- 🔁 Uses FIFO eviction when full (to save memory)
- 📉 Tracks hit/miss rate for performance diagnostics
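- 🛠️ Speeds up repeated accesses within a single process. With a DataLoader using `num_workers > 0`, each worker process holds its own copy of the dataset and its cache, which is discarded when workers are recreated each epoch unless `persistent_workers=True`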

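The FIFO policy with hit/miss counters can be sketched with an `OrderedDict`. The class name and the `get(key, loader)` interface are hypothetical; the notebook's own cache lives in `self.frame_cache` and `self.cache_stats`. Unlike an LRU cache, a hit does not refresh an entry's position, so the oldest insertion is always evicted first:

```python
from collections import OrderedDict

class FIFOFrameCache:
    """Bounded frame cache with first-in-first-out eviction and hit/miss counters (hypothetical)."""

    def __init__(self, limit):
        self.limit = limit
        self.frames = OrderedDict()  # insertion order = eviction order
        self.stats = {"hits": 0, "misses": 0}

    def get(self, key, loader):
        """Return the cached frame for `key`, calling `loader()` and caching it on a miss."""
        if key in self.frames:
            self.stats["hits"] += 1
            return self.frames[key]
        self.stats["misses"] += 1
        frame = loader()
        if frame is not None:
            if len(self.frames) >= self.limit:
                self.frames.popitem(last=False)  # evict the oldest insertion
            self.frames[key] = frame
        return frame

cache = FIFOFrameCache(limit=2)
for key in [("a.mp4", 0), ("a.mp4", 1), ("a.mp4", 0), ("b.mp4", 5), ("a.mp4", 0)]:
    cache.get(key, loader=lambda: f"decoded {key}")
print(cache.stats, list(cache.frames))
```

In this run, `("a.mp4", 0)` hits once, is then evicted when `("b.mp4", 5)` arrives, and misses on its last lookup; an LRU cache would have evicted `("a.mp4", 1)` instead.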
### 🧠 Design Choices Explained

#### ⏳ Lazy Loading  
Only loads what you use, when you use it. Indexing stores lightweight references (path, frame index, label) instead of decoded pixels, which keeps RAM usage low.

#### 💾 Why Use a Cache?  
Decoding video frames (seeking plus decoding) is slow. With augmentation enabled, the same base frame can be requested `1 + num_augments` times, so a cache avoids decoding it repeatedly.

#### 🎨 Multiple Augmentations  
Improves model robustness and performance across unseen scenarios or lighting conditions.

#### 📊 Statistics Tracking  
Important for debugging, understanding class balance, and verifying dataset health.

#### 🛡️ Robust Error Handling  
Doesn’t crash on corrupted data — reports it and optionally replaces with dummy image.

### 🏆 Practical Applications  

Great for:  
1. 🐕 Pet/Owner Recognition  
2. 🔒 Security Systems  
3. 🏠 Smart Home Applications  

### 🎬 Conclusion  

The VideoImageDataset class demonstrates advanced PyTorch techniques for:  
- 🧠 Memory efficiency  
- ⚡ Performance  
- 🔄 Flexibility 

In [None]:
class VideoImageDataset(Dataset):
    def __init__(self, data_sources, label_mapping, transform=None, 
             augment=False, num_augments=3, max_frames_per_video=300,
             frame_interval=1, cache_size=500):
        """
        Initialize a dataset which can handle both video and image data for person/pet classification.
        
        Args:
            data_sources (dict): Dictionary containing paths to video/image files organized by class
                               Format: {'owner': {'video_paths': [...], 'image_dirs': [...]}, 'pet': {...}}
            label_mapping (dict): Mapping from class names to integer labels 
                                 (e.g., {'owner': 0, 'pet': 1})
            transform (callable, optional): Torchvision transforms for image preprocessing
            augment (bool): Whether to apply data augmentation
            num_augments (int): Number of augmented versions to create per original frame
            max_frames_per_video (int): Maximum frames to sample from each video
            frame_interval (int): Sample every nth frame from videos
            cache_size (int): Maximum number of frames to keep in memory cache
        """
        
        # Store initialization parameters
        self.data_sources = data_sources        # Source paths organized by class
        self.label_mapping = label_mapping      # Class name to label index mapping
        self.transform = transform              # Image preprocessing pipeline
        self.augment = augment                  # Data augmentation flag
        self.num_augments = num_augments        # Number of augmentations per frame
        self.max_frames_per_video = max_frames_per_video  # Frame sampling limit
        self.frame_interval = frame_interval    # Interval between sampled frames
        
        # Setup caching system (attributes always exist, even when caching is disabled)
        self.use_cache = cache_size > 0
        self.cache_limit = cache_size
        self.frame_cache = {}
        self.cache_stats = {'hits': 0, 'misses': 0}
        
        # Data storage structures
        self.frame_sources = []  # Stores tuples of (source_type, path, frame_idx, label, metadata)
        self.class_counts = {name: 0 for name in label_mapping}  # Tracks samples per class
        self.video_metadata = {}  # Store video properties
        self.image_metadata = {}  # Store image properties
        
        # Track the source files for statistics
        self.source_tracking = {class_name: {'video_paths': set(), 'image_paths': set()} 
                               for class_name in label_mapping}
        
        # Build the dataset by indexing all available data
        self._index_all_data()
        
        # Validate that we found some data
        if len(self.frame_sources) == 0:
            raise ValueError("No valid frames were indexed. Please check your input paths.")
        
        # Calculate total samples including augmentations
        total_samples = len(self.frame_sources)
        if self.augment:
            total_samples *= (1 + self.num_augments)
        
        # Print dataset statistics
        print(f"\nDataset created with {total_samples} samples")
        print("Class distribution:")
        for name, idx in label_mapping.items():
            # Calculate effective count including augmentations
            count = self.class_counts[name] * (1 + self.num_augments if self.augment else 1)
            print(f"- {name}: {count} samples ({count/total_samples:.1%})")

    def _index_all_data(self):
        """
        Master method to index all data types including videos and image directories.
        This processes each class defined in data_sources, including owner, pet, other people, and background.
        Progress is reported to console during indexing operations.
        """
        print("\n" + "="*50)
        print("STARTING DATA INDEXING")
        print("="*50)
        
        # Track statistics for final report
        indexed_counts = {class_name: {'images': 0, 'videos': 0} for class_name in self.label_mapping}
        
        # Process each class in the data sources
        for class_name, sources in self.data_sources.items():
            if class_name not in self.label_mapping:
                print(f"Warning: Class '{class_name}' not found in label_mapping, skipping")
                continue
                
            class_label = self.label_mapping[class_name]
            
            print(f"\n{'-'*20}")
            print(f"INDEXING {class_name.upper()} DATA")
            print(f"{'-'*20}")
            
            # Process image directories first
            if 'image_dirs' in sources and sources['image_dirs']:
                image_dirs = sources['image_dirs'] if isinstance(sources['image_dirs'], list) else [sources['image_dirs']]
                
                for directory in image_dirs:
                    if os.path.isdir(directory):
                        print(f"\nIndexing {class_name} images from: {directory}")
                        image_paths = self._get_image_paths(directory)
                        
                        if image_paths:
                            # Sample a reasonable number of images if there are too many
                            sample_size = min(500, len(image_paths))
                            if len(image_paths) > sample_size:
                                print(f"Found {len(image_paths)} images, sampling {sample_size}...")
                                sampled_paths = random.sample(image_paths, sample_size)
                            else:
                                print(f"Found {len(image_paths)} images, processing all...")
                                sampled_paths = image_paths
                            
                            # Process the images
                            successful_count = 0
                            for i, img_path in enumerate(sampled_paths, 1):
                                if i % 50 == 0:  # Progress update every 50 images
                                    print(f"Indexed {i}/{len(sampled_paths)} {class_name} images")
                                if self._index_image(img_path, class_label, class_name):
                                    indexed_counts[class_name]['images'] += 1
                                    successful_count += 1
                                    # Add to tracking for statistics
                                    self.source_tracking[class_name]['image_paths'].add(img_path)
                            
                            print(f"Successfully indexed {successful_count} {class_name} images")
                            print()
                        else:
                            print(f"No valid images found in {directory}")
                    else:
                        print(f"WARNING: {class_name} image directory does not exist: {directory}")
            
            # Then process videos if specified
            if 'video_paths' in sources and sources['video_paths']:
                video_paths = sources['video_paths'] if isinstance(sources['video_paths'], list) else [sources['video_paths']]
                
                for path in video_paths:
                    if os.path.exists(path):
                        print(f"\nIndexing {class_name} video: {path}")
                        frames_processed = self._index_video(path, class_label, class_name)
                        if frames_processed > 0:
                            indexed_counts[class_name]['videos'] += 1
                            # Add to tracking for statistics
                            self.source_tracking[class_name]['video_paths'].add(path)
                        print(f"Processed {frames_processed} frames from {class_name} video")
                    else:
                        print(f"WARNING: {class_name} video path does not exist: {path}")
        
        # Print summary of indexed data
        print("\n" + "="*50)
        print("DATA INDEXING COMPLETE")
        print("="*50)
        print("\nINDEXING SUMMARY:")
        for class_name in self.label_mapping:
            counts = indexed_counts.get(class_name, {'images': 0, 'videos': 0})
            print(f"{class_name.capitalize()} data: {counts['images']} images, {counts['videos']} videos")
    
    def _get_image_paths(self, directory):
        """
        Get all image file paths from a directory.
        Supports .jpg, .jpeg, and .png files.
        
        Args:
            directory (str): Path to the directory containing images
            
        Returns:
            list: List of image file paths
        """
        return glob(os.path.join(directory, "*.jpg")) + \
               glob(os.path.join(directory, "*.jpeg")) + \
               glob(os.path.join(directory, "*.png"))

    def _index_video(self, path, label, class_name):
        """
        Index frames from a video file and store references for lazy loading.
        
        Args:
            path (str): Path to the video file
            label (int): Numeric label for classification
            class_name (str): Class name for tracking statistics
        
        Returns:
            int: Number of frames successfully indexed
        """
        frames_indexed = 0
        
        try:
            # Validate file exists
            if not os.path.exists(path):
                print(f"Error: Video file not found - {path}")
                return frames_indexed
            
            # Open video capture
            cap = cv2.VideoCapture(path)
            if not cap.isOpened():
                print(f"Error: Could not open video - {path}")
                return frames_indexed
            
            # Get video properties
            fps = cap.get(cv2.CAP_PROP_FPS)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps if fps > 0 else 0
            
            if total_frames <= 0:
                print(f"Warning: Video has {total_frames} frames - {path}")
                cap.release()
                return frames_indexed
            
            # Upper bound on frames to sample (frame_interval may yield fewer)
            target_frames = min(self.max_frames_per_video, total_frames)
            
            # Different sampling strategies based on video length
            if total_frames <= target_frames:
                # Short video: take every frame_interval-th frame
                frame_indices = list(range(0, total_frames, self.frame_interval))
            else:
                # For longer videos, choose between:
                # 1. Uniform sampling (default)
                # 2. Front-weighted sampling (more frames from beginning)
                # 3. Random sampling
                sampling_strategy = getattr(self, 'sampling_strategy', 'uniform')
                
                if sampling_strategy == 'front_weighted':
                    # Sample more from the beginning (useful for pet/owner videos)
                    first_half = int(total_frames * 0.5)
                    first_half_frames = int(target_frames * 0.7)
                    second_half_frames = target_frames - first_half_frames
                    
                    # max(1, ...) guards against division by zero for very small targets
                    step1 = max(self.frame_interval, first_half // max(1, first_half_frames))
                    step2 = max(self.frame_interval, (total_frames - first_half) // max(1, second_half_frames))
                    
                    frame_indices = list(range(0, first_half, step1))[:first_half_frames]
                    frame_indices += list(range(first_half, total_frames, step2))[:second_half_frames]
                elif sampling_strategy == 'random':
                    # Randomly sample frames
                    frame_indices = sorted(random.sample(range(total_frames), target_frames))
                else:  # uniform
                    # Uniform sampling throughout video
                    step = max(self.frame_interval, total_frames // target_frames)
                    frame_indices = list(range(0, total_frames, step))[:target_frames]
            
            print(f"Indexing {len(frame_indices)} frames from {total_frames} total frames ({duration:.1f}s)")
            
            # Store frame references for lazy loading
            video_info = {
                'path': path,
                'fps': fps,
                'total_frames': total_frames,
                'duration': duration,
                'width': int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                'height': int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            }
            
            for frame_idx in frame_indices:
                self.frame_sources.append(("video", path, frame_idx, label, video_info))
                frames_indexed += 1
                self.class_counts[class_name] += 1
            
            # Add video metadata to tracking
            self.video_metadata[path] = video_info
            
            cap.release()
            print(f"Successfully indexed {frames_indexed} frames from {path}")
            
        except Exception as e:
            print(f"Error indexing video {path}: {str(e)}")
            traceback.print_exc()
        
        return frames_indexed
    
    def _index_image(self, img_path, label, class_name):
        """
        Index an image file and store reference for lazy loading.
        
        Args:
            img_path (str): Path to the image file
            label (int): Numeric label for classification
            class_name (str): Class name for tracking statistics
        
        Returns:
            bool: True if image was successfully indexed, False otherwise
        """
        try:
            # Validate file exists
            if not os.path.exists(img_path):
                print(f"Error: Image file not found - {img_path}")
                return False
            
            # Check if it's a valid image by reading dimensions
            # This helps catch corrupt images early without loading full content
            try:
                with Image.open(img_path) as img:
                    width, height = img.size
                    image_info = {
                        'path': img_path,
                        'width': width,
                        'height': height,
                        'format': img.format,
                        'mode': img.mode
                    }
                    
                    # Skip images that are too small
                    min_size = getattr(self, 'min_image_size', 32)
                    if width < min_size or height < min_size:
                        print(f"Skipping small image ({width}x{height}): {img_path}")
                        return False
            except Exception as img_error:
                print(f"Invalid image file {img_path}: {str(img_error)}")
                return False
            
            # Store reference to the image
            self.frame_sources.append(("image", img_path, 0, label, image_info))
            self.class_counts[class_name] += 1
            
            # Track image metadata
            self.image_metadata[img_path] = image_info
            
            return True
            
        except Exception as e:
            print(f"Error indexing image {img_path}: {str(e)}")
            return False
    
    def __len__(self):
        """
        Return the total length of the dataset including augmentations.
        
        Returns:
            int: Total number of samples in the dataset
        """
        base_len = len(self.frame_sources)
        if self.augment and base_len > 0:
            return base_len * (1 + self.num_augments)
        return base_len
    
    def __getitem__(self, idx):
        """
        Get a sample from the dataset with lazy loading and augmentation.
        
        Args:
            idx (int): Index of the sample to retrieve
            
        Returns:
            tuple: (image_tensor, label_tensor)
        """
        try:
            # Calculate which frame and whether this is an augmented version
            if self.augment:
                base_idx = idx // (1 + self.num_augments)
                aug_version = idx % (1 + self.num_augments)
            else:
                base_idx = idx
                aug_version = 0
            
            # Validate index
            if base_idx >= len(self.frame_sources):
                raise IndexError(f"Index {base_idx} out of range for dataset with {len(self.frame_sources)} items")
            
            # Get the frame source info
            source_info = self.frame_sources[base_idx]
            if len(source_info) >= 5:
                source_type, path, frame_idx, label, metadata = source_info
            else:
                source_type, path, frame_idx, label = source_info
                metadata = {}
            
            # Track cache hits/misses *before* loading, since loading fills the cache
            if self.use_cache:
                self._update_cache_stats(path, frame_idx)
            
            # Load the frame on demand
            if source_type == "video":
                frame = self._load_video_frame(path, frame_idx)
            else:  # image
                frame = self._load_image(path)
            
            # Apply augmentation if needed (ensure we only augment when aug_version > 0)
            if self.augment and aug_version > 0:
                frame = self._augment_frame(frame, aug_version)
            
            # Apply standard transformation
            if self.transform:
                if not isinstance(frame, Image.Image):
                    frame = Image.fromarray(frame)
                frame = self.transform(frame)
            
            return frame, torch.tensor(label, dtype=torch.long)
            
        except Exception as e:
            print(f"Error getting item {idx}: {str(e)}")
            # Log more details about the problematic item
            if 'base_idx' in locals() and base_idx < len(self.frame_sources):
                print(f"Problematic item details: {self.frame_sources[base_idx]}")
            
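            # Return a default item instead of crashing. Caveat: it is labeled class 0,
            # so frequent load failures would quietly bias training toward that class.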
            dummy_shape = getattr(self, 'input_shape', (3, 224, 224))
            dummy_data = torch.zeros(dummy_shape)
            dummy_label = 0
            return dummy_data, torch.tensor(dummy_label, dtype=torch.long)
    
    def _load_video_frame(self, video_path, frame_idx):
        """
        Load a specific frame from a video file with caching support.
        
        Args:
            video_path (str): Path to the video file
            frame_idx (int): Index of the frame to load
            
        Returns:
            numpy.ndarray: RGB image as numpy array
        """
        # Check cache first if enabled
        if self.use_cache:
            cache_key = f"{video_path}_{frame_idx}"
            if cache_key in self.frame_cache:
                return self.frame_cache[cache_key]
        
        # Load frame from video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise IOError(f"Failed to open video: {video_path}")
        
        # Position video to the specific frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        cap.release()
        
        if not ret or frame is None:
            raise ValueError(f"Failed to load frame {frame_idx} from {video_path}")
        
        # Convert from BGR to RGB color space
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
        # Cache the frame if caching is enabled
        if self.use_cache:
            cache_key = f"{video_path}_{frame_idx}"
            self.frame_cache[cache_key] = rgb_frame
            
            # Limit cache size if needed
            if len(self.frame_cache) > self.cache_limit:
                # Remove oldest entry (simple FIFO implementation)
                oldest_key = next(iter(self.frame_cache))
                del self.frame_cache[oldest_key]
        
        return rgb_frame
    
    def _load_image(self, img_path):
        """
        Load an image file with caching support.
        
        Args:
            img_path (str): Path to the image file
            
        Returns:
            numpy.ndarray: RGB image as numpy array
        """
        # Check cache first if enabled
        if self.use_cache:
            if img_path in self.frame_cache:
                return self.frame_cache[img_path]
        
        # Try loading with PIL first (better color handling)
        try:
            with Image.open(img_path) as pil_img:
                # Convert to RGB if needed
                if pil_img.mode != 'RGB':
                    pil_img = pil_img.convert('RGB')
                img = np.array(pil_img)
        except Exception as pil_error:
            # Fall back to OpenCV
            img = cv2.imread(img_path)
            if img is None:
                raise ValueError(f"Failed to load image {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Cache the image if caching is enabled
        if self.use_cache:
            self.frame_cache[img_path] = img
            
            # Limit cache size if needed
            if len(self.frame_cache) > self.cache_limit:
                # Remove oldest entry
                oldest_key = next(iter(self.frame_cache))
                del self.frame_cache[oldest_key]
        
        return img
    
    def _augment_frame(self, frame, aug_version=1):
        """
        Apply data augmentation to an image.
        
        Args:
            frame (numpy.ndarray): Input image as numpy array
            aug_version (int): Version of augmentation to apply (allows for different augmentation sets)
            
        Returns:
            numpy.ndarray or PIL.Image: Augmented image
        """
        # Convert frame to PIL if it's numpy array
        if not isinstance(frame, Image.Image):
            frame_pil = Image.fromarray(frame)
        else:
            frame_pil = frame
        
        # Apply different augmentation strategies based on version
        if aug_version % 3 == 1:
            # Color augmentations
            augmenter = transforms.Compose([
                transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
                transforms.RandomGrayscale(p=0.05),
                transforms.RandomAutocontrast(p=0.2),
            ])
        elif aug_version % 3 == 2:
            # Geometric augmentations
            augmenter = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(15),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            ])
        else:
            # Mixed augmentations
            augmenter = transforms.Compose([
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
                transforms.RandomAffine(degrees=0, translate=(0.08, 0.08)),
            ])
        
        # Apply the augmentation
        augmented = augmenter(frame_pil)
        
        # Return in the same format as input
        if isinstance(frame, np.ndarray):
            return np.array(augmented)
        return augmented
    
    def _update_cache_stats(self, path, frame_idx):
        """
        Track cache hit/miss statistics for performance monitoring.
        
        Args:
            path (str): Path to the file
            frame_idx (int): Frame index (0 for images)
        """
        # Video frames are cached as "<path>_<frame_idx>", images as "<path>"
        is_hit = f"{path}_{frame_idx}" in self.frame_cache or path in self.frame_cache
        
        if is_hit:
            self.cache_stats['hits'] += 1
        else:
            self.cache_stats['misses'] += 1
        
        # Log cache stats occasionally
        total = self.cache_stats['hits'] + self.cache_stats['misses']
        if total % 1000 == 0:
            hit_rate = self.cache_stats['hits'] / total * 100 if total > 0 else 0
            print(f"Cache stats: {hit_rate:.1f}% hit rate ({self.cache_stats['hits']}/{total})")
    
    def get_source_statistics(self):
        """
        Return statistics about the data sources in this dataset.
        
        Returns:
            dict: Dictionary containing counts of videos and images per class
        """
        stats = {}
        for class_name in self.label_mapping.keys():
            stats[class_name] = {
                'videos': len(self.source_tracking[class_name]['video_paths']),
                'images': len(self.source_tracking[class_name]['image_paths']),
                'frames': self.class_counts[class_name]
            }
                
        return stats


if __name__ == "__main__":
    # ======================
    # Dataset Configuration
    # ======================
    
    # Define class labels and their corresponding integer mappings
    label_mapping = {
        "owner": 0,       # Class 0: Owner 
        "pet": 1,         # Class 1: Pet
        "other": 2,       # Class 2: Other people
        "background": 3   # Class 3: Empty background/scenes
    }
    
    # Data source configurations
    data_sources = {
        "owner": {
            "video_paths": ["../data/owner/owner.mp4"],
            "image_dirs": ["../data/owner/images"]
        },
        "pet": {
            "video_paths": ["../data/pet/pet.mp4"],
            "image_dirs": ["../data/pet/images"]
        },
        "other": {
            "video_paths": ["../data/other_people/other_people.mp4"],
            "image_dirs": ["../data/other_people/images"]
        },
        "background": {
            "video_paths": ["../data/background/background.mp4"],
            "image_dirs": ["../data/background/images"]
        }
    }
    
    # Image preprocessing pipeline
    transform = transforms.Compose([
        transforms.Resize(256),            # Resize shorter side to 256px
        transforms.CenterCrop(224),        # Crop to 224x224 (standard for ImageNet models)
        transforms.ToTensor(),             # Convert to PyTorch tensor
        transforms.Normalize(              # Normalize with ImageNet stats
            mean=[0.485, 0.456, 0.406],    # RGB mean values
            std=[0.229, 0.224, 0.225]      # RGB standard deviations
        )
    ])
    
    # ======================
    # Dataset Creation
    # ======================
    try:
        print("Creating dataset...")
        dataset = VideoImageDataset(
            data_sources=data_sources,     # Combined video and image sources
            label_mapping=label_mapping,   # Class mappings
            transform=transform,           # Preprocessing
            augment=True,                  # Enable data augmentation
            num_augments=4,                # 4 augmented versions per frame
            max_frames_per_video=1750,     # Limit frames per video
            frame_interval=1,              # Minimum spacing between sampled frames
            cache_size=1000                # Cache size for processed frames
        )
        
        # ======================
        # DataLoader Setup
        # ======================
        # Start with fewer workers for testing/debugging
        dataloader = DataLoader(
            dataset,
            batch_size=16,                 # Number of samples per batch
            shuffle=True,                  # Shuffle data each epoch
            num_workers=0,                 # Synchronous loading (safe for debugging)
            pin_memory=torch.cuda.is_available()  # Faster GPU transfer if available
        )
        
        # ======================
        # Sanity Check
        # ======================
        print()
        print("Testing dataloader...")
        class_counts = {i: 0 for i in range(len(label_mapping))}
        
        for i, (x, y) in enumerate(dataloader):
            # x: batch of images (shape: [batch_size, 3, 224, 224])
            # y: batch of labels (shape: [batch_size])
            print(f"Batch {i}: x shape = {x.shape}, y shape = {y.shape}")
            
            # Count class distribution
            for label in y.numpy():
                class_counts[label] += 1
            
            # Only check first 3 batches for quick verification
            if i == 2:
                break

        print()
        print("Class distribution in sampled batches:")
        for class_name, class_idx in label_mapping.items():
            print(f"  {class_name}: {class_counts[class_idx]} samples")

        print()
        print("Dataloader test successful!")
        
        # Report dataset statistics
        print()
        print(f"Total dataset size: {len(dataset)} samples")
        print(f"Sources loaded: {dataset.get_source_statistics()}")
        
    except Exception as e:
        # Handle any errors during dataset creation or loading
        print(f"Error during dataset setup: {str(e)}")
        import traceback
        traceback.print_exc()
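
The dataset class above does two pieces of index arithmetic: `_index_video` chooses which frames to keep (the `uniform` strategy by default), and `__len__`/`__getitem__` expose every indexed frame `1 + num_augments` times, mapping a flat index back to a source frame and an augmentation version (version 0 is the un-augmented original). The sketch below reproduces both calculations as plain functions so they can be checked in isolation; `uniform_frame_indices` and `split_index` are hypothetical helper names, not methods of the dataset class.

```python
def uniform_frame_indices(total_frames, max_frames, frame_interval=1):
    """Mirror the 'uniform' branch of _index_video: pick at most max_frames indices."""
    target = min(max_frames, total_frames)
    if total_frames <= target:
        # Short video: keep every frame_interval-th frame
        return list(range(0, total_frames, frame_interval))
    # Long video: spread samples evenly, never closer than frame_interval apart
    step = max(frame_interval, total_frames // target)
    return list(range(0, total_frames, step))[:target]


def split_index(idx, num_augments):
    """Map a flat dataset index to (base frame index, augmentation version)."""
    return idx // (1 + num_augments), idx % (1 + num_augments)


# A 10,000-frame video capped at 1,750 samples -> one frame every 5
indices = uniform_frame_indices(10_000, 1_750)
print(len(indices), indices[:4])  # 1750 [0, 5, 10, 15]

# With num_augments=4, flat indices 0..4 all read frame 0; index 5 starts frame 1
print([split_index(i, 4) for i in range(6)])
# [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 0)]
```

Because each frame occupies `1 + num_augments` consecutive flat indices, `shuffle=True` in the `DataLoader` is what mixes originals and augmented copies across batches.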

<hr>

## 🧠 Model Training Loop

Let’s break down what’s *really* happening under the hood during training — from setup to plotting final results.

### 🚀 Setup and Launch

Before training kicks off:

- **Model to Device** 📦  
  Send the model to GPU (`cuda`) or CPU using `model.to(device)` for efficient computation.

- **Loss Function** 🎯  
  We use `CrossEntropyLoss`, the standard loss for single-label multi-class classification. It takes raw logits and integer class labels.

- **Optimizer** ⚙️  
  `Adam` is our optimizer of choice, with a customizable learning rate.

- **Device Check** 💻  
  Automatically detects if CUDA is available and switches device accordingly.
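`CrossEntropyLoss` expects raw, unnormalized logits of shape `[batch_size, num_classes]` and integer class indices of shape `[batch_size]`. Internally it applies `log_softmax` followed by the negative log-likelihood, so the model's final layer should not apply a softmax itself. A quick check with made-up logits for the four classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up logits for a batch of 3 samples over 4 classes (owner, pet, other, background)
logits = torch.randn(3, 4)
labels = torch.tensor([0, 1, 3])  # integer class indices, dtype torch.long

criterion = nn.CrossEntropyLoss()
loss_builtin = criterion(logits, labels)

# Same value by hand: log-softmax, pick each sample's true-class entry, average
log_probs = F.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(len(labels)), labels].mean()

print(loss_builtin.item(), loss_manual.item())
```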

### 🔁 Epochs

Training runs over **`num_epochs`** (default: `10`).  
Each epoch = **1 full pass** through the training data.

Multiple epochs help the model gradually improve its performance through weight updates.

### 📊 Metrics Tracked

We're collecting and visualizing:

- **Loss** 📉 → Measures prediction error (lower is better).
- **Accuracy** ✅ → Percentage of correct predictions.

These are stored in:

```python
train_losses = []
train_accuracies = []
```

### 🔄 Inside Each Epoch

Here's what happens **for each epoch**:

1. **Set Model to Training Mode** 🏋️  
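   Enables `Dropout` and makes `BatchNorm` layers normalize with per-batch statistics (and update their running averages, even when their weights are frozen).  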
   ```python
   model.train()
   ```

2. **TQDM Progress Bar** ⏳  
   Real-time batch tracking using:  
   ```python
   loop = tqdm(train_loader)
   ```

3. **Move Batch to Device** 💨  
   Inputs and labels must be on the same device as the model:
   ```python
   images = images.to(device)
   labels = labels.to(device)
   ```

4. **Zero Gradients** 🧽  
   PyTorch accumulates gradients across `backward()` calls, so clear the ones left over from the previous batch:
   ```python
   optimizer.zero_grad()
   ```

5. **Forward Pass** 📤  
   Feed data into the model:
   ```python
   outputs = model(images)
   ```

6. **Compute Loss** 📏  
   Compare predictions vs labels:
   ```python
   loss = criterion(outputs, labels)
   ```

7. **Backward Pass** 🔙  
   Backpropagate the loss:
   ```python
   loss.backward()
   ```

8. **Update Weights** 🔧  
   Apply gradients to update parameters:
   ```python
   optimizer.step()
   ```

9. **Update Metrics** 📋  
   Track the running loss and count samples and correct predictions:
   ```python
   running_loss += loss.item()
   _, predicted = outputs.max(1)
   total += labels.size(0)
   correct += predicted.eq(labels).sum().item()
   ```
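The nine steps above can be exercised end to end without the real dataset. This sketch swaps in a hypothetical toy setup (a tiny linear classifier on synthetic 4-class data served through a `TensorDataset`) so the zero-grad → forward → loss → backward → step cycle and the running metrics can be checked in seconds; the model, data, and hyperparameters are illustrative assumptions, not the notebook's configuration.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical toy data: 4 well-separated clusters in 8 dimensions
num_classes, dim = 4, 8
centers = torch.randn(num_classes, dim) * 3
targets_all = torch.randint(0, num_classes, (512,))
features = centers[targets_all] + torch.randn(512, dim)
loader = DataLoader(TensorDataset(features, targets_all), batch_size=16, shuffle=True)

model = nn.Linear(dim, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

epoch_losses, epoch_accs = [], []
for epoch in range(3):
    model.train()                                  # 1. training mode
    running_loss, correct, total = 0.0, 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)  # 3. move to device
        optimizer.zero_grad()                      # 4. clear old gradients
        outputs = model(inputs)                    # 5. forward pass
        loss = criterion(outputs, targets)         # 6. compute loss
        loss.backward()                            # 7. backpropagate
        optimizer.step()                           # 8. update weights
        running_loss += loss.item()                # 9. update metrics
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    epoch_losses.append(running_loss / len(loader))
    epoch_accs.append(100.0 * correct / total)
    print(f"Epoch {epoch+1} | Loss: {epoch_losses[-1]:.4f} | Acc: {epoch_accs[-1]:.2f}%")
```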

### ✅ End of Epoch

Once all batches are done:

- Compute the **average loss** (mean of the per-batch losses) and **accuracy** (correct predictions / total samples).
- Log the results with:
  ```python
  print(f"Epoch {epoch+1} Results | Loss: ... | Acc: ...")
  ```
- Save the model weights (the same file is overwritten each epoch, so it always holds the latest weights):
  ```python
  torch.save(model.state_dict(), 'entity_classifier.pth')
  ```
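One subtlety in the epoch summary: `running_loss / len(train_loader)` is a mean of per-batch means, so when the last batch is smaller than `batch_size`, each of its samples counts slightly more than the others. Accuracy (`correct / total`) is already a per-sample figure. If an exact per-sample average loss is wanted, one option (an alternative, not what `train_model` does) is to weight each batch loss by its size. The numbers below are made up purely to show the arithmetic:

```python
# Made-up per-batch mean losses and batch sizes (the last batch is partial)
batch_losses = [0.50, 0.40, 0.90]
batch_sizes = [16, 16, 4]

# What train_model reports: unweighted mean of the batch means
mean_of_batch_means = sum(batch_losses) / len(batch_losses)

# Per-sample mean: weight each batch loss by the number of samples in it
per_sample_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

print(f"{mean_of_batch_means:.4f} vs {per_sample_mean:.4f}")  # 0.6000 vs 0.5000
```

Inside the training loop this amounts to `running_loss += loss.item() * labels.size(0)` and dividing by `total` at the end, which holds for the default `reduction='mean'`.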

### 📈 Final Visualization

After training ends, we generate a side-by-side plot:

- **Left Plot:** Training Loss  
- **Right Plot:** Training Accuracy  

Both with epochs on the x-axis.

> 💾 The plot is saved as:  
> `mobilenetv2_4class_finetune_YYYYMMDD.png`  
> With final metrics and model details in the title & footer.

def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device='cuda'):
    """Main training loop for the classification model
    
    Args:
        model: PyTorch model to train
        train_loader: DataLoader for training data
        criterion: Loss function
        optimizer: Optimization algorithm
        num_epochs: Number of complete passes through dataset
        device: 'cuda' or 'cpu' for training
        
    Returns:
        Trained model with updated weights
    """
    print("Starting training process...")
    # Move model to target device (GPU/CPU) - critical for performance
    model.to(device)
    
    # Lists to store metrics for visualization
    train_losses = []  # Track loss per epoch
    train_accuracies = []  # Track accuracy per epoch
    
    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0  # Accumulate loss across batches
        correct = 0  # Count correct predictions
        total = 0  # Total samples processed
        
        # Initialize progress bar with tqdm
        loop = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True)
        
        # Process each batch of data
        for batch_idx, (images, labels) in enumerate(loop):
            # Move data to target device (GPU/CPU)
            images = images.to(device)
            labels = labels.to(device)
            
            # Zero out gradients from previous batch
            optimizer.zero_grad()
            
            # Forward pass - get model predictions
            outputs = model(images)
            
            # Calculate loss between predictions and ground truth
            loss = criterion(outputs, labels)
            
            # Backward pass - compute gradients
            loss.backward()
            
            # Update model weights
            optimizer.step()
            
            # Update training statistics
            running_loss += loss.item()  # Add batch loss to running total
            _, predicted = outputs.max(1)  # Get predicted class indices
            total += labels.size(0)  # Increment total sample count
            correct += predicted.eq(labels).sum().item()  # Count correct predictions
            
            # Update progress bar with current metrics
            loop.set_postfix(
                loss=running_loss/(batch_idx+1),  # Average loss so far
                acc=100.*correct/total  # Running accuracy over the epoch so far
            )
        
        # Calculate epoch-level metrics
        epoch_loss = running_loss / len(train_loader)  # Average epoch loss
        epoch_acc = 100. * correct / total  # Epoch accuracy
        
        # Store metrics for visualization
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        
        # Print epoch summary
        print(f'Epoch {epoch+1} Results | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%')
    
        # Save model checkpoint after each epoch
        torch.save(model.state_dict(), 'entity_classifier.pth')
        print('Model weights saved to entity_classifier.pth')
    
    # Visualize training results (x-axis uses 1-based epoch numbers)
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(13, 5))

    # Loss plot (left subplot)
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'b', label='Loss')
    plt.title("Training Loss", pad=10)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True, alpha=0.3)  # Add subtle grid
    plt.legend(title="MobileNetV2 Variant\n(Modified Head)")
    
    # Accuracy plot (right subplot)
    plt.subplot(1, 2, 2) 
    plt.plot(epochs, train_accuracies, 'r', label='Accuracy')
    plt.title("Training Accuracy", pad=10) 
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.grid(True, alpha=0.3)  # Add subtle grid
    
    # Main title with final metrics
    plt.suptitle(
        "MobileNetV2 Fine-Tuning (Custom 4-Class Head)\n"
        f"Final Accuracy: {train_accuracies[-1]:.1f}% | Loss: {train_losses[-1]:.4f}",
        y=1.02,
        fontweight='bold'
    )
    
    # Footer note with architecture details
    plt.gcf().text(
        0.5, -0.08,
        "Backbone: ImageNet weights (frozen) | Head: 256FC-ReLU-Dropout-4FC (trainable)",
        ha='center',
        fontsize=9,
        color='#555'
    )
    
    # Final plot adjustments
    plt.tight_layout()
    plt.savefig(
        f"mobilenetv2_4class_finetune_{datetime.now().strftime('%Y%m%d')}.png", 
        dpi=300,  # High resolution output
        bbox_inches='tight'
    )
    
    return model


def setup_and_train(pretrained_model, dataloader, learning_rate=0.001, num_epochs=10):
    """Complete training setup and execution
    
    Args:
        pretrained_model: Model with initialized weights
        dataloader: Configured DataLoader instance
        learning_rate: Initial learning rate
        num_epochs: Number of training epochs
        
    Returns:
        Fully trained model
    """
    # Define loss function - CrossEntropy for classification
    criterion = torch.nn.CrossEntropyLoss()
    
    # Configure optimizer - Adam with specified learning rate
    optimizer = torch.optim.Adam(pretrained_model.parameters(), lr=learning_rate)
    
    # Auto-detect available device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device.upper()}')
    
    # Launch training process
    print("Starting model training...")
    trained_model = train_model(
        model=pretrained_model,
        train_loader=dataloader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        device=device
    )
    
    return trained_model

# Usage
trained_model = setup_and_train(
    pretrained_model=model,  # Your initialized model
    dataloader=dataloader,   # Your configured DataLoader
    learning_rate=5e-5,      # Lower learning rate for fine-tuning
    num_epochs=10            # Training cycles
)

🎉 **Success!** Our model is now trained, and the loss and accuracy curves show how it progressed! 🎉

We've completed the training process. The **training loss** and **accuracy** curves show how well the model fit the training data over the epochs. Because they are measured on the same (augmented) samples the model learns from, they show learning progress, not how the model will do on frames it has never seen.

Now that we've achieved solid performance on the training set, it's time to move on to **real-time inference**! 🚀
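Training metrics alone can't reveal overfitting, so a common complement is to hold out part of the data and evaluate it in `eval()` mode under `torch.no_grad()`. If such a split is built from this dataset, split by source video or image rather than by flat index: each frame appears five times (the original plus `num_augments=4` augmented copies) and neighboring video frames are near-duplicates, so a random index split would leak training frames into validation. The `evaluate` helper and the `val_loader` it expects are hypothetical additions, not part of this notebook; the toy model and random data below are placeholders so the sketch runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no gradient tracking needed for evaluation
def evaluate(model, loader, criterion, device="cpu"):
    """Return (per-sample mean loss, accuracy %) on a held-out loader."""
    model.eval()  # dropout off, BatchNorm uses its running statistics
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        total_loss += criterion(outputs, labels).item() * labels.size(0)
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += labels.size(0)
    return total_loss / total, 100.0 * correct / total


# Placeholder usage with a toy model and random data
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.2), nn.Linear(16, 4))
val_loader = DataLoader(
    TensorDataset(torch.randn(40, 8), torch.randint(0, 4, (40,))), batch_size=16
)
val_loss, val_acc = evaluate(toy_model, val_loader, nn.CrossEntropyLoss())
print(f"Val loss: {val_loss:.4f} | Val acc: {val_acc:.2f}%")
```

`evaluate` leaves the model in `eval()` mode; that is harmless here because `train_model` calls `model.train()` at the start of every epoch.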

<hr>

### What’s next? 👀
Instead of continuing in a notebook, the next step is to integrate the trained model into a standalone Python application (developed in **PyCharm**) for real-time usage. We'll use the model to make predictions on live data, such as a **webcam feed**. This lets us test the model's performance in real-world scenarios and see how it handles new, unseen data in an interactive application.

Let’s take the model for a spin and see how it performs when it’s really put to work in the app. 💻✨
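At its core, the real-time app repeats one step per webcam frame: convert OpenCV's BGR frame to RGB, apply the same preprocessing used in training, and run the model in `eval()` mode without gradients. A minimal sketch of that step, assuming the `label_mapping` from the dataset cell and weights restored with something like `model.load_state_dict(torch.load('entity_classifier.pth', map_location=device))`. `predict_frame` is a hypothetical helper name, and the toy model and stand-in transform exist only so the snippet runs on its own; in the app you would pass the fine-tuned MobileNetV2 (moved to `device`) and the notebook's `transform` pipeline.

```python
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

label_mapping = {"owner": 0, "pet": 1, "other": 2, "background": 3}
idx_to_class = {v: k for k, v in label_mapping.items()}


@torch.no_grad()
def predict_frame(model, frame_bgr, transform, device="cpu"):
    """Classify one BGR frame (H x W x 3 uint8, as cv2 returns it) -> (class name, probs)."""
    model.eval()
    # Reverse the channel axis: BGR -> RGB (same result as cv2.COLOR_BGR2RGB)
    frame_rgb = np.ascontiguousarray(frame_bgr[:, :, ::-1])
    tensor = transform(Image.fromarray(frame_rgb)).unsqueeze(0).to(device)  # add batch dim
    probs = torch.softmax(model(tensor), dim=1)[0]
    return idx_to_class[int(probs.argmax())], probs.cpu()


# Stand-ins so the sketch runs without torchvision or a webcam
def toy_transform(img):
    arr = np.asarray(img.resize((32, 32)), dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1)  # HWC -> CHW


toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 4))
fake_frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
name, probs = predict_frame(toy_model, fake_frame, toy_transform)
print(name, probs)
```

In a capture loop, `frame_bgr` would come from `cv2.VideoCapture(0).read()`; keeping the BGR→RGB conversion and the training-time normalization identical is what lets the webcam frames match what the model saw during training.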