In [1]:
import logging
import multiprocessing
import os
import random
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from facenet_pytorch import MTCNN
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    precision_recall_curve,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

from deepfake_dataset import DeepfakeDataset, collate_fn
from worker import worker_init_fn

In [2]:
BASE_VIDEO_PATH = "./dataset"
METADATA_PATH = "./dataset/celeb_df_metadata.csv"

# Data Preprocessing Pipeline

Imports and Logging Setup: This section imports necessary libraries (e.g., cv2 for video handling, torch for tensors, MTCNN for face detection) and configures logging. Logging is crucial for robustness, as the PDF highlights potential issues like corrupted videos or face detection failures; it allows tracking warnings/errors without crashing the pipeline.

Class Initialization (__init__): Initializes the dataset with metadata, transformations, frame count, and image size. It sets up MTCNN for face detection on GPU if available, ensuring efficiency for large datasets like Celeb-DF-v2 with varying video qualities.

Length Method (__len__): Returns the number of videos in the metadata, enabling PyTorch's DataLoader to iterate correctly. This is standard but essential for handling the imbalanced dataset size (890 real vs. 5639 fake).

Face Extraction Method (extract_faces_from_video): Handles video reading with error checks (e.g., file existence, opening failures) to prevent crashes on problematic files. It uses adaptive sampling (random for short videos, uniform for longer ones) to capture temporal information without bias, extracts faces via MTCNN, and logs failures for debugging.

Item Retrieval Method (__getitem__): Fetches a video's data by index, assigns binary labels (0 for real, 1 for fake), extracts faces, and handles shortages with noise-added repetitions to avoid overfitting on duplicates. It applies transformations and stacks frames into a tensor, preparing consistent inputs despite dataset variability.

In [3]:
# Setup logging for robustness
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Reconfigure the root logger to ensure output appears in the notebook
logger = logging.getLogger()
logger.handlers = []  # Clear any existing handlers
handler = logging.StreamHandler(
    sys.stdout
)  # Create a new handler to stream to the console
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

# Data Augmentation and Transforms

Training Transforms (train_transforms): Composes a sequence of augmentations starting with PIL conversion for compatibility, resizing to 224x224 (standard for backbones like EfficientNet), and random flips/rotations to simulate pose variations. Color jitter and affine shear address lighting/quality inconsistencies, enhancing robustness without masking deepfake artifacts; normalization uses ImageNet stats for transfer learning.

Validation Transforms (val_transforms): A minimal composition for evaluation, including PIL conversion, resizing, tensor conversion, and normalization. This ensures consistent inputs without random augmentations, allowing fair assessment of model generalization on the imbalanced validation set.

