### Compute Data Statistics

In [None]:
"""
YOLO Data Analysis Tool

This script analyzes bounding box data in YOLO format. It provides visualizations to better understand
data distributions, spatial concentrations, and bounding box dimensions. The updated version keeps all
original functionality intact, except that the correlation matrix now uses center coordinates in place
of boundary coordinates.

Key Updates:
- Boundary coordinates (`x_min`, `y_min`, `x_max`, `y_max`) in the correlation matrix
  are replaced with center coordinates (`x_center`, `y_center`).

Example Directory Structure:
    dataset/
    ├── images/
    ├── labels/

Example Usage:
    yolo_dir = '/path/to/yolo/dataset'
    save_dir = '/path/to/save/plots'  # Specify the directory to save plots
    global_save = True  # Set to True to save plots
    analyze_yolo_dir(yolo_dir, save_dir=save_dir, global_save=global_save)
"""

import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def read_yolo_labels(labels_dir):
    """
    Reads YOLO label files from the specified directory.

    Every ``*.txt`` file directly inside ``labels_dir`` is parsed. Each non-blank line
    is expected to hold exactly five whitespace-separated numbers:
    ``class x_center y_center width height``. The class value is discarded. Lines that
    do not parse into exactly five floats are reported and skipped; blank lines are
    skipped silently.

    Args:
        labels_dir (str): Path to the directory containing YOLO label files.

    Returns:
        np.array: Array of shape (N, 8) with one row per bounding box, laid out as
            (x_min, y_min, x_max, y_max, width, height, x_center, y_center).
            When no boxes are found the array has shape (0, 8).

    Raises:
        FileNotFoundError: If ``labels_dir`` does not exist.
    """
    print(f"Reading labels from directory: {labels_dir}")
    bbox_data = []

    label_path = Path(labels_dir)
    if not label_path.exists():
        raise FileNotFoundError(f"The specified labels directory does not exist: {labels_dir}")

    # Sorted so the row order does not depend on filesystem enumeration order.
    for label_file in sorted(label_path.glob('*.txt')):
        print(f"Processing label file: {label_file}")
        with label_file.open('r') as file:
            for line in file:
                fields = line.split()
                if not fields:
                    # A blank line (e.g. trailing newline) is not malformed data.
                    continue
                try:
                    # Unpacking fails with ValueError on a wrong field count as well
                    # as on non-numeric text, so one handler covers both cases.
                    _, x_center, y_center, width, height = map(float, fields)
                except ValueError:
                    print(f"Skipping malformed line in file {label_file}: {line.strip()}")
                    continue

                x_min = x_center - width / 2
                x_max = x_center + width / 2
                y_min = y_center - height / 2
                y_max = y_center + height / 2
                bbox_data.append((x_min, y_min, x_max, y_max, width, height, x_center, y_center))

    print(f"Total bounding boxes read: {len(bbox_data)}")
    # The reshape keeps the 8-wide column axis even when no boxes were read, so
    # callers can always index columns with ``[:, k]``.
    return np.array(bbox_data, dtype=float).reshape(-1, 8)


def compute_statistics(bbox_data):
    """
    Computes descriptive summary statistics for the bounding box features.

    Args:
        bbox_data (np.array): Array of bounding box data with columns
            (x_min, y_min, x_max, y_max, width, height, x_center, y_center).
            An empty array of any shape is accepted.

    Returns:
        pd.DataFrame: One row per feature, holding the columns produced by
            ``DataFrame.describe()``. The table is also printed.
    """
    columns = ['x_min', 'y_min', 'x_max', 'y_max', 'width', 'height', 'x_center', 'y_center']

    data = np.asarray(bbox_data, dtype=float)
    if data.size == 0:
        # An empty label set may arrive as a 1-D array with no column axis;
        # give it one so the 8-column DataFrame can still be built.
        data = data.reshape(0, len(columns))

    df = pd.DataFrame(data, columns=columns)
    stats = df.describe().transpose()
    print("Bounding Box Feature Statistics:")
    print(stats)
    return stats


