In [None]:
from mmengine import Config
import os

# 简化的配置文件修改方法
def update_single_config(config_path):
    """直接通过cfg属性修改配置文件"""
    # 加载配置文件
    cfg = Config.fromfile(config_path)
    
    # 从文件名提取信息
    filename = os.path.basename(config_path)
    
    # 确定模型类型
    if 'cascade-rcnn' in filename:
        model_type = 'cascade'
    elif 'faster-rcnn' in filename:
        model_type = 'faster'
    elif 'retinanet' in filename:
        model_type = 'retinanet'
    else:
        model_type = 'cascade'
    
    # 确定采样方法
    sampling_methods = ['ssc', 'sor', 'random', 'entropy', 'mus_cdb', 'margin', 'least_confidence']
    sampling_method = 'ssc'  # 默认值
    for method in sampling_methods:
        if method in filename:
            sampling_method = method
            break
    
    print(f"更新配置: {filename} ({model_type}_{sampling_method})")
    
    # 构建新的数据路径
    data_root = f'data/ForestDamages/active_learning_{model_type}_{sampling_method}'
    
    # 直接修改cfg的属性
    # 1. 修改数据根目录
    cfg.data_root = data_root
    
    # 2. 修改训练数据加载器路径
    cfg.train_dataloader.dataset.data_root = data_root
    cfg.train_dataloader.dataset.ann_file = f'{data_root}/annotations/instances_labeled_train.json'
    cfg.train_dataloader.dataset.data_prefix.img = f'{data_root}/images_labeled_train'
    
    # 3. 修改验证数据加载器路径  
    cfg.val_dataloader.dataset.data_root = data_root
    cfg.val_dataloader.dataset.ann_file = f'{data_root}/annotations/instances_labeled_val.json'
    cfg.val_dataloader.dataset.data_prefix.img = f'{data_root}/images_labeled_val'
    
    # 4. 修改验证评估器路径
    cfg.val_evaluator.ann_file = f'{data_root}/annotations/instances_labeled_val.json'
    
    # 5. 修改active_learning配置路径
    cfg.active_learning.data_root = data_root
    cfg.active_learning.ann_file = f'{data_root}/annotations/instances_unlabeled.json'
    cfg.active_learning.data_prefix.img = f'{data_root}/images_unlabeled'
    cfg.active_learning.train_pool_cfg.data_root = data_root
    
    # 6. 统一batch_size设置（重要：确保不同采样策略使用相同的batch size）
    # 训练batch size
    cfg.train_dataloader.batch_size = 4  # 统一设置为4，确保一致性
    
    # 验证batch size  
    cfg.val_dataloader.batch_size = 4
    
    # 主动学习推理batch size
    if hasattr(cfg, 'active_learning') and hasattr(cfg.active_learning, 'inference_options'):
        cfg.active_learning.inference_options.batch_size = 8  # 推理可以稍大一些
    
    # 7. 设置工作目录
    cfg.work_dir = f'work_dirs/{model_type}_{sampling_method}'
    
    # 保存修改后的配置
    cfg.dump(config_path)
    print(f"✓ 已保存: {data_root}, batch_size=4 (统一设置)")

# 批量更新所有配置文件
config_dir = 'al_configs'
config_files = []

# 收集所有配置文件
for root, dirs, files in os.walk(config_dir):
    for file in files:
        if file.endswith('.py'):
            config_files.append(os.path.join(root, file))

print(f"找到 {len(config_files)} 个配置文件")

# 批量更新
for config_file in sorted(config_files):
    try:
        update_single_config(config_file)
    except Exception as e:
        print(f"✗ 更新失败 {config_file}: {e}")

print("批量更新完成！")


### 准备配置文件

