# 环境设置

In [3]:
import sys
import os
import numpy as np
import torch

# 将 detectron2 的源码路径添加到 Python 的搜索路径中
# 这是我们从源码编译 detectron2 的文件夹
d2_path = '/mnt/lyh/DA-FasterCNN/detectron2-main'
if d2_path not in sys.path:
    sys.path.insert(0, d2_path)

# 验证路径是否已添加
print("sys.path has been updated:")
print(sys.path)

sys.path has been updated:
['/mnt/lyh/DA-FasterCNN/detectron2-main', '/home/refrain/anaconda3/envs/lyh_env/lib/python310.zip', '/home/refrain/anaconda3/envs/lyh_env/lib/python3.10', '/home/refrain/anaconda3/envs/lyh_env/lib/python3.10/lib-dynload', '', '/home/refrain/anaconda3/envs/lyh_env/lib/python3.10/site-packages']


# Check the installation
The right output running the following cell should be:<br><br>
nvcc: NVIDIA (R) Cuda compiler driver<br>
Copyright (c) 2005-2020 NVIDIA Corporation<br>
Built on Wed_Sep_21_10:33:58_PDT_2022<br>
Cuda compilation tools, release 11.8, V11.8.89<br>
Build cuda_11.8.r11.8/compiler.31833905_0<br>
torch:  2.0 ; cuda:  cu118<br>
detectron2: 0.6<br>

In [4]:
import torch, detectron2
!nvcc --version
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
print("detectron2:", detectron2.__version__)

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
torch:  2.0 ; cuda:  cu118
detectron2: 0.6
torch:  2.0 ; cuda:  cu118
detectron2: 0.6


In [5]:
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
import numpy as np
import cv2
import random
from detectron2 import model_zoo
from detectron2.config import get_cfg
import logging
import os
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
from detectron2.data import MetadataCatalog, build_detection_test_loader, build_detection_train_loader
from detectron2.modeling import build_model
from detectron2.solver import build_lr_scheduler, build_optimizer
from detectron2.utils.events import EventStorage
from detectron2.engine import default_argument_parser, default_setup, default_writers, launch
import torch, torchvision
from detectron2.data.datasets import register_coco_instances, load_coco_json, register_pascal_voc
# from google.colab import drive
# drive.mount('/content/drive')

  import pkg_resources


# 数据准备与注册

In [5]:
# 数据准备 (Cityscapes -> PASCAL VOC 格式)
# !!! 注意：此单元格仅在需要从原始数据生成VOC格式时运行一次。
# 如果 "/mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007" 目录中已有数据，请跳过此单元格。

import os, json, glob, shutil, pprint
import cv2
import xml.etree.ElementTree as ET

# --- 配置路径 ---
CITY_ROOT = "/mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscapes_raw"
OUT_BASE = "/mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007"
CLASSES = ['car','person','rider','truck','bus','train','motorcycle','bicycle']

# --- 确保输出目录存在 ---
os.makedirs(os.path.join(OUT_BASE, "JPEGImages"), exist_ok=True)
os.makedirs(os.path.join(OUT_BASE, "Annotations"), exist_ok=True)
os.makedirs(os.path.join(OUT_BASE, "ImageSets", "Main"), exist_ok=True)
print(f"输出目录 {OUT_BASE} 已准备好。")

# --- 辅助函数 ---
def write_voc_xml(img_id, fname, w, h, objs, out_path):
    # ... (此函数内容保持不变) ...
    ann = ET.Element("annotation")
    ET.SubElement(ann, "folder").text = "VOC2007"
    ET.SubElement(ann, "filename").text = fname
    size = ET.SubElement(ann, "size")
    ET.SubElement(size, "width").text = str(w)
    ET.SubElement(size, "height").text = str(h)
    ET.SubElement(size, "depth").text = "3"
    ET.SubElement(ann, "segmented").text = "0"
    for cls, (x1, y1, x2, y2) in objs:
        obj = ET.SubElement(ann, "object")
        ET.SubElement(obj, "name").text = cls
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"
        bb = ET.SubElement(obj, "bndbox")
        ET.SubElement(bb, "xmin").text = str(max(1, int(x1)))
        ET.SubElement(bb, "ymin").text = str(max(1, int(y1)))
        ET.SubElement(bb, "xmax").text = str(int(x2))
        ET.SubElement(bb, "ymax").text = str(int(y2))
    ET.ElementTree(ann).write(out_path, encoding="utf-8", xml_declaration=True)

