In [2]:
import os
import yaml
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from collections import Counter
from tqdm import tqdm
import random
from PIL import Image
import warnings
warnings.filterwarnings('ignore')


In [10]:
# Module 0: Data Loading Utilities
def load_yaml(yaml_path):
    """Load dataset configuration from YAML."""
    with open(yaml_path, 'r') as f:
        data = yaml.safe_load(f)
    return data

def get_paths(root_dir, split):
    img_dir = os.path.join(root_dir, split, 'images')
    label_dir = os.path.join(root_dir, split, 'labels')
    
    if not os.path.exists(img_dir):
        print(f"错误：{img_dir} 不存在")
        return [], []
    
    img_files = [f for f in os.listdir(img_dir) 
                 if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
    imgs = [os.path.join(img_dir, f) for f in img_files]
    
    matched = 0
    labels = []
    for img_file in img_files:
        stem = os.path.splitext(img_file)[0]  # 去掉 .jpg
        label_file = stem + '.txt'
        label_path = os.path.join(label_dir, os.path.basename(label_file))
        if os.path.exists(label_path):
            labels.append(label_path)
            matched += 1
    
    match_rate = matched / len(imgs) if imgs else 0
    unique_stem_rate = len(set(os.path.splitext(f)[0] for f in img_files)) / len(img_files)
    
    print(f"{split:>5}: {len(imgs):>5} 图像, {len(labels):>5} 标签, "
          f"标签匹配: {match_rate:.1%}, 图像唯一性: {unique_stem_rate:.1%}")
    
    return imgs, labels

# Module 1: Data Quantity Statistics
def analyze_data_quantity(root_dir, splits=['train', 'val', 'test']):
    """Analyze image and label counts per split."""
    stats = {}
    for split in splits:
        imgs, labels = get_paths(root_dir, split)
        stats[split] = {
            'images': len(imgs),
            'labels': len(labels),
            'match': len(imgs) == len(labels)
        }
    return stats

# Module 2: Class Distribution
def analyze_class_distribution(label_paths, class_names):
    """Count occurrences per class."""
    class_counts = Counter()
    for label_path in label_paths:
        with open(label_path, 'r') as f:
            lines = f.readlines()
        for line in lines:
            cls = int(line.split()[0])
            class_counts[cls] += 1
    df = pd.DataFrame({'class': class_names, 'count': [class_counts[i] for i in range(len(class_names))]})
    max_min_ratio = df['count'].max() / df['count'].min() if df['count'].min() > 0 else float('inf')
    return df, max_min_ratio

def visualize_class_distribution(df, output_path):
    """Plot class distribution bar chart."""
    plt.figure(figsize=(10, 6))
    sns.barplot(x='class', y='count', data=df)
    plt.title('Class Distribution')
    plt.xticks(rotation=45)
    plt.savefig(output_path)
    plt.close()

# Module 3: Image Sizes and Aspect Ratios
def analyze_image_sizes(img_paths):
    """Collect width, height, and aspect ratios."""
    sizes = []
    for img_path in tqdm(img_paths, desc="Analyzing sizes"):
        img = cv2.imread(img_path)
        if img is None:
            continue  # Skip invalid images
        h, w = img.shape[:2]
        sizes.append((w, h, w / h))
    df = pd.DataFrame(sizes, columns=['width', 'height', 'aspect_ratio'])
    return df

def visualize_image_sizes(df, output_path):
    """Plot histograms for sizes and aspect ratios."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    sns.histplot(df['width'], kde=True, color='blue', label='Width')
    sns.histplot(df['height'], kde=True, color='red', label='Height')
    plt.title('Width and Height Distribution')
    plt.legend()
    plt.subplot(1, 2, 2)
    sns.histplot(df['aspect_ratio'], kde=True)
    plt.title('Aspect Ratio Distribution')
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

# Module 4: Bounding Box Quality
def analyze_bbox_quality(label_paths):
    """Check for invalid or empty bounding boxes."""
    invalid = 0
    empty = 0
    total_bboxes = 0
    for label_path in label_paths:
        with open(label_path, 'r') as f:
            lines = f.readlines()
        if not lines:
            empty += 1
            continue
        for line in lines:
            parts = line.split()
            if len(parts) != 5:
                invalid += 1
                continue
            cls, x, y, w, h = map(float, parts)
            if any(v <= 0 for v in [w, h]) or x < 0 or y < 0 or x > 1 or y > 1:
                invalid += 1
            total_bboxes += 1
    return invalid, empty, total_bboxes

def visualize_random_samples(img_paths, label_paths, class_names, num_samples=5, output_path=None):
    """Visualize random images with bounding boxes."""
    samples = random.sample(range(len(img_paths)), min(num_samples, len(img_paths)))
    fig, axs = plt.subplots(1, len(samples), figsize=(20, 5))
    for i, idx in enumerate(samples):
        img = cv2.imread(img_paths[idx])
        h, w = img.shape[:2]
        with open(label_paths[idx], 'r') as f:
            lines = f.readlines()
        for line in lines:
            parts = line.split()
            if len(parts) != 5: continue
            cls, x, y, bw, bh = map(float, parts)
            x1 = int((x - bw/2) * w)
            y1 = int((y - bh/2) * h)
            x2 = int((x + bw/2) * w)
            y2 = int((y + bh/2) * h)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(img, class_names[int(cls)], (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        axs[i].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        axs[i].axis('off')
    if output_path:
        plt.savefig(output_path)
    plt.close()

# Module 5: Objects per Image
def analyze_objects_per_image(label_paths):
    """Calculate objects count per image."""
    counts = []
    for label_path in label_paths:
        with open(label_path, 'r') as f:
            lines = f.readlines()
        counts.append(len(lines))
    df = pd.DataFrame({'objects': counts})
    return df['objects'].mean(), df['objects'].max(), df

def visualize_objects_per_image(df, output_path):
    """Plot histogram of objects per image."""
    plt.figure(figsize=(8, 5))
    sns.histplot(df['objects'], kde=True)
    plt.title('Objects per Image Distribution')
    plt.savefig(output_path)
    plt.close()

# Module 6: Image Quality (Brightness & Blur)
def analyze_image_quality(img_paths):
    """Compute brightness and blur metrics."""
    brightness = []
    blur = []
    for img_path in tqdm(img_paths, desc="Analyzing quality"):
        img = cv2.imread(img_path, 0)
        if img is None:
            continue
        brightness.append(np.mean(img))
        lap = cv2.Laplacian(img, cv2.CV_64F).var()
        blur.append(lap)
    df = pd.DataFrame({'brightness': brightness, 'blur': blur})
    return df

def visualize_image_quality(df, output_path):
    """Plot distributions for brightness and blur."""
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    sns.histplot(df['brightness'], kde=True)
    plt.title('Brightness Distribution')
    plt.subplot(1, 2, 2)
    sns.histplot(df['blur'], kde=True)
    plt.title('Blur (Laplacian Variance) Distribution')
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def get_extreme_samples(img_paths, df, num=3):
    """Get paths of extreme images (dark, bright, blurry)."""
    dark = df.nsmallest(num, 'brightness').index
    bright = df.nlargest(num, 'brightness').index
    blurry = df.nsmallest(num, 'blur').index
    extremes = [img_paths[i] for i in list(dark) + list(bright) + list(blurry)]
    return extremes

# Module 7: Class Visual Differences
def collect_class_samples(img_paths, label_paths, class_names, num_per_class=3):
    """Collect sample images per class."""
    class_imgs = {i: [] for i in range(len(class_names))}
    for img_path, label_path in zip(img_paths, label_paths):
        with open(label_path, 'r') as f:
            lines = f.readlines()
        classes = set([int(line.split()[0]) for line in lines if line.strip()])
        for cls in classes:
            if len(class_imgs[cls]) < num_per_class:
                class_imgs[cls].append(img_path)
    return class_imgs

def visualize_class_samples(class_imgs, class_names, output_path):
    """Plot grid of sample images per class."""
    num_classes = len(class_imgs)
    rows, cols = (3, 4) if num_classes >= 12 else (1, num_classes)
    fig, axs = plt.subplots(rows, cols, figsize=(12, 9) if num_classes >= 12 else (20, 5))
    axs = axs.flatten()
    for i, (cls, imgs) in enumerate(class_imgs.items()):
        if imgs:
            img = cv2.imread(imgs[0])
            axs[i].imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            axs[i].set_title(class_names[cls])
            axs[i].axis('off')
    for j in range(i + 1, len(axs)):
        axs[j].axis('off')
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

def extract_histogram_features(img_paths):
    """Extract color histogram features."""
    features = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        if img is None:
            continue
        hist = cv2.calcHist([img], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]).flatten()
        features.append(hist)
    return np.array(features)

def analyze_and_visualize_clustering(features, labels, output_path):
    """Perform PCA/t-SNE and plot."""
    if len(features) == 0:
        return
    pca = PCA(n_components=2).fit_transform(features)
    tsne = TSNE(n_components=2).fit_transform(features)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    sns.scatterplot(x=pca[:,0], y=pca[:,1], hue=labels, palette='viridis')
    plt.title('PCA Clustering')
    plt.subplot(1, 2, 2)
    sns.scatterplot(x=tsne[:,0], y=tsne[:,1], hue=labels, palette='viridis')
    plt.title('t-SNE Clustering')
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

# Module 8: Split Consistency
def analyze_split_consistency(root_dir, class_names, splits=['train', 'val', 'test']):
    """Compare distributions across splits."""
    class_dfs = {}
    size_dfs = {}
    obj_dfs = {}
    for split in splits:
        imgs, labels = get_paths(root_dir, split)
        class_dfs[split] = analyze_class_distribution(labels, class_names)[0]
        size_dfs[split] = analyze_image_sizes(imgs)
        _, _, obj_dfs[split] = analyze_objects_per_image(labels)
    return class_dfs, size_dfs, obj_dfs

def visualize_split_consistency(class_dfs, size_dfs, obj_dfs, output_paths):
    """Plot comparisons for classes, sizes, objects."""
    # Class distribution
    combined_class = pd.concat([df.assign(split=split) for split, df in class_dfs.items()])
    plt.figure(figsize=(12, 6))
    sns.barplot(x='class', y='count', hue='split', data=combined_class)
    plt.title('Class Distribution Across Splits')
    plt.xticks(rotation=45)
    plt.savefig(output_paths[0])
    plt.close()
    
    # Aspect ratios
    plt.figure(figsize=(12, 6))
    for split, df in size_dfs.items():
        sns.kdeplot(df['aspect_ratio'], label=split)
    plt.title('Aspect Ratio Distribution Across Splits')
    plt.legend()
    plt.savefig(output_paths[1])
    plt.close()
    
    # Objects per image
    plt.figure(figsize=(12, 6))
    for split, df in obj_dfs.items():
        sns.kdeplot(df['objects'], label=split)
    plt.title('Objects per Image Across Splits')
    plt.legend()
    plt.savefig(output_paths[2])
    plt.close()

# Module 9: Data Augmentation Suggestions
def generate_augmentation_suggestions(class_df, max_min_ratio, size_df, bbox_invalid, bbox_total, avg_obj, quality_df):
    """Generate suggestions based on analyses."""
    suggestions = []
    if max_min_ratio > 3:
        suggestions.append(('样本极度不平衡', '类别平衡采样，增加稀有类数据增强'))
    if size_df['width'].std() > 100 or size_df['height'].std() > 100:
        suggestions.append(('尺寸差异大', 'Resize+Pad 保持长宽比'))
    if quality_df['blur'].mean() < 50:  # Adjustable threshold
        suggestions.append(('模糊样本较多', '使用 A.Sharpen 、A.MotionBlur 增强'))
    if quality_df['brightness'].min() < 50 or quality_df['brightness'].max() > 200:
        suggestions.append(('夜间样本少', '光照增强（CLAHE等）'))
    if avg_obj > 5:
        suggestions.append(('背景复杂遮挡多', 'Mosaic / MixUp 增强有效'))
    return pd.DataFrame(suggestions, columns=['问题', '策略'])

# Report Generation
def generate_report(output_dir, qty_stats, class_df, max_min_ratio, invalid, empty, total_bboxes, avg_obj, max_obj, suggestions):
    """Generate Markdown report with results and images."""
    report = """
# AgroPest-12 EDA Report

## 1. Data Quantity
{0}

## 2. Class Distribution
Max/Min Ratio: {1:.2f}
![Class Dist](class_dist.png)

## 3. Image Sizes
![Sizes](sizes.png)

## 4. BBox Quality
Invalid: {2}/{3} ({4:.2f}%)
Empty: {5}
![Samples](samples.png)

## 5. Objects per Image
Avg: {6:.2f}, Max: {7}
![Objects Hist](objects_hist.png)

## 6. Image Quality
![Quality](quality.png)
Extreme Samples: extreme_0.jpg, extreme_1.jpg, ... (check folder)

## 7. Class Differences
![Class Samples](class_samples.png)
![Clustering](clustering.png)

## 8. Split Consistency
![Class Consist](class_consist.png)
![Size Consist](size_consist.png)
![Obj Consist](obj_consist.png)

## 9. Augmentation Suggestions
{8}
""".format(
        qty_stats,
        max_min_ratio,
        invalid, total_bboxes, (invalid / total_bboxes * 100) if total_bboxes > 0 else 0,
        empty,
        avg_obj, max_obj,
        suggestions.to_markdown(index=False)
    )
    report_path = os.path.join(output_dir, 'report.md')
    with open(report_path, 'w') as f:
        f.write(report)
    print(f"Report generated at: {report_path}")


In [12]:
root_dir='data/AgroPest12'
output_dir='eda_outputs'
os.makedirs(output_dir, exist_ok=True)
    
# 加载配置
yaml_path = os.path.join(root_dir, 'data.yaml')
if not os.path.exists(yaml_path):
    raise FileNotFoundError(f"找不到 data.yaml，请确保在 {root_dir} 目录下")

config = load_yaml(yaml_path)
class_names = config.get('names', [])
print(f"检测到 {len(class_names)} 个类别: {class_names}")

# 加载所有数据路径（适配新结构）
splits = ['train', 'valid', 'test']
all_imgs = {}
all_labels = {}

print("\n=== 数据集结构检查 ===")
for split in splits:
    imgs, labels = get_paths(root_dir, split)
    all_imgs[split] = imgs
    all_labels[split] = labels



检测到 12 个类别: ['Ants', 'Bees', 'Beetles', 'Caterpillars', 'Earthworms', 'Earwigs', 'Grasshoppers', 'Moths', 'Slugs', 'Snails', 'Wasps', 'Weevils']

=== 数据集结构检查 ===
train: 11502 图像, 11502 标签, 标签匹配: 100.0%, 图像唯一性: 100.0%
valid:  1095 图像,  1095 标签, 标签匹配: 100.0%, 图像唯一性: 100.0%
 test:   546 图像,   546 标签, 标签匹配: 100.0%, 图像唯一性: 100.0%


In [13]:
# 1. 数据量统计
qty_stats = {
    split: {'images': len(imgs), 'labels': len(labels), 'match_rate': len(imgs) == len(labels)}
    for split, (imgs, labels) in zip(splits, [(all_imgs[s], all_labels[s]) for s in splits])
}
print("\n1. 数据量统计:")
print(pd.DataFrame(qty_stats).T)



1. 数据量统计:
      images labels match_rate
train  11502  11502       True
valid   1095   1095       True
test     546    546       True


In [14]:
# 2. 类别分布（基于train）
print("\n2. 类别分布分析...")
class_df, max_min_ratio = analyze_class_distribution(all_labels['train'], class_names)
print(f"最大/最小类别比例: {max_min_ratio:.2f}")
print(class_df.sort_values('count', ascending=False))



2. 类别分布分析...
最大/最小类别比例: 2.43
           class  count
0           Ants   2231
3   Caterpillars   1740
1           Bees   1596
9         Snails   1199
5        Earwigs   1182
10         Wasps   1167
4     Earthworms   1083
6   Grasshoppers   1071
7          Moths   1062
2        Beetles   1058
11       Weevils    975
8          Slugs    918


In [15]:
# 3. 图像尺寸
print("\n3. 图像尺寸分析...")
size_df_train = analyze_image_sizes(all_imgs['train'])
print(f"尺寸范围: {size_df_train['width'].min()}x{size_df_train['height'].min()} ~ {size_df_train['width'].max()}x{size_df_train['height'].max()}")



3. 图像尺寸分析...


Analyzing sizes: 100%|██████████| 11502/11502 [00:23<00:00, 489.96it/s]

尺寸范围: 640x640 ~ 640x640





In [16]:
# 4. 标注质量检查（全部数据）
print("\n4. 标注质量检查...")
all_label_paths = []
for labels in all_labels.values():
    all_label_paths.extend(labels)
invalid, empty, total_bboxes = analyze_bbox_quality(all_label_paths)
print(f"无效框: {invalid}/{total_bboxes} ({invalid/total_bboxes*100:.2f}%)")
print(f"空标签文件: {empty}/{len(all_label_paths)} ({empty/len(all_label_paths)*100:.2f}%)")



4. 标注质量检查...
无效框: 0/17312 (0.00%)
空标签文件: 3/13143 (0.02%)


In [17]:
# 5. 每图目标数
avg_obj, max_obj, _ = analyze_objects_per_image(all_labels['train'])
print(f"5. 每图平均目标数: {avg_obj:.2f}, 最大: {max_obj}")



5. 每图平均目标数: 1.33, 最大: 49


In [20]:
# 6. 图像质量
print("\n6. 图像质量分析...")
quality_df = analyze_image_quality(all_imgs['train'])
print(f"平均亮度: {quality_df['brightness'].mean():.1f}, 平均清晰度: {quality_df['blur'].mean():.1f}")

# 生成所有可视化
print("\n7. 生成可视化图表...")

# 类别分布图
visualize_class_distribution(class_df, os.path.join(output_dir, 'class_dist.png'))

# 尺寸分布图
visualize_image_sizes(size_df_train, os.path.join(output_dir, 'sizes.png'))

# 随机样本可视化
visualize_random_samples(all_imgs['train'], all_labels['train'], class_names, 
                        num_samples=8, output_path=os.path.join(output_dir, 'samples.png'))

# 每图目标数分布
_, _, obj_df = analyze_objects_per_image(all_labels['train'])
visualize_objects_per_image(obj_df, os.path.join(output_dir, 'objects_hist.png'))

# 质量分布
visualize_image_quality(quality_df, os.path.join(output_dir, 'quality.png'))

# 类别样本展示
class_imgs = collect_class_samples(all_imgs['train'], all_labels['train'], class_names)
visualize_class_samples(class_imgs, class_names, os.path.join(output_dir, 'class_samples.png'))

# 数据增强建议
suggestions = generate_augmentation_suggestions(
    class_df, max_min_ratio, size_df_train, invalid, total_bboxes, avg_obj, quality_df
)

print(f"\n✅ EDA 完成！结果保存在: {output_dir}")
print("\n9. 数据增强建议:")
print(suggestions)

# 生成报告
generate_report(output_dir, qty_stats, class_df, max_min_ratio, invalid, empty, 
                total_bboxes, avg_obj, max_obj, suggestions)

print({
    'qty_stats': qty_stats,
    'class_df': class_df,
    'max_min_ratio': max_min_ratio,
    'bbox_quality': (invalid, empty, total_bboxes),
    'objects_stats': (avg_obj, max_obj),
    'suggestions': suggestions
})


6. 图像质量分析...


Analyzing quality: 100%|██████████| 11502/11502 [00:51<00:00, 224.13it/s]


平均亮度: 118.5, 平均清晰度: 159.7

7. 生成可视化图表...

✅ EDA 完成！结果保存在: eda_outputs

9. 数据增强建议:
      问题            策略
0  夜间样本少  光照增强（CLAHE等）
Report generated at: eda_outputs\report.md
{'qty_stats': {'train': {'images': 11502, 'labels': 11502, 'match_rate': True}, 'valid': {'images': 1095, 'labels': 1095, 'match_rate': True}, 'test': {'images': 546, 'labels': 546, 'match_rate': True}}, 'class_df':            class  count
0           Ants   2231
1           Bees   1596
2        Beetles   1058
3   Caterpillars   1740
4     Earthworms   1083
5        Earwigs   1182
6   Grasshoppers   1071
7          Moths   1062
8          Slugs    918
9         Snails   1199
10         Wasps   1167
11       Weevils    975, 'max_min_ratio': np.float64(2.4302832244008714), 'bbox_quality': (0, 3, 17312), 'objects_stats': (np.float64(1.3286384976525822), np.int64(49)), 'suggestions':       问题            策略
0  夜间样本少  光照增强（CLAHE等）}
