# 🌲 Tree Detection Project

## Project Overview

This project detects trees in aerial imagery using YOLOv5 and PyTorch.

**Dataset**: NeonTreeEvaluation Benchmark  
**Goal**: Detect trees in aerial orthoimagery  
**Platform**: Google Colab (automatic GPU detection)

## Tech Stack
- **Deep Learning**: PyTorch + YOLOv5
- **Data Processing**: OpenCV, PIL, pandas
- **Visualization**: matplotlib, seaborn

In [None]:
# 1. Environment Detection and Basic Setup
import os
import sys
import torch
import torchvision
import platform
from pathlib import Path

print("=" * 50)
print("🚀 Environment Detection")
print("=" * 50)

# Check Python version
print(f"Python Version: {sys.version}")

# Check operating system
print(f"Operating System: {platform.system()}")

# Check whether the notebook is running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("✅ Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("❌ Not in Colab, running locally")

# Check GPU availability
if torch.cuda.is_available():
    print(f"✅ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   GPU Count: {torch.cuda.device_count()}")
    device = torch.device('cuda')
else:
    print("⚠️  GPU not available, using CPU")
    device = torch.device('cpu')

print(f"PyTorch Version: {torch.__version__}")
print(f"Device: {device}")
print("=" * 50)


In [None]:
# 2. Install Required Dependencies
print("📦 Installing Dependencies...")

# Install the ultralytics package and supporting libraries
# (the YOLOv5 repository itself is cloned in step 8)
!pip install -q ultralytics
!pip install -q opencv-python-headless
!pip install -q Pillow
!pip install -q matplotlib
!pip install -q seaborn
!pip install -q pandas
!pip install -q tqdm
!pip install -q scikit-learn
!pip install -q PyYAML

# Optionally mount Google Drive when running in Colab
if IN_COLAB:
    from google.colab import drive
    # drive.mount('/content/drive')  # Uncomment to mount Drive

print("✅ Dependencies installed successfully!")
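
Because `pip install -q` hides its output, it can help to confirm what actually got installed. The following is a minimal sketch, not part of the original notebook: `installed_versions` is a hypothetical helper built on the standard-library `importlib.metadata`, and the package list simply repeats the names installed above.

```python
from importlib import metadata

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Distribution names as passed to pip in the cell above
required = ["ultralytics", "opencv-python-headless", "Pillow", "matplotlib",
            "seaborn", "pandas", "tqdm", "scikit-learn", "PyYAML"]

for name, version in installed_versions(required).items():
    status = version if version is not None else "NOT INSTALLED"
    print(f"{name:<25} {status}")
```

A `None` entry points to an install that failed silently and is worth rerunning without `-q`.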

In [None]:
# 3. Import Required Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
import json
import zipfile
import requests
from tqdm import tqdm
import shutil
import yaml
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Prefer a font with CJK glyphs when available, falling back to DejaVu Sans
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("✅ Libraries imported successfully!")
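
The seeding above covers NumPy and PyTorch. A slightly broader variant is sketched below, assuming a hypothetical helper named `seed_everything`; it also seeds Python's `random` module and switches cuDNN to deterministic kernels when PyTorch is importable. It only affects the notebook process, so a training run launched in a separate process does not inherit it.

```python
import random

def seed_everything(seed=42):
    """Seed Python's RNG, plus NumPy and PyTorch when they are importable."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)  # covers every visible GPU
        # Trade some speed for repeatable cuDNN kernel selection
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # PyTorch not installed: nothing to seed

seed_everything(42)
first_draw = random.random()
seed_everything(42)
second_draw = random.random()
print(f"Same draw after reseeding: {first_draw == second_draw}")
```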

In [None]:
# 4. Data Download
def download_file(url, filename):
    """
    Stream a file from `url` to `filename` with a progress bar.
    """
    print(f"🔄 Starting download: {filename}")

    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()  # Fail early rather than saving an error page as a .zip
    total_size = int(response.headers.get('content-length', 0))

    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                file.write(chunk)
                pbar.update(len(chunk))

    print(f"✅ Download completed: {filename}")

# Create data directories
os.makedirs('data', exist_ok=True)
os.makedirs('data/raw', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

# Dataset URLs
DATASET_URL = "https://zenodo.org/records/5914554/files/evaluation.zip?download=1"
ANNOTATIONS_URL = "https://zenodo.org/records/5914554/files/annotations.zip?download=1"

# Download each archive unless it is already present
if not os.path.exists('data/raw/evaluation.zip'):
    download_file(DATASET_URL, 'data/raw/evaluation.zip')
else:
    print("✅ Evaluation dataset already exists")

if not os.path.exists('data/raw/annotations.zip'):
    download_file(ANNOTATIONS_URL, 'data/raw/annotations.zip')
else:
    print("✅ Annotations already exist")

print("📁 Data download completed!")
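
The `os.path.exists` checks above treat any file at the target path as a finished download, so an interrupted transfer would be skipped on the next run. One way around this, sketched here with a hypothetical helper `write_atomically`, is to stream into a temporary file and rename it only after the last chunk is written. The byte chunks in the demo stand in for `response.iter_content(...)`.

```python
import os
import tempfile

def write_atomically(chunks, dest_path):
    """Write an iterable of byte chunks to dest_path, creating it only on success."""
    dest_dir = os.path.dirname(os.path.abspath(dest_path))
    os.makedirs(dest_dir, exist_ok=True)
    # The temporary file lives in the same directory so the final rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            for chunk in chunks:
                if chunk:
                    tmp_file.write(chunk)
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Clean up the partial file so a later existence check is not fooled
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return dest_path

# Demo with in-memory chunks instead of a network response
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "demo.bin")
write_atomically([b"abc", b"", b"def"], demo_path)
with open(demo_path, "rb") as f:
    demo_bytes = f.read()
print(demo_bytes)
```

In `download_file`, the same idea would mean passing `response.iter_content(chunk_size=8192)` as `chunks`.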

In [None]:
# 5. Extract Datasets
def extract_zip(zip_path, extract_to):
    """
    Extract a ZIP archive into `extract_to`.
    """
    print(f"📦 Extracting: {zip_path}")

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

    print(f"✅ Extraction completed: {extract_to}")

# Extract the evaluation dataset
if not os.path.exists('data/raw/evaluation'):
    extract_zip('data/raw/evaluation.zip', 'data/raw/')
else:
    print("✅ Evaluation dataset already extracted")

# Extract the annotations
if not os.path.exists('data/raw/annotations'):
    extract_zip('data/raw/annotations.zip', 'data/raw/')
else:
    print("✅ Annotations already extracted")

# Explore the data structure
print("\n📊 Data Structure Analysis")
print("=" * 50)

# Check the evaluation dataset structure
eval_path = Path('data/raw/evaluation')
if eval_path.exists():
    print(f"Evaluation dataset path: {eval_path}")
    subdirs = [d for d in eval_path.iterdir() if d.is_dir()]
    print(f"Number of subdirectories: {len(subdirs)}")
    for subdir in subdirs[:5]:  # Show the first 5 subdirectories
        print(f"  - {subdir.name}")
    if len(subdirs) > 5:
        print(f"  ... and {len(subdirs)-5} more directories")

# Check the annotations structure (step 6 reads the XML files from this directory)
ann_path = Path('data/raw/annotations')
if ann_path.exists():
    print(f"\nAnnotations path: {ann_path}")
    ann_files = list(ann_path.glob('*.xml'))
    print(f"Number of XML annotation files: {len(ann_files)}")
    for ann_file in ann_files[:3]:  # Show the first 3 annotation files
        print(f"  - {ann_file.name}")

print("\n✅ Data extraction and structure analysis completed!")
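
`extract_zip` assumes the archive is intact. A truncated download can be caught before extraction with the standard-library checks `zipfile.is_zipfile` and `ZipFile.testzip`. The sketch below assumes a hypothetical helper `verify_and_extract` and exercises it on a small archive built in a temporary directory rather than on the real dataset.

```python
import os
import tempfile
import zipfile

def verify_and_extract(zip_path, extract_to):
    """Check archive integrity, extract it, and return its top-level entry names."""
    if not zipfile.is_zipfile(zip_path):
        raise ValueError(f"Not a valid ZIP archive: {zip_path}")
    with zipfile.ZipFile(zip_path, "r") as archive:
        corrupt_member = archive.testzip()  # reads every member and verifies its CRC
        if corrupt_member is not None:
            raise ValueError(f"Corrupt member in archive: {corrupt_member}")
        archive.extractall(extract_to)
        return sorted({name.split("/")[0] for name in archive.namelist()})

# Demo: build a tiny archive that mimics the annotations layout
work_dir = tempfile.mkdtemp()
demo_zip = os.path.join(work_dir, "annotations.zip")
with zipfile.ZipFile(demo_zip, "w") as archive:
    archive.writestr("annotations/plot_a.xml", "<annotation/>")
    archive.writestr("annotations/plot_b.xml", "<annotation/>")

out_dir = os.path.join(work_dir, "raw")
top_level = verify_and_extract(demo_zip, out_dir)
print(top_level)
```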

In [None]:
# 6. Data Format Conversion (NeonTree -> YOLO), using the real images
import xml.etree.ElementTree as ET

class NeonTreeToYOLO:
    """
    Convert the NeonTree dataset to YOLO format.
    """

    def __init__(self, annotations_root, evaluation_root, output_root):
        self.annotations_root = Path(annotations_root)
        self.evaluation_root = Path(evaluation_root)
        self.output_root = Path(output_root)

        # Create output directories
        self.create_yolo_structure()

        # Build the image filename mapping
        self.create_image_mapping()

    def create_yolo_structure(self):
        """Create the YOLO dataset directory structure."""
        folders = [
            'images/train', 'images/val', 'images/test',
            'labels/train', 'labels/val', 'labels/test'
        ]

        for folder in folders:
            (self.output_root / folder).mkdir(parents=True, exist_ok=True)

        print("✅ YOLO directory structure created")

    def create_image_mapping(self):
        """
        Map image base names (filename without extension) to their paths.
        """
        self.image_mapping = {}

        # Index the .tif files under evaluation/RGB
        rgb_dir = self.evaluation_root / 'RGB'
        if rgb_dir.exists():
            for img_file in rgb_dir.glob('*.tif'):
                # Path.stem drops the final extension
                base_name = img_file.stem
                self.image_mapping[base_name] = img_file

                # Also index files with a doubled extension (name.tif.tif) under the bare name
                if base_name.endswith('.tif'):
                    base_name_no_ext = base_name[:-4]
                    self.image_mapping[base_name_no_ext] = img_file

        print(f"📋 Created image mapping with {len(self.image_mapping)} entries")

    def find_matching_image(self, xml_filename):
        """
        Find the RGB image that corresponds to an XML annotation filename.
        """
        # Strip the .xml extension to get the base name
        base_name = Path(xml_filename).stem

        # Exact match
        if base_name in self.image_mapping:
            return self.image_mapping[base_name]

        # Fallback: compare the leading filename tokens
        for img_name, img_path in self.image_mapping.items():
            if self.files_match(base_name, img_name):
                return img_path

        return None

    def files_match(self, xml_name, img_name):
        """
        Fallback match: the first two underscore-separated tokens
        (site code and plot number) must both agree.
        """
        xml_parts = xml_name.lower().split('_')
        img_parts = img_name.lower().split('_')

        if len(xml_parts) < 2 or len(img_parts) < 2:
            return False

        # Require both tokens to agree. Matching on any single shared token
        # (such as the site code alone) would pair annotations with the wrong image.
        return xml_parts[:2] == img_parts[:2]
    
    def parse_xml_annotation(self, xml_file):
        """
        Parse one Pascal VOC style XML annotation file.
        """
        try:
            tree = ET.parse(xml_file)
            root = tree.getroot()

            # Get image information
            filename_elem = root.find('filename')
            if filename_elem is not None:
                filename = filename_elem.text
            else:
                filename = xml_file.stem + '.tif'  # default extension

            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)

            # Collect all tree objects
            objects = []
            for obj in root.findall('object'):
                name_elem = obj.find('name')
                if name_elem is not None and (name_elem.text or '').lower() == 'tree':
                    bndbox = obj.find('bndbox')
                    xmin = int(float(bndbox.find('xmin').text))
                    ymin = int(float(bndbox.find('ymin').text))
                    xmax = int(float(bndbox.find('xmax').text))
                    ymax = int(float(bndbox.find('ymax').text))

                    # Keep only boxes with positive width and height
                    if xmax > xmin and ymax > ymin:
                        objects.append({
                            'xmin': xmin,
                            'ymin': ymin,
                            'xmax': xmax,
                            'ymax': ymax
                        })

            return {
                'filename': filename,
                'width': width,
                'height': height,
                'objects': objects
            }

        except Exception as e:
            print(f"Error parsing XML file {xml_file}: {e}")
            return None

    def convert_bbox_to_yolo(self, bbox, img_width, img_height):
        """
        Convert pixel corner coordinates to YOLO format.

        YOLO format: [class_id, x_center, y_center, width, height],
        with the four coordinates normalized by the image size.
        """
        xmin, ymin, xmax, ymax = bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']

        # Compute the center point and box dimensions in pixels
        x_center = (xmin + xmax) / 2.0
        y_center = (ymin + ymax) / 2.0
        width = xmax - xmin
        height = ymax - ymin

        # Normalize to the [0, 1] range
        x_center /= img_width
        y_center /= img_height
        width /= img_width
        height /= img_height

        return [0, x_center, y_center, width, height]  # class ID 0 = tree
    
    def process_annotations(self):
        """
        Process every XML annotation file and write YOLO-format images and labels.
        """
        xml_files = list(self.annotations_root.glob('*.xml'))
        processed_count = 0
        skipped_count = 0
        missing_image_count = 0

        print(f"📝 Found {len(xml_files)} XML annotation files")

        for i, xml_file in enumerate(tqdm(xml_files, desc="Processing annotations")):
            try:
                # Parse the XML file
                annotation_data = self.parse_xml_annotation(xml_file)

                if annotation_data is None or not annotation_data['objects']:
                    skipped_count += 1
                    continue

                # Find the corresponding RGB image
                img_path = self.find_matching_image(xml_file.name)

                if img_path is None:
                    # Report only the first 5 unmatched files, then keep going quietly
                    if missing_image_count < 5:
                        print(f"⚠️  No matching image for: {xml_file.name}")
                    missing_image_count += 1
                    skipped_count += 1
                    continue

                # Load the real image
                img = cv2.imread(str(img_path))
                if img is None:
                    print(f"⚠️  Cannot read image: {img_path}")
                    skipped_count += 1
                    continue

                # If the XML size disagrees with the file, trust the actual image dimensions
                actual_height, actual_width = img.shape[:2]
                if actual_width != annotation_data['width'] or actual_height != annotation_data['height']:
                    annotation_data['width'] = actual_width
                    annotation_data['height'] = actual_height

                # Assign the dataset split from the file index
                if i % 10 < 7:  # 70% training set
                    split = 'train'
                elif i % 10 < 9:  # 20% validation set
                    split = 'val'
                else:  # 10% test set
                    split = 'test'

                # Save the image as JPEG
                img_filename = f"tree_{processed_count:06d}.jpg"
                img_save_path = self.output_root / 'images' / split / img_filename
                cv2.imwrite(str(img_save_path), img)

                # Save the YOLO-format labels
                label_filename = f"tree_{processed_count:06d}.txt"
                label_save_path = self.output_root / 'labels' / split / label_filename

                with open(label_save_path, 'w') as f:
                    for obj in annotation_data['objects']:
                        yolo_bbox = self.convert_bbox_to_yolo(
                            obj, annotation_data['width'], annotation_data['height']
                        )
                        # One object per line: class_id x_center y_center width height
                        f.write(f"{yolo_bbox[0]} {yolo_bbox[1]:.6f} {yolo_bbox[2]:.6f} {yolo_bbox[3]:.6f} {yolo_bbox[4]:.6f}\n")

                processed_count += 1

            except Exception as e:
                print(f"Error processing file {xml_file.name}: {e}")
                skipped_count += 1
                continue

        print(f"✅ Processing completed: {processed_count} samples processed, {skipped_count} skipped")
        return processed_count

# Run the data conversion
print("🔄 Starting data format conversion...")
print("📋 Using XML annotations with real RGB images")

# Create data directories
os.makedirs('data', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

# Candidate locations for the annotations and evaluation data
annotation_paths = ['data/raw/annotations', 'annotations', '/content/annotations']
eval_paths = ['data/raw/evaluation', 'evaluation', '/content/evaluation']

annotations_root = None
eval_root = None

# Find the annotations directory: accept a candidate only if it contains XML files
for path in annotation_paths:
    candidate = Path(path)
    if candidate.exists():
        xml_count = len(list(candidate.glob('*.xml')))
        if xml_count > 0:
            annotations_root = candidate
            print(f"✅ Found annotations dataset at: {annotations_root} ({xml_count} XML files)")
            break

if annotations_root is None:
    print("❌ Annotations dataset or XML files not found")
    print("📁 Make sure the annotations directory exists and contains XML files")

# Find the evaluation directory: accept a candidate only if it has an RGB subdirectory
for path in eval_paths:
    candidate = Path(path)
    rgb_dir = candidate / 'RGB'
    if rgb_dir.exists():
        eval_root = candidate
        img_count = len(list(rgb_dir.glob('*.tif')))
        print(f"✅ Found evaluation dataset at: {eval_root} ({img_count} images)")
        break

if eval_root is None:
    print("❌ Evaluation dataset not found; make sure evaluation.zip was downloaded and extracted")

# Only proceed if both datasets were found
if annotations_root is not None and eval_root is not None:
    # Initialize the converter
    converter = NeonTreeToYOLO(annotations_root, eval_root, 'data/processed')

    # Process the annotation data
    processed_samples = converter.process_annotations()

    if processed_samples > 0:
        print(f"✅ Data conversion completed: {processed_samples} samples")

        # Show dataset statistics
        train_images = len(list(Path('data/processed/images/train').glob('*.jpg')))
        val_images = len(list(Path('data/processed/images/val').glob('*.jpg')))
        test_images = len(list(Path('data/processed/images/test').glob('*.jpg')))

        print(f"📊 Dataset Statistics:")
        print(f"   Training: {train_images} images")
        print(f"   Validation: {val_images} images")
        print(f"   Testing: {test_images} images")
        print(f"   Total: {train_images + val_images + test_images} images")
    else:
        print("❌ No samples were processed")
        print("Check how the XML filenames map to the image filenames")
else:
    print("❌ Required datasets are missing; conversion cannot run")
    print("Make sure the following directories exist:")
    print("  - annotations directory (containing XML files)")
    print("  - evaluation/RGB directory (containing .tif image files)")
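
As a worked example of the coordinate conversion in `convert_bbox_to_yolo`, the sketch below uses a hypothetical standalone function, `voc_to_yolo`. Unlike the class method, it also clamps boxes to the image bounds, which is an added assumption rather than behavior of the converter above. Clamping keeps normalized values inside [0, 1] when an annotation extends past the image edge.

```python
def voc_to_yolo(xmin, ymin, xmax, ymax, img_width, img_height, class_id=0):
    """Convert pixel corners to a normalized YOLO row, or None if the box is empty."""
    # Clamp the corners to the image bounds
    xmin = min(max(xmin, 0), img_width)
    xmax = min(max(xmax, 0), img_width)
    ymin = min(max(ymin, 0), img_height)
    ymax = min(max(ymax, 0), img_height)
    if xmax <= xmin or ymax <= ymin:
        return None  # nothing of the box remains inside the image

    x_center = (xmin + xmax) / 2.0 / img_width
    y_center = (ymin + ymax) / 2.0 / img_height
    width = (xmax - xmin) / img_width
    height = (ymax - ymin) / img_height
    return (class_id, x_center, y_center, width, height)

# A 200x200 px box with corners (100, 50) and (300, 250) in a 400x400 image
row = voc_to_yolo(100, 50, 300, 250, 400, 400)
print("%d %.6f %.6f %.6f %.6f" % row)  # prints: 0 0.500000 0.375000 0.500000 0.500000
```

The center is ((100 + 300) / 2, (50 + 250) / 2) = (200, 150) px, which normalizes to (0.5, 0.375); the 200 px width and height each normalize to 0.5.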

In [None]:
# 7. Create YOLO Configuration Files (with absolute dataset paths)
import os
from pathlib import Path
import yaml

def create_dataset_yaml():
    """
    Create the YOLO dataset configuration file.
    """
    # Get the current working directory
    current_dir = os.getcwd()

    # Build absolute paths
    processed_dir = os.path.join(current_dir, 'data/processed')
    train_path = os.path.join(processed_dir, 'images/train')
    val_path = os.path.join(processed_dir, 'images/val')
    test_path = os.path.join(processed_dir, 'images/test')

    # Check which split directories exist
    paths_info = {
        'train': train_path,
        'val': val_path,
        'test': test_path
    }

    print("📋 Checking dataset paths:")
    existing_paths = {}

    for split, path in paths_info.items():
        if os.path.exists(path):
            img_count = len([f for f in os.listdir(path) if f.endswith('.jpg')])
            print(f"✅ {split}: {path} ({img_count} images)")
            existing_paths[split] = path
        else:
            print(f"❌ {split}: {path} (path does not exist)")

    # The training set is mandatory
    if 'train' not in existing_paths:
        print("❌ Error: Training set not found")
        return None

    # Build the dataset configuration; missing splits fall back to the training images
    dataset_config = {
        'path': processed_dir,    # dataset root directory
        'train': 'images/train',  # training images, relative to path
        'val': 'images/val' if 'val' in existing_paths else 'images/train',
        'test': 'images/test' if 'test' in existing_paths else 'images/train',
        'nc': 1,                  # number of classes
        'names': ['tree']         # class names
    }

    # Warn when a split falls back to the training images
    if 'val' not in existing_paths:
        print("⚠️  Validation set not found, using training set as validation")

    if 'test' not in existing_paths:
        print("⚠️  Test set not found, using training set as test")

    # Make sure the configuration directory exists
    os.makedirs('data', exist_ok=True)

    # Save the configuration file
    config_path = 'data/tree_dataset.yaml'
    with open(config_path, 'w') as f:
        yaml.dump(dataset_config, f, default_flow_style=False)

    print(f"✅ YOLO dataset configuration file created: {config_path}")

    # Verify that each configured split resolves to an existing directory
    print("\n🔍 Verifying configuration file:")
    for key, value in dataset_config.items():
        if key in ['train', 'val', 'test']:
            full_path = os.path.join(processed_dir, value)
            exists = os.path.exists(full_path)
            print(f"   {key}: {value} -> {full_path} ({'✅' if exists else '❌'})")
        else:
            print(f"   {key}: {value}")

    return config_path

def fix_data_structure():
    """
    Check the processed dataset layout and create missing val/test directories.
    """
    print("\n🔧 Checking and fixing data structure...")

    processed_dir = Path('data/processed')

    if not processed_dir.exists():
        print("❌ data/processed directory does not exist")
        return False

    # Directories that the YOLO layout requires
    required_dirs = [
        'images/train', 'images/val', 'images/test',
        'labels/train', 'labels/val', 'labels/test'
    ]

    missing_dirs = []
    for dir_path in required_dirs:
        full_path = processed_dir / dir_path
        if not full_path.exists():
            missing_dirs.append(dir_path)

    if missing_dirs:
        print(f"⚠️  Missing directories: {missing_dirs}")

        # Create missing validation and test directories; a missing training
        # directory is reported by the emptiness check below
        for missing_dir in missing_dirs:
            if 'val' in missing_dir or 'test' in missing_dir:
                (processed_dir / missing_dir).mkdir(parents=True, exist_ok=True)
                print(f"✅ Created directory: {missing_dir}")

    # Check that the training set contains data
    train_images = list((processed_dir / 'images/train').glob('*.jpg'))
    train_labels = list((processed_dir / 'labels/train').glob('*.txt'))

    print(f"📊 Training set statistics:")
    print(f"   Image files: {len(train_images)}")
    print(f"   Label files: {len(train_labels)}")

    if len(train_images) == 0:
        print("❌ Training set is empty, please run data conversion first")
        return False

    return True

# Check and repair the data structure
data_structure_ok = fix_data_structure()

if data_structure_ok:
    # Create the configuration file
    dataset_yaml_path = create_dataset_yaml()

    if dataset_yaml_path:
        # Display the configuration file content
        with open(dataset_yaml_path, 'r') as f:
            config_content = f.read()
            print("\n📋 Dataset Configuration Content:")
            print("=" * 40)
            print(config_content)
            print("=" * 40)

        print("\n✅ Configuration file created and verified")
    else:
        print("❌ Failed to create configuration file")
else:
    print("❌ Data structure fix failed, please check the data conversion step")
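
The structure check above counts files but does not look inside the label files. The sketch below assumes a hypothetical helper `check_label_file` that validates each row against the format written in step 6: five whitespace-separated fields, an integer class ID below `nc`, and four coordinates in [0, 1]. The demo file is synthetic.

```python
import os
import tempfile
from pathlib import Path

def check_label_file(label_path, num_classes=1):
    """Return a list of problems found in one YOLO label file (empty means it looks fine)."""
    problems = []
    lines = Path(label_path).read_text().splitlines()
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        fields = line.split()
        if len(fields) != 5:
            problems.append(f"line {line_no}: expected 5 fields, got {len(fields)}")
            continue
        try:
            class_id = int(fields[0])
            coords = [float(value) for value in fields[1:]]
        except ValueError:
            problems.append(f"line {line_no}: non-numeric value")
            continue
        if not 0 <= class_id < num_classes:
            problems.append(f"line {line_no}: class id {class_id} out of range")
        if any(not 0.0 <= value <= 1.0 for value in coords):
            problems.append(f"line {line_no}: coordinate outside [0, 1]")
    return problems

# Demo: one valid row followed by three different kinds of bad rows
demo_label = os.path.join(tempfile.mkdtemp(), "tree_000000.txt")
Path(demo_label).write_text(
    "0 0.500000 0.500000 0.200000 0.200000\n"
    "1 0.500000 0.500000 0.200000 0.200000\n"
    "0 1.200000 0.500000 0.200000 0.200000\n"
    "0 0.500000 0.500000\n"
)
problems = check_label_file(demo_label, num_classes=1)
for problem in problems:
    print(problem)
```

Running it over every file under `data/processed/labels` before training would surface conversion mistakes early.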

In [ ]:
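# 8. Professional YOLOv5 Tree Detection Training

import subprocess
import sys
import yaml
import os
import torch
import gc
from pathlib import Path

print("🌲 YOLOv5 tree detection training")
print("🚀 Ignoring any earlier best.pt; fine-tuning starts from the official pretrained weights")
print("🔧 Training arguments are passed to train.py as an explicit list")
print("=" * 70)

def memory_cleanup():
    """Free Python garbage and cached CUDA memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def create_professional_dataset_config():
    """Create the dataset configuration used for training."""
    print("📝 Creating the YOLOv5 dataset configuration...")

    # Check the dataset path
    dataset_path = os.path.join(os.getcwd(), 'data/processed')
    if not os.path.exists(dataset_path):
        print(f"❌ Dataset path does not exist: {dataset_path}")
        return None

    # Validate on the held-out split when it has images. Validating on the training
    # images reports optimistic metrics, so that is only a last-resort fallback.
    val_dir = os.path.join(dataset_path, 'images/val')
    has_val = os.path.isdir(val_dir) and any(f.endswith('.jpg') for f in os.listdir(val_dir))
    if not has_val:
        print("⚠️ Validation set is empty; falling back to the training images")

    # Configuration in the standard YOLOv5 dataset YAML layout
    professional_config = {
        'path': dataset_path,
        'train': 'images/train',
        'val': 'images/val' if has_val else 'images/train',
        'test': 'images/test',
        'nc': 1,
        'names': ['tree']
    }

    # Check the training data
    train_path = os.path.join(dataset_path, 'images/train')

    if os.path.exists(train_path):
        train_count = len([f for f in os.listdir(train_path) if f.endswith('.jpg')])
        print(f"   Training images: {train_count}")

        if train_count == 0:
            print("❌ Training data is empty")
            return None
    else:
        print("❌ Training data directory does not exist")
        return None

    # Save the configuration
    config_path = 'tree_detection_professional.yaml'
    with open(config_path, 'w') as f:
        yaml.dump(professional_config, f, default_flow_style=False)

    print(f"✅ Configuration file: {config_path}")
    return config_path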

# Run memory cleanup
memory_cleanup()

# Make sure the YOLOv5 repository is available
if not os.path.exists('yolov5'):
    print("📦 Cloning the official YOLOv5 repository...")
    !git clone https://github.com/ultralytics/yolov5.git
    print("✅ YOLOv5 repository cloned")

# Install/update dependencies without changing the working directory,
# so a failed install cannot leave the notebook inside yolov5/
print("📦 Checking YOLOv5 dependencies...")
!pip install -r yolov5/requirements.txt --quiet
print("✅ YOLOv5 dependency check finished")
# Create the dataset configuration
dataset_yaml = create_professional_dataset_config()

if dataset_yaml is None:
    print("❌ Could not create the dataset configuration; training aborted")
else:
    print("\n🚀 Starting YOLOv5 training...")

    # Choose training parameters from the available hardware
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        device_name = torch.cuda.get_device_name(0)
        print(f"🎮 GPU: {device_name} ({gpu_memory:.1f} GB)")

        # Scale model size and batch size with GPU memory
        if gpu_memory >= 20:  # e.g. A100, RTX 3090, RTX 4090
            model_size = 'yolov5l'
            batch_size = 24
            img_size = 640
            epochs = 200
            workers = 6
        elif gpu_memory >= 15:  # e.g. V100 16 GB
            model_size = 'yolov5m'
            batch_size = 20
            img_size = 640
            epochs = 150
            workers = 4
        elif gpu_memory >= 8:   # e.g. T4, which reports just under 15 GB
            model_size = 'yolov5s'
            batch_size = 16
            img_size = 640
            epochs = 120
            workers = 2
        else:                   # smaller GPUs
            model_size = 'yolov5s'
            batch_size = 8
            img_size = 512
            epochs = 100
            workers = 2

        device_param = '0'
    else:
        print("💻 Training on CPU (not recommended)")
        model_size = 'yolov5s'
        batch_size = 4
        img_size = 512
        epochs = 50
        workers = 1
        device_param = 'cpu'

    print(f"\n⚙️  Training configuration:")
    print(f"   Model: {model_size}.pt (official pretrained weights)")
    print(f"   Image size: {img_size}")
    print(f"   Batch size: {batch_size}")
    print(f"   Epochs: {epochs}")
    print(f"   Data loader workers: {workers}")
    print(f"   Device: {device_param}")

    # Build the training command
    professional_command = [
        sys.executable, 'yolov5/train.py',
        '--data', dataset_yaml,
        '--weights', f'{model_size}.pt',        # official pretrained weights
        '--epochs', str(epochs),
        '--batch-size', str(batch_size),
        '--imgsz', str(img_size),
        '--device', device_param,
        '--workers', str(workers),
        '--project', 'runs/train',
        '--name', 'tree_detection_professional',
        '--optimizer', 'SGD',
        '--patience', '30',                     # early-stopping patience (epochs)
        '--save-period', '25',                  # save a checkpoint every 25 epochs
        '--cache',                              # cache images
        '--exist-ok'
    ]

    # Extra option for GPUs with at least 8 GB
    if torch.cuda.is_available() and gpu_memory >= 8:
        professional_command.append('--multi-scale')  # multi-scale training
        # No mixed-precision flag is added: YOLOv5's train.py does not define --amp
        # and enables AMP on CUDA devices by itself after an internal check.

    print(f"\n🔥 Training command:")
    cmd_str = ' '.join(professional_command)
    print(f"   {cmd_str}")

    print(f"\n🚀 Starting training...")
    print(f"💡 Strategy: official pretrained weights + a small, stable set of arguments")
    print(f"🎯 Goal: a high-quality tree detection model")
    print("=" * 70)
    
    # 执行专业训练
    try:
        memory_cleanup()  # 训练前清理
        
        print("⏳ 训练进行中，请耐心等待...")
        result = subprocess.run(
            professional_command,
            capture_output=True,
            text=True,
            timeout=7200  # 2小时超时
        )
        
        if result.returncode == 0:
            print("🎉 专业训练成功完成!")
            
            # 检查训练结果
            model_dir = 'runs/train/tree_detection_professional/weights'
            if os.path.exists(model_dir):
                best_model = os.path.join(model_dir, 'best.pt')
                last_model = os.path.join(model_dir, 'last.pt')
                
                if os.path.exists(best_model):
                    model_size_mb = os.path.getsize(best_model) / (1024 * 1024)
                    print(f"✅ 最佳模型: {best_model} ({model_size_mb:.1f}MB)")
                    final_model_path = best_model
                elif os.path.exists(last_model):
                    model_size_mb = os.path.getsize(last_model) / (1024 * 1024)
                    print(f"✅ 最新模型: {last_model} ({model_size_mb:.1f}MB)")
                    final_model_path = last_model
                else:
                    final_model_path = None
                    print("⚠️ 未找到训练模型")
                
                # 显示训练总结
                print("\n📊 训练输出总结:")
                output_lines = result.stdout.split('\n')
                key_lines = []
                for line in output_lines:
                    if any(keyword in line.lower() for keyword in 
                          ['results', 'best', 'map', 'precision', 'recall', 'fitness', 'epoch']):
                        key_lines.append(line)
                
                # 显示最后15行关键信息
                for line in key_lines[-15:]:
                    if line.strip():
                        print(f"   {line}")
                        
            else:
                print("❌ 训练结果目录不存在")
                final_model_path = None
                
        else:
            print(f"❌ 训练失败，返回码: {result.returncode}")
            print("\n错误输出:")
            if result.stderr:
                print(result.stderr[:1500])
            
            print("\n标准输出片段:")
            if result.stdout:
                stdout_lines = result.stdout.split('\n')
                for line in stdout_lines[-20:]:
                    if line.strip():
                        print(f"   {line}")
            
            final_model_path = None
            
    except subprocess.TimeoutExpired:
        print("⏰ 训练超时，检查部分结果...")
        model_dir = 'runs/train/tree_detection_professional/weights'
        if os.path.exists(os.path.join(model_dir, 'last.pt')):
            final_model_path = os.path.join(model_dir, 'last.pt')
            print(f"✅ 找到部分训练模型: {final_model_path}")
        else:
            final_model_path = None
            
    except Exception as e:
        print(f"❌ 训练异常: {e}")
        final_model_path = None
    
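    # Final cleanup and summary
    memory_cleanup()
    
    print("\n🎯 YOLOv5 professional training run finished")
    if final_model_path:
        print(f"✅ Trained model: {final_model_path}")
        print("🔥 Based on the official pretrained weights and stable parameters")
        print("📈 Uses a simplified configuration for compatibility")
        print("🚀 Run Cell 9 to test model inference")
    else:
        print("❌ Training did not complete successfully")
        print("🔧 Possible fixes:")
        print("   1. Check the dataset format and paths")
        print("   2. Make sure there is enough GPU memory")
        print("   3. Check the YOLOv5 installation")
        print("   4. Try lowering the batch_size parameter")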
    
    print("=" * 70)
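
The weight selection above (use `best.pt` when it exists, otherwise fall back to `last.pt`) can be factored into a small helper. This is only a minimal sketch, assuming the usual YOLOv5 `weights/` layout; `pick_weights` is a hypothetical name and not part of YOLOv5.

```python
from pathlib import Path
from typing import Optional


def pick_weights(weights_dir: str) -> Optional[Path]:
    """Return best.pt if present, else last.pt, else None (hypothetical helper)."""
    for name in ("best.pt", "last.pt"):
        candidate = Path(weights_dir) / name
        if candidate.is_file():
            return candidate
    return None
```

For example, `pick_weights('runs/train/tree_detection_professional/weights')` would return the path to use for inference, or `None` if training produced no checkpoint.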

In [None]:
# 9. 专业模型推理测试 - Professional Model Inference Testing

import torch
import gc
import os
import sys
import subprocess
import time
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2

def comprehensive_cleanup():
    """Free Python and CUDA memory; return True on success, False on error."""
    try:
        print("🧹 Cleaning up memory...")
        
        # Python garbage collection
        collected = gc.collect()
        print(f"   Unreachable objects collected: {collected}")
        
        # Release cached CUDA memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()
            
            allocated = torch.cuda.memory_allocated() / 1024**3
            reserved = torch.cuda.memory_reserved() / 1024**3
            print(f"   CUDA memory: allocated {allocated:.1f}GB, reserved {reserved:.1f}GB")
        
        return True
        
    except Exception as e:
        print(f"⚠️ Cleanup warning: {e}")
        return False

def professional_inference(model_path, test_images, confidence_threshold=0.25):
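    """Run yolov5/detect.py on the given images and summarize the detections."""
    import shutil

    print(f"🔍 Model inference test: {model_path}")
    print(f"📊 Number of test images: {len(test_images)}")
    print(f"🎯 Confidence threshold: {confidence_threshold}")
    
    # Create the output directory
    output_dir = 'professional_inference_results'
    os.makedirs(output_dir, exist_ok=True)
    
    # Stage the selected images in their own directory so detect.py runs on
    # exactly these files rather than on everything in the source directory
    staging_dir = os.path.join(output_dir, 'input_images')
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.makedirs(staging_dir)
    for image_path in test_images:
        shutil.copy(image_path, staging_dir)
    
    # Remove results from earlier runs so stale label files are not counted
    shutil.rmtree(os.path.join(output_dir, 'detect'), ignore_errors=True)
    
    # Build the inference command
    inference_command = [
        sys.executable, 'yolov5/detect.py',
        '--weights', model_path,
        '--source', staging_dir,
        '--project', output_dir,
        '--name', 'detect',
        '--img', '640',
        '--conf', str(confidence_threshold),
        '--iou', '0.45',
        '--max-det', '1000',
        '--save-txt',
        '--save-conf',
        '--exist-ok'
    ]
    
    # Add the device argument
    if torch.cuda.is_available():
        inference_command.extend(['--device', '0'])
    else:
        inference_command.extend(['--device', 'cpu'])
    
    print(f"\n🚀 Running inference command:")
    print(f"   {' '.join(inference_command)}")
    
    # Run inference
    start_time = time.time()
    try:
        result = subprocess.run(
            inference_command,
            capture_output=True,
            text=True,
            timeout=300  # 5-minute timeout
        )
        
        inference_time = time.time() - start_time
        
        if result.returncode == 0:
            print(f"✅ Inference completed (elapsed: {inference_time:.1f}s)")
            
            # Analyze the inference results
            results = analyze_inference_results(output_dir)
            return results
            
        else:
            # The traceback is at the end of stderr, so show the tail
            print(f"❌ Inference failed: {result.stderr[-500:]}")
            return None
            
    except subprocess.TimeoutExpired:
        print("⏰ Inference timed out")
        return None
    except Exception as e:
        print(f"❌ Inference error: {e}")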
        return None

def analyze_inference_results(output_dir):
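    """Parse the label files written by detect.py and compute detection statistics."""
    print("\n📊 Analyzing inference results...")
    
    results_dir = Path(output_dir) / 'detect'
    if not results_dir.exists():
        print("❌ Inference results directory does not exist")
        return None
    
    # Detection counters
    total_detections = 0
    confidence_scores = []
    image_counts = {}
    
    # Parse the label files (detect.py writes one only for images with detections)
    label_dir = results_dir / 'labels'
    if label_dir.exists():
        label_files = list(label_dir.glob('*.txt'))
        print(f"   Label files to process: {len(label_files)}")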
        
        for label_file in label_files:
            image_name = label_file.stem
            detections_in_image = 0
            
            with open(label_file, 'r') as f:
                for line in f:
                    if line.strip():
                        parts = line.strip().split()
                        if len(parts) >= 6:
                            conf = float(parts[5])
                            confidence_scores.append(conf)
                            total_detections += 1
                            detections_in_image += 1
            
            image_counts[image_name] = detections_in_image
    
    # Collect the annotated output images
    output_images = list(results_dir.glob('*.jpg')) + list(results_dir.glob('*.png'))
    
    # Compute statistics
    if confidence_scores:
        avg_confidence = np.mean(confidence_scores)
        max_confidence = np.max(confidence_scores)
        min_confidence = np.min(confidence_scores)
        std_confidence = np.std(confidence_scores)
        
        # Confidence distribution buckets
        high_conf_count = sum(1 for c in confidence_scores if c > 0.7)
        medium_conf_count = sum(1 for c in confidence_scores if 0.4 <= c <= 0.7)
        low_conf_count = sum(1 for c in confidence_scores if c < 0.4)
        
        print(f"\n📈 Detection statistics:")
        print(f"   Total detections: {total_detections}")
        print(f"   Mean confidence: {avg_confidence:.3f}")
        print(f"   Confidence range: {min_confidence:.3f} - {max_confidence:.3f}")
        print(f"   Confidence std. dev.: {std_confidence:.3f}")
        print(f"   High confidence (>0.7): {high_conf_count}")
        print(f"   Medium confidence (0.4-0.7): {medium_conf_count}")
        print(f"   Low confidence (<0.4): {low_conf_count}")
        
        # Per-image statistics. Only images with at least one detection have a
        # label file, so the averages below exclude images with no detections.
        if image_counts:
            avg_detections_per_image = np.mean(list(image_counts.values()))
            max_detections = max(image_counts.values())
            print(f"\n🖼️ Per-image detection statistics:")
            print(f"   Images with detections: {len(image_counts)}")
            print(f"   Mean detections per image: {avg_detections_per_image:.1f}")
            print(f"   Max detections in one image: {max_detections}")
            
            # Show the images with the most detections
            top_images = sorted(image_counts.items(), key=lambda x: x[1], reverse=True)[:5]
            print(f"   Images with the most detections:")
            for img_name, count in top_images:
                print(f"     {img_name}: {count} detections")
        
        # Output image info
        if output_images:
            print(f"\n📁 Output files:")
            print(f"   Annotated result images: {len(output_images)}")
            for img_path in output_images[:5]:  # show the first 5
                print(f"     {img_path.name}")
            if len(output_images) > 5:
                print(f"     ... and {len(output_images)-5} more files")
        
        return {
            'total_detections': total_detections,
            'avg_confidence': avg_confidence,
            'max_confidence': max_confidence,
            'min_confidence': min_confidence,
            'std_confidence': std_confidence,
            'high_conf_count': high_conf_count,
            'medium_conf_count': medium_conf_count,
            'low_conf_count': low_conf_count,
            'confidence_scores': confidence_scores,
            'image_counts': image_counts,
            'output_images': output_images,
            'results_dir': str(results_dir)
        }
    else:
        print("⚠️ No objects detected")
        return {
            'total_detections': 0,
            'avg_confidence': 0,
            'results_dir': str(results_dir)
        }

def visualize_results(results):
    """Plot the confidence histogram and the confidence-level pie chart."""
    if not results or results['total_detections'] == 0:
        print("📊 No detections to visualize")
        return
    
    print("\n📊 Creating result charts...")
    
    try:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        # Confidence histogram. Plot text is plain English because matplotlib's
        # default font has no CJK or emoji glyphs.
        confidences = results['confidence_scores']
        axes[0].hist(confidences, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        axes[0].set_title('Detection Confidence Distribution', fontsize=14, fontweight='bold')
        axes[0].set_xlabel('Confidence')
        axes[0].set_ylabel('Frequency')
        axes[0].grid(True, alpha=0.3)
        
        # Mark the mean confidence
        avg_conf = results['avg_confidence']
        axes[0].axvline(avg_conf, color='red', linestyle='--', linewidth=2,
                       label=f'Mean confidence: {avg_conf:.3f}')
        axes[0].legend()
        
        # Confidence-level pie chart
        high_count = results['high_conf_count']
        medium_count = results['medium_conf_count'] 
        low_count = results['low_conf_count']
        
        labels = ['High (>0.7)', 'Medium (0.4-0.7)', 'Low (<0.4)']
        sizes = [high_count, medium_count, low_count]
        colors = ['#2ecc71', '#f39c12', '#e74c3c']
        
        # Drop empty buckets
        filtered_data = [(label, size, color) for label, size, color in zip(labels, sizes, colors) if size > 0]
        if filtered_data:
            labels, sizes, colors = zip(*filtered_data)
            axes[1].pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
            axes[1].set_title('Confidence Level Distribution', fontsize=14, fontweight='bold')
        else:
            axes[1].text(0.5, 0.5, 'No detection data', ha='center', va='center', 
                        transform=axes[1].transAxes, fontsize=16)
            axes[1].set_title('Confidence Level Distribution', fontsize=14, fontweight='bold')
        
        plt.tight_layout()
        plt.show()
        
        print("✅ Charts created")
        
    except Exception as e:
        print(f"❌ Failed to create charts: {e}")

def show_sample_results(results):
    """Display up to three annotated output images."""
    if not results or not results.get('output_images'):
        print("📷 No sample images to display")
        return
    
    print("\n📷 Showing sample detection results...")
    
    try:
        output_images = results['output_images']
        sample_count = min(3, len(output_images))
        
        if sample_count > 0:
            fig, axes = plt.subplots(1, sample_count, figsize=(5*sample_count, 5))
            if sample_count == 1:
                axes = [axes]
            
            for i in range(sample_count):
                img_path = output_images[i]
                img = cv2.imread(str(img_path))
                if img is not None:
                    # OpenCV loads BGR; matplotlib expects RGB
                    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    axes[i].imshow(img_rgb)
                    axes[i].set_title(f'Detections: {img_path.name}', fontsize=12)
                    axes[i].axis('off')
                else:
                    axes[i].text(0.5, 0.5, 'Failed to load image', ha='center', va='center',
                               transform=axes[i].transAxes)
                    axes[i].set_title(f'Error: {img_path.name}')
            
            plt.tight_layout()
            plt.show()
            print(f"✅ Displayed {sample_count} sample detection results")
        else:
            print("📷 No detection result images to display")
            
    except Exception as e:
        print(f"❌ Failed to display samples: {e}")

# Main flow
print("🚀 Starting model inference test...")
print("=" * 70)

# Free memory
comprehensive_cleanup()

# Look for a trained model
model_candidates = [
    'runs/train/tree_detection_professional/weights/best.pt',
    'runs/train/tree_detection_professional/weights/last.pt',
    'runs/train/tree_detection_modern/weights/best.pt',
    'runs/train/tree_detection_basic/weights/best.pt'
]

available_model = None
for model_path in model_candidates:
    if os.path.exists(model_path):
        model_size = os.path.getsize(model_path) / (1024 * 1024)
        print(f"🎯 Found trained model: {model_path} ({model_size:.1f}MB)")
        available_model = model_path
        break

if not available_model:
    print("❌ No trained model found")
    print("💡 Run Cell 8 first to train the model")
else:
    # Look for test images, preferring the test split
    test_dirs = [
        'data/processed/images/test',
        'data/processed/images/val', 
        'data/processed/images/train'
    ]
    
    test_images = []
    for test_dir in test_dirs:
        if os.path.exists(test_dir):
            images = list(Path(test_dir).glob('*.jpg'))
            test_images.extend(images)
            if len(test_images) >= 10:  # cap the number of test images
                test_images = test_images[:10]
                break
    
    if not test_images:
        print("❌ No test images found")
    else:
        print(f"📸 Found {len(test_images)} test images")
        
        # Run the inference test
        results = professional_inference(available_model, test_images)
        
        if results:
            print(f"\n🎉 Inference test finished!")
            
            # Plot the results
            visualize_results(results)
            
            # Show sample results
            show_sample_results(results)
            
            # Summary report
            print(f"\n📋 Inference test summary:")
            print(f"   Model: {available_model}")
            print(f"   Test images: {len(test_images)}")
            print(f"   Total detections: {results['total_detections']}")
            if results['total_detections'] > 0:
                print(f"   Mean confidence: {results['avg_confidence']:.3f}")
                print(f"   High-confidence detections (>0.7): {results['high_conf_count']}")
            print(f"   Results directory: {results['results_dir']}")
            
        else:
            print("❌ Inference test failed")

# Final cleanup
comprehensive_cleanup()
print("✅ Model inference test finished")
print("=" * 70)
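
The statistics in this cell depend on the label format that `detect.py` writes when `--save-txt` and `--save-conf` are both set: one line per detection, `class x_center y_center width height confidence`, with coordinates normalized to the image size. The sketch below parses such lines and applies the same confidence thresholds as the cell above. The helper names and the three sample lines are made up for illustration.

```python
from collections import Counter


def parse_label_line(line):
    """Parse one `cls x y w h conf` line; return (class_id, box, conf) or None."""
    parts = line.split()
    if len(parts) < 6:
        return None  # no confidence column (file written without --save-conf)
    class_id = int(parts[0])
    x, y, w, h, conf = map(float, parts[1:6])
    return class_id, (x, y, w, h), conf


def confidence_bucket(conf):
    """Thresholds used in this notebook: >0.7 high, 0.4-0.7 medium, <0.4 low."""
    if conf > 0.7:
        return "high"
    if conf >= 0.4:
        return "medium"
    return "low"


# Illustrative label-file content, not real model output
sample = """0 0.512 0.430 0.080 0.095 0.91
0 0.210 0.655 0.060 0.070 0.55
0 0.770 0.120 0.050 0.048 0.31"""

parsed = [parse_label_line(line) for line in sample.splitlines()]
counts = Counter(confidence_bucket(conf) for _, _, conf in parsed)
print(dict(counts))  # {'high': 1, 'medium': 1, 'low': 1}
```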

In [None]:
# 10. YOLOv5模型评估和性能分析 - YOLOv5 Model Evaluation and Performance Analysis
def evaluate_yolov5_performance(model_path, test_images):
    """
    评估YOLOv5模型性能
    Evaluate YOLOv5 model performance
    """
    print("📊 开始YOLOv5模型性能评估 Starting YOLOv5 model performance evaluation...")
    
    total_images = len(test_images)
    total_detections = 0
    confidence_scores = []
    processing_times = []
    
    # 创建临时目录存储评估结果 Create temporary directory for evaluation results
    eval_dir = 'temp_eval'
    os.makedirs(eval_dir, exist_ok=True)
    
    print(f"🔍 评估 {total_images} 张图像 Evaluating {total_images} images...")
    
    # 对每张图像进行检测 Detect on each image
    for i, img_path in enumerate(test_images):
        try:
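            import time
            # Wall-clock time for the whole detect.py subprocess: this includes
            # Python startup and model loading, not just inference
            start_time = time.time()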
            
            # 构建检测命令 Build detection command
            detect_command = [
                sys.executable, 'yolov5/detect.py',
                '--weights', model_path,
                '--source', str(img_path),
                '--project', eval_dir,
                '--name', f'eval_{i}',
                '--save-txt',  # 保存检测结果 save detection results
                '--save-conf', # 保存置信度 save confidence scores
                '--exist-ok',
                '--nosave'     # 不保存图像，只要文本结果 don't save images, only text results
            ]
            
            # 运行检测 Run detection
            result = subprocess.run(detect_command, 
                                  capture_output=True, 
                                  text=True, 
                                  cwd='.')
            
            processing_time = time.time() - start_time
            processing_times.append(processing_time)
            
            if result.returncode == 0:
                # 读取检测结果 Read detection results
                result_dir = Path(eval_dir) / f'eval_{i}' / 'labels'
                if result_dir.exists():
                    for label_file in result_dir.glob('*.txt'):
                        with open(label_file, 'r') as f:
                            lines = f.readlines()
                            for line in lines:
                                if line.strip():
                                    parts = line.strip().split()
                                    if len(parts) >= 6:  # class x y w h conf
                                        confidence = float(parts[5])
                                        confidence_scores.append(confidence)
                                        total_detections += 1
                                        
            if (i + 1) % 10 == 0 or i == total_images - 1:
                print(f"   已处理 {i+1}/{total_images} 张图像 Processed {i+1}/{total_images} images")
                
        except Exception as e:
            print(f"评估图像 {i+1} 时出错 Error evaluating image {i+1}: {e}")
            continue
    
    # 清理临时目录 Clean up temporary directory
    import shutil
    if os.path.exists(eval_dir):
        shutil.rmtree(eval_dir)
    
    # 计算统计信息 Calculate statistics
    avg_detections_per_image = total_detections / total_images if total_images > 0 else 0
    avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
    max_confidence = np.max(confidence_scores) if confidence_scores else 0
    min_confidence = np.min(confidence_scores) if confidence_scores else 0
    avg_processing_time = np.mean(processing_times) if processing_times else 0
    
    print("\n📈 YOLOv5模型性能统计 YOLOv5 Model Performance Statistics")
    print("=" * 60)
    print(f"总测试图像数量 Total test images: {total_images}")
    print(f"总检测数量 Total detections: {total_detections}")
    print(f"平均每张图像检测数量 Average detections per image: {avg_detections_per_image:.2f}")
    print(f"平均置信度 Average confidence: {avg_confidence:.3f}")
    print(f"最高置信度 Maximum confidence: {max_confidence:.3f}")
    print(f"最低置信度 Minimum confidence: {min_confidence:.3f}")
    print(f"平均处理时间 Average processing time: {avg_processing_time:.3f} 秒/张 seconds per image")
    
    # 创建可视化图表 Create visualization charts
    if confidence_scores or processing_times:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        # Confidence distribution. Plot text is English-only because
        # matplotlib's default font has no CJK or emoji glyphs.
        if confidence_scores:
            axes[0].hist(confidence_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            axes[0].set_title('Detection Confidence Distribution', fontsize=12, fontweight='bold')
            axes[0].set_xlabel('Confidence Score')
            axes[0].set_ylabel('Frequency')
            axes[0].grid(True, alpha=0.3)
            
            # Mark the mean
            axes[0].axvline(avg_confidence, color='red', linestyle='--', 
                          label=f'Mean: {avg_confidence:.3f}')
            axes[0].legend()
        else:
            axes[0].text(0.5, 0.5, 'No Confidence Data', 
                       ha='center', va='center', transform=axes[0].transAxes)
        
        # Processing time distribution
        if processing_times:
            axes[1].hist(processing_times, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
            axes[1].set_title('Processing Time Distribution', fontsize=12, fontweight='bold')
            axes[1].set_xlabel('Processing Time (seconds)')
            axes[1].set_ylabel('Frequency')
            axes[1].grid(True, alpha=0.3)
            
            # Mark the mean
            axes[1].axvline(avg_processing_time, color='red', linestyle='--', 
                          label=f'Mean: {avg_processing_time:.3f}s')
            axes[1].legend()
        else:
            axes[1].text(0.5, 0.5, 'No Processing Time Data', 
                       ha='center', va='center', transform=axes[1].transAxes)
        
        plt.tight_layout()
        plt.show()
    
    return {
        'total_images': total_images,
        'total_detections': total_detections,
        'avg_detections_per_image': avg_detections_per_image,
        'confidence_scores': confidence_scores,
        'avg_confidence': avg_confidence,
        'processing_times': processing_times,
        'avg_processing_time': avg_processing_time
    }

def run_yolov5_validation(model_path, dataset_yaml):
    """
    运行YOLOv5官方验证脚本
    Run YOLOv5 official validation script
    """
    print("🔬 运行YOLOv5官方验证 Running YOLOv5 official validation...")
    
    try:
        # 构建验证命令 Build validation command
        val_command = [
            sys.executable, 'yolov5/val.py',
            '--weights', model_path,
            '--data', dataset_yaml,
            '--img', '640',
            '--batch', '8',
            '--conf', '0.001',  # 低置信度阈值以获得更多检测 low confidence threshold for more detections
            '--iou', '0.6',     # IoU阈值 IoU threshold
            '--task', 'val',
            '--device', '0' if torch.cuda.is_available() else 'cpu',
            '--save-txt',
            '--save-conf',
            '--project', 'runs/val',
            '--name', 'tree_detection_val',
            '--exist-ok'
        ]
        
        print(f"验证命令 Validation command: {' '.join(val_command)}")
        
        # 运行验证 Run validation
        result = subprocess.run(val_command, 
                              capture_output=True, 
                              text=True, 
                              cwd='.')
        
        if result.returncode == 0:
            print("✅ YOLOv5官方验证完成 YOLOv5 official validation completed!")
            
            # 显示验证输出的关键信息 Show key information from validation output
            output_lines = result.stdout.split('\n')
            print("\n📊 验证结果摘要 Validation Results Summary:")
            print("-" * 50)
            
            for line in output_lines:
                if any(keyword in line.lower() for keyword in ['precision', 'recall', 'map', 'f1']):
                    print(f"   {line.strip()}")
            
            # 检查是否生成了结果文件 Check if result files were generated
            val_results_dir = Path('runs/val/tree_detection_val')
            if val_results_dir.exists():
                print(f"\n📁 验证结果保存在 Validation results saved in: {val_results_dir}")
                
                # 列出生成的文件 List generated files
                result_files = list(val_results_dir.glob('*'))
                for f in result_files:
                    if f.is_file():
                        print(f"   {f.name}")
            
        else:
            print(f"❌ YOLOv5验证失败 YOLOv5 validation failed")
            print(f"错误输出 Error output: {result.stderr}")
            
    except Exception as e:
        print(f"❌ 验证过程中出现错误 Error during validation: {e}")

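# If a trained model exists, run the performance evaluation.
# Prefer the model found in Cell 9 (available_model), then the one produced by
# Cell 8 (final_model_path), rather than relying on a leftover loop variable.
model_path = globals().get('available_model') or globals().get('final_model_path')
if model_path is not None and os.path.exists(model_path):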
    print(f"🎯 使用模型进行性能评估 Using model for performance evaluation: {model_path}")
    
    # 获取测试图像 Get test images
    test_image_dir = Path('data/processed/images/test')
    eval_images = []
    
    if test_image_dir.exists():
        eval_images = list(test_image_dir.glob('*.jpg'))
    
    # 如果没有处理过的测试图像，使用原始评估图像 If no processed test images, use original evaluation images
    if not eval_images:
        eval_rgb_dir = Path('data/raw/evaluation/RGB')
        if eval_rgb_dir.exists():
            eval_images = list(eval_rgb_dir.glob('*.tif'))[:20]  # 限制为20张图像
            print(f"📸 使用原始评估图像 Using original evaluation images: {len(eval_images)}")
    
    if eval_images:
        # 进行性能评估 Perform performance evaluation
        performance_stats = evaluate_yolov5_performance(model_path, eval_images)
        
        # 运行官方验证（如果数据集配置存在）Run official validation if dataset config exists
        if 'dataset_yaml_path' in locals() and os.path.exists(dataset_yaml_path):
            run_yolov5_validation(model_path, dataset_yaml_path)
        
    else:
        print("⚠️  未找到测试图像，跳过性能评估 No test images found, skipping performance evaluation")
        
else:
    print("⚠️  未找到训练好的模型，跳过性能评估 No trained model found, skipping performance evaluation")
    print("请先完成模型训练 Please complete model training first")
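
The per-image statistics in this cell count detections and confidences but never compare them with ground truth; only `val.py` does that. To show what such a comparison involves, here is a simplified sketch of IoU matching for normalized YOLO boxes. It uses greedy one-to-one matching at a single IoU threshold, which is an assumption made for illustration and not YOLOv5's own mAP computation. All function names are hypothetical.

```python
def to_corners(box):
    """Convert a normalized (x_center, y_center, w, h) box to (x1, y1, x2, y2)."""
    x, y, w, h = box
    return x - w / 2, y - h / 2, x + w / 2, y + h / 2


def iou(box_a, box_b):
    """Intersection over union of two (x_center, y_center, w, h) boxes."""
    ax1, ay1, ax2, ay2 = to_corners(box_a)
    bx1, by1, bx2, by2 = to_corners(box_b)
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union > 0 else 0.0


def precision_recall(predictions, ground_truth, iou_threshold=0.5):
    """predictions: list of (box, conf); ground_truth: list of boxes."""
    unmatched = list(ground_truth)
    true_positives = 0
    # Match the most confident predictions first; each truth box is used once
    for box, _conf in sorted(predictions, key=lambda p: p[1], reverse=True):
        best = max(unmatched, key=lambda gt: iou(box, gt), default=None)
        if best is not None and iou(box, best) >= iou_threshold:
            unmatched.remove(best)
            true_positives += 1
    precision = true_positives / len(predictions) if predictions else 0.0
    recall = true_positives / len(ground_truth) if ground_truth else 0.0
    return precision, recall
```

With two ground-truth trees and two predictions of which one overlaps a tree, `precision_recall` returns `(0.5, 0.5)`: half the predictions are correct and half the trees are found.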

## 📝 项目总结 Project Summary

### 完成的功能 Completed Features
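1. ✅ **Environment detection** - automatically detects the GPU and the Colab environment
2. ✅ **Data download** - automatically downloads the NeonTreeEvaluation dataset
3. ✅ **Data conversion** - converts the NeonTree format to YOLO format
4. ✅ **Model training** - trains a tree detector with YOLOv5
5. ✅ **Inference** - detects trees in the test images
6. ✅ **Result visualization** - shows detection results and performance statistics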

### 主要参数说明 Key Parameter Explanations
- **Epochs**: 50 (adjust based on results)
- **Batch Size**: 16 (adjust to fit GPU memory)
- **Image Size**: 640x640 pixels
- **Learning Rate**: 0.001 (AdamW optimizer)
- **Data Split**: 70% train / 20% validation / 10% test
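
A 70/20/10 split like the one listed above can be produced with a seeded shuffle so that the split is reproducible. This is only a sketch: the notebook's actual split code is not shown in this section, and `split_dataset`, the seed, and the file names are assumptions.

```python
import random


def split_dataset(items, train_frac=0.7, val_frac=0.2, seed=42):
    """Shuffle deterministically, then cut into train/val/test lists."""
    shuffled = sorted(items)  # sort first so the result does not depend on input order
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]  # whatever remains, roughly 10%
    return train, val, test


# Hypothetical tile names, for illustration only
names = [f"tile_{i:03d}.jpg" for i in range(100)]
train, val, test = split_dataset(names)
print(len(train), len(val), len(test))  # 70 20 10
```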

### 使用说明 Usage Instructions
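1. Run all code cells in Google Colab
2. Make sure a GPU runtime is available to speed up training
3. Adjust the training parameters as needed
4. Inspect the annotated detection images in `professional_inference_results/detect`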

### 故障排除 Troubleshooting
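- If you run out of memory, reduce the batch_size parameter
- If training takes too long, reduce the number of epochs
- If detection quality is poor, try training for more epochs
- Make sure the dataset was downloaded and extracted correctly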
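
The first tip (reduce the batch size when memory runs out) can be automated with a simple retry loop. A minimal sketch, assuming the caller supplies a `run_training` function that returns `True` on success and `False` on failure. Both `train_with_backoff` and `fake_training` are hypothetical names, and `fake_training` only stands in for a real training call.

```python
def train_with_backoff(run_training, batch_size=16, min_batch_size=2):
    """Call run_training(batch_size), halving the batch size after each failure.

    Returns the batch size that succeeded, or None if every attempt failed.
    """
    while batch_size >= min_batch_size:
        print(f"Trying batch_size={batch_size}")
        if run_training(batch_size):
            return batch_size
        batch_size //= 2
    return None


# Stand-in for a real training call: pretend only batch sizes <= 4 fit in memory
def fake_training(batch_size):
    return batch_size <= 4


print(train_with_backoff(fake_training))  # 4
```

In the notebook, `run_training` could wrap the `subprocess.run` call to `yolov5/train.py` and return whether the return code was 0.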

### 进一步改进 Further Improvements
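- Data augmentation to improve generalization
- Hyperparameter tuning to improve model performance
- Multi-scale training to improve detection accuracy
- Model ensembling to improve overall results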