def polygon_to_bbox(poly):
    xs = [p[0] for p in poly]
    ys = [p[1] for p in poly]
    return min(xs), min(ys), max(xs), max(ys)

def collect_ids(img_root, suffix='_leftImg8bit.png'):
    # ... (此函数内容可简化以提高鲁棒性) ...
    ids = {}
    for f in glob.glob(os.path.join(img_root, "**", f"*{suffix}"), recursive=True):
         bn = os.path.basename(f)
         img_id = bn.split("_leftImg8bit")[0]
         ids[img_id] = f
    return ids

# --- 1. 收集所有图像ID和路径 ---
clear_train_ids = collect_ids(os.path.join(CITY_ROOT, "leftImg8bit", "train"))
clear_val_ids   = collect_ids(os.path.join(CITY_ROOT, "leftImg8bit", "val"))

# 查找并收集 foggy 图像
foggy_root_path = os.path.join(CITY_ROOT, "leftImg8bit_foggyDBF")
foggy_train_ids = collect_ids(foggy_root_path, suffix='_leftImg8bit_foggy_beta_0.02.png')

all_image_paths = {**clear_train_ids, **clear_val_ids, **foggy_train_ids}
print(f"总共找到 {len(all_image_paths)} 个唯一的图像ID。")

# --- 2. 收集所有标注文件 ---
gt_files = glob.glob(os.path.join(CITY_ROOT, "gtFine", "**", "*_gtFine_polygons.json"), recursive=True)
gt_index = {os.path.basename(p).replace("_gtFine_polygons.json", ""): p for p in gt_files}
print(f"找到 {len(gt_index)} 个标注文件。")

# --- 3. 转换图像和标注 ---
for img_id, img_src in all_image_paths.items():
    img = cv2.imread(img_src)
    if img is None: continue
    h, w = img.shape[:2]
    
    # 保存为jpg
    out_jpg = os.path.join(OUT_BASE, "JPEGImages", f"{img_id}.jpg")
    cv2.imwrite(out_jpg, img)

    # 转换标注
    objs = []
    if img_id in gt_index:
        with open(gt_index[img_id], "r") as f:
            data = json.load(f)
        for s in data.get("objects", []):
            if s.get("label") in CLASSES:
                objs.append((s["label"], polygon_to_bbox(s["polygon"])))
    
    xml_out = os.path.join(OUT_BASE, "Annotations", f"{img_id}.xml")
    write_voc_xml(img_id, f"{img_id}.jpg", w, h, objs, xml_out)

print("图像和标注转换完成。")

# --- 4. 生成 ImageSets 列表 ---
def dump_list(fn, ids):
    path = os.path.join(OUT_BASE, "ImageSets", "Main", fn + ".txt")
    with open(path, "w") as f:
        f.write("\n".join(sorted(ids)))
    print(f"已写入 {len(ids)} 个ID到 {path}")

dump_list("train_s", clear_train_ids.keys())
dump_list("train_t", foggy_train_ids.keys())
dump_list("test_t",  clear_val_ids.keys())

print("\n数据准备完成！VOC数据集位于:", OUT_BASE)

输出目录 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007 已准备好。
总共找到 3475 个唯一的图像ID。
找到 5000 个标注文件。
图像和标注转换完成。
已写入 2975 个ID到 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/ImageSets/Main/train_s.txt
已写入 3475 个ID到 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/ImageSets/Main/train_t.txt
已写入 500 个ID到 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/ImageSets/Main/test_t.txt

数据准备完成！VOC数据集位于: /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007
图像和标注转换完成。
已写入 2975 个ID到 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/ImageSets/Main/train_s.txt
已写入 3475 个ID到 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/ImageSets/Main/train_t.txt
已写入 500 个ID到 /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/ImageSets/Main/test_t.txt

数据准备完成！VOC数据集位于: /mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007


In [8]:
# 注册数据集
# 运行此单元格来清理并重新注册所有需要的数据集。
# 每次启动 notebook/kernel 进行训练或评估前，都需要运行一次。

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_pascal_voc
from pprint import pprint