In [None]:
# 专门用于统一所有配置文件的batch_size设置
def unify_batch_sizes(target_batch_size=4):
    """
    统一所有配置文件的batch_size，解决不同采样策略性能不一致的问题
    """
    config_dir = 'al_configs'
    config_files = []
    
    # 收集所有配置文件
    for root, dirs, files in os.walk(config_dir):
        for file in files:
            if file.endswith('.py'):
                config_files.append(os.path.join(root, file))
    
    print(f"=== 统一batch_size设置为: {target_batch_size} ===")
    
    for config_file in sorted(config_files):
        try:
            cfg = Config.fromfile(config_file)
            
            # 检查当前的batch_size设置
            train_bs = cfg.train_dataloader.batch_size
            val_bs = cfg.val_dataloader.batch_size
            
            # 统一设置batch_size
            cfg.train_dataloader.batch_size = target_batch_size
            cfg.val_dataloader.batch_size = target_batch_size
            
            # 如果有主动学习推理设置，也统一
            if hasattr(cfg, 'active_learning') and hasattr(cfg.active_learning, 'inference_options'):
                old_inference_bs = cfg.active_learning.inference_options.batch_size
                cfg.active_learning.inference_options.batch_size = target_batch_size * 2  # 推理可以稍大
                print(f"{os.path.basename(config_file)}: train({train_bs}->{target_batch_size}), val({val_bs}->{target_batch_size}), inference({old_inference_bs}->{target_batch_size*2})")
            else:
                print(f"{os.path.basename(config_file)}: train({train_bs}->{target_batch_size}), val({val_bs}->{target_batch_size})")
            
            # 保存配置
            cfg.dump(config_file)
            
        except Exception as e:
            print(f"✗ 处理失败 {config_file}: {e}")
    
    print("=== batch_size统一完成 ===")

# 执行统一batch_size操作
# unify_batch_sizes(4)  # 取消注释来执行


In [None]:
# 解决mini-batch组成差异的核心问题
def fix_minibatch_consistency():
    """
    解决不同采样策略mini-batch组成不一致的问题
    """
    config_dir = 'al_configs'
    config_files = []
    
    # 收集所有配置文件
    for root, dirs, files in os.walk(config_dir):
        for file in files:
            if file.endswith('.py'):
                config_files.append(os.path.join(root, file))
    
    print("=== 修复mini-batch一致性问题 ===")
    
    for config_file in sorted(config_files):
        try:
            cfg = Config.fromfile(config_file)
            
            filename = os.path.basename(config_file)
            print(f"处理: {filename}")
            
            # 1. 固定随机种子 - 确保数据加载顺序一致
            cfg.randomness = dict(
                seed=42,                    # 固定种子
                deterministic=True,         # 确定性行为
                diff_rank_seed=False        # 不同rank使用相同种子
            )
            
            # 2. 修改数据采样器 - 禁用shuffle确保顺序一致
            if hasattr(cfg, 'train_dataloader') and hasattr(cfg.train_dataloader, 'sampler'):
                cfg.train_dataloader.sampler = dict(
                    type='DefaultSampler', 
                    shuffle=False,          # 禁用随机打乱
                    seed=42                 # 固定种子
                )
            
            # 3. 禁用AspectRatioBatchSampler - 避免按宽高比重新组合batch
            if hasattr(cfg, 'train_dataloader'):
                cfg.train_dataloader.batch_sampler = None  # 使用默认batch采样
            
            # 4. 固定数据处理pipeline中的随机性
            if hasattr(cfg, 'train_pipeline'):
                for i, transform in enumerate(cfg.train_pipeline):
                    if transform.get('type') == 'RandomFlip':
                        # 可以选择完全禁用或固定种子
                        cfg.train_pipeline[i]['prob'] = 0.0  # 禁用随机翻转
                    elif transform.get('type') == 'PhotoMetricDistortion':
                        # 禁用颜色增强的随机性
                        cfg.train_pipeline[i] = dict(type='Identity')  # 替换为无操作
            
            # 5. 确保验证也使用一致的设置
            if hasattr(cfg, 'val_dataloader') and hasattr(cfg.val_dataloader, 'sampler'):
                cfg.val_dataloader.sampler = dict(
                    type='DefaultSampler',
                    shuffle=False,
                    seed=42
                )
            
            print(f"  ✓ 已固定随机种子和采样顺序")
            
            # 保存配置
            cfg.dump(config_file)
            
        except Exception as e:
            print(f"✗ 处理失败 {config_file}: {e}")
    
    print("=== mini-batch一致性修复完成 ===")