def plot_correlation_matrix_with_centers(bbox_data, save_dir=None, global_save=False):
    """
    Plots a correlation matrix for bounding box features, replacing
    boundary coordinates (x_min, y_min, x_max, y_max) with center coordinates.

    The matrix covers width, height, x_center and y_center. It is printed, drawn as
    an annotated heatmap, optionally saved as 'correlation_matrix_with_centers.png',
    and then shown. Nothing is plotted when ``bbox_data`` is empty.

    Args:
        bbox_data (np.array): Array of bounding box data with columns
            (x_min, y_min, x_max, y_max, width, height, x_center, y_center).
        save_dir (str, optional): Directory to save the plot; created if missing.
            Default is None.
        global_save (bool, optional): Flag to save the plot globally. The plot is
            saved only when this is True and ``save_dir`` is set. Default is False.
    """
    if len(bbox_data) == 0:
        # An empty array has no column axis to build the DataFrame from.
        print("No bounding box data found.")
        return

    columns = ['x_min', 'y_min', 'x_max', 'y_max', 'width', 'height', 'x_center', 'y_center']
    df = pd.DataFrame(bbox_data, columns=columns)

    # Keep only the size and center features; boundary coordinates are left out.
    df = df[['width', 'height', 'x_center', 'y_center']]

    # Compute correlation matrix
    correlation = df.corr()
    print("Feature Correlation Matrix:")
    print(correlation)

    plt.figure(figsize=(8, 6))
    sns.heatmap(correlation, annot=True, fmt=".2f", cmap='coolwarm', square=True)
    plt.title("Train Data: Correlation Matrix of Lesion Features (Using Center Coordinates)")

    if global_save and save_dir:
        # savefig does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, 'correlation_matrix_with_centers.png'))
        print(f"Correlation matrix saved to: {save_dir}")
    plt.show()


def plot_heatmap(bbox_data, image_height=100, image_width=100, save_dir=None, global_save=False):
    """
    Plots a heatmap of lesion locations based on bounding box centers.

    Each center is scaled by the grid size and counted in the grid cell it falls
    into. The grid is drawn, optionally saved as 'heatmap_of_lesion_locations.png',
    and then shown. Nothing is plotted when ``bbox_data`` is empty.

    Args:
        bbox_data (np.array): Array of bounding box data with columns
            (x_min, y_min, x_max, y_max, width, height, x_center, y_center).
            Center coordinates are presumably normalized to [0, 1] -- confirm
            against the label files.
        image_height (int, optional): Number of grid rows. Default is 100.
        image_width (int, optional): Number of grid columns. Default is 100.
        save_dir (str, optional): Directory to save the plot; created if missing.
            Default is None.
        global_save (bool, optional): Flag to save the plot globally. The plot is
            saved only when this is True and ``save_dir`` is set. Default is False.
    """
    if len(bbox_data) == 0:
        # An empty array has no columns to slice.
        print("No bounding box data found.")
        return

    # Centers live in the last two columns; columns 0 and 1 are x_min / y_min.
    x_centers = bbox_data[:, 6]
    y_centers = bbox_data[:, 7]

    # Scale to grid cells and clip on both sides: the upper bound catches a
    # coordinate of exactly 1.0, the lower bound stops negative values from
    # wrapping around to the far edge of the grid.
    x_idx = np.clip((x_centers * image_width).astype(int), 0, image_width - 1)
    y_idx = np.clip((y_centers * image_height).astype(int), 0, image_height - 1)

    heatmap_accumulator = np.zeros((image_height, image_width))
    # np.add.at accumulates repeated indices, unlike plain fancy-index "+=".
    np.add.at(heatmap_accumulator, (y_idx, x_idx), 1)

    plt.figure(figsize=(10, 10))
    plt.imshow(heatmap_accumulator, cmap='viridis', interpolation='nearest', origin='lower')
    plt.colorbar(label='Number of Lesions')
    plt.title('Train Data: Heatmap of Lesion Locations (Center Coordinates)')
    plt.xlabel('X coordinate')
    plt.ylabel('Y coordinate')

    if global_save and save_dir:
        # savefig does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, 'heatmap_of_lesion_locations.png'))
        print(f"Heatmap saved to: {save_dir}")
    plt.show()
    

def _save_and_show_plot(filename, description, save_dir, global_save):
    """
    Saves the current figure when saving is enabled, then displays it.

    Args:
        filename (str): File name of the image inside ``save_dir``.
        description (str): Plot name used in the confirmation message.
        save_dir (str or None): Directory to save into; created if missing.
        global_save (bool): Saving happens only when this is True and ``save_dir`` is set.
    """
    # Save before show: after plt.show() the current figure may no longer hold
    # the plot, so a later savefig can write a blank image.
    if global_save and save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, filename))
        print(f"{description} saved to: {save_dir}")
    plt.show()