In [4]:
# Define enhanced data augmentation for training robustness, considering dataset variability
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomAffine(degrees=0, shear=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Model Architecture Implementation

Class Initialization (__init__): Sets up the model with frame count (20 for better temporal coverage), selects a pretrained CNN backbone (e.g., EfficientNet-B4 for efficiency on facial details), and defines bidirectional LSTMs for temporal analysis, attention layers for focusing on key frames, and a classifier with batch norm/dropout for stability on imbalanced data.

Forward Pass (forward): Reshapes input for batch CNN processing, extracts features with mixed precision for speed, reshapes for LSTM, applies bidirectional temporal modeling to detect inconsistencies, weights frames via attention (useful for variable-length videos), and classifies via dense layers, outputting a sigmoid probability for real/fake.

In [5]:
import torchvision.models as models
from torch.cuda.amp import autocast
from torch.nn import functional as F


class DeepfakeDetector(nn.Module):
    def __init__(self, num_frames=20, backbone="efficientnet_b4", dropout_rate=0.6):
        super(DeepfakeDetector, self).__init__()
        self.num_frames = num_frames

        # CNN Backbone for feature extraction
        if backbone == "efficientnet_b4":
            weights = models.EfficientNet_B4_Weights.IMAGENET1K_V1
            # self.backbone = models.efficientnet_b4(pretrained=True)
            self.backbone = models.efficientnet_b4(weights=weights)
            self.backbone.classifier = nn.Identity()  # Remove final classifier
            feature_dim = 1792
        elif backbone == "resnet50":
            weights = models.ResNet50_Weights.IMAGENET1K_V1
            self.backbone = models.resnet50(weights=weights)
            # self.backbone = models.resnet50(pretrained=True)
            self.backbone.fc = nn.Identity()
            feature_dim = 2048

        # Temporal processing layers with bidirectional LSTM for better sequence modeling
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=512,
            num_layers=3,
            batch_first=True,
            dropout=dropout_rate,
            bidirectional=True,
        )

        # Attention mechanism for frame importance weighting
        self.attention = nn.Sequential(
            nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 1), nn.Softmax(dim=1)
        )

        # Final classification layers with additional regularization
        self.classifier = nn.Sequential(
            nn.Linear(1024, 256),  # Adjusted for bidirectional
            nn.BatchNorm1d(256),  # Added for stability
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),
            # nn.Sigmoid()
        )

    def forward(self, x):
        batch_size, num_frames, C, H, W = x.shape
        x = x.view(batch_size * num_frames, C, H, W)

        # Use torch.amp.autocast for mixed precision
        with torch.amp.autocast(device_type="cuda"):
            features = self.backbone(x)

        features = features.view(batch_size, num_frames, -1)
        lstm_out, _ = self.lstm(features)

        attention_weights = self.attention(lstm_out)

        # The correct approach is to multiply lstm_out by the attention weights
        # and then sum across the time dimension.
        # weighted_features = (lstm_out * attention_weights).sum(dim=1)
        # weighted_features = lstm_out * attention_weights + lstm_out  # Helps with gradient flow.
        # weighted_features = weighted_features.mean(dim=1)  # Reduce over sequence dimension (average pooling)

        # output = self.classifier(weighted_features)
        context_vector = (lstm_out * attention_weights).sum(dim=1)
        output = self.classifier(context_vector)

        return output

# Training Setup and Configuration

Training Setup Function (setup_training): Loads metadata, performs stratified split to preserve imbalance ratios, creates datasets, computes sample weights for oversampling (addressing PDF's 1:6 imbalance), and sets up DataLoaders with a sampler for balanced batching.

Training Function (train_model): Configures device/loss/optimizer with weighted BCE for imbalance, initializes mixed-precision scaler, and runs epochs with training loops (forward pass, backprop with clipping for stability), validation (no-grad inference), metrics calculation, scheduling, and early stopping to save the best model and prevent overfitting.

In [6]:
import matplotlib.pyplot as plt
import torch.amp
from sklearn.metrics import accuracy_score, precision_recall_curve, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        bce_loss = nn.BCEWithLogitsLoss()(inputs, targets)
        pt = torch.exp(-bce_loss)
        return self.alpha * (1 - pt) ** self.gamma * bce_loss



def setup_training():
    # Load metadata from dataset analysis
    metadata = pd.read_csv(METADATA_PATH)
    # metadata['filepath'] = metadata['filepath'].apply(lambda x: os.path.join(BASE_VIDEO_PATH, x))
    metadata["filepath"] = metadata["filepath"].apply(
        # lambda x: os.path.join(BASE_VIDEO_PATH, x)
        lambda x: os.path.join(BASE_VIDEO_PATH, x)
    )

    # Train/validation split with stratification
    train_df, val_df = train_test_split(
        metadata, test_size=0.2, stratify=metadata["label"], random_state=42
    )

    # Define a cache directory in a writable location
    face_cache_directory = "/kaggle/working/face_cache"
    # Create datasets
    train_dataset = DeepfakeDataset(train_df, transform=train_transforms)
    val_dataset = DeepfakeDataset(val_df, transform=val_transforms)

    # Compute class weights for oversampling to handle imbalance (real: ~890, fake: ~5639 total)
    class_counts = metadata["label"].value_counts()
    # class_weights = {0: 1.0 / class_counts['real'], 1: 1.0 / class_counts['fake']}
    class_weights = {
        0: class_counts["fake"] / class_counts["real"],
        1: 1.0,
    }  # This oversamples real videos more heavily.

    # sample_weights = [class_weights[0] if label == 'real' else class_weights[1] for label in train_df['label']]
    # sample_weights = [class_weights[label] for label in train_df['label']]
    sample_weights = [
        class_weights[1 if label == "fake" else 0] for label in train_df["label"]
    ]

    sampler = WeightedRandomSampler(
        sample_weights, len(sample_weights), replacement=True
    )

    # Calculate pos_weight from the dataframe
    # This is safe because it's before any data loading failures
    real_count = (train_df["label"] == "real").sum()
    fake_count = (train_df["label"] == "fake").sum()
    pos_weight = torch.tensor([fake_count / real_count])

    # Create data loaders with oversampling
    # Apply the collate_fn to both training and validation loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=6,
        shuffle=False,  # Shuffle is handled by WeightedRandomSampler
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=6,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=collate_fn,
        worker_init_fn=worker_init_fn
    )

    return train_loader, val_loader, pos_weight