# 选择性解决方案：只固定种子但保留数据增强
def fix_seeds_only():
    """只固定随机种子，保持数据增强的多样性"""
    config_dir = 'al_configs'
    config_files = []
    
    for root, dirs, files in os.walk(config_dir):
        for file in files:
            if file.endswith('.py'):
                config_files.append(os.path.join(root, file))
    
    print("=== 仅固定随机种子 ===")
    
    for config_file in sorted(config_files):
        try:
            cfg = Config.fromfile(config_file)
            filename = os.path.basename(config_file)
            
            # 固定全局随机种子
            cfg.randomness = dict(
                seed=42,
                deterministic=True,
                diff_rank_seed=False
            )
            
            # 固定数据采样器种子，但保持shuffle
            if hasattr(cfg, 'train_dataloader'):
                cfg.train_dataloader.sampler = dict(
                    type='DefaultSampler',
                    shuffle=True,  # 保持shuffle但种子固定
                    seed=42
                )
            
            cfg.dump(config_file)
            print(f"✓ {filename}: 已固定种子")
            
        except Exception as e:
            print(f"✗ {filename}: {e}")

# 终极解决方案：确保所有策略使用完全相同的训练样本顺序
def ensure_identical_training_order():
    """
    确保所有采样策略在每个epoch中使用完全相同的样本顺序
    这是最彻底的解决方案
    """
    config_dir = 'al_configs'
    
    print("=== 确保训练样本顺序完全一致 ===")
    print("方法1: 使用固定的sample list")
    print("方法2: 禁用所有随机性")
    print("方法3: 使用相同的random state")
    
    # 添加环境变量配置
    env_config = """
# 添加到配置文件末尾
env_cfg = dict(
    cudnn_benchmark=False,     # 禁用cudnn benchmark确保确定性
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl')
)

# 确保完全确定性的训练
import os
import torch
import numpy as np
import random

# 设置所有随机种子
def set_all_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_all_seeds(42)
"""
    
    print("推荐在主训练脚本中添加以上代码确保完全确定性")

# 执行选择（取消注释需要的方案）
# fix_minibatch_consistency()     # 最严格：禁用所有随机性
# fix_seeds_only()                # 中等：固定种子但保持增强
# ensure_identical_training_order() # 查看终极方案


In [None]:
# 主动学习专用解决方案：确保所有策略在每轮使用相同的已标注数据
def sync_active_learning_data():
    """
    主动学习场景下的特殊处理：
    确保不同采样策略在每一轮训练中使用完全相同的已标注数据池
    """
    print("=== 主动学习数据同步方案 ===")
    print("""
    主动学习中mini-batch差异的核心问题：
    
    1. 问题根源：
       - 不同采样策略选择不同的样本进行标注
       - 即使batch_size相同，但每轮训练的数据池不同
       - 导致模型学习到不同的特征分布
    
    2. 解决方案：
       
       方案A: 固定初始标注池（推荐用于对比实验）
       ├── 所有策略从相同的初始标注样本开始
       ├── 每轮训练使用完全相同的已标注数据
       └── 只有新选择的样本不同，但已标注部分保持一致
       
       方案B: 数据增强一致性
       ├── 固定数据增强的随机种子
       ├── 确保相同图片在不同策略中得到相同的增强
       └── 保持训练的随机性但增强过程一致
       
       方案C: 批次组成控制
       ├── 控制每个batch中不同类别样本的比例
       ├── 确保不同策略的batch有相似的样本分布
       └── 使用ClassAwareSampler或自定义采样器
    
    3. 实现建议：
    """)
    
    # 为主动学习添加专用配置
    al_config_template = '''
# 主动学习专用配置
active_learning = dict(
    # 数据同步设置
    sync_settings=dict(
        use_fixed_initial_pool=True,        # 使用固定的初始标注池
        initial_pool_seed=42,               # 初始池选择种子
        batch_composition_control=True,     # 控制batch组成
        augmentation_sync=True,             # 同步数据增强
    ),
    
    # 确保一致的推理设置
    inference_options=dict(
        batch_size=8,
        deterministic=True,                 # 确定性推理
        score_thr=0.05,
        seed=42,                           # 推理种子
    ),
    
    # 样本选择同步
    sample_selection=dict(
        num_samples=200,
        selection_seed=42,                  # 选择过程种子
        use_deterministic_selection=True,   # 确定性选择
    )
)

# 训练配置同步
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=3,
    val_interval=1,
    # 确定性训练设置
    deterministic=True,
    seed=42
)

# 随机性控制
randomness = dict(
    seed=42,
    deterministic=True,
    diff_rank_seed=False
)
'''
    
    print("建议在配置文件中添加以上同步设置")
    print("\n4. 代码实现示例：")
    
    code_example = '''
# 在训练脚本中添加
import torch
import numpy as np
import random

def ensure_reproducibility():
    # 设置所有随机种子
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    # 确定性设置
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # 环境变量
    import os
    os.environ['PYTHONHASHSEED'] = str(seed)

# 在每轮主动学习开始前调用
ensure_reproducibility()
'''
    
    print(code_example)

