# 🌲 Tree Detection Project

## Project Overview

This project implements tree detection from aerial imagery using YOLOv5 and PyTorch.

**Dataset**: NeonTreeEvaluation Benchmark  
**Goal**: Detect trees in aerial orthoimagery  
**Platform**: Google Colab (automatic GPU detection)

## Tech Stack
- **Deep Learning**: PyTorch + YOLOv5
- **Data Processing**: OpenCV, PIL, pandas
- **Visualization**: matplotlib, seaborn

In [None]:
# 1. 环境检测和基础设置 - Environment Detection and Basic Setup
import os
import sys
import torch
import torchvision
import platform
from pathlib import Path

print("=" * 50)
print("🚀 环境检测 Environment Detection")
print("=" * 50)

# 检查Python版本 Check Python version
print(f"Python版本 Python Version: {sys.version}")

# 检查操作系统 Check OS
print(f"操作系统 Operating System: {platform.system()}")

# 检查是否在Colab环境 Check if running in Colab
try:
    import google.colab
    IN_COLAB = True
    print("✅ 运行环境: Google Colab - Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("❌ 运行环境: 本地环境 - Running locally")

# 检查GPU可用性 Check GPU availability
if torch.cuda.is_available():
    print(f"✅ GPU可用 GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA版本 CUDA Version: {torch.version.cuda}")
    print(f"   GPU数量 GPU Count: {torch.cuda.device_count()}")
    device = torch.device('cuda')
else:
    print("⚠️  GPU不可用，将使用CPU - GPU not available, using CPU")
    device = torch.device('cpu')

print(f"PyTorch版本 PyTorch Version: {torch.__version__}")
print(f"设备 Device: {device}")
print("=" * 50)


In [None]:
# 2. 安装必要依赖 - Install Required Dependencies
print("📦 安装依赖包 Installing Dependencies...")

# 安装YOLOv5和相关依赖 Install YOLOv5 and dependencies
!pip install -q ultralytics
!pip install -q opencv-python-headless
!pip install -q Pillow
!pip install -q matplotlib
!pip install -q seaborn
!pip install -q pandas
!pip install -q tqdm
!pip install -q scikit-learn
!pip install -q PyYAML

# 如果在Colab环境，挂载Google Drive (可选)
# Mount Google Drive in Colab (optional)
if IN_COLAB:
    from google.colab import drive
    # drive.mount('/content/drive')  # 取消注释以挂载Drive - Uncomment to mount Drive

print("✅ 依赖安装完成 Dependencies installed successfully!")

In [None]:
# 3. 导入所需库 - Import Required Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
import json
import zipfile
import requests
from tqdm import tqdm
import shutil
import yaml
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# 设置matplotlib中文字体支持 Set matplotlib Chinese font support
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# 设置随机种子 Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("✅ 库导入完成 Libraries imported successfully!")

In [None]:
# 4. 数据下载和解压 - Data Download and Extraction
def download_file(url, filename):
    """
    下载文件的函数 Function to download files
    """
    print(f"🔄 开始下载 Starting download: {filename}")
    
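    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()  # Fail on HTTP errors instead of saving an error page as the ZIP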
    total_size = int(response.headers.get('content-length', 0))
    
    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                file.write(chunk)
                pbar.update(len(chunk))
    
    print(f"✅ 下载完成 Download completed: {filename}")

# 创建数据目录 Create data directories
os.makedirs('data', exist_ok=True)
os.makedirs('data/raw', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

# 数据集URL Dataset URLs
DATASET_URL = "https://zenodo.org/records/5914554/files/evaluation.zip?download=1"
ANNOTATIONS_URL = "https://zenodo.org/records/5914554/files/annotations.zip?download=1"

# 下载数据集 Download datasets
if not os.path.exists('data/raw/evaluation.zip'):
    download_file(DATASET_URL, 'data/raw/evaluation.zip')
else:
    print("✅ 评估数据集已存在 Evaluation dataset already exists")

if not os.path.exists('data/raw/annotations.zip'):
    download_file(ANNOTATIONS_URL, 'data/raw/annotations.zip')
else:
    print("✅ 标注数据已存在 Annotations already exist")

print("📁 数据下载完成 Data download completed!")

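The download cell skips any archive that already exists on disk, so an interrupted download would be reused on the next run. The sketch below shows one way to check an archive before trusting it. The helper name `zip_is_intact` is hypothetical and not part of the original notebook; it relies only on the standard-library `zipfile` module.

```python
import zipfile
from pathlib import Path


def zip_is_intact(zip_path):
    """Return True if zip_path is a readable ZIP whose members all pass their CRC check."""
    path = Path(zip_path)
    if not path.is_file() or not zipfile.is_zipfile(path):
        return False
    try:
        with zipfile.ZipFile(path) as archive:
            # testzip() returns the name of the first corrupt member, or None if all are fine
            return archive.testzip() is None
    except zipfile.BadZipFile:
        return False
```

A possible use is to delete and re-download `data/raw/evaluation.zip` when `zip_is_intact` returns `False`, rather than only checking that the file exists. Note that `testzip()` reads every member, so it can take a while on large archives.
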
In [None]:
# 5. 解压数据集 - Extract Datasets
def extract_zip(zip_path, extract_to):
    """
    解压ZIP文件的函数 Function to extract ZIP files
    """
    print(f"📦 解压文件 Extracting: {zip_path}")
    
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    
    print(f"✅ 解压完成 Extraction completed: {extract_to}")

# 解压评估数据集 Extract evaluation dataset
if not os.path.exists('data/raw/evaluation'):
    extract_zip('data/raw/evaluation.zip', 'data/raw/')
else:
    print("✅ 评估数据集已解压 Evaluation dataset already extracted")

# 解压标注数据 Extract annotations
if not os.path.exists('data/raw/annotations'):
    extract_zip('data/raw/annotations.zip', 'data/raw/')
else:
    print("✅ 标注数据已解压 Annotations already extracted")

# 查看数据结构 Explore data structure
print("\n📊 数据结构分析 Data Structure Analysis")
print("=" * 50)

# 检查评估数据集结构 Check evaluation dataset structure
eval_path = Path('data/raw/evaluation')
if eval_path.exists():
    print(f"评估数据集路径 Evaluation dataset path: {eval_path}")
    subdirs = [d for d in eval_path.iterdir() if d.is_dir()]
    print(f"子目录数量 Number of subdirectories: {len(subdirs)}")
    for subdir in subdirs[:5]:  # 显示前5个子目录 Show first 5 subdirectories
        print(f"  - {subdir.name}")
    if len(subdirs) > 5:
        print(f"  ... 还有 {len(subdirs)-5} 个目录 and {len(subdirs)-5} more directories")

# 检查标注数据结构 Check annotations structure
ann_path = Path('data/raw/annotations')
if ann_path.exists():
    print(f"\n标注数据路径 Annotations path: {ann_path}")
    ann_files = list(ann_path.glob('*.xml'))  # The conversion step reads these XML annotations
    print(f"Number of XML annotation files: {len(ann_files)}")
    for ann_file in ann_files[:3]:  # 显示前3个标注文件 Show first 3 annotation files
        print(f"  - {ann_file.name}")

print("\n✅ 数据解压和结构分析完成 Data extraction and structure analysis completed!")

In [ ]:
# 6. Data Format Conversion (NeonTree -> YOLO), using the real RGB images
import xml.etree.ElementTree as ET

class NeonTreeToYOLO:
    """
    将NeonTree数据集转换为YOLO格式的类
    Class to convert NeonTree dataset to YOLO format
    """
    
    def __init__(self, annotations_root, evaluation_root, output_root):
        self.annotations_root = Path(annotations_root)
        self.evaluation_root = Path(evaluation_root)
        self.output_root = Path(output_root)
        
        # 创建输出目录 Create output directories
        self.create_yolo_structure()
        
        # 创建图像文件名映射 Create image filename mapping
        self.create_image_mapping()
    
    def create_yolo_structure(self):
        """创建YOLO数据集目录结构 Create YOLO dataset directory structure"""
        folders = [
            'images/train', 'images/val', 'images/test',
            'labels/train', 'labels/val', 'labels/test'
        ]
        
        for folder in folders:
            (self.output_root / folder).mkdir(parents=True, exist_ok=True)
        
        print("✅ YOLO目录结构创建完成 YOLO directory structure created")
    
    def create_image_mapping(self):
        """
        创建图像文件名到路径的映射
        Create mapping from image filenames to paths
        """
        self.image_mapping = {}
        
        # Index every GeoTIFF in evaluation/RGB by its stem (filename without extension)
        rgb_dir = self.evaluation_root / 'RGB'
        if rgb_dir.exists():
            for img_file in rgb_dir.glob('*.tif'):
                # Use the name without its extension as the key
                base_name = img_file.stem
                self.image_mapping[base_name] = img_file
                
                # Handle double extensions such as "name.tif.tif", whose stem still ends in ".tif"
                if base_name.endswith('.tif'):
                    base_name_no_ext = base_name[:-4]
                    self.image_mapping[base_name_no_ext] = img_file
        
        print(f"📋 创建图像映射 Created image mapping with {len(self.image_mapping)} entries")
    
    def find_matching_image(self, xml_filename):
        """
        根据XML文件名查找对应的RGB图像文件
        Find matching RGB image file based on XML filename
        """
        # Strip the .xml extension to get the base name
        base_name = xml_filename.replace('.xml', '')
        
        # Exact match
        if base_name in self.image_mapping:
            return self.image_mapping[base_name]
        
        # Fuzzy match: fall back to comparing filename parts
        for img_name, img_path in self.image_mapping.items():
            # files_match compares the site code and the number that follows it
            if self.files_match(base_name, img_name):
                return img_path
        
        return None
    
    def files_match(self, xml_name, img_name):
        """
        判断XML文件名和图像文件名是否匹配
        Check if XML filename matches image filename
        """
        # Compare case-insensitively
        xml_clean = xml_name.lower()
        img_clean = img_name.lower()
        
        # Split into underscore-separated parts; the first part is treated as the site code
        xml_parts = xml_clean.split('_')
        img_parts = img_clean.split('_')
        
        # When the site codes match, compare the number that follows the site code
        if len(xml_parts) >= 3 and len(img_parts) >= 2:
            if xml_parts[0] == img_parts[0]:
                return xml_parts[1] == img_parts[1]
        
        # Fallback: any shared part longer than two characters counts as a match.
        # This is permissive (a shared site code or year is enough), so it can
        # pair an annotation with the wrong image.
        return any(part in img_clean for part in xml_parts if len(part) > 2)
    
    def parse_xml_annotation(self, xml_file):
        """
        解析XML标注文件
        Parse XML annotation file
        """
        try:
            tree = ET.parse(xml_file)
            root = tree.getroot()
            
            # 获取图像信息 Get image information
            filename_elem = root.find('filename')
            if filename_elem is not None:
                filename = filename_elem.text
            else:
                filename = xml_file.stem + '.tif'  # Fall back to the default extension
            
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            
            # 获取所有树木对象 Get all tree objects
            objects = []
            for obj in root.findall('object'):
                name_elem = obj.find('name')
                if name_elem is not None and name_elem.text.lower() == 'tree':
                    bndbox = obj.find('bndbox')
                    xmin = int(float(bndbox.find('xmin').text))
                    ymin = int(float(bndbox.find('ymin').text))
                    xmax = int(float(bndbox.find('xmax').text))
                    ymax = int(float(bndbox.find('ymax').text))
                    
                    # 验证边界框有效性 Validate bounding box
                    if xmax > xmin and ymax > ymin:
                        objects.append({
                            'xmin': xmin,
                            'ymin': ymin,
                            'xmax': xmax,
                            'ymax': ymax
                        })
            
            return {
                'filename': filename,
                'width': width,
                'height': height,
                'objects': objects
            }
            
        except Exception as e:
            print(f"解析XML文件时出错 Error parsing XML file {xml_file}: {e}")
            return None
    
    def convert_bbox_to_yolo(self, bbox, img_width, img_height):
        """
        将边界框坐标转换为YOLO格式
        Convert bounding box coordinates to YOLO format
        
        YOLO格式: [class_id, x_center, y_center, width, height] (归一化 normalized)
        """
        xmin, ymin, xmax, ymax = bbox['xmin'], bbox['ymin'], bbox['xmax'], bbox['ymax']
        
        # 计算中心点和宽高 Calculate center point and dimensions
        x_center = (xmin + xmax) / 2.0
        y_center = (ymin + ymax) / 2.0
        width = xmax - xmin
        height = ymax - ymin
        
        # 归一化 Normalize
        x_center /= img_width
        y_center /= img_height
        width /= img_width
        height /= img_height
        
        return [0, x_center, y_center, width, height]  # 类别ID为0 (树木 tree)
    
    def process_annotations(self):
        """
        处理所有XML标注文件并转换为YOLO格式
        Process all XML annotation files and convert to YOLO format
        """
        xml_files = list(self.annotations_root.glob('*.xml'))
        processed_count = 0
        skipped_count = 0
        
        print(f"📝 找到 {len(xml_files)} 个XML标注文件 Found {len(xml_files)} XML annotation files")
        
        for i, xml_file in enumerate(tqdm(xml_files, desc="处理标注 Processing annotations")):
            try:
                # 解析XML文件 Parse XML file
                annotation_data = self.parse_xml_annotation(xml_file)
                
                if annotation_data is None or not annotation_data['objects']:
                    skipped_count += 1
                    continue
                
                # 查找对应的RGB图像文件 Find corresponding RGB image file
                img_path = self.find_matching_image(xml_file.name)
                
                if img_path is None:
                    # Report a missing image only while fewer than 5 files have been skipped
                    if skipped_count < 5:
                        print(f"⚠️  No matching image for: {xml_file.name}")
                    skipped_count += 1
                    continue
                
                # 读取真实图像 Load real image
                img = cv2.imread(str(img_path))
                if img is None:
                    print(f"⚠️  无法读取图像 Cannot read image: {img_path}")
                    skipped_count += 1
                    continue
                
                # 验证图像尺寸 Verify image dimensions
                actual_height, actual_width = img.shape[:2]
                if actual_width != annotation_data['width'] or actual_height != annotation_data['height']:
                    # 使用实际图像尺寸 Use actual image dimensions
                    annotation_data['width'] = actual_width
                    annotation_data['height'] = actual_height
                
                # 确定数据集分割 Determine dataset split
                if i % 10 < 7:  # 70% 训练集 training set
                    split = 'train'
                elif i % 10 < 9:  # 20% 验证集 validation set
                    split = 'val'
                else:  # 10% 测试集 test set
                    split = 'test'
                
                # 保存图像 Save image
                img_filename = f"tree_{processed_count:06d}.jpg"
                img_save_path = self.output_root / 'images' / split / img_filename
                cv2.imwrite(str(img_save_path), img)
                
                # 保存YOLO格式标签 Save YOLO format labels
                label_filename = f"tree_{processed_count:06d}.txt"
                label_save_path = self.output_root / 'labels' / split / label_filename
                
                with open(label_save_path, 'w') as f:
                    for obj in annotation_data['objects']:
                        yolo_bbox = self.convert_bbox_to_yolo(
                            obj, annotation_data['width'], annotation_data['height']
                        )
                        # YOLO format: class_id x_center y_center width height
                        f.write(f"{yolo_bbox[0]} {yolo_bbox[1]:.6f} {yolo_bbox[2]:.6f} {yolo_bbox[3]:.6f} {yolo_bbox[4]:.6f}\n")
                
                processed_count += 1
                
            except Exception as e:
                print(f"处理文件 {xml_file.name} 时出错 Error processing file {xml_file.name}: {e}")
                skipped_count += 1
                continue
        
        print(f"✅ 处理完成 Processing completed: {processed_count} 个样本 samples processed, {skipped_count} 个跳过 skipped")
        return processed_count

# 执行数据转换 Execute data conversion
print("🔄 开始数据格式转换 Starting data format conversion...")
print("📋 使用XML格式的标注数据和真实RGB图像 Using XML annotations with real RGB images")

# 创建data目录 Create data directories
os.makedirs('data', exist_ok=True)
os.makedirs('data/processed', exist_ok=True)

# 检查annotations和evaluation数据是否存在 Check if annotations and evaluation data exist
annotation_paths = ['data/raw/annotations', 'annotations', '/content/annotations']
eval_paths = ['data/raw/evaluation', 'evaluation', '/content/evaluation']

annotations_root = None
eval_root = None

# Find the annotations directory (the first candidate that actually contains XML files)
for path in annotation_paths:
    candidate = Path(path)
    if not candidate.exists():
        continue
    xml_count = len(list(candidate.glob('*.xml')))
    if xml_count > 0:
        annotations_root = candidate
        print(f"✅ Found annotations dataset at: {annotations_root} ({xml_count} XML files)")
        break

if annotations_root is None:
    print("❌ Annotations dataset or XML files not found")
    print("📁 Make sure the annotations directory exists and contains XML files")

# Find the evaluation directory (the first candidate that has an RGB subdirectory)
for path in eval_paths:
    rgb_dir = Path(path) / 'RGB'
    if rgb_dir.exists():
        eval_root = Path(path)
        img_count = len(list(rgb_dir.glob('*.tif')))
        print(f"✅ Found evaluation dataset at: {eval_root} ({img_count} images)")
        break

if eval_root is None:
    print("❌ Evaluation dataset not found; make sure evaluation.zip has been downloaded and extracted")

# 只有当两个数据集都找到时才进行转换 Only proceed if both datasets are found
if annotations_root is not None and eval_root is not None:
    # 初始化转换器 Initialize converter
    converter = NeonTreeToYOLO(annotations_root, eval_root, 'data/processed')
    
    # 处理标注数据 Process annotation data
    processed_samples = converter.process_annotations()
    
    if processed_samples > 0:
        print(f"✅ 数据转换完成 Data conversion completed: {processed_samples} 个样本 samples")
        
        # 显示数据集统计 Show dataset statistics
        train_images = len(list(Path('data/processed/images/train').glob('*.jpg')))
        val_images = len(list(Path('data/processed/images/val').glob('*.jpg')))
        test_images = len(list(Path('data/processed/images/test').glob('*.jpg')))
        
        print("📊 Dataset Statistics:")
        print(f"   Training: {train_images} images")
        print(f"   Validation: {val_images} images")
        print(f"   Testing: {test_images} images")
        print(f"   Total: {train_images + val_images + test_images} images")
    else:
        print("❌ No samples were processed")
        print("Check how the XML files and image files are matched")
else:
    print("❌ Required datasets are missing, cannot convert")
    print("Make sure the following directories exist:")
    print("  - annotations directory (containing XML files)")
    print("  - evaluation/RGB directory (containing .tif images)")

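The coordinate conversion in `convert_bbox_to_yolo` can be checked by hand. The following is a minimal standalone sketch of the same arithmetic together with its inverse; the function names `voc_to_yolo` and `yolo_to_voc` are illustrative and do not appear in the notebook.

```python
def voc_to_yolo(xmin, ymin, xmax, ymax, img_w, img_h):
    """Convert a corner-format pixel box to normalized (x_center, y_center, width, height)."""
    x_center = (xmin + xmax) / 2.0 / img_w
    y_center = (ymin + ymax) / 2.0 / img_h
    width = (xmax - xmin) / img_w
    height = (ymax - ymin) / img_h
    return x_center, y_center, width, height


def yolo_to_voc(x_center, y_center, width, height, img_w, img_h):
    """Invert voc_to_yolo: recover pixel corners from a normalized YOLO box."""
    xmin = (x_center - width / 2.0) * img_w
    ymin = (y_center - height / 2.0) * img_h
    xmax = (x_center + width / 2.0) * img_w
    ymax = (y_center + height / 2.0) * img_h
    return xmin, ymin, xmax, ymax


# A 200x200 pixel box inside a 400x400 image
box = voc_to_yolo(100, 50, 300, 250, 400, 400)
print(box)  # (0.5, 0.375, 0.5, 0.5)
print(yolo_to_voc(*box, 400, 400))  # (100.0, 50.0, 300.0, 250.0)
```

Running the round trip on a few converted labels is a quick way to confirm that the width and height used for normalization really belong to the image that was saved.
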
In [ ]:
# 7. Create YOLO Configuration Files
import os
from pathlib import Path
import yaml

def create_dataset_yaml():
    """
    创建YOLO数据集配置文件
    Create YOLO dataset configuration file
    """
    # 获取当前工作目录 Get current working directory
    current_dir = os.getcwd()
    
    # 构建绝对路径 Build absolute paths
    processed_dir = os.path.join(current_dir, 'data/processed')
    train_path = os.path.join(processed_dir, 'images/train')
    val_path = os.path.join(processed_dir, 'images/val')
    test_path = os.path.join(processed_dir, 'images/test')
    
    # 检查路径是否存在 Check if paths exist
    paths_info = {
        'train': train_path,
        'val': val_path,
        'test': test_path
    }
    
    print("📋 检查数据集路径 Checking dataset paths:")
    existing_paths = {}
    
    for split, path in paths_info.items():
        if os.path.exists(path):
            img_count = len([f for f in os.listdir(path) if f.endswith('.jpg')])
            print(f"✅ {split}: {path} ({img_count} 张图像 images)")
            existing_paths[split] = path
        else:
            print(f"❌ {split}: {path} (路径不存在 path does not exist)")
    
    # 确保至少有训练集存在 Ensure at least training set exists
    if 'train' not in existing_paths:
        print("❌ 错误：未找到训练集 Error: Training set not found")
        return None
    
    # Create the dataset configuration; a missing split falls back to the training images
    dataset_config = {
        'path': processed_dir,    # dataset root dir
        'train': 'images/train',  # train images, relative to path
        'val': 'images/val' if 'val' in existing_paths else 'images/train',
        'test': 'images/test' if 'test' in existing_paths else 'images/train',
        'nc': 1,                  # number of classes
        'names': ['tree']         # class names
    }
    
    if 'val' not in existing_paths:
        print("⚠️  Validation set not found, using training set as validation")
    
    if 'test' not in existing_paths:
        print("⚠️  Test set not found, using training set as test")
    
    # Make sure the config directory exists
    os.makedirs('data', exist_ok=True)
    
    # 保存配置文件 Save configuration file
    config_path = 'data/tree_dataset.yaml'
    with open(config_path, 'w') as f:
        yaml.dump(dataset_config, f, default_flow_style=False)
    
    print(f"✅ YOLO数据集配置文件创建完成 YOLO dataset configuration file created: {config_path}")
    
    # Print each config entry and check that the split paths exist
    print("\n🔍 验证配置文件内容 Verifying configuration file:")
    for key, value in dataset_config.items():
        if key in ['train', 'val', 'test']:
            full_path = os.path.join(processed_dir, value)
            exists = os.path.exists(full_path)
            print(f"   {key}: {value} -> {full_path} ({'✅' if exists else '❌'})")
        else:
            print(f"   {key}: {value}")
    
    return config_path

def fix_data_structure():
    """
    修复数据结构问题
    Fix data structure issues
    """
    print("\n🔧 检查并修复数据结构 Checking and fixing data structure...")
    
    processed_dir = Path('data/processed')
    
    if not processed_dir.exists():
        print("❌ data/processed 目录不存在 data/processed directory does not exist")
        return False
    
    # Check the required directory structure
    required_dirs = [
        'images/train', 'images/val', 'images/test',
        'labels/train', 'labels/val', 'labels/test'
    ]
    
    missing_dirs = []
    for dir_path in required_dirs:
        full_path = processed_dir / dir_path
        if not full_path.exists():
            missing_dirs.append(dir_path)
    
    if missing_dirs:
        print(f"⚠️  缺少目录 Missing directories: {missing_dirs}")
        
        # Create missing validation and test directories; a missing train directory is not created
        for missing_dir in missing_dirs:
            if 'val' in missing_dir or 'test' in missing_dir:
                (processed_dir / missing_dir).mkdir(parents=True, exist_ok=True)
                print(f"✅ 创建目录 Created directory: {missing_dir}")
    
    # Check that the training set contains data
    train_images = list((processed_dir / 'images/train').glob('*.jpg'))
    train_labels = list((processed_dir / 'labels/train').glob('*.txt'))
    
    print(f"📊 训练集统计 Training set statistics:")
    print(f"   图像文件 Image files: {len(train_images)}")
    print(f"   标签文件 Label files: {len(train_labels)}")
    
    if len(train_images) == 0:
        print("❌ 训练集为空，请先运行数据转换 Training set is empty, please run data conversion first")
        return False
    
    return True

# Check and repair the data structure
data_structure_ok = fix_data_structure()

if data_structure_ok:
    # 创建配置文件 Create configuration file
    dataset_yaml_path = create_dataset_yaml()
    
    if dataset_yaml_path:
        # 显示配置文件内容 Display configuration file content
        with open(dataset_yaml_path, 'r') as f:
            config_content = f.read()
            print("\n📋 数据集配置内容 Dataset Configuration Content:")
            print("=" * 40)
            print(config_content)
            print("=" * 40)
        
        # Final confirmation
        print("\n✅ 配置文件创建并验证完成 Configuration file created and verified")
    else:
        print("❌ 配置文件创建失败 Failed to create configuration file")
else:
    print("❌ 数据结构修复失败，请检查数据转换步骤 Data structure fix failed, please check data conversion step")

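The configuration cell checks that the split directories exist and counts images, but it does not look inside the label files. Below is a small sketch of a per-file sanity check, assuming the five-field, normalized label format written by the conversion step. The helper `check_label_file` is hypothetical and not part of the notebook.

```python
from pathlib import Path


def check_label_file(label_path, num_classes=1):
    """Return a list of human-readable problems found in one YOLO label file."""
    problems = []
    lines = Path(label_path).read_text().splitlines()
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        fields = line.split()
        if len(fields) != 5:
            problems.append(f"line {line_no}: expected 5 fields, got {len(fields)}")
            continue
        try:
            class_id = int(fields[0])
            coords = [float(value) for value in fields[1:]]
        except ValueError:
            problems.append(f"line {line_no}: non-numeric value")
            continue
        if not 0 <= class_id < num_classes:
            problems.append(f"line {line_no}: class id {class_id} out of range")
        if any(not 0.0 <= value <= 1.0 for value in coords):
            problems.append(f"line {line_no}: coordinate outside [0, 1]")
    return problems
```

One way to use it is to loop over `data/processed/labels/train/*.txt` and print any file whose result is non-empty before starting training.
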
In [ ]:
# 8. YOLOv5 Model Training
import torch
import subprocess
import sys
import gc  # garbage collection

# 首先安装YOLOv5 First install YOLOv5
print("📦 安装YOLOv5... Installing YOLOv5...")
try:
    # 克隆YOLOv5仓库 Clone YOLOv5 repository
    if not os.path.exists('yolov5'):
        !git clone https://github.com/ultralytics/yolov5.git
        os.chdir('yolov5')
        !pip install -r requirements.txt
        os.chdir('..')
    print("✅ YOLOv5安装完成 YOLOv5 installation completed")
except Exception as e:
    print(f"YOLOv5安装警告 YOLOv5 installation warning: {e}")

# 添加yolov5到系统路径 Add yolov5 to system path
if os.path.exists('yolov5'):
    sys.path.append('yolov5')

def check_system_resources():
    """
    检查系统资源
    Check system resources
    """
    print("💻 检查系统资源 Checking system resources...")
    
    # Check GPU memory
    if torch.cuda.is_available():
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
        gpu_allocated = torch.cuda.memory_allocated(0) / 1024**3
        gpu_reserved = torch.cuda.memory_reserved(0) / 1024**3
        
        print(f"🎮 GPU内存 GPU Memory:")
        print(f"   总内存 Total: {gpu_memory:.1f} GB")
        print(f"   已分配 Allocated: {gpu_allocated:.1f} GB")
        print(f"   已保留 Reserved: {gpu_reserved:.1f} GB")
        print(f"   可用 Available: {gpu_memory - gpu_reserved:.1f} GB")
        
        # Suggest a smaller batch size when less than 2 GB is available
        if (gpu_memory - gpu_reserved) < 2.0:
            print("⚠️  Low GPU memory; consider reducing batch_size")
            return 'low_memory'
    
    # Check disk space
    import shutil
    disk_usage = shutil.disk_usage('/')
    free_space = disk_usage.free / 1024**3
    
    print("💾 Disk Space:")
    print(f"   Free space: {free_space:.1f} GB")
    
    if free_space < 2.0:
        print("⚠️  Low disk space")
        return 'low_disk'
    
    return 'ok'

def fix_best_pt_compatibility(model_path, output_path='best_fixed.pt'):
    """
    修复best.pt模型的兼容性问题 (内存优化版)
    Fix compatibility issues with best.pt model (memory optimized)
    """
    print(f"🔧 尝试修复模型兼容性 Attempting to fix model compatibility: {model_path}")
    
    try:
        # Method 1: load the full checkpoint with weights_only=False
        print("   Method 1: Load with weights_only=False")
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        
        if 'model' in checkpoint:
            print("   ✅ 成功加载模型 Successfully loaded model")
            
            # Extract the model weights and key metadata
            model_state = checkpoint['model']
            
            # 创建新的兼容检查点 Create new compatible checkpoint
            new_checkpoint = {
                'model': model_state,
                'epoch': checkpoint.get('epoch', 0),
                'best_fitness': checkpoint.get('best_fitness', 0.0),
            }
            
            # Release memory that is no longer needed
            del checkpoint
            gc.collect()
            
            # 保存修复后的模型 Save fixed model
            torch.save(new_checkpoint, output_path)
            print(f"   ✅ 修复后的模型已保存 Fixed model saved: {output_path}")
            return output_path
            
    except Exception as e1:
        print(f"   ❌ Method 1 failed: {e1}")
        # Free memory before retrying
        gc.collect()
        
        try:
            # Method 2: extract only the state_dict (lightest option)
            print("   Method 2: Extract state_dict only")
            
            checkpoint = torch.load(model_path, map_location='cpu')
            
            # Try the checkpoint layouts that are likely to occur
            state_dict = None
            if 'model' in checkpoint:
                if hasattr(checkpoint['model'], 'state_dict'):
                    state_dict = checkpoint['model'].state_dict()
                elif isinstance(checkpoint['model'], dict):
                    state_dict = checkpoint['model']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            
            if state_dict is not None:
                # Save a weights-only file
                torch.save({'model': state_dict}, output_path)
                print(f"   ✅ state_dict extraction successful: {output_path}")
                return output_path
                    
        except Exception as e2:
            print(f"   ❌ Method 2 failed: {e2}")
    
    print("   ❌ 所有修复方法都失败了 All fix methods failed")
    return None

# Check system resources
resource_status = check_system_resources()

# Adjust training parameters to the available resources
if resource_status == 'low_memory':
    batch_size = 4
    img_size = 416  # smaller input images
    epochs = 10     # fewer training epochs
    workers = 1
    print("🔧 Low memory detected, using a lightweight configuration")
elif resource_status == 'low_disk':
    batch_size = 8
    img_size = 640
    epochs = 15
    workers = 1
    print("🔧 Low disk space detected, reducing epochs")
else:
    batch_size = 8
    img_size = 640
    epochs = 20  # modest default to avoid very long runs
    workers = 2
# 检查并修复预训练模型 Check and fix pre-trained model
pretrained_model_path = 'best.pt'
use_custom_model = False
fixed_model_path = None

if os.path.exists(pretrained_model_path):
    print(f"✅ Pre-trained model found: {pretrained_model_path}")
    
    # Try to fix model compatibility
    fixed_model_path = fix_best_pt_compatibility(pretrained_model_path)
    
    if fixed_model_path and os.path.exists(fixed_model_path):
        print(f"🎯 Using fixed model: {fixed_model_path}")
        weights_path = fixed_model_path
        use_custom_model = True
    else:
        print("⚠️  Model fix failed, falling back to YOLOv5s as the base model")
        weights_path = 'yolov5s.pt'
else:
    print("⚠️  best.pt not found, using the YOLOv5s pre-trained model")
    weights_path = 'yolov5s.pt'

print("\n⚙️  训练参数配置 Training parameters configuration:")
print(f"   图像尺寸 Image size: {img_size}")
print(f"   批次大小 Batch size: {batch_size}")
print(f"   训练轮数 Epochs: {epochs}")
print(f"   数据集配置 Dataset config: {dataset_yaml_path}")
print(f"   模型权重 Model weights: {weights_path}")
print(f"   设备 Device: {device}")
print(f"   资源状态 Resource status: {resource_status}")
print("=" * 60)

# Build the training command (--resume is deliberately not passed)
training_command = [
    sys.executable, 'yolov5/train.py',
    '--img', str(img_size),
    '--batch', str(batch_size),
    '--epochs', str(epochs),
    '--data', dataset_yaml_path,
    '--weights', weights_path,
    '--project', 'runs/train',
    '--name', 'tree_detection_light',  # run name for the lightweight configuration
    '--patience', '5',    # stop early after 5 epochs without improvement
    '--save-period', '3', # save a checkpoint every 3 epochs
    '--workers', str(workers),
    '--cache', 'ram',     # cache images in RAM rather than on disk
    '--exist-ok'
]

# 如果有GPU，添加device参数 If GPU available, add device parameter
if torch.cuda.is_available():
    training_command.extend(['--device', '0'])

print("🚀 开始YOLOv5轻量化训练 Starting YOLOv5 lightweight training...")
print(f"训练命令 Training command: {' '.join(training_command)}")
print("💡 已优化内存使用并修复resume参数问题 Memory usage optimized and resume parameter issue fixed")
print("=" * 60)

# Free memory before training
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 开始训练 Start training
try:
    print("⏳ 训练开始，请耐心等待... Training started, please wait...")
    result = subprocess.run(training_command, 
                          capture_output=True, 
                          text=True, 
                          cwd='.',
                          timeout=1800)  # 30-minute timeout
    
    if result.returncode == 0:
        print("✅ YOLOv5模型训练完成 YOLOv5 model training completed!")
        print(f"训练结果保存在 Training results saved in: runs/train/tree_detection_light")
        
        # 显示训练输出的最后几行 Show last few lines of training output
        output_lines = result.stdout.split('\n')
        print("\n📊 训练输出摘要 Training output summary:")
        for line in output_lines[-10:]:
            if line.strip():
                print(f"   {line}")
                
    else:
        print(f"❌ 训练失败 Training failed with return code: {result.returncode}")
        print("错误输出 Error output:")
        print(result.stderr[:1000])  # 限制错误输出长度

except subprocess.TimeoutExpired:
    print("⏰ 训练超时，可能需要减少epochs或batch_size Training timed out; consider reducing epochs or batch_size")
except KeyboardInterrupt:
    print("⚠️  训练被用户中断 Training interrupted by user")
except Exception as e:
    print(f"❌ 训练过程中出现错误 Error during training: {e}")

# 检查训练结果 Check training results
weights_dir = Path('runs/train/tree_detection_light/weights')
if weights_dir.exists():
    best_pt = weights_dir / 'best.pt'
    last_pt = weights_dir / 'last.pt'
    
    if best_pt.exists():
        print(f"✅ 最佳模型保存于 Best model saved at: {best_pt}")
        # 设置模型路径供后续使用 Record the model path for later cells
        model_path = str(best_pt)
    elif last_pt.exists():
        print(f"✅ 最后模型保存于 Last model saved at: {last_pt}")
        model_path = str(last_pt)
    else:
        print("⚠️  未找到模型权重文件 No weight files found in the weights directory")
        model_path = None
        
else:
    print("⚠️  未找到训练结果目录 Training results directory not found")
    model_path = None

# 清理内存 Free memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("🧹 内存清理完成 Memory cleanup completed")
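
The training call above captures all output and prints a summary only after the process exits, so no progress is visible while training runs. The following is a minimal sketch of an alternative that echoes output line by line. The helper name `run_streaming` and its parameters are hypothetical, not part of this notebook, and the timeout is only checked when a new line arrives, so a process that stays silent would not be interrupted.

```python
import subprocess
import sys
import time
from collections import deque


def run_streaming(command, timeout_s=1800.0, keep_last=10, echo=True):
    """Run a command, echo its output as it arrives, and keep the last lines.

    Returns (returncode, last_lines). Raises subprocess.TimeoutExpired if the
    time limit is exceeded (checked only when a new output line arrives).
    """
    start = time.monotonic()
    tail = deque(maxlen=keep_last)  # only the most recent lines are retained
    proc = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same stream
        text=True,
        bufsize=1,  # line-buffered in text mode
    )
    try:
        for raw_line in proc.stdout:
            line = raw_line.rstrip("\n")
            if echo:
                print(line)
            tail.append(line)
            if time.monotonic() - start > timeout_s:
                raise subprocess.TimeoutExpired(command, timeout_s)
        proc.wait()
    finally:
        # Make sure the child is not left running if we exit early
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        proc.stdout.close()
    return proc.returncode, list(tail)
```

Under these assumptions it could replace the `subprocess.run(...)` call, for example `returncode, last_lines = run_streaming(training_command)`.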

In [ ]:
# 9. 防崩溃GPU优化推理 - Crash-Resistant GPU Optimized Inference
import torch
import gc
import os
import sys
import subprocess
import time
import signal
from pathlib import Path

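# Handler for signal-based timeouts. It is not registered anywhere in this cell;
# the subprocess calls below rely on subprocess.run(timeout=...) instead.
def timeout_handler(signum, frame):
    print("⏰ Inference process timed out, aborting")
    raise TimeoutError("Inference process timeout")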

def comprehensive_cleanup():
    """
    综合内存和GPU清理 - 增强版
    Enhanced comprehensive memory and GPU cleanup
    """
    try:
        print("🧹 Running deep memory cleanup...")
        
        # 1. Python garbage collection
        collected = gc.collect()
        print(f"   Objects freed by garbage collection: {collected}")
        
        # 2. CUDA cache cleanup
        if torch.cuda.is_available():
            # Release cached blocks held by the allocator
            torch.cuda.empty_cache()
            
            # Wait for all CUDA streams to finish
            torch.cuda.synchronize()
            
            # Release memory held by CUDA IPC handles
            torch.cuda.ipc_collect()
            
            # Reset peak memory statistics
            torch.cuda.reset_peak_memory_stats()
            
            # Report memory state after cleanup
            allocated = torch.cuda.memory_allocated() / 1024**3
            reserved = torch.cuda.memory_reserved() / 1024**3
            print(f"   CUDA memory after cleanup: allocated {allocated:.1f}GB, reserved {reserved:.1f}GB")
        
        # 3. Report process memory usage (reporting only; nothing is freed here)
        import psutil
        process = psutil.Process(os.getpid())
        memory_mb = process.memory_info().rss / 1024 / 1024
        print(f"   Process memory usage (RSS): {memory_mb:.1f} MB")
        
        # 4. No explicit deletion of local variables: deleting entries from
        #    locals() has no effect inside a function, and locals are released on return.
        
        return True
        
    except Exception as e:
        print(f"⚠️  Warning during cleanup: {e}")
        return False

def check_system_stability():
    """
    检查系统稳定性
    Check system stability
    """
    print("🔍 Checking system stability...")
    
    try:
        import psutil
        status = 'stable'
        
        # Check system memory usage
        memory = psutil.virtual_memory()
        memory_usage = memory.percent
        
        print(f"   System memory usage: {memory_usage:.1f}%")
        
        if memory_usage > 90:
            print("⚠️  System memory usage is very high and may cause a crash")
            return 'critical'
        elif memory_usage > 80:
            # Record the warning but keep going, so a critical GPU or disk
            # state is not hidden by an early return
            print("⚠️  System memory usage is high")
            status = 'warning'
        
        # Check GPU state
        if torch.cuda.is_available():
            try:
                # Minimal GPU allocation test
                test_tensor = torch.zeros(10, 10).cuda()
                test_result = test_tensor.sum()
                del test_tensor
                torch.cuda.empty_cache()
                
                gpu_memory = torch.cuda.get_device_properties(0).total_memory
                allocated = torch.cuda.memory_allocated()
                usage_percent = (allocated / gpu_memory) * 100
                
                print(f"   GPU memory usage (this process): {usage_percent:.1f}%")
                
                if usage_percent > 90:
                    return 'critical'
                elif usage_percent > 70:
                    status = 'warning'
                    
            except Exception as e:
                print(f"   GPU health check failed: {e}")
                return 'gpu_error'
        
        # Check disk space
        disk = psutil.disk_usage('/')
        disk_usage = (disk.used / disk.total) * 100
        
        print(f"   Disk usage: {disk_usage:.1f}%")
        
        if disk_usage > 95:
            return 'critical'
        
        if status == 'stable':
            print("✅ System state is healthy")
        return status
        
    except Exception as e:
        print(f"❌ System stability check failed: {e}")
        return 'error'

def safe_gpu_test():
    """
    安全的GPU测试
    Safe GPU test
    """
    print("🧪 Running safe GPU test...")
    
    try:
        if not torch.cuda.is_available():
            print("❌ CUDA is not available")
            return False
        
        # Test 1: basic tensor operation
        test1 = torch.randn(100, 100, device='cuda')
        result1 = test1 @ test1.T
        del test1, result1
        
        # Test 2: memory allocation and release at increasing sizes
        for size in [100, 500, 1000]:
            test_tensor = torch.zeros(size, size, device='cuda')
            del test_tensor
            torch.cuda.empty_cache()
        
        # Test 3: device synchronization
        torch.cuda.synchronize()
        
        print("✅ GPU test passed")
        return True
        
    except Exception as e:
        print(f"❌ GPU test failed: {e}")
        comprehensive_cleanup()
        return False

def ultra_safe_inference(model_path, test_image_path, output_dir='ultra_safe_results'):
    """
    超安全推理模式 - 单张图像处理
    Ultra-safe inference mode - single image processing
    """
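    print(f"🛡️  Ultra-safe inference mode: {test_image_path}")
    
    try:
        # Create a dedicated output directory
        os.makedirs(output_dir, exist_ok=True)
        
        # Use the most conservative settings. Fall back to CPU when CUDA is missing
        # or when the notebook-level GPU test (gpu_safe, if defined) has failed.
        use_gpu = torch.cuda.is_available() and globals().get('gpu_safe', True)
        safe_params = {
            'img_size': 320,  # small input size to limit memory use
            'conf': 0.5,      # relatively high confidence threshold
            'device': '0' if use_gpu else 'cpu'
        }
        
        print(f"   Safe parameters: {safe_params}")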
        
        # Build a conservative inference command
        cmd = [
            sys.executable, 'yolov5/detect.py',
            '--weights', model_path,
            '--source', str(test_image_path),
            '--project', output_dir,
            '--name', 'safe_detect',
            '--img', str(safe_params['img_size']),
            '--conf', str(safe_params['conf']),
            '--device', safe_params['device'],
            '--save-txt',
            '--line-thickness', '1',
            '--exist-ok',
            '--nosave'  # skip annotated images; only the label files are needed
        ]
        
        print(f"   Command: {' '.join(cmd)}")
        
        # Record the start time to measure inference duration
        start_time = time.time()
        
        # Run inference
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=60,  # 1-minute timeout
            cwd='.'
        )
        
        inference_time = time.time() - start_time
        
        if result.returncode == 0:
            print(f"✅ Inference succeeded (took {inference_time:.1f}s)")
            
            # Count detections across all label files
            result_dir = Path(output_dir) / 'safe_detect'
            detection_count = 0
            
            if result_dir.exists():
                label_dir = result_dir / 'labels'
                if label_dir.exists():
                    for label_file in label_dir.glob('*.txt'):
                        with open(label_file, 'r') as f:
                            lines = f.readlines()
                            # Accumulate rather than overwrite, so every label file counts
                            detection_count += len([l for l in lines if l.strip()])
            
            print(f"   Detected {detection_count} objects")
            return {
                'success': True,
                'detections': detection_count,
                'time': inference_time,
                'output': result_dir
            }
        else:
            print(f"❌ Inference failed: {result.stderr[:200]}")
            return {'success': False, 'error': result.stderr}
            
    except subprocess.TimeoutExpired:
        print("⏰ Inference timed out")
        return {'success': False, 'error': 'timeout'}
    except Exception as e:
        print(f"❌ Inference error: {e}")
        return {'success': False, 'error': str(e)}
    finally:
        # Always clean up, whether inference succeeded or not
        comprehensive_cleanup()

def progressive_inference_test(model_path, test_images_dir):
    """
    渐进式推理测试 - 逐步增加复杂度
    Progressive inference test - gradually increase complexity
    """
    print("🔄 Starting progressive inference test...")
    
    # Collect test images
    test_images = list(Path(test_images_dir).glob('*.jpg'))
    if not test_images:
        print("❌ No test images found")
        return False
    
    print(f"Found {len(test_images)} test images")
    
    results = []
    max_images = min(3, len(test_images))  # test at most 3 images
    
    for i in range(max_images):
        print(f"\n--- Test image {i+1}/{max_images} ---")
        
        # Check system state before each inference
        stability = check_system_stability()
        
        if stability in ['critical', 'error']:
            print(f"⚠️  Abnormal system state ({stability}), stopping test")
            break
        
        # Run ultra-safe inference
        img_path = test_images[i]
        result = ultra_safe_inference(model_path, img_path, f'progressive_test_{i+1}')
        
        if result['success']:
            results.append(result)
            print(f"✅ Image {i+1} processed successfully")
        else:
            print(f"❌ Image {i+1} failed: {result.get('error', 'unknown')}")
            # Extra cleanup and a 2-second pause after a failure
            comprehensive_cleanup()
            time.sleep(2)
        
        # Clean up and pause briefly between images
        comprehensive_cleanup()
        time.sleep(1)
    
    # Summarize results
    success_count = len(results)
    print("\n📊 Progressive test summary:")
    print(f"   Success rate: {success_count}/{max_images} ({success_count/max_images*100:.1f}%)")
    
    if results:
        total_detections = sum(r['detections'] for r in results)
        avg_time = sum(r['time'] for r in results) / len(results)
        print(f"   Total detections: {total_detections}")
        print(f"   Average inference time: {avg_time:.1f}s")
    
    return success_count > 0

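# Main execution flow - crash-resistant version
print("🚀 Starting crash-resistant GPU inference system...")

# Initial deep cleanup
comprehensive_cleanup()

# System stability check
stability_status = check_system_stability()

if stability_status in ['critical', 'error']:
    print("❌ System state is unstable; inference cannot run safely")
    print("Consider restarting the kernel or freeing resources used by other programs")
else:
    # GPU safety test
    gpu_safe = safe_gpu_test()
    
    if not gpu_safe:
        print("❌ GPU test failed, falling back to CPU mode")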
    
    # Look for an available model
    model_candidates = [
        'runs/train/tree_detection_light/weights/best.pt',
        'runs/train/tree_detection/weights/best.pt',
        'best_fixed.pt',
        'yolov5s.pt'
    ]
    
    selected_model = None
    # Use a separate loop variable so the model_path set by the training cell is not overwritten
    for candidate in model_candidates:
        if os.path.exists(candidate):
            model_size = os.path.getsize(candidate) / (1024 * 1024)
            print(f"🎯 Found model: {candidate} ({model_size:.1f}MB)")
            selected_model = candidate
            break
    
    if not selected_model:
        print("❌ No usable model found")
    else:
        # Look for test images
        test_dirs = [
            'data/processed/images/test',
            'data/processed/images/val',
            'data/processed/images/train'
        ]
        
        test_images_dir = None
        for test_dir in test_dirs:
            if os.path.exists(test_dir):
                img_count = len(list(Path(test_dir).glob('*.jpg')))
                if img_count > 0:
                    test_images_dir = test_dir
                    print(f"📁 Using test images from: {test_dir} ({img_count} images)")
                    break
        
        if test_images_dir:
            # Run the progressive inference test
            if progressive_inference_test(selected_model, test_images_dir):
                print("✅ Crash-resistant inference test completed successfully")
            else:
                print("❌ Inference test failed")
        else:
            print("❌ No test images found")

# Final system cleanup
print("\n🧹 Running final system cleanup...")
comprehensive_cleanup()

# Final status check
final_status = check_system_stability()
print(f"📋 Final system status: {final_status}")
print("✅ Crash-resistant inference system finished")
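
The stability check above mixes measurement (via `psutil` and `torch`) with the decision about which status to report. As a rough sketch, the decision can be isolated in a pure function that is easy to test without a GPU. The function name `classify_stability` is hypothetical. The thresholds are the ones used in this cell (memory 90%/80%, GPU 90%/70%, disk 95%), and the most severe status wins.

```python
def classify_stability(memory_pct, gpu_pct=None, disk_pct=None):
    """Map usage percentages to 'stable', 'warning' or 'critical'.

    gpu_pct and disk_pct may be None when that resource was not measured.
    """
    severity = {"stable": 0, "warning": 1, "critical": 2}
    levels = ["stable"]

    # System memory: above 90% is critical, above 80% is a warning
    if memory_pct > 90:
        levels.append("critical")
    elif memory_pct > 80:
        levels.append("warning")

    # GPU memory: above 90% is critical, above 70% is a warning
    if gpu_pct is not None:
        if gpu_pct > 90:
            levels.append("critical")
        elif gpu_pct > 70:
            levels.append("warning")

    # Disk: only a critical threshold is defined
    if disk_pct is not None and disk_pct > 95:
        levels.append("critical")

    # Report the most severe level that was triggered
    return max(levels, key=severity.get)
```

For example, `classify_stability(85, gpu_pct=95)` reports `'critical'` because the GPU threshold outranks the memory warning.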

In [ ]:
# 10. YOLOv5模型评估和性能分析 - YOLOv5 Model Evaluation and Performance Analysis
def evaluate_yolov5_performance(model_path, test_images):
    """
    评估YOLOv5模型性能
    Evaluate YOLOv5 model performance
    """
    print("📊 开始YOLOv5模型性能评估 Starting YOLOv5 model performance evaluation...")
    
    total_images = len(test_images)
    total_detections = 0
    confidence_scores = []
    processing_times = []
    
    # 创建临时目录存储评估结果 Create temporary directory for evaluation results
    eval_dir = 'temp_eval'
    os.makedirs(eval_dir, exist_ok=True)
    
    print(f"🔍 评估 {total_images} 张图像 Evaluating {total_images} images...")
    
    # 对每张图像进行检测 Detect on each image
    for i, img_path in enumerate(test_images):
        try:
            import time
            start_time = time.time()
            
            # 构建检测命令 Build detection command
            detect_command = [
                sys.executable, 'yolov5/detect.py',
                '--weights', model_path,
                '--source', str(img_path),
                '--project', eval_dir,
                '--name', f'eval_{i}',
                '--save-txt',  # 保存检测结果 save detection results
                '--save-conf', # 保存置信度 save confidence scores
                '--exist-ok',
                '--nosave'     # 不保存图像，只要文本结果 don't save images, only text results
            ]
            
            # 运行检测 Run detection
            result = subprocess.run(detect_command, 
                                  capture_output=True, 
                                  text=True, 
                                  cwd='.')
            
            processing_time = time.time() - start_time
            processing_times.append(processing_time)
            
            if result.returncode == 0:
                # 读取检测结果 Read detection results
                result_dir = Path(eval_dir) / f'eval_{i}' / 'labels'
                if result_dir.exists():
                    for label_file in result_dir.glob('*.txt'):
                        with open(label_file, 'r') as f:
                            lines = f.readlines()
                            for line in lines:
                                if line.strip():
                                    parts = line.strip().split()
                                    if len(parts) >= 6:  # class x y w h conf
                                        confidence = float(parts[5])
                                        confidence_scores.append(confidence)
                                        total_detections += 1
                                        
            if (i + 1) % 10 == 0 or i == total_images - 1:
                print(f"   已处理 {i+1}/{total_images} 张图像 Processed {i+1}/{total_images} images")
                
        except Exception as e:
            print(f"评估图像 {i+1} 时出错 Error evaluating image {i+1}: {e}")
            continue
    
    # 清理临时目录 Clean up temporary directory
    import shutil
    if os.path.exists(eval_dir):
        shutil.rmtree(eval_dir)
    
    # 计算统计信息 Calculate statistics
    avg_detections_per_image = total_detections / total_images if total_images > 0 else 0
    avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
    max_confidence = np.max(confidence_scores) if confidence_scores else 0
    min_confidence = np.min(confidence_scores) if confidence_scores else 0
    avg_processing_time = np.mean(processing_times) if processing_times else 0
    
    print("\n📈 YOLOv5模型性能统计 YOLOv5 Model Performance Statistics")
    print("=" * 60)
    print(f"总测试图像数量 Total test images: {total_images}")
    print(f"总检测数量 Total detections: {total_detections}")
    print(f"平均每张图像检测数量 Average detections per image: {avg_detections_per_image:.2f}")
    print(f"平均置信度 Average confidence: {avg_confidence:.3f}")
    print(f"最高置信度 Maximum confidence: {max_confidence:.3f}")
    print(f"最低置信度 Minimum confidence: {min_confidence:.3f}")
    print(f"平均处理时间 Average processing time: {avg_processing_time:.3f} 秒/张 seconds per image")
    
    # 创建可视化图表 Create visualization charts
    if confidence_scores or processing_times:
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        # 置信度分布图 Confidence distribution
        if confidence_scores:
            axes[0].hist(confidence_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            axes[0].set_title('🎯 检测置信度分布 Detection Confidence Distribution', fontsize=12, fontweight='bold')
            axes[0].set_xlabel('置信度 Confidence Score')
            axes[0].set_ylabel('频次 Frequency')
            axes[0].grid(True, alpha=0.3)
            
            # 添加统计线 Add statistical lines
            axes[0].axvline(avg_confidence, color='red', linestyle='--', 
                          label=f'平均值 Mean: {avg_confidence:.3f}')
            axes[0].legend()
        else:
            axes[0].text(0.5, 0.5, '无置信度数据\nNo Confidence Data', 
                       ha='center', va='center', transform=axes[0].transAxes)
        
        # 处理时间分布图 Processing time distribution
        if processing_times:
            axes[1].hist(processing_times, bins=20, alpha=0.7, color='lightcoral', edgecolor='black')
            axes[1].set_title('⏱️ 处理时间分布 Processing Time Distribution', fontsize=12, fontweight='bold')
            axes[1].set_xlabel('处理时间 Processing Time (seconds)')
            axes[1].set_ylabel('频次 Frequency')
            axes[1].grid(True, alpha=0.3)
            
            # 添加统计线 Add statistical lines
            axes[1].axvline(avg_processing_time, color='red', linestyle='--', 
                          label=f'平均值 Mean: {avg_processing_time:.3f}s')
            axes[1].legend()
        else:
            axes[1].text(0.5, 0.5, '无处理时间数据\nNo Processing Time Data', 
                       ha='center', va='center', transform=axes[1].transAxes)
        
        plt.tight_layout()
        plt.show()
    
    return {
        'total_images': total_images,
        'total_detections': total_detections,
        'avg_detections_per_image': avg_detections_per_image,
        'confidence_scores': confidence_scores,
        'avg_confidence': avg_confidence,
        'processing_times': processing_times,
        'avg_processing_time': avg_processing_time
    }

def run_yolov5_validation(model_path, dataset_yaml):
    """
    运行YOLOv5官方验证脚本
    Run YOLOv5 official validation script
    """
    print("🔬 运行YOLOv5官方验证 Running YOLOv5 official validation...")
    
    try:
        # 构建验证命令 Build validation command
        val_command = [
            sys.executable, 'yolov5/val.py',
            '--weights', model_path,
            '--data', dataset_yaml,
            '--img', '640',
            '--batch', '8',
            '--conf', '0.001',  # 低置信度阈值以获得更多检测 low confidence threshold for more detections
            '--iou', '0.6',     # IoU阈值 IoU threshold
            '--task', 'val',
            '--device', '0' if torch.cuda.is_available() else 'cpu',
            '--save-txt',
            '--save-conf',
            '--project', 'runs/val',
            '--name', 'tree_detection_val',
            '--exist-ok'
        ]
        
        print(f"验证命令 Validation command: {' '.join(val_command)}")
        
        # 运行验证 Run validation
        result = subprocess.run(val_command, 
                              capture_output=True, 
                              text=True, 
                              cwd='.')
        
        if result.returncode == 0:
            print("✅ YOLOv5官方验证完成 YOLOv5 official validation completed!")
            
            # 显示验证输出的关键信息 Show key information from validation output
            output_lines = result.stdout.split('\n')
            print("\n📊 验证结果摘要 Validation Results Summary:")
            print("-" * 50)
            
            for line in output_lines:
                if any(keyword in line.lower() for keyword in ['precision', 'recall', 'map', 'f1']):
                    print(f"   {line.strip()}")
            
            # 检查是否生成了结果文件 Check if result files were generated
            val_results_dir = Path('runs/val/tree_detection_val')
            if val_results_dir.exists():
                print(f"\n📁 验证结果保存在 Validation results saved in: {val_results_dir}")
                
                # 列出生成的文件 List generated files
                result_files = list(val_results_dir.glob('*'))
                for f in result_files:
                    if f.is_file():
                        print(f"   {f.name}")
            
        else:
            print(f"❌ YOLOv5验证失败 YOLOv5 validation failed")
            print(f"错误输出 Error output: {result.stderr}")
            
    except Exception as e:
        print(f"❌ 验证过程中出现错误 Error during validation: {e}")

# 如果有训练好的模型，进行性能评估 If trained model exists, perform performance evaluation
if 'model_path' in locals() and model_path is not None and os.path.exists(model_path):
    print(f"🎯 使用模型进行性能评估 Using model for performance evaluation: {model_path}")
    
    # 获取测试图像 Get test images
    test_image_dir = Path('data/processed/images/test')
    eval_images = []
    
    if test_image_dir.exists():
        eval_images = list(test_image_dir.glob('*.jpg'))
    
    # 如果没有处理过的测试图像，使用原始评估图像 If no processed test images, use original evaluation images
    if not eval_images:
        eval_rgb_dir = Path('data/raw/evaluation/RGB')
        if eval_rgb_dir.exists():
            eval_images = list(eval_rgb_dir.glob('*.tif'))[:20]  # 限制为20张图像 limit to 20 images
            print(f"📸 使用原始评估图像 Using original evaluation images: {len(eval_images)}")
    
    if eval_images:
        # 进行性能评估 Perform performance evaluation
        performance_stats = evaluate_yolov5_performance(model_path, eval_images)
        
        # 运行官方验证（如果数据集配置存在）Run official validation if dataset config exists
        if 'dataset_yaml_path' in locals() and os.path.exists(dataset_yaml_path):
            run_yolov5_validation(model_path, dataset_yaml_path)
        
    else:
        print("⚠️  未找到测试图像，跳过性能评估 No test images found, skipping performance evaluation")
        
else:
    print("⚠️  未找到训练好的模型，跳过性能评估 No trained model found, skipping performance evaluation")
    print("请先完成模型训练 Please complete model training first")
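
The evaluation above reads label files written with `--save-txt --save-conf`, where each line holds `class x y w h conf` in normalized coordinates. A minimal sketch of that parsing step as a standalone function follows. The name `parse_yolo_label_text` and the dictionary keys are hypothetical, chosen for illustration only.

```python
def parse_yolo_label_text(text):
    """Parse YOLO label lines of the form 'class x y w h [conf]'.

    Returns a list of dicts. 'conf' is None when the line has no confidence
    column (that is, the file was written without --save-conf).
    """
    detections = []
    for line in text.splitlines():
        parts = line.split()
        if len(parts) < 5:
            continue  # skip blank or malformed lines
        x, y, w, h = (float(value) for value in parts[1:5])
        detections.append({
            "class_id": int(parts[0]),
            "x": x,
            "y": y,
            "w": w,
            "h": h,
            "conf": float(parts[5]) if len(parts) >= 6 else None,
        })
    return detections


sample = "0 0.5 0.5 0.2 0.1 0.91\n\n0 0.25 0.75 0.1 0.1 0.40\n"
parsed = parse_yolo_label_text(sample)
print(len(parsed), [d["conf"] for d in parsed])
```

Separating parsing from file access keeps the counting logic testable without running `detect.py`.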

## 📝 项目总结 Project Summary

### 完成的功能 Completed Features
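1. ✅ **Environment detection** - Automatically detects the GPU and the Colab environment
2. ✅ **Data download** - Automatically downloads the NeonTreeEvaluation dataset
3. ✅ **Data conversion** - Converts NeonTree annotations to YOLO format
4. ✅ **Model training** - Trains YOLOv5 for tree detection
5. ✅ **Inference** - Runs tree detection on test images
6. ✅ **Result visualization** - Displays detection results and performance statistics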
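
The data conversion step can be illustrated with a short sketch. It assumes each source annotation gives a tree as pixel-coordinate corners (`xmin`, `ymin`, `xmax`, `ymax`) together with the image size. The function name `corner_box_to_yolo` is hypothetical and this is not necessarily the conversion code used earlier in the notebook.

```python
def corner_box_to_yolo(xmin, ymin, xmax, ymax, img_w, img_h, class_id=0):
    """Convert a pixel corner box to a YOLO label line.

    YOLO expects 'class x_center y_center width height', with all four
    box values normalized to the range [0, 1] by the image size.
    """
    if xmax <= xmin or ymax <= ymin:
        raise ValueError("box must have positive width and height")
    x_center = (xmin + xmax) / 2.0 / img_w
    y_center = (ymin + ymax) / 2.0 / img_h
    width = (xmax - xmin) / img_w
    height = (ymax - ymin) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"


# A 200x100 pixel box in a 400x400 image
print(corner_box_to_yolo(100, 50, 300, 150, 400, 400))
```

With a single "tree" class, `class_id` stays at 0 for every box.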

### 主要参数说明 Key Parameter Explanations
- **训练轮数 Epochs**: 50 (adjust based on results)
- **批次大小 Batch Size**: 16 (adjust to fit GPU memory)
- **图像尺寸 Image Size**: 640x640 pixels
- **学习率 Learning Rate**: 0.001 (AdamW optimizer)
- **数据分割 Data Split**: 70% train / 20% validation / 10% test
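
As an illustration of the 70/20/10 split listed above, here is a minimal sketch of a reproducible split over a list of image names. The function name `split_dataset` and the fixed seed are assumptions made for this example, and the notebook's own split code may differ.

```python
import random


def split_dataset(items, train=0.7, val=0.2, seed=42):
    """Shuffle items reproducibly and split them into train/val/test.

    The test set receives whatever remains after train and val are taken.
    """
    shuffled = sorted(items)  # sort first so the result does not depend on input order
    random.Random(seed).shuffle(shuffled)
    n_train = round(len(shuffled) * train)
    n_val = round(len(shuffled) * val)
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


names = [f"img_{i:02d}.jpg" for i in range(20)]
parts = split_dataset(names)
print({key: len(value) for key, value in parts.items()})
```

Fixing the seed keeps the split identical across notebook restarts, so validation results stay comparable between runs.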

### 使用说明 Usage Instructions
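1. Run all code cells in Google Colab
2. Make sure a GPU runtime is available to speed up training
3. Adjust the training parameters as needed
4. View the detection result images in the results directory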

### 故障排除 Troubleshooting
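- If memory runs out, reduce the batch_size parameter
- If training takes too long, reduce the number of epochs
- If detection quality is poor, try training for more epochs
- Make sure the dataset was downloaded and extracted correctly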
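
The first troubleshooting tip, reducing the batch size when memory runs out, can be automated with a simple back-off loop. This is only a sketch: `run_training` is a hypothetical callable that launches one training attempt with the given batch size and returns `True` on success, for example by wrapping the training subprocess and checking its return code.

```python
def train_with_batch_backoff(run_training, batch_size=16, min_batch_size=1):
    """Retry training with a halved batch size until an attempt succeeds.

    Returns the batch size that worked, or None if every attempt failed.
    """
    while batch_size >= min_batch_size:
        if run_training(batch_size):
            return batch_size
        batch_size //= 2  # halve the batch size and try again
    return None


# Stand-in for a real training launch: pretend only batch sizes <= 4 fit in memory
attempts = []

def fake_training(batch_size):
    attempts.append(batch_size)
    return batch_size <= 4

print(train_with_batch_backoff(fake_training, batch_size=16), attempts)
```

Halving is a common heuristic. A real launcher should also distinguish out-of-memory failures from other errors before retrying.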

### 进一步改进 Further Improvements
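- Data augmentation to improve model generalization
- Hyperparameter tuning to optimize model performance
- Multi-scale training to improve detection accuracy
- Model ensembling to improve overall results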