def register_datasets(voc_root):
    """
    安全地注销并重新注册所有 Cityscapes VOC 数据集。
    """
    classes = ['car','person','rider','truck','bus','train','motorcycle','bicycle']
    
    # 定义需要注册的数据集和对应的txt文件名
    split_map = {
        "city_trainS": "train_s", # 源域训练集
        "city_trainT": "train_t", # 目标域训练集
        "city_testT": "test_t",   # 目标域测试集
    }

    # 1. 安全地注销已存在的数据集，避免重注册错误
    for name in split_map.keys():
        if name in DatasetCatalog.list():
            DatasetCatalog.remove(name)
        if name in MetadataCatalog.list():
            MetadataCatalog.remove(name)
    
    print(f"已清理旧的注册: {list(split_map.keys())}")

    # 2. 重新注册所有数据集
    for name, split_file in split_map.items():
        register_pascal_voc(name, voc_root, split_file, 2007, classes)
    
    print("\n已重新注册数据集:")
    
    # 3. 打印统计信息以供验证
    for name in split_map.keys():
        try:
            num_samples = len(DatasetCatalog.get(name))
            print(f"- {name}: {num_samples} 个样本")
        except Exception as e:
            print(f"- {name}: 注册失败! {e}")

# --- 执行注册 ---
VOC_ROOT_PATH = "/mnt/lyh/DA-FasterCNN/DA-Faster-RCNN/datasets/cityscape/VOC2007/"
register_datasets(VOC_ROOT_PATH)

# 可选：预览一条数据记录来检查格式是否正确
print("\n预览 'city_trainS' 的第一条记录:")
pprint(DatasetCatalog.get("city_trainS")[0])

已清理旧的注册: ['city_trainS', 'city_trainT', 'city_testT']

已重新注册数据集:
- city_trainS: 2975 个样本
- city_trainS: 2975 个样本
- city_trainT: 3475 个样本
- city_testT: 500 个样本

预览 'city_trainS' 的第一条记录:
- city_trainT: 3475 个样本
- city_testT: 500 个样本