# 快速修复函数：添加确定性设置到所有配置
def add_deterministic_settings():
    """为所有配置文件添加确定性设置"""
    config_dir = 'al_configs'
    config_files = []
    
    for root, dirs, files in os.walk(config_dir):
        for file in files:
            if file.endswith('.py'):
                config_files.append(os.path.join(root, file))
    
    print("=== 添加确定性设置 ===")
    
    for config_file in sorted(config_files):
        try:
            cfg = Config.fromfile(config_file)
            filename = os.path.basename(config_file)
            
            # 添加随机性控制
            cfg.randomness = dict(seed=42, deterministic=True, diff_rank_seed=False)
            
            # 修改环境配置
            if not hasattr(cfg, 'env_cfg'):
                cfg.env_cfg = dict()
            cfg.env_cfg.update(dict(
                cudnn_benchmark=False,  # 重要：禁用benchmark确保确定性
                mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
                dist_cfg=dict(backend='nccl')
            ))
            
            # 修改数据采样器确保一致性
            if hasattr(cfg, 'train_dataloader'):
                cfg.train_dataloader.sampler = dict(
                    type='DefaultSampler', 
                    shuffle=True,  # 保持shuffle但种子固定
                    seed=42
                )
                
                # 添加worker初始化函数确保每个worker的种子一致
                cfg.train_dataloader.worker_init_fn = 'seed_worker'
            
            cfg.dump(config_file)
            print(f"✓ {filename}: 已添加确定性设置")
            
        except Exception as e:
            print(f"✗ {filename}: {e}")
    
    print("=== 确定性设置添加完成 ===")

# 执行函数
# add_deterministic_settings()  # 取消注释执行
# sync_active_learning_data()   # 查看详细方案


In [None]:
# 测试单个配置文件
config_path = 'al_configs/cascade-rcnn/cascade-rcnn_r101_ssc.py'

# 加载配置文件
cfg = Config.fromfile(config_path)

# 查看关键配置信息
print("=== 当前配置信息 ===")
print(f"数据根目录: {cfg.data_root}")
print(f"训练数据: {cfg.train_dataloader.dataset.ann_file}")
print(f"验证数据: {cfg.val_dataloader.dataset.ann_file}")
print(f"主动学习数据: {cfg.active_learning.ann_file}")
print(f"工作目录: {getattr(cfg, 'work_dir', '未设置')}")
print(f"批次大小: {cfg.train_dataloader.batch_size}")


In [None]:
# 配置文件后处理工具
import os
import re
from pathlib import Path
from mmengine import Config

