# 🌲 树木检测项目 - Tree Detection Project

## 项目概述 Project Overview

本项目使用YOLOv5和PyTorch实现基于航空影像的树木检测任务  
This project implements tree detection from aerial imagery using YOLOv5 and PyTorch

**数据集 Dataset**: NeonTreeEvaluation Benchmark  
**目标 Goal**: 检测航空正射影像中的树木 - Detect trees in aerial orthoimagery  
**平台 Platform**: Google Colab (自动GPU检测 - Auto GPU detection)

## 技术栈 Tech Stack
- **深度学习框架 Deep Learning**: PyTorch + YOLOv5 (via the `ultralytics` package)
- **数据处理 Data Processing**: OpenCV, PIL, pandas
- **可视化 Visualization**: matplotlib, seaborn

In [None]:
# 1. 环境检测和基础设置 - Environment Detection and Basic Setup
import os
import sys
import torch
import torchvision
import platform
from pathlib import Path

print("=" * 50)
print("🚀 环境检测 Environment Detection")
print("=" * 50)

# 检查Python版本 Check Python version
print(f"Python版本 Python Version: {sys.version}")

# 检查操作系统 Check OS
print(f"操作系统 Operating System: {platform.system()}")

# 检查是否在Colab环境 Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("✅ 运行环境: Google Colab - Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("❌ 运行环境: 本地环境 - Running locally")

# 检查GPU可用性 Check GPU availability
if torch.cuda.is_available():
    print(f"✅ GPU可用 GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA版本 CUDA Version: {torch.version.cuda}")
    print(f"   GPU数量 GPU Count: {torch.cuda.device_count()}")
    device = torch.device('cuda')
else:
    print("⚠️  GPU不可用，将使用CPU - GPU not available, using CPU")
    device = torch.device('cpu')

print(f"PyTorch版本 PyTorch Version: {torch.__version__}")
print(f"设备 Device: {device}")
print("=" * 50)

In [None]:
# 2. 安装必要依赖 - Install Required Dependencies
print("📦 安装依赖包 Installing Dependencies...")

# 安装YOLOv5和相关依赖 Install YOLOv5 and dependencies
!pip install -q ultralytics
!pip install -q opencv-python-headless
!pip install -q Pillow
!pip install -q matplotlib
!pip install -q seaborn
!pip install -q pandas
!pip install -q tqdm
!pip install -q scikit-learn
!pip install -q PyYAML

# 如果在Colab环境，挂载Google Drive (可选)
# Mount Google Drive in Colab (optional)
if IN_COLAB:
    from google.colab import drive
    # drive.mount('/content/drive')  # 取消注释以挂载Drive - Uncomment to mount Drive

print("✅ 依赖安装完成 Dependencies installed successfully!")

In [None]:
# 3. 导入所需库 - Import Required Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
import json
import zipfile
import requests
from tqdm import tqdm
import shutil
import yaml
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# 设置matplotlib中文字体支持 Set matplotlib Chinese font support (SimHei must be installed; Colab does not ship it, so Chinese labels may render as empty boxes)
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# 设置随机种子 Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("✅ 库导入完成 Libraries imported successfully!")

In [None]:
# 4. 数据下载和解压 - Data Download and Extraction
def download_file(url, filename):
    """
    下载文件的函数 Function to download files
    """
    print(f"🔄 开始下载 Starting download: {filename}")
    
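    response = requests.get(url, stream=True, timeout=60)
    # Stop on HTTP errors instead of saving an error page as a .zip file
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))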
    
    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                file.write(chunk)
                pbar.update(len(chunk))
    
    print(f"✅ 下载完成 Download completed: {filename}")

# 创建数据目录 Create data directories
os.makedirs('data', exist_ok=True)
os.makedirs('data/raw', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

# 数据集URL Dataset URLs
DATASET_URL = "https://zenodo.org/records/5914554/files/evaluation.zip?download=1"
ANNOTATIONS_URL = "https://zenodo.org/records/5914554/files/annotations.zip?download=1"

# 下载数据集 Download datasets
if not os.path.exists('data/raw/evaluation.zip'):
    download_file(DATASET_URL, 'data/raw/evaluation.zip')
else:
    print("✅ 评估数据集已存在 Evaluation dataset already exists")

if not os.path.exists('data/raw/annotations.zip'):
    download_file(ANNOTATIONS_URL, 'data/raw/annotations.zip')
else:
    print("✅ 标注数据已存在 Annotations already exist")

print("📁 数据下载完成 Data download completed!")
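
The existence checks above only test whether each `.zip` file is present, so a download interrupted halfway leaves a truncated archive that is never fetched again and later fails inside `zipfile`. A minimal sketch of one way around this, assuming it is acceptable to write to a temporary `<name>.part` file first; `write_chunks_atomically` and `download_file_atomic` are hypothetical helpers, not part of any library:

```python
import os
from typing import Iterable


def write_chunks_atomically(chunks: Iterable[bytes], dest: str) -> int:
    """Write byte chunks to dest via a temporary '.part' file; return bytes written.

    The final file only appears after every chunk has been written, so an
    interrupted transfer never leaves a truncated file at `dest`.
    """
    tmp_path = dest + ".part"
    written = 0
    try:
        with open(tmp_path, "wb") as f:
            for chunk in chunks:
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
                    written += len(chunk)
        os.replace(tmp_path, dest)  # atomic rename on the same filesystem
    except BaseException:
        # Remove the partial file so the next run downloads from scratch
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return written


def download_file_atomic(url: str, dest: str, timeout: float = 60.0) -> int:
    """Stream `url` into `dest`, raising on HTTP errors instead of saving an error page."""
    import requests  # imported here so the writer above has no third-party dependency

    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        return write_chunks_atomically(response.iter_content(chunk_size=8192), dest)
```

For example, `download_file_atomic(DATASET_URL, 'data/raw/evaluation.zip')` could stand in for the `download_file` call; the progress bar is left out to keep the sketch short.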

In [None]:
# 5. 解压数据集 - Extract Datasets
def extract_zip(zip_path, extract_to):
    """
    解压ZIP文件的函数 Function to extract ZIP files
    """
    print(f"📦 解压文件 Extracting: {zip_path}")
    
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    
    print(f"✅ 解压完成 Extraction completed: {extract_to}")

# 解压评估数据集 Extract evaluation dataset
if not os.path.exists('data/raw/evaluation'):
    extract_zip('data/raw/evaluation.zip', 'data/raw/')
else:
    print("✅ 评估数据集已解压 Evaluation dataset already extracted")

# 解压标注数据 Extract annotations
if not os.path.exists('data/raw/annotations'):
    extract_zip('data/raw/annotations.zip', 'data/raw/')
else:
    print("✅ 标注数据已解压 Annotations already extracted")

# 查看数据结构 Explore data structure
print("\n📊 数据结构分析 Data Structure Analysis")
print("=" * 50)

# 检查评估数据集结构 Check evaluation dataset structure
eval_path = Path('data/raw/evaluation')
if eval_path.exists():
    print(f"评估数据集路径 Evaluation dataset path: {eval_path}")
    subdirs = [d for d in eval_path.iterdir() if d.is_dir()]
    print(f"子目录数量 Number of subdirectories: {len(subdirs)}")
    for subdir in subdirs[:5]:  # 显示前5个子目录 Show first 5 subdirectories
        print(f"  - {subdir.name}")
    if len(subdirs) > 5:
        print(f"  ... 还有 {len(subdirs)-5} 个目录 and {len(subdirs)-5} more directories")

# 检查标注数据结构 Check annotations structure
ann_path = Path('data/raw/annotations')
if ann_path.exists():
    print(f"\n标注数据路径 Annotations path: {ann_path}")
    ann_files = list(ann_path.glob('*.csv'))
    print(f"CSV标注文件数量 Number of CSV annotation files: {len(ann_files)}")
    for ann_file in ann_files[:3]:  # 显示前3个标注文件 Show first 3 annotation files
        print(f"  - {ann_file.name}")

print("\n✅ 数据解压和结构分析完成 Data extraction and structure analysis completed!")
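
The conversion step below only reads `*.csv` files. If the structure analysis reports zero CSV annotation files, the archive may instead contain one Pascal VOC style XML file per image (an assumption to verify against the extracted files). A minimal sketch that turns such XML into the column layout the converter looks for (`xmin`, `ymin`, `xmax`, `ymax`) plus an `image_path` column; `parse_voc_xml` is a hypothetical helper:

```python
import xml.etree.ElementTree as ET

import pandas as pd


def parse_voc_xml(xml_text: str) -> pd.DataFrame:
    """Parse one Pascal VOC annotation document into a DataFrame of pixel boxes."""
    root = ET.fromstring(xml_text)
    filename = root.findtext("filename", default="")
    rows = []
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        if box is None:
            continue  # objects without a bounding box cannot be converted
        rows.append({
            "image_path": filename,
            "label": obj.findtext("name", default="Tree"),
            # VOC stores coordinates as text; some tools write floats, so parse as float
            "xmin": float(box.findtext("xmin")),
            "ymin": float(box.findtext("ymin")),
            "xmax": float(box.findtext("xmax")),
            "ymax": float(box.findtext("ymax")),
        })
    return pd.DataFrame(rows, columns=["image_path", "label", "xmin", "ymin", "xmax", "ymax"])
```

All files could then be combined with `pd.concat([parse_voc_xml(p.read_text()) for p in Path('data/raw/annotations').glob('*.xml')], ignore_index=True)`.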

In [None]:
# 6. 数据格式转换 - Data Format Conversion (NeonTree -> YOLO)
class NeonTreeToYOLO:
    """
    将NeonTree数据集转换为YOLO格式的类
    Class to convert NeonTree dataset to YOLO format
    """
    
    def __init__(self, data_root, output_root):
        self.data_root = Path(data_root)
        self.output_root = Path(output_root)
        self.annotations_path = Path('data/raw/annotations')
        
        # 创建输出目录 Create output directories
        self.create_yolo_structure()
    
    def create_yolo_structure(self):
        """创建YOLO数据集目录结构 Create YOLO dataset directory structure"""
        folders = [
            'images/train', 'images/val', 'images/test',
            'labels/train', 'labels/val', 'labels/test'
        ]
        
        for folder in folders:
            (self.output_root / folder).mkdir(parents=True, exist_ok=True)
        
        print("✅ YOLO目录结构创建完成 YOLO directory structure created")
    
    def load_annotations(self):
        """加载标注数据 Load annotation data"""
        annotation_files = list(self.annotations_path.glob('*.csv'))
        all_annotations = []
        
        for ann_file in annotation_files:
            try:
                df = pd.read_csv(ann_file)
                df['site'] = ann_file.stem  # 添加站点信息 Add site information
                all_annotations.append(df)
                print(f"✅ 加载标注文件 Loaded annotations: {ann_file.name} ({len(df)} 条记录 records)")
            except Exception as e:
                print(f"❌ 无法加载 Could not load: {ann_file.name} - {e}")
        
        if all_annotations:
            combined_df = pd.concat(all_annotations, ignore_index=True)
            print(f"📊 总标注数量 Total annotations: {len(combined_df)}")
            return combined_df
        else:
            print("❌ 未找到有效的标注数据 No valid annotation data found")
            return pd.DataFrame()
    
    def find_rgb_images(self):
        """查找RGB图像文件 Find RGB image files"""
        rgb_images = []
        
        # 搜索evaluation目录下的RGB文件夹 Search for RGB folders in evaluation directory
        for site_dir in self.data_root.iterdir():
            if site_dir.is_dir():
                rgb_dir = site_dir / 'RGB'
                if rgb_dir.exists():
                    # 查找图像文件 Find image files
                    for img_ext in ['*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff']:
                        rgb_images.extend(list(rgb_dir.glob(img_ext)))
        
        print(f"🖼️  找到RGB图像 Found RGB images: {len(rgb_images)}")
        return rgb_images
    
    def convert_bbox_to_yolo(self, bbox, img_width, img_height):
        """
        将边界框坐标转换为YOLO格式
        Convert bounding box coordinates to YOLO format
        
        YOLO格式: [class_id, x_center, y_center, width, height] (归一化 normalized)
        """
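        # 假设bbox格式为 [x_min, y_min, x_max, y_max]
        # Assume bbox format is [x_min, y_min, x_max, y_max] in pixel coordinates
        x_min, y_min, x_max, y_max = bbox
        
        # Clip to the image bounds so normalized values stay within [0, 1]
        x_min, x_max = max(0.0, x_min), min(float(img_width), x_max)
        y_min, y_max = max(0.0, y_min), min(float(img_height), y_max)
        
        # 计算中心点和宽高 Calculate center point and dimensions
        x_center = (x_min + x_max) / 2.0
        y_center = (y_min + y_max) / 2.0
        width = x_max - x_min
        height = y_max - y_min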
        
        # 归一化 Normalize
        x_center /= img_width
        y_center /= img_height
        width /= img_width
        height /= img_height
        
        return [0, x_center, y_center, width, height]  # 类别ID为0 (树木 tree)
    
    def process_annotations(self, annotations_df, rgb_images, seed=42):
        """
        处理标注数据并转换为YOLO格式
        Process annotations and convert to YOLO format.
        All boxes of one image are written to a single label file, and the
        train/val/test split is made per image so no image appears in two splits.
        """
        # 创建图像名称到路径的映射 Create mapping from image names to paths
        img_name_to_path = {img_path.stem: img_path for img_path in rgb_images}
        
        # Decide which image each annotation row belongs to: use an 'image_path'
        # column if the CSV has one (assumed column name), otherwise fall back
        # to the CSV file name stored in 'site'
        if 'image_path' in annotations_df.columns:
            key_series = annotations_df['image_path'].astype(str).map(lambda p: Path(p).stem)
        else:
            key_series = annotations_df['site'].astype(str)
        annotations_df = annotations_df.assign(image_key=key_series)
        
        # 根据实际的CSV列名调整 Pick bounding box columns from the actual CSV column names
        if {'xmin', 'ymin', 'xmax', 'ymax'}.issubset(annotations_df.columns):
            box_cols = ['xmin', 'ymin', 'xmax', 'ymax']
        elif {'left', 'top', 'right', 'bottom'}.issubset(annotations_df.columns):
            box_cols = ['left', 'top', 'right', 'bottom']
        else:
            print("❌ No bounding box columns found")
            return 0
        
        # Split by image (70% train / 20% val / 10% test), not by box
        image_keys = sorted(k for k in annotations_df['image_key'].unique() if k in img_name_to_path)
        if not image_keys:
            print("❌ No annotations could be matched to images")
            return 0
        np.random.default_rng(seed).shuffle(image_keys)
        n_train = int(0.7 * len(image_keys))
        n_val = int(0.2 * len(image_keys))
        split_of = {}
        for i, key in enumerate(image_keys):
            split_of[key] = 'train' if i < n_train else ('val' if i < n_train + n_val else 'test')
        
        print(f"📝 开始处理标注 Starting annotation processing...")
        processed_count = 0
        matched = annotations_df[annotations_df['image_key'].isin(list(split_of))]
        grouped = matched.groupby('image_key')
        
        for key, group in tqdm(grouped, total=grouped.ngroups):
            img_path = img_name_to_path[key]
            
            # 读取图像获取尺寸 Read image to get dimensions (once per image)
            img = cv2.imread(str(img_path))
            if img is None:
                print(f"❌ Could not read image: {img_path}")
                continue
            img_height, img_width = img.shape[:2]
            
            # 转换为YOLO格式 Convert to YOLO format, dropping boxes with no area
            yolo_bboxes = [
                self.convert_bbox_to_yolo(bbox, img_width, img_height)
                for bbox in group[box_cols].to_numpy(dtype=float)
            ]
            yolo_bboxes = [b for b in yolo_bboxes if b[3] > 0 and b[4] > 0]
            
            # 保存图像和标签 Save image and label
            self.save_image_and_label(img, yolo_bboxes, split_of[key], key)
            processed_count += 1
        
        print(f"✅ 处理完成 Processing completed: {processed_count} images")
        return processed_count
    
    def save_image_and_label(self, img, yolo_bboxes, split, stem):
        """
        保存图像和对应的YOLO标签
        Save image and corresponding YOLO label (one line per box)
        """
        # 保存图像 Save image, keeping the original file stem for traceability
        img_save_path = self.output_root / 'images' / split / f"{stem}.jpg"
        cv2.imwrite(str(img_save_path), img)
        
        # 保存标签 Save label
        label_save_path = self.output_root / 'labels' / split / f"{stem}.txt"
        with open(label_save_path, 'w') as f:
            for bbox in yolo_bboxes:
                # YOLO格式: class_id x_center y_center width height
                f.write(f"{bbox[0]} {bbox[1]:.6f} {bbox[2]:.6f} {bbox[3]:.6f} {bbox[4]:.6f}\n")

# 执行数据转换 Execute data conversion
print("🔄 开始数据格式转换 Starting data format conversion...")
converter = NeonTreeToYOLO('data/raw/evaluation', 'data/processed')

# 加载标注数据 Load annotation data
annotations_df = converter.load_annotations()

if not annotations_df.empty:
    # 查找RGB图像 Find RGB images
    rgb_images = converter.find_rgb_images()
    
    if rgb_images:
        # 处理标注 Process annotations
        processed_samples = converter.process_annotations(annotations_df, rgb_images)
        print(f"✅ 数据转换完成 Data conversion completed: {processed_samples} 个样本 samples")
    else:
        print("❌ 未找到RGB图像 No RGB images found")
else:
    print("❌ 未找到标注数据 No annotation data found")
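
As a worked check of `convert_bbox_to_yolo`: a box from (100, 50) to (200, 150) in a 400 × 400 px image has centre (150, 100) and size 100 × 100, which normalizes to (0.375, 0.25, 0.25, 0.25). The standalone sketch below repeats that conversion (with clipping to the image) together with its inverse, so written labels can be checked by converting them back to pixels; `xyxy_to_yolo` and `yolo_to_xyxy` are hypothetical names:

```python
from typing import Sequence, Tuple


def xyxy_to_yolo(bbox: Sequence[float], img_w: int, img_h: int) -> Tuple[float, float, float, float]:
    """Pixel [x_min, y_min, x_max, y_max] -> normalized (x_c, y_c, w, h), clipped to the image."""
    x_min, y_min, x_max, y_max = bbox
    x_min, x_max = max(0.0, x_min), min(float(img_w), x_max)
    y_min, y_max = max(0.0, y_min), min(float(img_h), y_max)
    return (
        (x_min + x_max) / 2.0 / img_w,
        (y_min + y_max) / 2.0 / img_h,
        (x_max - x_min) / img_w,
        (y_max - y_min) / img_h,
    )


def yolo_to_xyxy(yolo: Sequence[float], img_w: int, img_h: int) -> Tuple[float, float, float, float]:
    """Normalized (x_c, y_c, w, h) -> pixel [x_min, y_min, x_max, y_max]."""
    x_c, y_c, w, h = yolo
    return (
        (x_c - w / 2.0) * img_w,
        (y_c - h / 2.0) * img_h,
        (x_c + w / 2.0) * img_w,
        (y_c + h / 2.0) * img_h,
    )


yolo = xyxy_to_yolo([100, 50, 200, 150], 400, 400)
print(yolo)  # (0.375, 0.25, 0.25, 0.25)
print(yolo_to_xyxy(yolo, 400, 400))  # (100.0, 50.0, 200.0, 150.0)
```

A box that extends past the image edge, such as [-10, 0, 50, 40] in a 100 × 100 image, is clipped to [0, 0, 50, 40] before normalizing.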

In [None]:
# 7. 创建YOLO配置文件 - Create YOLO Configuration Files
def create_dataset_yaml():
    """
    创建YOLO数据集配置文件
    Create YOLO dataset configuration file
    """
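    dataset_config = {
        # Absolute dataset root: ultralytics resolves train/val/test relative to
        # 'path', and a relative or missing 'path' can be resolved against its own
        # datasets directory setting rather than the notebook's working directory
        'path': str(Path('data/processed').resolve()),
        'train': 'images/train',
        'val': 'images/val',
        'test': 'images/test',
        'nc': 1,  # 类别数量 number of classes
        'names': ['tree']  # 类别名称 class names
    }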
    
    # 保存配置文件 Save configuration file
    with open('data/tree_dataset.yaml', 'w') as f:
        yaml.dump(dataset_config, f, default_flow_style=False)
    
    print("✅ YOLO数据集配置文件创建完成 YOLO dataset configuration file created")
    print("📄 配置文件路径 Configuration file path: data/tree_dataset.yaml")
    
    return 'data/tree_dataset.yaml'

# 创建配置文件 Create configuration file
dataset_yaml_path = create_dataset_yaml()

# 显示配置文件内容 Display configuration file content
with open(dataset_yaml_path, 'r') as f:
    config_content = f.read()
    print("\n📋 数据集配置内容 Dataset Configuration Content:")
    print("=" * 40)
    print(config_content)
    print("=" * 40)
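
Ultralytics finds each image's label file by swapping the last `images` directory in the path for `labels` and the extension for `.txt`, which is why the converter mirrors the two directory trees. A small sanity check built on that convention can catch images that would silently be treated as containing no trees; `label_path_for` and `find_unlabeled_images` are hypothetical helpers, and the path mapping is a simplified version of the library's rule:

```python
from pathlib import Path
from typing import List

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}


def label_path_for(image_path: Path) -> Path:
    """Map .../images/<split>/name.ext to .../labels/<split>/name.txt."""
    parts = list(image_path.parts)
    # Replace the last 'images' component in the path
    idx = len(parts) - 1 - parts[::-1].index("images")
    parts[idx] = "labels"
    return Path(*parts).with_suffix(".txt")


def find_unlabeled_images(dataset_root: str) -> List[Path]:
    """Return images under <root>/images whose label file is missing or empty."""
    missing = []
    for img in sorted(Path(dataset_root, "images").rglob("*")):
        if img.suffix.lower() not in IMAGE_SUFFIXES:
            continue  # skip directories and non-image files
        label = label_path_for(img)
        if not label.exists() or label.stat().st_size == 0:
            missing.append(img)
    return missing
```

Running `find_unlabeled_images('data/processed')` after the conversion step should return an empty list.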

In [None]:
# 8. YOLOv5模型训练 - YOLOv5 Model Training
from ultralytics import YOLO

# 检查预训练模型 Check pre-trained model
pretrained_model_path = 'best.pt'
if os.path.exists(pretrained_model_path):
    print(f"✅ 找到预训练模型 Pre-trained model found: {pretrained_model_path}")
    model = YOLO(pretrained_model_path)
else:
    print("⚠️  未找到best.pt，使用YOLOv5s预训练模型 best.pt not found, using YOLOv5s pre-trained model")
    # The ultralytics package serves YOLOv5 weights as the 'u' variant (yolov5su.pt)
    model = YOLO('yolov5su.pt')

# 训练参数设置 Training parameters
training_args = {
    'data': dataset_yaml_path,        # 数据集配置文件 dataset configuration file
    'epochs': 50,                     # 训练轮数 training epochs (可根据需要调整 adjust as needed)
    'imgsz': 640,                     # 输入图像尺寸 input image size
    'batch': 16,                      # 批次大小 batch size (根据GPU内存调整 adjust based on GPU memory)
    'optimizer': 'AdamW',             # 优化器 optimizer
    'lr0': 0.001,                     # 初始学习率 initial learning rate
    'weight_decay': 0.0005,           # 权重衰减 weight decay
    'warmup_epochs': 3,               # 预热轮数 warmup epochs
    'patience': 10,                   # 早停耐心值 early stopping patience
    'save_period': 10,                # 模型保存间隔 model save interval
    'device': 0 if torch.cuda.is_available() else 'cpu',  # 设备 device (GPU index or 'cpu')
    'workers': 2,                     # 数据加载工作进程数 data loading workers
    'project': 'runs/detect',         # 项目目录 project directory
    'name': 'tree_detection',         # 实验名称 experiment name
    'exist_ok': True,                 # 覆盖已存在的实验 overwrite existing experiment
    'pretrained': True,               # 是否使用预训练权重 whether to use pre-trained weights
    'verbose': True                   # 详细输出 verbose output
}

print("🚀 开始训练模型 Starting model training...")
print(f"训练参数 Training parameters: {training_args}")
print("=" * 60)

# 开始训练 Start training
try:
    results = model.train(**training_args)
    print("✅ 模型训练完成 Model training completed!")
    print(f"训练结果保存在 Training results saved in: runs/detect/tree_detection")
except Exception as e:
    print(f"❌ 训练过程中出现错误 Error during training: {e}")
    print("请检查数据集是否正确准备 Please check if dataset is properly prepared")

In [None]:
# 9. 模型推理和结果可视化 - Model Inference and Result Visualization
def load_trained_model():
    """
    加载训练好的模型
    Load trained model
    """
    # 查找最新的训练结果 Find latest training results
    model_paths = [
        'runs/detect/tree_detection/weights/best.pt',
        'runs/detect/tree_detection/weights/last.pt',
        'best.pt',  # 原始预训练模型 original pre-trained model
        'yolov5su.pt'  # 默认模型 default model (file name ultralytics saves for YOLOv5s weights)
    ]
    
    for model_path in model_paths:
        if os.path.exists(model_path):
            print(f"✅ 加载模型 Loading model: {model_path}")
            return YOLO(model_path)
    
    print("❌ 未找到可用模型 No available model found")
    return None

def visualize_predictions(model, test_images, output_dir='results'):
    """
    可视化预测结果
    Visualize prediction results
    """
    os.makedirs(output_dir, exist_ok=True)
    
    test_images = list(test_images)[:10]  # 限制为前10张图像 limit to first 10 images
    print(f"🖼️  开始推理 Starting inference on {len(test_images)} images...")
    
    for i, img_path in enumerate(test_images):
        try:
            # 进行推理 Perform inference
            results = model(str(img_path))
            
            # 获取预测结果 Get prediction results
            result = results[0]
            
            # 在图像上绘制检测框 Draw detection boxes on image
            annotated_img = result.plot()
            
            # 保存结果图像 Save result image
            output_path = os.path.join(output_dir, f'detection_result_{i+1}.jpg')
            cv2.imwrite(output_path, annotated_img)
            
            # 显示检测信息 Display detection information
            boxes = result.boxes
            if boxes is not None and len(boxes) > 0:
                print(f"图像 {i+1} Image {i+1}: 检测到 {len(boxes)} 棵树 trees detected")
                for j, box in enumerate(boxes):
                    conf = box.conf[0].item()
                    print(f"  树木 Tree {j+1}: 置信度 confidence = {conf:.3f}")
            else:
                print(f"图像 {i+1} Image {i+1}: 未检测到树木 No trees detected")
                
        except Exception as e:
            print(f"处理图像 {i+1} 时出错 Error processing image {i+1}: {e}")
    
    print(f"✅ 推理完成 Inference completed! 结果保存在 Results saved in: {output_dir}")

def display_sample_results(results_dir='results', num_samples=4):
    """
    显示样本检测结果
    Display sample detection results
    """
    if not os.path.exists(results_dir):
        print("❌ 结果目录不存在 Results directory does not exist")
        return
        
    result_images = sorted(f for f in os.listdir(results_dir) if f.endswith('.jpg'))
    
    if not result_images:
        print("❌ 未找到结果图像 No result images found")
        return
    
    # 创建子图 Create subplots (2x2 grid, so at most 4 images are shown)
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('🌲 树木检测结果 Tree Detection Results', fontsize=16, fontweight='bold')
    
    axes = axes.flatten()
    num_shown = min(num_samples, len(axes), len(result_images))
    
    for i in range(num_shown):
        img_path = os.path.join(results_dir, result_images[i])
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        axes[i].imshow(img)
        axes[i].set_title(f'检测结果 Detection Result {i+1}', fontsize=12)
        axes[i].axis('off')
    
    # 隐藏未使用的子图 Hide unused subplots
    for i in range(num_shown, len(axes)):
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

# 加载模型并进行推理 Load model and perform inference
model = load_trained_model()

if model is not None:
    # 获取测试图像 Get test images
    test_image_dir = Path('data/processed/images/test')
    if test_image_dir.exists():
        test_images = list(test_image_dir.glob('*.jpg'))
        if test_images:
            # 进行推理和可视化 Perform inference and visualization
            visualize_predictions(model, test_images)
            
            # 显示样本结果 Display sample results
            display_sample_results()
            
        else:
            print("❌ 测试图像目录为空 Test image directory is empty")
    else:
        print("❌ 测试图像目录不存在 Test image directory does not exist")
        
        # 使用一些示例图像进行测试 Use some sample images for testing
        eval_images = sorted(p for p in Path('data/raw/evaluation').rglob('*') if p.suffix.lower() in {'.jpg', '.jpeg', '.png', '.tif', '.tiff'})
        if eval_images:
            print(f"📸 使用原始评估图像进行测试 Using original evaluation images for testing: {len(eval_images)}")
            visualize_predictions(model, eval_images[:5])  # 使用前5张图像 use first 5 images
            display_sample_results()
else:
    print("❌ 无法加载模型进行推理 Cannot load model for inference")
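
`result.plot()` only produces an annotated picture. To keep the detections themselves, for example to compare them with the NeonTree ground truth, the pixel boxes and scores can be taken from `result.boxes.xyxy` and `result.boxes.conf`. A minimal sketch, assuming those tensors have already been moved to NumPy with `.cpu().numpy()`; `detections_to_frame` is a hypothetical helper:

```python
import numpy as np
import pandas as pd


def detections_to_frame(xyxy: np.ndarray, conf: np.ndarray, image_name: str,
                        min_conf: float = 0.0) -> pd.DataFrame:
    """Build one row per predicted box: image name, pixel corners and confidence."""
    xyxy = np.asarray(xyxy, dtype=float).reshape(-1, 4)
    conf = np.asarray(conf, dtype=float).reshape(-1)
    keep = conf >= min_conf  # optional confidence threshold
    frame = pd.DataFrame(xyxy[keep], columns=["xmin", "ymin", "xmax", "ymax"])
    frame.insert(0, "image_path", image_name)
    frame["score"] = conf[keep]
    return frame
```

With a loaded model this might be called as `res = model(str(img_path))[0]` followed by `detections_to_frame(res.boxes.xyxy.cpu().numpy(), res.boxes.conf.cpu().numpy(), img_path.name)`.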

In [None]:
# 10. 模型评估和性能分析 - Model Evaluation and Performance Analysis
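def evaluate_model_performance(model, test_images):
    """
    Summarize detection counts and confidence scores on the given images.
    No ground truth is used, so these are not accuracy metrics; for mAP on the
    labelled split use model.val(data=dataset_yaml_path, split='test').
    """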
    print("📊 开始模型性能评估 Starting model performance evaluation...")
    
    total_images = len(test_images)
    total_detections = 0
    confidence_scores = []
    
    # 统计检测结果 Count detection results
    for i, img_path in enumerate(test_images):
        try:
            results = model(str(img_path))
            result = results[0]
            
            if result.boxes is not None:
                num_detections = len(result.boxes)
                total_detections += num_detections
                
                # 收集置信度分数 Collect confidence scores
                for box in result.boxes:
                    conf = box.conf[0].item()
                    confidence_scores.append(conf)
                    
        except Exception as e:
            print(f"评估图像 {i+1} 时出错 Error evaluating image {i+1}: {e}")
    
    # 计算统计信息 Calculate statistics
    avg_detections_per_image = total_detections / total_images if total_images > 0 else 0
    avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
    max_confidence = np.max(confidence_scores) if confidence_scores else 0
    min_confidence = np.min(confidence_scores) if confidence_scores else 0
    
    print("\n📈 模型性能统计 Model Performance Statistics")
    print("=" * 50)
    print(f"总测试图像数量 Total test images: {total_images}")
    print(f"总检测数量 Total detections: {total_detections}")
    print(f"平均每张图像检测数量 Average detections per image: {avg_detections_per_image:.2f}")
    print(f"平均置信度 Average confidence: {avg_confidence:.3f}")
    print(f"最高置信度 Maximum confidence: {max_confidence:.3f}")
    print(f"最低置信度 Minimum confidence: {min_confidence:.3f}")
    
    # 绘制置信度分布图 Plot confidence distribution
    if confidence_scores:
        plt.figure(figsize=(10, 6))
        plt.hist(confidence_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        plt.title('🎯 检测置信度分布 Detection Confidence Distribution', fontsize=14, fontweight='bold')
        plt.xlabel('置信度 Confidence Score')
        plt.ylabel('频次 Frequency')
        plt.grid(True, alpha=0.3)
        plt.show()
    
    return {
        'total_images': total_images,
        'total_detections': total_detections,
        'avg_detections_per_image': avg_detections_per_image,
        'confidence_scores': confidence_scores,
        'avg_confidence': avg_confidence
    }

# 如果有测试图像，进行性能评估 If test images exist, perform performance evaluation
if model is not None:
    test_image_dir = Path('data/processed/images/test')
    if test_image_dir.exists():
        test_images = list(test_image_dir.glob('*.jpg'))
        if test_images:
            performance_stats = evaluate_model_performance(model, test_images)
        else:
            # 使用原始评估图像 Use original evaluation images
            eval_images = sorted(p for p in Path('data/raw/evaluation').rglob('*') if p.suffix.lower() in {'.jpg', '.jpeg', '.png', '.tif', '.tiff'})
            if eval_images:
                performance_stats = evaluate_model_performance(model, eval_images[:20])  # 限制为20张图像
    else:
        print("⚠️  跳过性能评估，未找到测试图像 Skipping performance evaluation, no test images found")
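
The statistics above describe only the model's outputs: without ground truth they cannot say how many trees were found or missed. Besides `model.val(data=dataset_yaml_path, split='test')`, which reports mAP, the sketch below gives a hand-checkable alternative. It greedily matches predictions (highest confidence first) to ground-truth boxes at IoU ≥ 0.5 and reports precision and recall. The greedy one-to-one matching and the 0.5 threshold are common detection-evaluation conventions assumed here, not the NeonTree benchmark's official scoring, and `box_iou` / `precision_recall` are hypothetical names:

```python
import numpy as np


def box_iou(a, b) -> np.ndarray:
    """Pairwise IoU between boxes a (N, 4) and b (M, 4) in [xmin, ymin, xmax, ymax] form."""
    a = np.asarray(a, dtype=float).reshape(-1, 4)
    b = np.asarray(b, dtype=float).reshape(-1, 4)
    ix1 = np.maximum(a[:, None, 0], b[None, :, 0])
    iy1 = np.maximum(a[:, None, 1], b[None, :, 1])
    ix2 = np.minimum(a[:, None, 2], b[None, :, 2])
    iy2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(ix2 - ix1, 0, None) * np.clip(iy2 - iy1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    union = area_a[:, None] + area_b[None, :] - inter
    return np.where(union > 0, inter / np.maximum(union, 1e-12), 0.0)


def precision_recall(pred_boxes, pred_scores, gt_boxes, iou_thr: float = 0.5):
    """Greedy one-to-one matching; returns (precision, recall, true_positives)."""
    pred_boxes = np.asarray(pred_boxes, dtype=float).reshape(-1, 4)
    gt_boxes = np.asarray(gt_boxes, dtype=float).reshape(-1, 4)
    order = np.argsort(-np.asarray(pred_scores, dtype=float))  # most confident first
    ious = box_iou(pred_boxes, gt_boxes)
    matched_gt = np.zeros(len(gt_boxes), dtype=bool)
    tp = 0
    for i in order:
        if len(gt_boxes) == 0:
            break
        candidates = np.where(matched_gt, -1.0, ious[i])  # each tree can match only once
        j = int(np.argmax(candidates))
        if candidates[j] >= iou_thr:
            matched_gt[j] = True
            tp += 1
    precision = tp / len(pred_boxes) if len(pred_boxes) else 0.0
    recall = tp / len(gt_boxes) if len(gt_boxes) else 0.0
    return precision, recall, tp
```

Ground-truth boxes can be read from the YOLO label files and converted back to pixels before calling `precision_recall`; a duplicate prediction on an already-matched tree counts as a false positive.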

## 📝 项目总结 Project Summary

### 完成的功能 Completed Features
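1. ✅ **环境检测 Environment Detection** - Automatically detects the GPU and the Colab runtime
2. ✅ **数据下载 Data Download** - Automatically downloads the NeonTreeEvaluation dataset
3. ✅ **数据转换 Data Conversion** - Converts NeonTree annotations to YOLO format
4. ✅ **模型训练 Model Training** - Trains a YOLOv5 model for tree detection
5. ✅ **结果推理 Inference** - Runs tree detection on the test images
6. ✅ **结果可视化 Visualization** - Displays detection results and summary statistics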

### 主要参数说明 Key Parameter Explanations
- **训练轮数 Epochs**: 50 (adjust based on results)
- **批次大小 Batch Size**: 16 (adjust to the available GPU memory)
- **图像尺寸 Image Size**: 640×640 pixels
- **学习率 Learning Rate**: 0.001 (AdamW optimizer)
- **数据分割 Data Split**: 70% train / 20% validation / 10% test

### 使用说明 Usage Instructions
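1. Run all code cells in Google Colab
2. Make sure a GPU runtime is enabled to speed up training
3. Adjust the training parameters as needed
4. Inspect the detection images in the `results` directory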

### 故障排除 Troubleshooting
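- If you run out of memory, reduce `batch` in `training_args`
- If training takes too long, reduce `epochs`
- If detection quality is poor, try training for more epochs (and raise `patience`, since early stopping may end training before `epochs` is reached)
- Make sure the dataset was downloaded and extracted correctly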

### 进一步改进 Further Improvements
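- Data augmentation to improve the model's generalization
- Hyperparameter tuning to improve model performance
- Multi-scale training to improve detection accuracy
- Model ensembling to improve overall results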
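
For the data-augmentation item, ultralytics exposes augmentation settings as ordinary `train()` arguments. A sketch of overrides that may suit nadir aerial imagery, where trees have no preferred up/down orientation; the argument names are standard ultralytics training settings, but the values are untested assumptions rather than tuned settings, and `base_args` is a stand-in for the notebook's `training_args`:

```python
# Hypothetical augmentation overrides for model.train(); values are illustrative, not tuned
augmentation_args = {
    'flipud': 0.5,   # vertical flips: top-down imagery has no "up" direction
    'fliplr': 0.5,   # horizontal flips
    'degrees': 0.0,  # rotation off: rotating axis-aligned boxes tends to loosen them
    'scale': 0.5,    # random zoom in/out, for varying crown sizes
    'hsv_v': 0.4,    # brightness jitter for different illumination conditions
    'mosaic': 1.0,   # mosaic augmentation probability
}

# Merge with the base training arguments; later keys override earlier ones
base_args = {'epochs': 50, 'imgsz': 640, 'batch': 16}
train_args_with_aug = {**base_args, **augmentation_args}
print(sorted(train_args_with_aug))
```

In the notebook this would become `model.train(**{**training_args, **augmentation_args})`.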