In [None]:
!pip install ultralytics opencv-python-headless torch matplotlib pandas torchmetrics scipy
!pip install deep-sort-realtime
!pip install torchvision
!pip install seaborn
!pip install tqdm
!pip install moviepy

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from ultralytics import YOLO
from torchmetrics.detection import MeanAveragePrecision
from deep_sort_realtime.deepsort_tracker import DeepSort
from google.colab.patches import cv2_imshow
from scipy.spatial.distance import euclidean
from collections import defaultdict
import warnings
import os
from tqdm import tqdm
from moviepy.editor import VideoFileClip, clips_array, concatenate_videoclips
%matplotlib inline

# Suppress warnings for cleaner output
# NOTE(review): this silences *every* Python warning for the rest of the session,
# including library deprecation notices — re-enable when debugging.
warnings.filterwarnings('ignore')

# ==================== CONFIGURATION ====================
input_video_path = "/content/1.mp4"  # source clip (Colab filesystem path)
output_video_dir = "/content/output_videos"  # created below; not otherwise referenced in the visible code
combined_output_path = "/content/combined_output.mp4"  # side-by-side comparison video written by the main loop
os.makedirs(output_video_dir, exist_ok=True)

model = YOLO('yolov8l.pt')  # YOLOv8 Large model, shared by every noise configuration
conf_threshold = 0.3  # Confidence threshold for detections (read by detect_objects)

# One entry per experimental run; the main loop processes the video once per entry.
# Keys:
#   type      - noise family; dispatched on by add_temporal_noise and adaptive_denoise
#   intensity - strength of the noise; its meaning depends on `type` (see add_temporal_noise)
#   label     - unique display name; also the key for every per-configuration result dict
#   color     - 0-255 triple used for the overlay title and the matplotlib line plot.
#               NOTE(review): the values read as RGB (the plot divides by 255 and passes them to
#               matplotlib as RGB), while OpenCV drawing calls interpret colour tuples as BGR —
#               confirm which convention is intended.
noise_configs = [
    # No noise (baseline)
    {"type": "none", "intensity": 0, "label": "Original (No Noise)", "color": (0, 255, 0)},

    # Salt & Pepper noise
    {"type": "salt_and_pepper", "intensity": 5, "label": "Salt & Pepper Low", "color": (255, 100, 100)},
    {"type": "salt_and_pepper", "intensity": 10, "label": "Salt & Pepper Medium", "color": (255, 50, 50)},
    {"type": "salt_and_pepper", "intensity": 15, "label": "Salt & Pepper High", "color": (255, 0, 0)},

    # Flicker noise (temporal)
    {"type": "flicker", "intensity": 10, "label": "Flicker Low", "color": (100, 255, 100)},
    {"type": "flicker", "intensity": 20, "label": "Flicker Medium", "color": (50, 255, 50)},
    {"type": "flicker", "intensity": 30, "label": "Flicker High", "color": (0, 255, 0)},

    # Motion Blur (temporal)
    {"type": "motion_blur", "intensity": 5, "label": "Motion Blur Low", "color": (100, 100, 255)},
    {"type": "motion_blur", "intensity": 10, "label": "Motion Blur Medium", "color": (50, 50, 255)},
    {"type": "motion_blur", "intensity": 15, "label": "Motion Blur High", "color": (0, 0, 255)},

    # Temporal Gaussian
    {"type": "temporal_gaussian", "intensity": 5, "label": "Temporal Gaussian Low", "color": (255, 255, 100)},
    {"type": "temporal_gaussian", "intensity": 10, "label": "Temporal Gaussian Medium", "color": (255, 255, 50)},
    {"type": "temporal_gaussian", "intensity": 15, "label": "Temporal Gaussian High", "color": (255, 255, 0)},

    # Gaussian (spatial)
    {"type": "gaussian", "intensity": 5, "label": "Gaussian Low", "color": (255, 100, 255)},
    {"type": "gaussian", "intensity": 10, "label": "Gaussian Medium", "color": (255, 50, 255)},
    {"type": "gaussian", "intensity": 15, "label": "Gaussian High", "color": (255, 0, 255)},
]