class ConfigPostProcessor:
    def __init__(self, base_dir="al_configs"):
        self.base_dir = Path(base_dir)
        
    def extract_info_from_filename(self, filepath):
        """从文件名中提取模型类型和采样方法信息"""
        filename = filepath.stem
        
        # 提取模型类型
        if "cascade-rcnn" in filename:
            model_type = "cascade-rcnn"
        elif "faster-rcnn" in filename:
            model_type = "faster-rcnn"
        elif "retinanet" in filename:
            model_type = "retinanet"
        else:
            model_type = "unknown"
            
        # 提取采样方法
        sampling_methods = ["ssc", "sor", "random", "entropy", "mus_cdb", "margin", "least_confidence"]
        sampling_method = "unknown"
        
        for method in sampling_methods:
            if method in filename.replace("-", "_").replace(" ", "_"):
                sampling_method = method
                break
                
        return model_type, sampling_method
    
    def update_active_learning_paths(self, config_path):
        """更新主动学习相关的数据路径"""
        with open(config_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        # 从文件名提取信息
        model_type, sampling_method = self.extract_info_from_filename(config_path)
        
        # 生成新的路径
        model_prefix = model_type.replace("-", "_")
        new_data_root = f"data/ForestDamages/active_learning_{model_prefix}_{sampling_method}"
        
        # 更新路径的正则表达式
        patterns_replacements = [
            # 更新 data_root
            (r"data_root\s*=\s*['\"][^'\"]*['\"]", f"data_root='{new_data_root}'"),
            # 更新 ann_file 中的路径
            (r"ann_file\s*=\s*['\"]([^'\"]*)/annotations/instances_unlabeled\.json['\"]",
             f"ann_file='{new_data_root}/annotations/instances_unlabeled.json'"),
            # 更新 data_prefix 中的路径
            (r"data_prefix\s*=\s*dict\s*\(\s*img\s*=\s*['\"]([^'\"]*)/images_unlabeled['\"]",
             f"data_prefix=dict(img='{new_data_root}/images_unlabeled'"),
            # 更新 train_pool_cfg 中的路径
            (r"(train_pool_cfg\s*=\s*dict\s*\(\s*data_root\s*=\s*['\"])[^'\"]*(['\"])",
             rf"\g<1>{new_data_root}\g<2>"),
        ]
        
        for pattern, replacement in patterns_replacements:
            content = re.sub(pattern, replacement, content)
            
        return content
    
    def update_batch_sizes(self, content, model_type, sampling_method):
        """根据模型类型和采样方法优化batch size"""
        
        # 定义不同情况下的最优batch size
        batch_size_map = {
            ("cascade-rcnn", "sor"): 2,      # SOR方法内存需求大
            ("cascade-rcnn", "mus_cdb"): 32, # MUS-CDB可以用较大batch
            ("faster-rcnn", "sor"): 2,
            ("faster-rcnn", "mus_cdb"): 32,
            ("retinanet", "sor"): 4,         # RetinaNet相对内存友好
            ("retinanet", "mus_cdb"): 64,
        }
        
        # 默认batch size
        default_batch_sizes = {
            "cascade-rcnn": 16,
            "faster-rcnn": 16, 
            "retinanet": 32
        }
        
        # 获取合适的batch size
        batch_size = batch_size_map.get(
            (model_type, sampling_method), 
            default_batch_sizes.get(model_type, 16)
        )
        
        # 更新batch_size
        content = re.sub(
            r"batch_size\s*=\s*\d+",
            f"batch_size={batch_size}",
            content
        )
        
        return content
    
    def update_score_thresholds(self, content, model_type):
        """根据模型类型调整检测阈值"""
        
        thresholds = {
            "cascade-rcnn": {"score": 0.05, "nms": 0.5},  # Cascade R-CNN可以用更低阈值
            "faster-rcnn": {"score": 0.08, "nms": 0.4},   # 标准设置
            "retinanet": {"score": 0.1, "nms": 0.3}       # RetinaNet需要更高阈值
        }
        
        if model_type in thresholds:
            threshold_info = thresholds[model_type]
            
            # 更新 score_threshold
            content = re.sub(
                r"score_threshold\s*=\s*[\d.]+",
                f"score_threshold = {threshold_info['score']}",
                content
            )
            
            # 更新 nms_iou_threshold  
            content = re.sub(
                r"nms_iou_threshold\s*=\s*[\d.]+",
                f"nms_iou_threshold = {threshold_info['nms']}",
                content
            )
            
        return content
    
    def add_missing_work_dir(self, content, config_path):
        """添加工作目录配置"""
        filename = config_path.stem
        
        if "work_dir" not in content:
            work_dir = f"work_dirs/{filename}"
            content += f"\n\n# 工作目录\nwork_dir = '{work_dir}'\n"
            
        return content
    
    def process_config_file(self, config_path):
        """处理单个配置文件"""
        print(f"🔧 处理配置文件: {config_path}")
        
        # 提取模型信息
        model_type, sampling_method = self.extract_info_from_filename(config_path)
        print(f"   模型: {model_type}, 采样方法: {sampling_method}")
        
        # 更新路径
        content = self.update_active_learning_paths(config_path)
        
        # 更新参数
        content = self.update_batch_sizes(content, model_type, sampling_method)
        content = self.update_score_thresholds(content, model_type)
        content = self.add_missing_work_dir(content, config_path)
        
        # 保存文件
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(content)
            
        print(f"   ✅ 完成")
        return True
    
    def get_all_config_files(self):
        """获取所有配置文件路径"""
        config_files = []
        model_types = ["cascade-rcnn", "faster-rcnn", "retinanet"]
        
        for model_type in model_types:
            model_dir = self.base_dir / model_type
            if model_dir.exists():
                config_files.extend(list(model_dir.glob("*.py")))
        return config_files
    
    def process_all_configs(self):
        """处理所有配置文件"""
        config_files = self.get_all_config_files()
        
        print(f"🚀 找到 {len(config_files)} 个配置文件")
        print("="*60)
        
        success_count = 0
        for config_file in config_files:
            try:
                if self.process_config_file(config_file):
                    success_count += 1
            except Exception as e:
                print(f"❌ 处理 {config_file} 时出错: {e}")
        
        print("="*60)
        print(f"✅ 处理完成: {success_count}/{len(config_files)} 个文件成功")
        
        return success_count == len(config_files)
    
    def validate_config(self, config_path):
        """验证配置文件是否可以正常加载"""
        try:
            cfg = Config.fromfile(str(config_path))
            print(f"✅ 配置文件验证成功: {config_path.name}")
            return True
        except Exception as e:
            print(f"❌ 配置文件验证失败: {config_path.name}, 错误: {e}")
            return False
    
    def validate_all_configs(self):
        """验证所有配置文件"""
        config_files = self.get_all_config_files()
        
        print("🔍 验证所有配置文件...")
        print("="*60)
        
        success_count = 0
        for config_file in config_files:
            if self.validate_config(config_file):
                success_count += 1
        
        print("="*60)
        print(f"验证完成: {success_count}/{len(config_files)} 个配置文件通过验证")

# 创建处理器实例
processor = ConfigPostProcessor()
print("✅ 配置文件后处理器初始化完成")


In [None]:
# 🚀 执行配置文件后处理
print("开始处理所有配置文件...")
success = processor.process_all_configs()

if success:
    print("\n🎉 所有配置文件更新成功!")
else:
    print("\n⚠️ 部分配置文件更新失败，请检查错误信息")


In [None]:
# 🔍 验证配置文件加载
print("验证配置文件是否可以正常加载...")
processor.validate_all_configs()


In [None]:
# 🧪 测试单个配置文件加载 (原有的测试代码)
print("测试单个配置文件加载...")

try:
    # 加载配置文件
    cfg = Config.fromfile('al_configs/cascade-rcnn/cascade-rcnn_F_r101_ssc_16_200.py')
    
    # 打印配置信息
    print("✅ 配置文件加载成功!")
    print(f"模型类型: {cfg.model.type}")
    print(f"数据根目录: {cfg.active_learning.data_root}")
    print(f"检测阈值: {cfg.score_threshold}")
    print(f"最大检测框数: {cfg.max_boxes_per_img}")
    
    # 保存到文件
    with open('generated_config.py', 'w') as f:
        f.write(cfg.pretty_text)
    print("✅ 配置文件已保存到 generated_config.py")
    
except Exception as e:
    print(f"❌ 配置文件加载失败: {e}")


In [5]:
from mmengine import Config

# 加载配置文件
cfg = Config.fromfile('al_configs/cascade-rcnn/cascade-rcnn_F_r101_ssc_16_200.py')

    # 直接修改cfg的属性
    # 1. 修改数据根目录
cfg.data_root = data_root

# 2. 修改训练数据加载器路径
cfg.train_dataloader.dataset.data_root = data_root
cfg.train_dataloader.dataset.ann_file = f'{data_root}/annotations/instances_labeled_train.json'
cfg.train_dataloader.dataset.data_prefix.img = f'{data_root}/images_labeled_train'

# 3. 修改验证数据加载器路径  
cfg.val_dataloader.dataset.data_root = data_root
cfg.val_dataloader.dataset.ann_file = f'{data_root}/annotations/instances_labeled_val.json'
cfg.val_dataloader.dataset.data_prefix.img = f'{data_root}/images_labeled_val'

# 4. 修改验证评估器路径
cfg.val_evaluator.ann_file = f'{data_root}/annotations/instances_labeled_val.json'

# 5. 修改active_learning配置路径
cfg.active_learning.data_root = data_root
cfg.active_learning.ann_file = f'{data_root}/annotations/instances_unlabeled.json'
cfg.active_learning.data_prefix.img = f'{data_root}/images_unlabeled'
cfg.active_learning.train_pool_cfg.data_root = data_root

# 6. 设置工作目录
cfg.work_dir = f'work_dirs/{model_type}_{sampling_method}'

# 打印配置信息
print(cfg.pretty_text)

# 保存到文件
with open('try.py', 'w') as f:
    f.write(cfg.pretty_text)

active_learning = dict(
    ann_file=
    'data/ForestDamages/active_learning_cascade_ssc/annotations/instances_unlabeled.json',
    data_prefix=dict(
        img='data/ForestDamages/active_learning_cascade_ssc/images_unlabeled'),
    data_root='data/ForestDamages/active_learning_cascade_ssc',
    inference_options=dict(
        batch_size=16,
        sample_size=0,
        save_results=True,
        score_thr=0.08,
        selected_metric='ssc_score',
        uncertainty_methods=[
            'ssc',
        ]),
    max_iterations=16,
    sample_selection=dict(
        num_samples=200,
        rl_metric='',
        sample_selector='default',
        uncertainty_metric='ssc_score'),
    train_pool_cfg=dict(
        ann_file='annotations/instances_labeled_train.json',
        data_prefix=dict(img='images_labeled_train'),
        data_root='data/ForestDamages/active_learning_cascade_ssc'))
data_root = 'data/ForestDamages/active_learning_cascade_sor'
dataset_type = 'ForestDamagesDataset'
d

In [None]:


->

from mmengine import Config
import os

# 简化的配置文件修改方法
def update_single_config(config_path):
    """直接通过cfg属性修改配置文件"""
    # 加载配置文件
    cfg = Config.fromfile(config_path)
    
    # 从文件名提取信息
    filename = os.path.basename(config_path)
    
    # 确定模型类型
    if 'cascade-rcnn' in filename:
        model_type = 'cascade'
    elif 'faster-rcnn' in filename:
        model_type = 'faster'
    elif 'retinanet' in filename:
        model_type = 'retinanet'
    else:
        model_type = 'cascade'
    
    # 确定采样方法
    sampling_methods = ['ssc', 'sor', 'random', 'entropy', 'mus_cdb', 'margin', 'least_confidence']
    sampling_method = 'ssc'  # 默认值
    for method in sampling_methods:
        if method in filename:
            sampling_method = method
            break
    
    print(f"更新配置: {filename} ({model_type}_{sampling_method})")
    
    # 构建新的数据路径
    data_root = f'data/ForestDamages/active_learning_{model_type}_{sampling_method}'
    
    # 直接修改cfg的属性
    # 1. 修改数据根目录
    cfg.data_root = data_root
    
    # 2. 修改训练数据加载器路径
    cfg.train_dataloader.dataset.data_root = data_root
    cfg.train_dataloader.dataset.ann_file = f'{data_root}/annotations/instances_labeled_train.json'
    cfg.train_dataloader.dataset.data_prefix.img = f'{data_root}/images_labeled_train'
    
    # 3. 修改验证数据加载器路径  
    cfg.val_dataloader.dataset.data_root = data_root
    cfg.val_dataloader.dataset.ann_file = f'{data_root}/annotations/instances_labeled_val.json'
    cfg.val_dataloader.dataset.data_prefix.img = f'{data_root}/images_labeled_val'
    
    # 4. 修改验证评估器路径
    cfg.val_evaluator.ann_file = f'{data_root}/annotations/instances_labeled_val.json'
    
    # 5. 修改active_learning配置路径
    cfg.active_learning.data_root = data_root
    cfg.active_learning.ann_file = f'{data_root}/annotations/instances_unlabeled.json'
    cfg.active_learning.data_prefix.img = f'{data_root}/images_unlabeled'
    cfg.active_learning.train_pool_cfg.data_root = data_root
    
    # 6. 设置工作目录
    cfg.work_dir = f'work_dirs/{model_type}_{sampling_method}'
    
    # 保存修改后的配置
    cfg.dump(config_path)
    print(f"✓ 已保存: {data_root}")

# 批量更新所有配置文件
config_dir = 'al_configs'
config_files = []

# 收集所有配置文件
for root, dirs, files in os.walk(config_dir):
    for file in files:
        if file.endswith('.py'):
            config_files.append(os.path.join(root, file))

print(f"找到 {len(config_files)} 个配置文件")

# 批量更新
for config_file in sorted(config_files):
    try:
        update_single_config(config_file)
    except Exception as e:
        print(f"✗ 更新失败 {config_file}: {e}")

print("批量更新完成！")