# In[18]:
import json
import os
import random
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt  # 确保导入matplotlib

# Divisor used to normalize raw depth-image pixel values.
MAX_DEPTH_VALUE = 1 << 16  # 65536, i.e. 2 ** 16

def read_coco_annotations(json_file):
    """Load a COCO-format annotation file and return the decoded JSON object.

    Args:
        json_file: Path to the ``.json`` annotation file.

    Returns:
        The parsed JSON document. ``main`` in this file reads its
        ``'images'``, ``'annotations'`` and ``'categories'`` entries.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Pin the encoding so that
    # non-ASCII category or file names do not depend on the platform locale.
    with open(json_file, 'r', encoding='utf-8') as f:
        return json.load(f)

def calculate_center(bbox):
    """Return the integer ``(x, y)`` centre of an ``[x, y, w, h]`` bounding box.

    Half of the width and height is added to the top-left corner, and each
    coordinate is then truncated with ``int()``.
    """
    left, top, width, height = bbox
    center_x = int(left + width / 2)
    center_y = int(top + height / 2)
    return center_x, center_y

def get_depth_from_image(depth_image, center_point):
    """Return the depth-image pixel at ``center_point``, clamped to the image.

    Args:
        depth_image: Array indexed as ``depth_image[row, col]``. The first two
            axes are height and width.
        center_point: ``(x, y)`` pixel coordinates. ``x`` selects the column
            and ``y`` selects the row.

    Returns:
        ``depth_image[y, x]``, with both coordinates first clamped into the
        valid range ``[0, size - 1]`` of their axis.
    """
    x, y = center_point
    height, width = depth_image.shape[:2]
    # Clamp on both sides. NumPy treats a negative index as an offset from the
    # far edge, so an unclamped negative coordinate would silently read a pixel
    # from the opposite side of the image.
    row = min(max(y, 0), height - 1)
    col = min(max(x, 0), width - 1)
    return depth_image[row, col]

def plot_scatter(depths, log_sizes, categories, category_styles, title, x_title, y_title,
                 output_path=None, dpi=400, sample_ratios=None):
    """Draw a per-category scatter plot of object size against relative depth.

    Args:
        depths: x-axis values, one per object.
        log_sizes: y-axis values, one per object, parallel to ``depths``.
        categories: Category name of each object, parallel to ``depths``.
        category_styles: Mapping from category name to a dict that may hold
            ``'color'``, ``'alpha'``, ``'size'`` and ``'marker'``. Missing
            keys default to ``'gray'``, ``0.5``, ``5`` and ``'o'``.
        title: Figure title.
        x_title: x-axis label.
        y_title: y-axis label.
        output_path: If truthy, the figure is also saved to this path.
        dpi: Figure resolution.
        sample_ratios: Optional mapping from category name to the fraction of
            that category's points to draw, chosen at random. Defaults to
            ``{'car': 0.7}``. Ratios are clamped so at most every point is drawn.
    """
    if sample_ratios is None:
        # Keep the historical default of drawing a random 70% of 'car' points.
        sample_ratios = {'car': 0.7}

    # Group the parallel lists by category in a single pass.
    grouped = {}
    for depth, log_size, category in zip(depths, log_sizes, categories):
        cate_depths, cate_sizes = grouped.setdefault(category, ([], []))
        cate_depths.append(depth)
        cate_sizes.append(log_size)

    plt.figure(figsize=(8, 9), dpi=dpi)  # figure size and resolution

    # Sort so the legend order and drawing (z) order are identical on every
    # run. Iterating a set of strings depends on per-process hash randomization.
    for category in sorted(grouped, key=str):
        cate_depths, cate_sizes = grouped[category]

        ratio = sample_ratios.get(category)
        if ratio is not None:
            count = len(cate_depths)
            keep = random.sample(range(count), min(count, max(0, int(count * ratio))))
            cate_depths = [cate_depths[i] for i in keep]
            cate_sizes = [cate_sizes[i] for i in keep]

        style = category_styles.get(category, {})
        plt.scatter(cate_depths, cate_sizes,
                    alpha=style.get('alpha', 0.5),
                    s=style.get('size', 5),
                    label=category,
                    color=style.get('color', 'gray'),
                    marker=style.get('marker', 'o'))

    plt.xlabel(x_title, fontsize=17)
    plt.ylabel(y_title, fontsize=17)
    plt.title(title, fontsize=16)
    plt.grid(True)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)

    # Legend anchored at the lower-left corner of the axes.
    plt.legend(title='Category', fontsize=16, bbox_to_anchor=(0, 0), loc='lower left')

    # Compute the layout only after every text element (labels, ticks, legend)
    # exists, so their final sizes are taken into account.
    plt.tight_layout()

    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
    plt.show()

def main(annotation_file, depth_image_dir, category_styles, sample_size=100, output_path=None, dpi=400):
    """Sample annotated images and plot object size against centre-point depth.

    For each annotation on a random subset of images, the processing is:

    * The matching depth map is ``<image base name>.png`` inside
      ``depth_image_dir``.
    * The pixel at the bounding-box centre is read and divided by
      ``MAX_DEPTH_VALUE``.
    * That depth is paired with ``log_1.8(box area + 1)``.

    Images without a depth map, and depth values outside ``[0, 1]``, are
    skipped.

    Args:
        annotation_file: Path to the COCO-format annotation JSON.
        depth_image_dir: Directory containing the depth PNGs.
        category_styles: Per-category style mapping, passed to ``plot_scatter``.
        sample_size: Maximum number of images to sample at random.
        output_path: Optional path for saving the figure.
        dpi: Figure resolution.
    """
    annotations = read_coco_annotations(annotation_file)
    image_info_list = annotations['images']
    sampled_images = random.sample(image_info_list, min(sample_size, len(image_info_list)))

    category_id_to_name = {cat['id']: cat['name'] for cat in annotations['categories']}

    # Index the sampled images' annotations by image id. The full annotation
    # list is then scanned once instead of once per sampled image.
    sampled_ids = {image_info['id'] for image_info in sampled_images}
    anns_by_image = {}
    for ann in annotations['annotations']:
        if ann['image_id'] in sampled_ids:
            anns_by_image.setdefault(ann['image_id'], []).append(ann)

    sizes, depths, categories = [], [], []

    for image_info in sampled_images:
        image_id = image_info['id']
        image_filename = os.path.basename(image_info['file_name'])
        # Swap only the extension. This maps '.jpeg' and other extensions too,
        # and never rewrites a '.jpg' that appears elsewhere in the name.
        depth_filename = os.path.splitext(image_filename)[0] + '.png'
        depth_image_path = os.path.join(depth_image_dir, depth_filename)

        if not os.path.exists(depth_image_path):
            continue

        # Close the file handle as soon as the pixels are copied into the array.
        with Image.open(depth_image_path) as depth_file:
            depth_image = np.array(depth_file, dtype=np.float32)

        for ann in anns_by_image.get(image_id, []):
            bbox = ann['bbox']
            category_name = category_id_to_name[ann['category_id']]
            center = calculate_center(bbox)
            depth_value = get_depth_from_image(depth_image, center)

            normalized_depth_value = depth_value / MAX_DEPTH_VALUE
            if not (0 <= normalized_depth_value <= 1):
                continue

            w, h = bbox[2], bbox[3]
            # Log base 1.8 of (area + 1). The +1 maps a zero-area box to 0
            # rather than -inf.
            log_size = np.log(w * h + 1) / np.log(1.8)

            depths.append(normalized_depth_value)
            sizes.append(log_size)
            categories.append(category_name)

    if sizes and depths:
        plot_scatter(depths, sizes, categories, category_styles,
                     title='Plot',
                     x_title='Relative Depth of Object', y_title='Object Size',
                     output_path=output_path, dpi=dpi)
    else:
        print("No valid data to plot.")

if __name__ == "__main__":
    # Dataset locations. These are machine-specific; adjust before running elsewhere.
    annotation_file = "/opt/data/private/fcf/mmdetection/data/HazyDet-365k/train/train_coco.json"
    depth_image_dir = "/opt/data/private/fcf/mmdetection/data/HazyDetdevkit/HazyDet/depth_images"

    # Scatter style per category: colour, transparency, marker area and marker shape.
    category_styles = {
        'car': dict(color='blue', alpha=0.2, size=45, marker='o'),
        'bus': dict(color='orange', alpha=0.7, size=55, marker='^'),
        'truck': dict(color='red', alpha=0.5, size=55, marker='D'),
    }

    main(annotation_file, depth_image_dir, category_styles, sample_size=70)
    
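# Use these matplotlib marker symbols as the 'marker' value in the
# `category_styles` dict to give each category its own scatter-point shape:
# 'o' : circle
# '^' : triangle pointing up
# 'v' : triangle pointing down
# 's' : square
# 'p' : pentagon
# '*' : star
# 'h' : hexagon ("hexagon1")
# 'H' : hexagon ("hexagon2", hexagon1 rotated by 30 degrees)
# '+' : plus sign
# 'x' : x-shaped cross
# 'D' : diamond
# 'd' : thin diamond
# '|' : vertical line
# '_' : horizontal line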