# ==================== ENHANCED FUNCTIONS ====================
def add_temporal_noise(frame, noise_type="gaussian", intensity=5, prev_noisy_frame=None):
    """Corrupt ``frame`` with one kind of synthetic noise.

    Args:
        frame: image array. Except for ``"none"``, non-uint8 input is cast to
            uint8 before the noise is applied.
        noise_type: ``"none"``, ``"gaussian"``, ``"salt_and_pepper"``,
            ``"flicker"``, ``"motion_blur"`` or ``"temporal_gaussian"``.
            Unknown types yield an unmodified copy.
        intensity: noise strength. Its meaning depends on the type:
            * gaussian / temporal_gaussian: noise standard deviation (capped at 50);
            * salt_and_pepper: ``intensity / 200`` of the pixels are set to 0 or 255;
            * flicker: a global gain drawn uniformly from ``1 ± intensity / 150``;
            * motion_blur: length of a horizontal box kernel (about ``intensity // 2``,
              forced odd, clamped to 3..21).
        prev_noisy_frame: the previous noisy output. Only ``"temporal_gaussian"``
            reads it, blending 60 % of it with 40 % of the current frame before
            adding fresh noise. When it is ``None`` (first frame) or its shape does
            not match, the current frame is used as the history, so the first frame
            is noised as well instead of being returned clean.

    Returns:
        ``(noisy_frame, frame)``: the corrupted uint8 image and the (possibly
        uint8-cast) input. On an internal error the message is printed and an
        unmodified copy is returned as the noisy frame.
    """
    if noise_type == "none":
        return frame.copy(), frame  # Return both original and noisy for comparison

    if frame.dtype != np.uint8:
        frame = frame.astype(np.uint8)

    try:
        if noise_type == "gaussian":
            frame_float = frame.astype(np.float32)
            noise = np.random.normal(0, min(intensity, 50), frame.shape).astype(np.float32)
            noisy_frame = cv2.add(frame_float, noise)
            noisy_frame = np.clip(noisy_frame, 0, 255).astype(np.uint8)

        elif noise_type == "salt_and_pepper":
            noise_prob = intensity / 200
            noisy_frame = frame.copy()
            # One draw per pixel location, so all channels of a pixel flip together
            random_mask = np.random.random(frame.shape[:2])
            noisy_frame[random_mask > (1 - noise_prob / 2)] = 255
            noisy_frame[random_mask < noise_prob / 2] = 0

        elif noise_type == "flicker":
            flicker_factor = 1 + (np.random.rand() - 0.5) * (intensity / 75)
            noisy_frame = np.clip(frame * flicker_factor, 0, 255).astype(np.uint8)

        elif noise_type == "motion_blur":
            kernel_size = max(3, min(intensity // 2, 21))
            kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
            # Single non-zero middle row => horizontal averaging (blur along x only)
            kernel = np.zeros((kernel_size, kernel_size))
            kernel[kernel_size // 2, :] = 1.0 / kernel_size
            noisy_frame = cv2.filter2D(frame, -1, kernel)

        elif noise_type == "temporal_gaussian":
            # Seed the recursion with the clean frame on the first call, so the
            # first frame of the sequence is noised like all the others.
            if prev_noisy_frame is None or prev_noisy_frame.shape != frame.shape:
                history = frame
            else:
                history = prev_noisy_frame
            noise = np.random.normal(0, min(intensity, 50), frame.shape)
            noisy_frame = 0.6 * history.astype(np.float32) + \
                          0.4 * frame.astype(np.float32) + noise
            noisy_frame = np.clip(noisy_frame, 0, 255).astype(np.uint8)

        else:
            noisy_frame = frame.copy()

        return noisy_frame, frame  # Return both noisy and original for comparison

    except Exception as e:
        print(f"Error in {noise_type} noise: {e}")
        return frame.copy(), frame

def adaptive_denoise(frame, noise_type, intensity):
    """Apply the denoising filter paired with ``noise_type`` and time it.

    Filter choice:
        * gaussian / temporal_gaussian -> non-local means, strength ``3 + intensity // 3``;
        * salt_and_pepper -> median blur, odd kernel clamped to 3..11;
        * flicker -> bilateral filter, sigmas ``50 + 3 * intensity``;
        * motion_blur -> non-local means with a fixed strength of 10
          (NOTE(review): this smooths but does not deblur);
        * anything else -> unmodified copy.

    Returns:
        ``(denoised_frame, seconds)``. For ``"none"`` or zero intensity the frame is
        copied and the time is reported as 0. On a filter error the message is
        printed and a copy is returned with the time spent so far.
    """
    if noise_type == "none" or intensity == 0:
        return frame.copy(), 0  # nothing to undo, so no time is spent

    tick_start = cv2.getTickCount()

    def _elapsed_seconds():
        return (cv2.getTickCount() - tick_start) / cv2.getTickFrequency()

    try:
        if noise_type in ("gaussian", "temporal_gaussian"):
            strength = 3 + intensity // 3
            result = cv2.fastNlMeansDenoisingColored(
                frame.astype(np.uint8),
                h=strength,
                hColor=strength,
                templateWindowSize=7,
                searchWindowSize=21,
            )
        elif noise_type == "salt_and_pepper":
            ksize = max(3, min(3 + intensity // 4, 11))
            if ksize % 2 == 0:
                ksize += 1  # medianBlur requires an odd aperture
            result = cv2.medianBlur(frame.astype(np.uint8), ksize)
        elif noise_type == "flicker":
            sigma = 50 + intensity * 3
            result = cv2.bilateralFilter(frame.astype(np.uint8), 9, sigma, sigma)
        elif noise_type == "motion_blur":
            result = cv2.fastNlMeansDenoisingColored(
                frame.astype(np.uint8),
                h=10,
                hColor=10,
                templateWindowSize=7,
                searchWindowSize=21,
            )
        else:
            result = frame.copy()

        return result, _elapsed_seconds()

    except Exception as e:
        print(f"Error in denoising: {e}")
        return frame.copy(), _elapsed_seconds()

def detect_objects(frame, model, conf=None, min_size=10):
    """Run the YOLO model on one frame and return size-filtered detections.

    Args:
        frame: BGR uint8 image as produced by ``cv2.VideoCapture.read``. It is
            passed to ``model.predict`` unchanged. Ultralytics interprets numpy
            images as BGR, so converting to RGB first (as this function used to)
            would swap the red and blue channels seen by the network.
        model: Ultralytics ``YOLO`` model (only ``predict`` and ``names`` are used).
        conf: confidence threshold forwarded to ``predict``. ``None`` means the
            module-level ``conf_threshold``.
        min_size: a box is kept only if its width and height are both strictly
            greater than this many pixels.

    Returns:
        ``(detections, class_counts)``.
        ``detections`` is a list of ``([x1, y1, x2, y2], confidence, class_id)``
        tuples with integer pixel corners. ``class_counts`` maps class name to
        the number of kept boxes. On any error the message is printed and
        ``([], defaultdict(int))`` is returned.
    """
    try:
        if conf is None:
            conf = conf_threshold
        results = model.predict(source=frame, save=False, conf=conf, verbose=False)

        detections = []
        class_counts = defaultdict(int)

        if results and getattr(results[0], 'boxes', None) is not None:
            for box in results[0].boxes:
                if box.xyxy.numel() == 0 or box.conf.numel() == 0 or box.cls.numel() == 0:
                    continue
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                if (x2 - x1) <= min_size or (y2 - y1) <= min_size:
                    continue  # drop tiny boxes
                cls = int(box.cls[0])
                detections.append(([x1, y1, x2, y2], box.conf[0].item(), cls))
                class_counts[model.names[cls]] += 1

        return detections, class_counts
    except Exception as e:
        print(f"Detection error: {e}")
        return [], defaultdict(int)

def generate_dynamic_gt(frame, detections):
    """Build a one-box pseudo ground-truth target from a list of detections.

    ``frame`` is accepted for call-site compatibility but is not read.

    Selection rule: among detections with confidence > 0.7, take the one with the
    largest box area. If none qualifies, take the most confident detection.

    Args:
        detections: list of ``([x1, y1, x2, y2], confidence, class_id)`` tuples.

    Returns:
        torchmetrics-style target dict with ``"boxes"`` (float32, shape (k, 4))
        and ``"labels"`` (int32, shape (k,)), where k is 1 or 0. An empty target is
        returned for no detections or (after printing the error) on any failure.
    """
    def _empty_target():
        return {
            "boxes": torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int32)
        }

    if not detections:
        return _empty_target()

    try:
        confident = [det for det in detections if det[1] > 0.7]

        if confident:
            chosen = max(
                confident,
                key=lambda det: (det[0][2] - det[0][0]) * (det[0][3] - det[0][1]),
            )
        else:
            chosen = max(detections, key=lambda det: det[1], default=None)

        if not chosen:
            return _empty_target()

        return {
            "boxes": torch.tensor([chosen[0]], dtype=torch.float32),
            "labels": torch.tensor([chosen[2]], dtype=torch.int32)
        }
    except Exception as e:
        print(f"GT generation error: {e}")
        return _empty_target()

def analyze_movement(tracks, prev_positions, frame_count):
    """Measure how far each confirmed track's box centre moved since it was last seen.

    Args:
        tracks: tracker output. Each item must provide ``is_confirmed()``,
            ``track_id`` and ``to_ltrb()`` (left, top, right, bottom).
            Unconfirmed tracks are skipped.
        prev_positions: dict ``track_id -> (centre ndarray, frame index)``. It is
            updated in place with the current centre of every confirmed track.
        frame_count: index of the current frame.

    Returns:
        dict with ``"displacements"`` (pixels moved since the stored position) and
        ``"velocities"`` (displacement divided by frames elapsed). Each list is
        ``[0]`` when no track had an earlier position. Per-track errors are
        printed and that track is skipped.
    """
    displacements = []
    speeds = []

    for trk in tracks:
        if not trk.is_confirmed():
            continue

        try:
            tid = trk.track_id
            box = trk.to_ltrb()
            centre = np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])

            if tid in prev_positions:
                last_centre, last_frame = prev_positions[tid]
                if frame_count > last_frame:  # only measure forward in time
                    step = euclidean(centre, last_centre)
                    speed = step / (frame_count - last_frame)
                    displacements.append(step)
                    speeds.append(speed)

            prev_positions[tid] = (centre, frame_count)
        except Exception as e:
            print(f"Movement analysis error: {e}")

    return {
        "displacements": displacements or [0],
        "velocities": speeds or [0],
    }

def calculate_optical_flow(prev_frame, current_frame):
    """Compute dense Farneback optical flow between two BGR frames and summarise it.

    Returns:
        dict with:
            * ``"avg_magnitude"`` / ``"std_magnitude"``: mean / std of per-pixel flow
              magnitude in pixels. Only pixels with ``0.1 < magnitude < 10`` are
              used (near-static and extreme vectors are excluded); 0 when none qualify.
            * ``"avg_angle"``: circular mean of the per-pixel flow direction in
              radians, in ``[0, 2*pi)``.
            * ``"flow"``: the raw flow field from ``calcOpticalFlowFarneback``, or ``None``.
        On any error the message is printed and all summaries are 0.
    """
    try:
        prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
        current_gray = cv2.cvtColor(current_frame, cv2.COLOR_BGR2GRAY)

        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, current_gray,
            None, 0.5, 3, 15, 3, 5, 1.2, 0
        )

        magnitude, angle = cv2.cartToPolar(flow[..., 0], flow[..., 1])

        # Filter out near-zero and extreme magnitudes
        valid_magnitude = magnitude[(magnitude > 0.1) & (magnitude < 10)]
        has_valid = valid_magnitude.size > 0

        # Angles wrap around at 2*pi, so an arithmetic mean is wrong (e.g. 0.1 and
        # 6.2 rad would average to about pi). Average the unit vectors instead.
        mean_angle = float(
            np.arctan2(np.mean(np.sin(angle)), np.mean(np.cos(angle))) % (2 * np.pi)
        )

        return {
            "avg_magnitude": float(np.mean(valid_magnitude)) if has_valid else 0,
            "std_magnitude": float(np.std(valid_magnitude)) if has_valid else 0,
            "avg_angle": mean_angle,
            "flow": flow
        }
    except Exception as e:
        print(f"Optical flow error: {e}")
        return {
            "avg_magnitude": 0,
            "std_magnitude": 0,
            "avg_angle": 0,
            "flow": None
        }

def safe_metric_compute(metric):
    """Compute a torchmetrics detection metric, sanitising the headline values.

    ``map``, ``map_50`` and ``mar_100`` are replaced by ``tensor(0.0)`` when they
    are NaN, negative (torchmetrics reports -1 when there is no ground truth) or
    missing. They are always stored as tensors, so callers can keep using
    ``.item()`` on them. (Assigning a plain ``0.0`` made ``.item()`` raise
    AttributeError.)

    Args:
        metric: object with a ``compute()`` method returning a dict of tensors,
            e.g. ``torchmetrics.detection.MeanAveragePrecision``.

    Returns:
        The computed result dict with sanitised values. If anything fails, the
        error is printed and a dict with only those three keys set to
        ``tensor(0.0)`` is returned.
    """
    keys = ("map", "map_50", "mar_100")
    try:
        result = metric.compute()
        # Handle cases where there are no detections / no ground truth
        for key in keys:
            value = torch.as_tensor(result.get(key, float("nan")))
            if torch.isnan(value) or value < 0:
                value = torch.tensor(0.0)
            result[key] = value
        return result
    except Exception as e:
        print(f"Metric computation error: {e}")
        return {key: torch.tensor(0.0) for key in keys}

def create_comparison_frame(original_frame, processed_frame, config, metrics):
    """Place the original (left) and processed (right) frames side by side with text.

    Both frames must have the same height and width and 3 channels. The result is
    an ``(h, 2*w, 3)`` uint8 image with a white divider. Both frames are copied
    at full size; no resizing is done.

    Args:
        original_frame: clean BGR frame shown on the left.
        processed_frame: noisy/denoised BGR frame shown on the right.
        config: noise configuration. ``config['label']`` is drawn as the right-hand
            title in ``config['color']``. That colour is given in RGB (as the
            matplotlib plots use it) and is reversed here to OpenCV's BGR order.
        metrics: mapping of name to value drawn under the title. Python and NumPy
            numbers are shown with two decimals; anything else is shown verbatim.

    Returns:
        The composed comparison image.
    """
    h, w = original_frame.shape[:2]
    comparison_frame = np.zeros((h, w * 2, 3), dtype=np.uint8)

    # Original on the left, processed on the right
    comparison_frame[:, :w] = original_frame
    comparison_frame[:, w:] = processed_frame

    # Divider line
    cv2.line(comparison_frame, (w, 0), (w, h), (255, 255, 255), 2)

    # OpenCV draws in BGR; the configured colours are RGB triples.
    title_color_bgr = tuple(int(c) for c in reversed(config['color']))

    cv2.putText(comparison_frame, "Original", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    cv2.putText(comparison_frame, config['label'], (w + 10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, title_color_bgr, 2)

    # Metrics on the right side, one line each
    y_offset = 60
    for name, value in metrics.items():
        # NumPy scalars (e.g. float32 from np.mean) are not Python floats
        if isinstance(value, (int, float, np.integer, np.floating)):
            text = f"{name}: {value:.2f}"
        else:
            text = f"{name}: {value}"
        cv2.putText(comparison_frame, text, (w + 10, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1)
        y_offset += 30

    return comparison_frame

# ==================== MAIN PROCESSING ====================
if not os.path.exists(input_video_path):
    raise FileNotFoundError(f"Video file not found: {input_video_path}")

cap = cv2.VideoCapture(input_video_path)
if not cap.isOpened():
    raise ValueError(f"Cannot open video: {input_video_path}")

frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# CAP_PROP_FPS is fractional for many clips (e.g. 29.97) and is 0/NaN when the
# container does not report it. Truncating it to int distorts playback speed, and
# 0 makes VideoWriter fail, so keep the float and fall back to a sane default.
DEFAULT_FPS = 30.0
fps = cap.get(cv2.CAP_PROP_FPS)
if not np.isfinite(fps) or fps <= 0:
    fps = DEFAULT_FPS

# CAP_PROP_FRAME_COUNT is only an estimate and can be 0, negative or NaN for some streams.
raw_frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
total_frames = int(raw_frame_count) if np.isfinite(raw_frame_count) else -1
print(f"Video Info: {frame_width}x{frame_height}, {fps:.2f} FPS, {total_frames} frames")

# Enhanced data storage (all keyed by the noise configuration label)
metrics_data = []
frame_metrics = defaultdict(list)
class_performance = defaultdict(lambda: defaultdict(list))
movement_data = defaultdict(list)
optical_flow_data = defaultdict(list)
detection_data = defaultdict(list)
comparison_frames = []  # Store frames for combined video

# Limit frames for testing. When the frame count is unknown, rely on cap.read()
# failing at end of stream instead of processing nothing.
FRAME_LIMIT = 300
MAX_FRAMES = min(FRAME_LIMIT, total_frames) if total_frames > 0 else FRAME_LIMIT

# Initialize video writer for combined output (original | processed, side by side)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
combined_out = cv2.VideoWriter(combined_output_path, fourcc, fps, (frame_width * 2, frame_height))
if not combined_out.isOpened():
    print(f"Warning: could not open video writer for {combined_output_path}; "
          f"the combined video will not be saved")

# Pseudo ground truth per frame index, built from detections on the *clean* frame.
# Feeding generate_dynamic_gt the same detections that are being evaluated scores
# each prediction set against a box picked from itself, which makes mAP blind to the
# noise. The clean frames are identical for every configuration (the capture is
# rewound below), so the reference is computed once per frame index and reused.
clean_gt_cache = {}

for config in noise_configs:
    label = config['label']
    print(f"\n=== Processing {label} ===")

    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # rewind: every configuration sees the same frames
    metric = MeanAveragePrecision()
    tracker = DeepSort(max_age=30)
    prev_positions = {}
    prev_frame = None
    prev_noisy_frame = None

    # Reported for frame 1, where there is no previous frame to compute flow against
    flow_results = {
        "avg_magnitude": 0,
        "std_magnitude": 0,
        "avg_angle": 0,
        "flow": None
    }

    frame_count = 0
    total_detections = 0
    total_displacement = 0
    total_velocity = 0

    with tqdm(total=MAX_FRAMES, desc=f"Processing {label}") as progress_bar:
        while frame_count < MAX_FRAMES:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            progress_bar.update(1)

            # 1. Add noise (returns both noisy and original)
            noisy_frame, original_frame = add_temporal_noise(
                frame, config["type"], config["intensity"], prev_noisy_frame
            )
            prev_noisy_frame = noisy_frame.copy()

            # 2. Adaptive denoising with timing
            denoised_frame, denoise_time = adaptive_denoise(
                noisy_frame, config["type"], config["intensity"]
            )

            # 3. Object detection: ([x1, y1, x2, y2], conf, cls) plus per-class counts
            detections, class_counts = detect_objects(denoised_frame, model)
            total_detections += len(detections)
            detection_data[label].append(len(detections))

            # Track class performance (only classes detected in this frame get an entry)
            for class_name, count in class_counts.items():
                class_performance[label][class_name].append(count)

            # 4. Tracking with DeepSORT. deep_sort_realtime expects
            #    ([left, top, width, height], confidence, class), not corner coordinates.
            ltwh_detections = [
                ([x1, y1, x2 - x1, y2 - y1], det_conf, det_cls)
                for (x1, y1, x2, y2), det_conf, det_cls in detections
            ]
            tracks = tracker.update_tracks(ltwh_detections, frame=denoised_frame)

            # 5. Movement analysis (per-frame means over confirmed tracks)
            movement_results = analyze_movement(tracks, prev_positions, frame_count)
            avg_displacement = np.mean(movement_results["displacements"])
            avg_velocity = np.mean(movement_results["velocities"])

            total_displacement += avg_displacement
            total_velocity += avg_velocity
            movement_data[label].append(avg_displacement)

            # 6. Optical flow between consecutive denoised frames
            if prev_frame is not None:
                flow_results = calculate_optical_flow(prev_frame, denoised_frame)
                optical_flow_data[label].append(flow_results["avg_magnitude"])
            prev_frame = denoised_frame.copy()

            # 7. Metric evaluation against the clean-frame pseudo ground truth
            if frame_count not in clean_gt_cache:
                if config["type"] == "none":
                    # Without noise, the denoised frame is a copy of the clean one,
                    # so its detections can be reused without another inference.
                    clean_detections = detections
                else:
                    clean_detections, _ = detect_objects(original_frame, model)
                clean_gt_cache[frame_count] = generate_dynamic_gt(original_frame, clean_detections)
            gt = clean_gt_cache[frame_count]

            # Only update metrics if we have either predictions or ground truth
            if len(detections) > 0 or len(gt["boxes"]) > 0:
                preds = [{
                    "boxes": torch.tensor([d[0] for d in detections], dtype=torch.float32) if detections else torch.zeros((0, 4), dtype=torch.float32),
                    "scores": torch.tensor([d[1] for d in detections], dtype=torch.float32) if detections else torch.zeros((0,), dtype=torch.float32),
                    "labels": torch.tensor([d[2] for d in detections], dtype=torch.int32) if detections else torch.zeros((0,), dtype=torch.int32)
                }]
                metric.update(preds, [gt])

            # 8. Store frame metrics
            frame_metrics[label].append({
                "frame_num": frame_count,
                "detections": len(detections),
                "denoise_time": denoise_time,
                "avg_displacement": avg_displacement,
                "avg_velocity": avg_velocity,
                "optical_flow": flow_results["avg_magnitude"],
                "class_counts": dict(class_counts)
            })

            # 9. Every 5th frame: build the side-by-side comparison and write it straight
            #    into the combined video. Buffering these frames for all configurations
            #    can exhaust memory on long or high-resolution clips.
            if frame_count % 5 == 0:
                current_metrics = {
                    "Detections": len(detections),
                    "Displacement": avg_displacement,
                    "Velocity": avg_velocity,
                    "Denoise Time": f"{denoise_time*1000:.1f}ms",
                    "Optical Flow": flow_results["avg_magnitude"]
                }
                comparison_frame = create_comparison_frame(
                    original_frame, denoised_frame, config, current_metrics
                )
                combined_out.write(comparison_frame)

            # 10. Visualization (first frame and every 30th frame)
            if frame_count % 30 == 0 or frame_count == 1:
                viz_frame = denoised_frame.copy()

                # Draw ground truth (green)
                for box in gt["boxes"]:
                    x1, y1, x2, y2 = map(int, box)
                    cv2.rectangle(viz_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    cv2.putText(viz_frame, "GT", (x1, y1-10),
                               cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

                # Draw detections (red)
                for box, conf, cls in detections:
                    x1, y1, x2, y2 = box
                    cv2.rectangle(viz_frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
                    cv2.putText(viz_frame, f"{model.names[cls]} {conf:.2f}",
                               (x1, y1-30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

                # Additional info
                cv2.putText(viz_frame, f"{config['label']} - Frame {frame_count}", (10, 30),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
                cv2.putText(viz_frame, f"Detections: {len(detections)}", (10, 60),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
                cv2.putText(viz_frame, f"Avg Displacement: {avg_displacement:.2f} px", (10, 90),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
                cv2.putText(viz_frame, f"Avg Velocity: {avg_velocity:.2f} px/frame", (10, 120),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

                cv2_imshow(viz_frame)

    # Calculate final metrics for this noise configuration
    result = safe_metric_compute(metric)
    avg_detections = total_detections / frame_count if frame_count > 0 else 0
    avg_displacement = total_displacement / frame_count if frame_count > 0 else 0
    avg_velocity = total_velocity / frame_count if frame_count > 0 else 0
    avg_optical_flow = np.mean(optical_flow_data[label]) if optical_flow_data[label] else 0

    # Coefficient of variation; protected against division by zero
    det_stability = (np.std(detection_data[label]) / avg_detections) if avg_detections > 0 else 0
    move_stability = (np.std(movement_data[label]) / avg_displacement) if avg_displacement > 0 else 0

    denoise_times = [f['denoise_time'] for f in frame_metrics[label]]

    metrics_data.append({
        "Noise Type": label,
        "Noise Intensity": config["intensity"],
        # float() accepts both 0-d tensors and plain numbers
        "mAP": float(result['map']),
        # NOTE: torchmetrics 'map_50' is mAP at IoU 0.5, not precision. The column
        # name is kept because downstream CSVs and plots use it.
        "Precision": float(result['map_50']),
        # 'mar_100' is mean average recall with at most 100 detections per image.
        "Recall": float(result['mar_100']),
        "Avg Detections": avg_detections,
        "Avg Displacement": avg_displacement,
        "Avg Velocity": avg_velocity,
        "Avg Optical Flow": avg_optical_flow,
        "Detection Stability": det_stability,
        "Movement Stability": move_stability,
        "Avg Denoise Time (ms)": np.mean(denoise_times) * 1000 if denoise_times else 0.0
    })

    print(f"\nResults for {config['label']}:")
    print(f"- mAP: {metrics_data[-1]['mAP']:.4f}")
    print(f"- Precision: {metrics_data[-1]['Precision']:.4f}")
    print(f"- Recall: {metrics_data[-1]['Recall']:.4f}")
    print(f"- Avg Detections per Frame: {metrics_data[-1]['Avg Detections']:.2f}")
    print(f"- Avg Object Displacement: {metrics_data[-1]['Avg Displacement']:.2f} px")
    print(f"- Avg Object Velocity: {metrics_data[-1]['Avg Velocity']:.2f} px/frame")
    print(f"- Avg Optical Flow: {metrics_data[-1]['Avg Optical Flow']:.2f}")
    print(f"- Detection Stability: {metrics_data[-1]['Detection Stability']:.4f}")
    print(f"- Movement Stability: {metrics_data[-1]['Movement Stability']:.4f}")
    print(f"- Avg Denoise Time: {metrics_data[-1]['Avg Denoise Time (ms)']:.2f} ms")

# Write any buffered comparison frames to the combined video and close the writer.
# The writer is released in `finally` so the file is finalised even if a write fails.
print("\nCreating combined output video...")
try:
    for frame in comparison_frames:
        combined_out.write(frame)
finally:
    combined_out.release()
    comparison_frames.clear()  # buffered frames are no longer needed; free the memory
print(f"Combined video saved at: {combined_output_path}")

# ==================== ENHANCED ANALYSIS & VISUALIZATION ====================
# Save frame-by-frame metrics (one CSV per noise configuration)
for config_label, metrics in frame_metrics.items():
    df = pd.DataFrame(metrics)
    df.to_csv(f"frame_metrics_{config_label.replace(' ', '_')}.csv", index=False)

# Average detections per processed frame for every (noise configuration, class) pair.
# class_performance only records a count on frames where the class was detected,
# so the summed counts are divided by the number of processed frames. Dividing by
# len(counts) would inflate the average.
class_avg = {}
for config_label, per_class in class_performance.items():
    n_frames = len(frame_metrics.get(config_label, [])) or 1
    class_avg[config_label] = {
        cls: sum(counts) / n_frames for cls, counts in per_class.items()
    }

# Rows: noise configuration; columns: object class (0 where never detected)
class_perf_df = pd.DataFrame.from_dict(class_avg, orient='index').fillna(0.0)

# Save class performance data
class_perf_df.to_csv("class_performance.csv")

# Heatmap generation with error handling
try:
    if not class_perf_df.empty:
        plt.figure(figsize=(15, 8))
        sns.heatmap(class_perf_df, annot=True, fmt=".2f", cmap="YlGnBu")
        plt.title("Average Detections by Noise Type and Object Class", pad=15, fontweight='bold')
        plt.xlabel("Object Class")
        plt.ylabel("Noise Type")
        plt.tight_layout()
        plt.savefig("class_performance_heatmap.png", dpi=300, bbox_inches='tight')
        plt.show()
    else:
        print("Warning: Class performance dataframe is empty")
except Exception as e:
    print(f"Error generating heatmap: {e}")

# Save final metrics
metrics_df = pd.DataFrame(metrics_data)
metrics_df.to_csv("final_metrics.csv", index=False)

# 'Noise Type' holds the unique per-run label (e.g. "Gaussian Low"). Plots that vary
# intensity need the noise *family* ("gaussian", "flicker", ...) instead, otherwise
# every run becomes its own one-point line.
label_to_family = {c['label']: c['type'] for c in noise_configs}
metrics_plot_df = metrics_df.assign(
    **{"Noise Family": metrics_df["Noise Type"].map(label_to_family)}
)

# Enhanced visualization
plt.style.use('seaborn-v0_8')
plt.rcParams.update({
    'font.size': 12,
    'figure.titlesize': 16,
    'axes.titlesize': 14,
    'axes.labelweight': 'bold',
    'figure.figsize': (15, 10)
})

# 2. Frame-by-frame metrics comparison
plt.figure(figsize=(15, 10))
for config in noise_configs[:4]:  # Plot first 4 for clarity
    df = pd.DataFrame(frame_metrics.get(config['label'], []))
    if df.empty:
        continue  # configuration produced no frames
    plt.plot(df['frame_num'], df['detections'], label=config['label'], color=np.array(config['color'])/255)
plt.title("Detection Counts Over Time by Noise Type", pad=15, fontweight='bold')
plt.xlabel("Frame Number")
plt.ylabel("Detections")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("detections_over_time.png", dpi=300, bbox_inches='tight')
plt.show()

# 3. Denoising performance. There is one bar per x value, so do not dodge by hue
#    (dodging would shrink and offset each bar).
plt.figure(figsize=(15, 6))
sns.barplot(data=metrics_df, x='Noise Type', y='Avg Denoise Time (ms)',
            hue='Noise Intensity', palette='viridis', dodge=False)
plt.title("Denoising Processing Time by Noise Type", pad=15, fontweight='bold')
plt.xticks(rotation=45)
plt.tight_layout()
plt.savefig("denoising_performance.png", dpi=300, bbox_inches='tight')
plt.show()

# 4. mAP comparison
plt.figure(figsize=(15, 6))
sns.barplot(data=metrics_df, x='Noise Type', y='mAP', palette='viridis')
plt.title("mAP Across Different Noise Types", pad=15, fontweight='bold')
plt.xticks(rotation=45)
plt.ylim(0, 1)
plt.tight_layout()
plt.savefig("mAP_comparison.png", dpi=300, bbox_inches='tight')
plt.show()

# 5. Noise intensity vs velocity: one line per noise family across its intensities
plt.figure(figsize=(15, 6))
sns.lineplot(data=metrics_plot_df, x='Noise Intensity', y='Avg Velocity',
             hue='Noise Family', style='Noise Family', markers=True, dashes=False)
plt.title("Effect of Noise Intensity on Object Velocity", pad=15, fontweight='bold')
plt.xlabel("Noise Intensity")
plt.ylabel("Average Velocity (px/frame)")
plt.grid(True)
plt.tight_layout()
plt.savefig("noise_intensity_vs_velocity.png", dpi=300, bbox_inches='tight')
plt.show()

# 6. Metrics correlation (constant columns produce NaN cells)
plt.figure(figsize=(10, 8))
corr = metrics_df[["mAP", "Precision", "Recall", "Avg Detections",
                  "Avg Displacement", "Avg Velocity", "Avg Optical Flow", "Detection Stability"]].corr()
sns.heatmap(corr, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0,
            annot_kws={"size": 10}, fmt=".2f")
plt.title("Correlation Between Performance Metrics", pad=15, fontweight='bold')
plt.tight_layout()
plt.savefig("metrics_correlation.png", dpi=300, bbox_inches='tight')
plt.show()

# ==================== CLEANUP & RESULTS ====================
# Release the input capture, then list the artefacts produced by this run.
cap.release()

_summary_lines = (
    "\nEnhanced processing complete! New outputs include:",
    "- Frame-by-frame metrics for each noise type",
    "- Class-specific performance analysis",
    "- Denoising performance metrics",
    "- Movement and stability metrics",
    "- Combined output video showing all results",
    "- Visualizations:",
    "  - class_performance_heatmap.png",
    "  - detections_over_time.png",
    "  - denoising_performance.png",
    "  - mAP_comparison.png",
    "  - noise_intensity_vs_velocity.png",
    "  - metrics_correlation.png",
)
for _line in _summary_lines:
    print(_line)