# Training function with early stopping and gradient clipping
def train_model(
    model,
    train_loader,
    val_loader,
    pos_weight,
    num_epochs=30,
    learning_rate=1e-5,
    patience=10,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))

    # Use BCEWithLogitsLoss for numerical stability
    # criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, factor=0.5
    )

    criterion = FocalLoss().to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-4
    )
    scheduler = ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

    best_val_loss = float("inf")
    epochs_no_improve = 0
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss, train_correct = 0.0, 0

        for batch_idx, (data, targets) in enumerate(train_loader):
            if data.nelement() == 0:  # Skip empty batches
                continue
            data, targets = data.to(device), targets.to(device)
            targets = targets.unsqueeze(1)

            optimizer.zero_grad()

            # Use torch.amp.autocast
            with torch.amp.autocast(device_type="cuda"):
                outputs = model(data)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping

            scaler.unscale_(optimizer)  # Unscale gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            train_correct += (predictions == targets).sum().item()

            if batch_idx % 10 == 0:
                logging.info(
                    f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}"
                )

        # Validation phase
        model.eval()
        val_loss, val_correct = 0.0, 0
        all_predictions, all_targets = [], []

        with torch.no_grad(), autocast():
            for data, targets in val_loader:
                data, targets = data.to(device), targets.to(device)
                targets = targets.unsqueeze(1)

                outputs = model(data)
                loss = criterion(outputs, targets)

                probabilities = torch.sigmoid(outputs)
                predictions = (probabilities > 0.5).float()

                val_loss += loss.item()
                predictions = (torch.sigmoid(outputs) > 0.5).float()
                val_correct += (predictions == targets).sum().item()

                all_predictions.extend(torch.sigmoid(outputs).cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        # Calculate metrics
        train_acc = train_correct / len(train_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)
        auc_score = roc_auc_score(all_targets, all_predictions)
        avg_val_loss = val_loss / len(val_loader)

        train_losses.append(train_loss / len(train_loader))
        val_losses.append(avg_val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        scheduler.step(avg_val_loss)

        logging.info(
            f"Epoch {epoch+1}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, AUC: {auc_score:.4f}"
        )

        # Early stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "best_deepfake_detector_model.pth")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                logging.info(f"Early stopping at epoch {epoch+1}")
                break

    return train_losses, val_losses, train_accs, val_accs

# Model Validation and Evaluation

Evaluation Function (evaluate_model): Sets model to eval mode, collects predictions/probabilities with no-grad and mixed precision, computes metrics like accuracy/AUC/PR curves, and calculates EER (key for imbalanced security tasks per PDF). It generates and saves plots for visual analysis.

Main Execution Block (if __name__ == "__main__"): Orchestrates the pipeline by setting up loaders, initializing the model, training, evaluating, and saving the final model, with logging for progress tracking.

In [7]:
from sklearn.metrics import roc_curve


def evaluate_model(model, val_loader):
    """Comprehensive model evaluation with multiple metrics including EER"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    all_predictions = []
    all_probabilities = []
    all_targets = []

    with torch.no_grad(), autocast():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            targets = targets.unsqueeze(1)

            outputs = model(data)  # Raw logits

            probabilities = torch.sigmoid(outputs)
            predictions = (probabilities > 0.5).float()

            all_predictions.extend(predictions.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    # Calculate comprehensive metrics
    accuracy = accuracy_score(all_targets, all_predictions)
    auc_score = roc_auc_score(all_targets, all_probabilities)

    # Precision-Recall curve
    precision, recall, thresholds = precision_recall_curve(
        all_targets, all_probabilities
    )

    # Equal Error Rate (EER) calculation
    fpr, tpr, thresh = roc_curve(all_targets, all_probabilities)
    fnr = 1 - tpr
    eer_threshold = thresh[np.nanargmin(np.absolute(fnr - fpr))]
    if np.all(fpr == 0) or np.all(fnr == 0):
        eer = 0.5  # Fallback for degenerate cases
    else:
        eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]

    # Plot evaluation metrics
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # ROC Curve
    axes[0].plot(fpr, tpr, label=f"ROC Curve (AUC = {auc_score:.3f})")
    axes[0].plot([0, 1], [0, 1], "k--")
    axes[0].set_xlabel("False Positive Rate")
    axes[0].set_ylabel("True Positive Rate")
    axes[0].set_title("ROC Curve")
    axes[0].legend()

    # Precision-Recall Curve
    axes[1].plot(recall, precision, label=f"PR Curve")
    axes[1].set_xlabel("Recall")
    axes[1].set_ylabel("Precision")
    axes[1].set_title("Precision-Recall Curve")
    axes[1].legend()

    plt.tight_layout()
    plt.savefig("model_evaluation_metrics.png", dpi=300, bbox_inches="tight")
    plt.show()

    print(f"Model Evaluation Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"AUC-ROC: {auc_score:.4f}")
    print(f"Equal Error Rate (EER): {eer:.4f}")

    return accuracy, auc_score, eer


# Main execution code
if __name__ == "__main__":    
    # Explicitly tells CUDA to use the 'spawn' start method
    multiprocessing.set_start_method("spawn", force=True)

    logging.info("Initializing deepfake detection model...")

    # Setup data loaders
    train_loader, val_loader, pos_weight = setup_training()

    # Initialize model
    model = DeepfakeDetector(num_frames=15, backbone="efficientnet_b4")

    logging.info("Starting model training...")
    train_losses, val_losses, train_accs, val_accs = train_model(
        model, train_loader, val_loader, pos_weight, num_epochs=20, patience=5
    )
    # Plot the training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss")
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training and Validation Loss Over Epochs")
    plt.legend()
    plt.grid(True)
    plt.savefig("loss_curves.png")
    plt.show()

    logging.info("Evaluating model performance...")
    accuracy, auc_score, eer = evaluate_model(model, val_loader)

    # Save final model
    torch.save(model.state_dict(), "deepfake_detector_model.pth")
    logging.info("Model saved successfully!")

2025-10-07 04:39:54,754 - INFO - Initializing deepfake detection model...
2025-10-07 04:39:55,484 - INFO - Starting model training...


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))


2025-10-07 04:40:26,582 - INFO - Epoch 1/20, Batch 0/871, Loss: 0.0567
2025-10-07 04:40:44,857 - INFO - Epoch 1/20, Batch 10/871, Loss: 0.0432


[2025-10-07 04:40:52,302] [INFO] [Worker ID: 8331]: Processing video 100/5223...


2025-10-07 04:41:13,594 - INFO - Epoch 1/20, Batch 20/871, Loss: 0.0422
2025-10-07 04:41:38,061 - INFO - Epoch 1/20, Batch 30/871, Loss: 0.0502


[2025-10-07 04:41:39,007] [INFO] [Worker ID: 8259]: Processing video 200/5223...


2025-10-07 04:42:08,330 - INFO - Epoch 1/20, Batch 40/871, Loss: 0.0482


[2025-10-07 04:42:24,590] [INFO] [Worker ID: 8126]: Processing video 300/5223...


2025-10-07 04:42:33,698 - INFO - Epoch 1/20, Batch 50/871, Loss: 0.0407
2025-10-07 04:43:07,246 - INFO - Epoch 1/20, Batch 60/871, Loss: 0.0415


[2025-10-07 04:43:12,372] [INFO] [Worker ID: 8259]: Processing video 400/5223...


2025-10-07 04:43:36,141 - INFO - Epoch 1/20, Batch 70/871, Loss: 0.0438
2025-10-07 04:44:00,820 - INFO - Epoch 1/20, Batch 80/871, Loss: 0.0613


[2025-10-07 04:44:02,551] [INFO] [Worker ID: 8192]: Processing video 500/5223...


2025-10-07 04:44:34,949 - INFO - Epoch 1/20, Batch 90/871, Loss: 0.0426


[2025-10-07 04:44:51,858] [INFO] [Worker ID: 8126]: Processing video 600/5223...


2025-10-07 04:44:59,790 - INFO - Epoch 1/20, Batch 100/871, Loss: 0.0496
2025-10-07 04:45:31,049 - INFO - Epoch 1/20, Batch 110/871, Loss: 0.0459


[2025-10-07 04:45:41,616] [INFO] [Worker ID: 8192]: Processing video 700/5223...


2025-10-07 04:45:58,424 - INFO - Epoch 1/20, Batch 120/871, Loss: 0.0428


[2025-10-07 04:46:31,144] [INFO] [Worker ID: 8331]: Processing video 800/5223...


2025-10-07 04:46:31,865 - INFO - Epoch 1/20, Batch 130/871, Loss: 0.0449
2025-10-07 04:47:03,962 - INFO - Epoch 1/20, Batch 140/871, Loss: 0.0436


[2025-10-07 04:47:21,664] [INFO] [Worker ID: 8126]: Processing video 900/5223...


2025-10-07 04:47:30,253 - INFO - Epoch 1/20, Batch 150/871, Loss: 0.0394
2025-10-07 04:47:57,432 - INFO - Epoch 1/20, Batch 160/871, Loss: 0.0399


[2025-10-07 04:48:09,463] [INFO] [Worker ID: 8259]: Processing video 1000/5223...


2025-10-07 04:48:31,863 - INFO - Epoch 1/20, Batch 170/871, Loss: 0.0395
2025-10-07 04:48:58,816 - INFO - Epoch 1/20, Batch 180/871, Loss: 0.0361


[2025-10-07 04:48:59,891] [INFO] [Worker ID: 8331]: Processing video 1100/5223...


2025-10-07 04:49:30,509 - INFO - Epoch 1/20, Batch 190/871, Loss: 0.0489


[2025-10-07 04:49:46,805] [INFO] [Worker ID: 8331]: Processing video 1200/5223...


2025-10-07 04:49:53,743 - INFO - Epoch 1/20, Batch 200/871, Loss: 0.0523
2025-10-07 04:50:30,733 - INFO - Epoch 1/20, Batch 210/871, Loss: 0.0433


[2025-10-07 04:50:37,007] [INFO] [Worker ID: 8126]: Processing video 1300/5223...


2025-10-07 04:50:55,293 - INFO - Epoch 1/20, Batch 220/871, Loss: 0.0302


[2025-10-07 04:51:26,782] [INFO] [Worker ID: 8192]: Processing video 1400/5223...


2025-10-07 04:51:27,002 - INFO - Epoch 1/20, Batch 230/871, Loss: 0.0434
2025-10-07 04:51:54,183 - INFO - Epoch 1/20, Batch 240/871, Loss: 0.0475


[2025-10-07 04:52:15,308] [INFO] [Worker ID: 8126]: Processing video 1500/5223...


2025-10-07 04:52:25,308 - INFO - Epoch 1/20, Batch 250/871, Loss: 0.0381
2025-10-07 04:52:55,704 - INFO - Epoch 1/20, Batch 260/871, Loss: 0.0554


[2025-10-07 04:53:06,507] [INFO] [Worker ID: 8331]: Processing video 1600/5223...


2025-10-07 04:53:24,234 - INFO - Epoch 1/20, Batch 270/871, Loss: 0.0426


[2025-10-07 04:53:55,084] [INFO] [Worker ID: 8126]: Processing video 1700/5223...


2025-10-07 04:53:56,870 - INFO - Epoch 1/20, Batch 280/871, Loss: 0.0450
2025-10-07 04:54:33,697 - INFO - Epoch 1/20, Batch 290/871, Loss: 0.0356


[2025-10-07 04:54:50,247] [INFO] [Worker ID: 8259]: Processing video 1800/5223...


2025-10-07 04:55:04,695 - INFO - Epoch 1/20, Batch 300/871, Loss: 0.0435
2025-10-07 04:55:28,525 - INFO - Epoch 1/20, Batch 310/871, Loss: 0.0374


[2025-10-07 04:55:37,756] [INFO] [Worker ID: 8259]: Processing video 1900/5223...


2025-10-07 04:56:00,386 - INFO - Epoch 1/20, Batch 320/871, Loss: 0.0528
2025-10-07 04:56:23,154 - INFO - Epoch 1/20, Batch 330/871, Loss: 0.0392


[2025-10-07 04:56:25,976] [INFO] [Worker ID: 8259]: Processing video 2000/5223...


2025-10-07 04:57:02,846 - INFO - Epoch 1/20, Batch 340/871, Loss: 0.0345


[2025-10-07 04:57:17,902] [INFO] [Worker ID: 8192]: Processing video 2100/5223...


2025-10-07 04:57:28,264 - INFO - Epoch 1/20, Batch 350/871, Loss: 0.0560
2025-10-07 04:58:02,653 - INFO - Epoch 1/20, Batch 360/871, Loss: 0.0464


[2025-10-07 04:58:06,604] [INFO] [Worker ID: 8192]: Processing video 2200/5223...


2025-10-07 04:58:28,869 - INFO - Epoch 1/20, Batch 370/871, Loss: 0.0455


[2025-10-07 04:58:54,967] [INFO] [Worker ID: 8192]: Processing video 2300/5223...


2025-10-07 04:58:59,848 - INFO - Epoch 1/20, Batch 380/871, Loss: 0.0505
2025-10-07 04:59:27,309 - INFO - Epoch 1/20, Batch 390/871, Loss: 0.0560


[2025-10-07 04:59:44,166] [INFO] [Worker ID: 8192]: Processing video 2400/5223...


2025-10-07 04:59:56,665 - INFO - Epoch 1/20, Batch 400/871, Loss: 0.0434
2025-10-07 05:00:30,135 - INFO - Epoch 1/20, Batch 410/871, Loss: 0.0408


[2025-10-07 05:00:35,300] [INFO] [Worker ID: 8192]: Processing video 2500/5223...


2025-10-07 05:01:03,300 - INFO - Epoch 1/20, Batch 420/871, Loss: 0.0543


[2025-10-07 05:01:24,435] [INFO] [Worker ID: 8331]: Processing video 2600/5223...


2025-10-07 05:01:29,338 - INFO - Epoch 1/20, Batch 430/871, Loss: 0.0396
2025-10-07 05:02:00,857 - INFO - Epoch 1/20, Batch 440/871, Loss: 0.0416


[2025-10-07 05:02:10,337] [INFO] [Worker ID: 8192]: Processing video 2700/5223...


2025-10-07 05:02:23,967 - INFO - Epoch 1/20, Batch 450/871, Loss: 0.0531


[2025-10-07 05:03:02,803] [INFO] [Worker ID: 8259]: Processing video 2800/5223...


2025-10-07 05:03:04,281 - INFO - Epoch 1/20, Batch 460/871, Loss: 0.0485
2025-10-07 05:03:28,861 - INFO - Epoch 1/20, Batch 470/871, Loss: 0.0389


[2025-10-07 05:03:52,644] [INFO] [Worker ID: 8192]: Processing video 2900/5223...


2025-10-07 05:03:58,100 - INFO - Epoch 1/20, Batch 480/871, Loss: 0.0378
2025-10-07 05:04:22,814 - INFO - Epoch 1/20, Batch 490/871, Loss: 0.0361


[2025-10-07 05:04:41,068] [INFO] [Worker ID: 8192]: Processing video 3000/5223...


2025-10-07 05:04:54,243 - INFO - Epoch 1/20, Batch 500/871, Loss: 0.0480
2025-10-07 05:05:23,742 - INFO - Epoch 1/20, Batch 510/871, Loss: 0.0465


[2025-10-07 05:05:27,893] [INFO] [Worker ID: 8259]: Processing video 3100/5223...


2025-10-07 05:05:50,701 - INFO - Epoch 1/20, Batch 520/871, Loss: 0.0422


[2025-10-07 05:06:20,534] [INFO] [Worker ID: 8192]: Processing video 3200/5223...


2025-10-07 05:06:30,299 - INFO - Epoch 1/20, Batch 530/871, Loss: 0.0390
2025-10-07 05:06:50,975 - INFO - Epoch 1/20, Batch 540/871, Loss: 0.0450


[2025-10-07 05:07:08,682] [INFO] [Worker ID: 8331]: Processing video 3300/5223...


2025-10-07 05:07:25,726 - INFO - Epoch 1/20, Batch 550/871, Loss: 0.0514
2025-10-07 05:07:48,182 - INFO - Epoch 1/20, Batch 560/871, Loss: 0.0418


[2025-10-07 05:07:59,298] [INFO] [Worker ID: 8331]: Processing video 3400/5223...


2025-10-07 05:08:19,184 - INFO - Epoch 1/20, Batch 570/871, Loss: 0.0432
2025-10-07 05:08:45,796 - INFO - Epoch 1/20, Batch 580/871, Loss: 0.0531


[2025-10-07 05:08:46,485] [INFO] [Worker ID: 8126]: Processing video 3500/5223...


2025-10-07 05:09:12,545 - INFO - Epoch 1/20, Batch 590/871, Loss: 0.0329


[2025-10-07 05:09:36,681] [INFO] [Worker ID: 8259]: Processing video 3600/5223...


2025-10-07 05:09:51,348 - INFO - Epoch 1/20, Batch 600/871, Loss: 0.0353
2025-10-07 05:10:18,922 - INFO - Epoch 1/20, Batch 610/871, Loss: 0.0407


[2025-10-07 05:10:25,939] [INFO] [Worker ID: 8331]: Processing video 3700/5223...


2025-10-07 05:10:54,655 - INFO - Epoch 1/20, Batch 620/871, Loss: 0.0510


[2025-10-07 05:11:17,550] [INFO] [Worker ID: 8126]: Processing video 3800/5223...


2025-10-07 05:11:20,658 - INFO - Epoch 1/20, Batch 630/871, Loss: 0.0534
2025-10-07 05:11:56,430 - INFO - Epoch 1/20, Batch 640/871, Loss: 0.0622


[2025-10-07 05:12:06,593] [INFO] [Worker ID: 8331]: Processing video 3900/5223...


2025-10-07 05:12:21,185 - INFO - Epoch 1/20, Batch 650/871, Loss: 0.0572
2025-10-07 05:12:49,773 - INFO - Epoch 1/20, Batch 660/871, Loss: 0.0462


[2025-10-07 05:12:57,399] [INFO] [Worker ID: 8192]: Processing video 4000/5223...


2025-10-07 05:13:17,662 - INFO - Epoch 1/20, Batch 670/871, Loss: 0.0460


[2025-10-07 05:13:45,144] [INFO] [Worker ID: 8192]: Processing video 4100/5223...


2025-10-07 05:13:54,734 - INFO - Epoch 1/20, Batch 680/871, Loss: 0.0452
2025-10-07 05:14:19,771 - INFO - Epoch 1/20, Batch 690/871, Loss: 0.0286


[2025-10-07 05:14:36,469] [INFO] [Worker ID: 8126]: Processing video 4200/5223...


2025-10-07 05:14:57,493 - INFO - Epoch 1/20, Batch 700/871, Loss: 0.0374
2025-10-07 05:15:22,764 - INFO - Epoch 1/20, Batch 710/871, Loss: 0.0341


[2025-10-07 05:15:26,600] [INFO] [Worker ID: 8259]: Processing video 4300/5223...


2025-10-07 05:16:01,319 - INFO - Epoch 1/20, Batch 720/871, Loss: 0.0412


[2025-10-07 05:16:17,797] [INFO] [Worker ID: 8126]: Processing video 4400/5223...


2025-10-07 05:16:21,309 - INFO - Epoch 1/20, Batch 730/871, Loss: 0.0462
2025-10-07 05:16:56,855 - INFO - Epoch 1/20, Batch 740/871, Loss: 0.0553


[2025-10-07 05:17:08,457] [INFO] [Worker ID: 8331]: Processing video 4500/5223...


2025-10-07 05:17:21,756 - INFO - Epoch 1/20, Batch 750/871, Loss: 0.0534
2025-10-07 05:17:54,591 - INFO - Epoch 1/20, Batch 760/871, Loss: 0.0446


[2025-10-07 05:18:00,056] [INFO] [Worker ID: 8192]: Processing video 4600/5223...


2025-10-07 05:18:20,075 - INFO - Epoch 1/20, Batch 770/871, Loss: 0.0432


[2025-10-07 05:18:48,152] [INFO] [Worker ID: 8259]: Processing video 4700/5223...


2025-10-07 05:18:52,503 - INFO - Epoch 1/20, Batch 780/871, Loss: 0.0461
2025-10-07 05:19:17,556 - INFO - Epoch 1/20, Batch 790/871, Loss: 0.0494


[2025-10-07 05:19:38,924] [INFO] [Worker ID: 8192]: Processing video 4800/5223...


2025-10-07 05:19:46,825 - INFO - Epoch 1/20, Batch 800/871, Loss: 0.0471
2025-10-07 05:20:13,729 - INFO - Epoch 1/20, Batch 810/871, Loss: 0.0359


[2025-10-07 05:20:24,223] [INFO] [Worker ID: 8259]: Processing video 4900/5223...


2025-10-07 05:20:43,166 - INFO - Epoch 1/20, Batch 820/871, Loss: 0.0513
2025-10-07 05:21:08,523 - INFO - Epoch 1/20, Batch 830/871, Loss: 0.0444


[2025-10-07 05:21:10,046] [INFO] [Worker ID: 8259]: Processing video 5000/5223...


2025-10-07 05:21:36,186 - INFO - Epoch 1/20, Batch 840/871, Loss: 0.0401


[2025-10-07 05:21:59,340] [INFO] [Worker ID: 8192]: Processing video 5100/5223...


2025-10-07 05:22:14,386 - INFO - Epoch 1/20, Batch 850/871, Loss: 0.0509
2025-10-07 05:22:37,561 - INFO - Epoch 1/20, Batch 860/871, Loss: 0.0482


[2025-10-07 05:22:46,454] [INFO] [Worker ID: 8192]: Processing video 5200/5223...


2025-10-07 05:23:08,512 - INFO - Epoch 1/20, Batch 870/871, Loss: 0.0491


  with torch.no_grad(), autocast():
[2025-10-07 05:24:03,283] [INFO] [Worker ID: 92901]: Processing video 100/1306...
[2025-10-07 05:24:51,689] [INFO] [Worker ID: 92901]: Processing video 200/1306...
[2025-10-07 05:25:36,940] [INFO] [Worker ID: 92834]: Processing video 300/1306...
[2025-10-07 05:26:23,815] [INFO] [Worker ID: 92703]: Processing video 400/1306...
[2025-10-07 05:27:14,403] [INFO] [Worker ID: 92834]: Processing video 500/1306...
[2025-10-07 05:28:01,228] [INFO] [Worker ID: 92767]: Processing video 600/1306...
[2025-10-07 05:28:48,266] [INFO] [Worker ID: 92767]: Processing video 700/1306...
[2025-10-07 05:29:36,954] [INFO] [Worker ID: 92834]: Processing video 800/1306...
[2025-10-07 05:30:24,279] [INFO] [Worker ID: 92703]: Processing video 900/1306...
[2025-10-07 05:31:10,981] [INFO] [Worker ID: 92767]: Processing video 1000/1306...
[2025-10-07 05:32:02,140] [INFO] [Worker ID: 92834]: Processing video 1100/1306...
[2025-10-07 05:32:48,236] [INFO] [Worker ID: 92767]: Process

ValueError: Input contains NaN.