预览 'city_trainS' 的第一条记录:
{'annotations': [{'bbox': [608.0, 419.0, 807.0, 532.0],
                  'bbox_mode': <BoxMode.XYXY_ABS: 0>,
                  'category_id': 0},
                 {'bbox': [144.0, 428.0, 304.0, 502.0],
                  'bbox_mode': <BoxMode.XYXY_ABS: 0>,
                  'category_id': 0},
                 {'bbox': [144.0, 428.0, 304.0, 502.0],
                  'bbox_mode': <BoxMode.XYXY_ABS: 0>,
                  'category_id': 0},
                 {'bbox': [1961.0, 487.0, 2047.0, 526.0],
                  'bbox_mode': <BoxMode.XYXY_ABS: 0>,
                  'category_id': 0},
                 {'bbox': [1511.0, 445.0, 1660.0, 499.0],
                  'bbox_mode': <BoxMode.XYXY_ABS: 0>,
                  'category_id': 0},
    

#Training Loop Definition
Run the following block

In [6]:
logger = logging.getLogger("detectron2")

def do_train(cfg_source, cfg_target, model, resume = False):

    model.train()
    optimizer = build_optimizer(cfg_source, model)
    scheduler = build_lr_scheduler(cfg_source, optimizer)
    checkpointer = DetectionCheckpointer(model, cfg_source.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler)

    start_iter = (checkpointer.resume_or_load(cfg_source.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    max_iter = cfg_source.SOLVER.MAX_ITER

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg_source.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)
    writers = default_writers(cfg_source.OUTPUT_DIR, max_iter) if comm.is_main_process() else []

    data_loader_source = build_detection_train_loader(cfg_source)
    data_loader_target = build_detection_train_loader(cfg_target)
    logger.info("Starting training from iteration {}".format(start_iter))

    lambda_hyper = 0.1

    with EventStorage(start_iter) as storage:
        for data_source, data_target, iteration in zip(data_loader_source, data_loader_target, range(start_iter, max_iter)):
            storage.iter = iteration

            loss_dict = model(data_source, False, 1)
            loss_dict_target = model(data_target, True, 1)
            
            loss_dict["loss_image_d"] += loss_dict_target["loss_image_d"]
            loss_dict["loss_instance_d"] += loss_dict_target["loss_instance_d"]
            loss_dict["loss_consistency_d"] += loss_dict_target["loss_consistency_d"]

            loss_dict["loss_image_d"] *= ( 0.5 * lambda_hyper)
            loss_dict["loss_instance_d"] *= ( 0.5 * lambda_hyper)
            loss_dict["loss_consistency_d"] *= ( 0.5 * lambda_hyper)

            losses = sum(loss_dict.values())
            assert torch.isfinite(losses).all(), loss_dict

            loss_dict_reduced = {k: v.item() for k, v in comm.reduce_dict(loss_dict).items()}
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()

            if iteration - start_iter > 5 and ((iteration + 1) % 50 == 0 or iteration == max_iter - 1):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)

#Configuration Definition
Define the configuration for the source (cfg_source) and target dataset (cfg_target). The cfg_source contains also the parameters which will be used by the network such us:<br>
learning rate, number of training iterations, weight decay, number of classes etc...

## Backbone for Faster RCNN
this implementations works with three kind of backbone:<br> FPN: "COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"<br>
DC5: "COCO-Detection/faster_rcnn_R_50_DC5_1x.yaml"<br>
C4: "COCO-Detection/faster_rcnn_R_50_C4_1x.yaml"<br>

You can also use their variants such us faster_rcnn_R_101_C4_3x, faster_rcnn_R_50_DC5_3x, faster_rcnn_R_101_DC5_3x, etc...





In [7]:
cfg_source = get_cfg()
cfg_source.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
cfg_source.DATASETS.TRAIN = ("city_trainS",)
cfg_source.DATALOADER.NUM_WORKERS = 2
cfg_source.MODEL.WEIGHTS = "/mnt/lyh/DA-FasterCNN/weights/COCO-Detection/faster_rcnn_R_50_FPN_1x/model_final_b275ba.pkl"
#cfg_source.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml")
cfg_source.SOLVER.IMS_PER_BATCH = 4
cfg_source.SOLVER.BASE_LR = 0.0005
cfg_source.SOLVER.WARMUP_FACTOR = 1.0 / 100
cfg_source.SOLVER.WARMUP_ITERS = 1000
cfg_source.SOLVER.MAX_ITER = 5000
cfg_source.INPUT.MIN_SIZE_TRAIN = (600,)
cfg_source.INPUT.MIN_SIZE_TEST = 0
os.makedirs(cfg_source.OUTPUT_DIR, exist_ok=True)
cfg_source.MODEL.ROI_HEADS.NUM_CLASSES = 8
model = build_model(cfg_source)

cfg_target = get_cfg()
cfg_target.DATALOADER.FILTER_EMPTY_ANNOTATIONS = False
cfg_target.DATASETS.TRAIN = ("city_trainT",)
cfg_target.INPUT.MIN_SIZE_TRAIN = (600,)
cfg_target.DATALOADER.NUM_WORKERS = 0
cfg_target.SOLVER.IMS_PER_BATCH = 4

build_resnet_fpn_backbone


In [9]:
do_train(cfg_source,cfg_target,model)

[32m[10/21 08:50:35 d2.checkpoint.detection_checkpoint]: [0m[DetectionCheckpointer] Loading from /mnt/lyh/DA-FasterCNN/weights/COCO-Detection/faster_rcnn_R_50_FPN_1x/model_final_b275ba.pkl ...
[32m[10/21 08:50:35 d2.checkpoint.detection_checkpoint]: [0m[DetectionCheckpointer] Loading from /mnt/lyh/DA-FasterCNN/weights/COCO-Detection/faster_rcnn_R_50_FPN_1x/model_final_b275ba.pkl ...


Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (9, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (9,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (32, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (32,) in the model! You might want to double check if this is expected.
Some model parameters or buffers are not found in the checkpoint:
[34mdiscriminator.net.0.weight[0m
[34mdiscriminator.net.2.weight[0m
[34mdiscriminator

[32m[10/21 08:50:35 d2.data.build]: [0mRemoved 10 images with no usable annotations. 2965 images left.
[32m[10/21 08:50:35 d2.data.build]: [0mDistribution of instances among all 8 categories:
[36m|  category  | #instances   |  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|:----------:|:-------------|
|    car     | 27155        |   person   | 17994        |   rider    | 1807         |
|   truck    | 489          |    bus     | 385          |   train    | 171          |
| motorcycle | 739          |  bicycle   | 3729         |            |              |
|   total    | 52469        |            |              |            |              |[0m
[32m[10/21 08:50:35 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(600,), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[10/21 08:50:35 d2.data.build]: [0mUsing training sampler TrainingSampler
[

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[10/21 08:50:54 d2.utils.events]: [0m iter: 49  total_loss: 3.013  loss_cls: 1.699  loss_box_reg: 0.687  loss_rpn_cls: 0.1223  loss_rpn_loc: 0.2381  loss_image_d: 0.06936  loss_instance_d: 0.06531  loss_consistency_d: 0.0005173     lr: 2.9255e-05  max_mem: 4885M
[32m[10/21 08:51:12 d2.utils.events]: [0m eta: 0:29:47  iter: 99  total_loss: 1.995  loss_cls: 0.7895  loss_box_reg: 0.6646  loss_rpn_cls: 0.09943  loss_rpn_loc: 0.3001  loss_image_d: 0.06927  loss_instance_d: 0.06101  loss_consistency_d: 0.0008913     lr: 5.4005e-05  max_mem: 4885M
[32m[10/21 08:51:12 d2.utils.events]: [0m eta: 0:29:47  iter: 99  total_loss: 1.995  loss_cls: 0.7895  loss_box_reg: 0.6646  loss_rpn_cls: 0.09943  loss_rpn_loc: 0.3001  loss_image_d: 0.06927  loss_instance_d: 0.06101  loss_consistency_d: 0.0008913     lr: 5.4005e-05  max_mem: 4885M
[32m[10/21 08:51:31 d2.utils.events]: [0m eta: 0:29:28  iter: 149  total_loss: 1.646  loss_cls: 0.5828  loss_box_reg: 0.6551  loss_rpn_cls: 0.08639  loss_rpn

##Evalutate the performance
runt the PascalVOCDetectionEvaluator if your annotations are in PASCAL VOC otherwhise run the COCOEvaluator<br>

The mAP50 is the object detection result on the dataset. In this case, for the cityscape dataset, the result is 39.5%


In [10]:
#PASCAL VOC evaluation
from detectron2.evaluation import inference_on_dataset, PascalVOCDetectionEvaluator
from detectron2.data import build_detection_test_loader
evaluator = PascalVOCDetectionEvaluator("city_testT")
val_loader = build_detection_test_loader(cfg_source, "city_testT")
res = inference_on_dataset(model, val_loader, evaluator)
print(res)

[32m[10/21 09:40:20 d2.data.build]: [0mDistribution of instances among all 8 categories:
[36m|  category  | #instances   |  category  | #instances   |  category  | #instances   |
|:----------:|:-------------|:----------:|:-------------|:----------:|:-------------|
|    car     | 4667         |   person   | 3419         |   rider    | 556          |
|   truck    | 93           |    bus     | 98           |   train    | 23           |
| motorcycle | 149          |  bicycle   | 1175         |            |              |
|   total    | 10180        |            |              |            |              |[0m
[32m[10/21 09:40:20 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(0, 0), max_size=1333, sample_style='choice')]
[32m[10/21 09:40:20 d2.data.common]: [0mSerializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[32m[10/21 09:40:20 d2.data.common]: [0mSerializing 500 elements to byt

In [None]:
# 短测训练（使用本地权重）：加载本地 model_final_b275ba.pkl 并运行小规模训练
print("=== 短测训练（本地权重）===")
# 覆盖 cfg 为小规模并指向本地权重
cfg_source.MODEL.WEIGHTS = "/mnt/lyh/DA-FasterCNN/model_final_b275ba.pkl"
cfg_source.SOLVER.MAX_ITER = 50
cfg_source.SOLVER.IMS_PER_BATCH = 2
cfg_source.DATALOADER.NUM_WORKERS = 0
cfg_source.SOLVER.CHECKPOINT_PERIOD = 50

cfg_target.SOLVER.IMS_PER_BATCH = 2

# 重新构建模型
model = build_model(cfg_source)
print("Model built with local weights")

# 运行短测训练
try:
    do_train(cfg_source, cfg_target, model, resume=False)
    print("短测训练完成")
except Exception as e:
    import traceback
    print("短测训练失败:")
    traceback.print_exc()