def plot_bbox_analysis(bbox_data, plots_to_generate=None, figsize=(6, 6), save_dir=None, global_save=False):
    """
    Generates visualizations based on bounding box data.

    Available plots (selected by name in ``plots_to_generate``):
        - 'total_instances': bar with the total number of boxes.
        - 'centered_bboxes': every box outline drawn around the point (0.5, 0.5).
        - 'center_distribution': 2-D histogram of box centers.
        - 'width_vs_height': 2-D histogram of box width against height.

    Args:
        bbox_data (np.array): Array of bounding box data with columns
            (x_min, y_min, x_max, y_max, width, height, x_center, y_center).
        plots_to_generate (list, optional): List of plots to generate.
            Default is None (all plots).
        figsize (tuple, optional): Size of the figure for each plot. Default is (6, 6).
        save_dir (str, optional): Directory to save the plots; created if missing.
            Default is None.
        global_save (bool, optional): Flag to save the plots globally. Plots are
            saved only when this is True and ``save_dir`` is set. Default is False.
    """
    print("Starting plot generation...")
    if len(bbox_data) == 0:
        print("No bounding box data found.")
        return

    if plots_to_generate is None:
        plots_to_generate = ['total_instances', 'centered_bboxes', 'center_distribution', 'width_vs_height']

    x_centers = bbox_data[:, 6]
    y_centers = bbox_data[:, 7]
    widths = bbox_data[:, 4]
    heights = bbox_data[:, 5]

    if 'total_instances' in plots_to_generate:
        print("Plotting total instances...")
        plt.figure(figsize=figsize)
        plt.bar(['unlabeled'], [len(bbox_data)], color='blue')
        plt.ylabel('instances')
        plt.title('Train Data: Total Bounding Box Instances')
        _save_and_show_plot('total_instances.png', 'Total instances plot', save_dir, global_save)

    if 'centered_bboxes' in plots_to_generate:
        print("Plotting centered bounding boxes...")
        plt.figure(figsize=figsize)
        axes = plt.gca()
        for width, height in zip(widths, heights):
            # Anchor each rectangle so that its middle sits at (0.5, 0.5).
            axes.add_patch(
                plt.Rectangle((0.5 - width / 2, 0.5 - height / 2), width, height,
                              fill=False, edgecolor='blue', linewidth=0.5)
            )
        plt.xlim(0, 1)
        plt.ylim(0, 1)
        plt.title('Train Data: Bounding Boxes Centered')
        plt.xlabel('x')
        plt.ylabel('y')
        _save_and_show_plot('centered_bounding_boxes.png', 'Centered bounding boxes plot',
                            save_dir, global_save)

    if 'center_distribution' in plots_to_generate:
        print("Plotting center distribution...")
        plt.figure(figsize=figsize)
        plt.hist2d(x_centers, y_centers, bins=50, cmap='Blues')
        plt.colorbar(label='Density')
        plt.xlabel('x')
        plt.ylabel('y')
        plt.title('Train Data: Bounding Box Center Distribution')
        _save_and_show_plot('center_distribution.png', 'Center distribution plot',
                            save_dir, global_save)

    if 'width_vs_height' in plots_to_generate:
        print("Plotting width vs. height distribution...")
        plt.figure(figsize=figsize)
        plt.hist2d(widths, heights, bins=50, cmap='Blues')
        plt.colorbar(label='Density')
        plt.xlabel('width')
        plt.ylabel('height')
        plt.title('Train Data: Bounding Box Width vs Height')
        _save_and_show_plot('width_vs_height_distribution.png', 'Width vs height distribution plot',
                            save_dir, global_save)


def analyze_yolo_dir(yolo_dir, plots_to_generate=None, save_dir=None, global_save=False):
    """
    Analyzes bounding box data in YOLO format and generates specified plots.

    Reads the labels from ``<yolo_dir>/labels``, then produces the bounding box
    plots, the location heatmap, the correlation matrix and the summary
    statistics. When no bounding boxes are found, the function reports that and
    returns without plotting.

    Args:
        yolo_dir (str): Path to the YOLO dataset directory.
        plots_to_generate (list, optional): List of plots to generate. Default is None (all plots).
        save_dir (str, optional): Directory to save the plots. Default is None.
        global_save (bool, optional): Flag to save the plots globally. Default is False.

    Raises:
        FileNotFoundError: If ``<yolo_dir>/labels`` does not exist.
    """
    print(f"Starting analysis for YOLO directory: {yolo_dir}")
    labels_dir = os.path.join(yolo_dir, 'labels')

    if not os.path.exists(labels_dir):
        raise FileNotFoundError(f"Labels directory not found: {labels_dir}")

    print("Reading bounding box data...")
    bbox_data = read_yolo_labels(labels_dir)

    if len(bbox_data) == 0:
        # The heatmap and correlation steps slice columns out of the array,
        # which fails when there are no rows to slice.
        print("No bounding box data found; skipping plots and statistics.")
        return

    print("Plotting analysis...")
    plot_bbox_analysis(bbox_data, plots_to_generate, save_dir=save_dir, global_save=global_save)

    # Additional new plots
    print("Generating heatmap and correlation matrix...")
    plot_heatmap(bbox_data, save_dir=save_dir, global_save=global_save)
    plot_correlation_matrix_with_centers(bbox_data, save_dir=save_dir, global_save=global_save)
    compute_statistics(bbox_data)


# Example usage
# The call is guarded so that importing this file does not start an analysis;
# running it as a script or as a notebook cell still does.
yolo_dir = '/home/falcon/student1/coronis-data/itobos-challenge-data/train'  # Replace with your YOLO dataset directory
if __name__ == "__main__":
    analyze_yolo_dir(yolo_dir)

