# 模块导入 

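## 介绍

本系统采用两阶段联合训练框架。主干使用基于 DSFD 的人脸检测模型，同时引入 IAT 高曝光图像矫正模型作为曝光伪标签生成器。训练过程中，IAT 模型只做前向推理（`eval()` + `torch.no_grad()`，不参与优化），其在高曝光场景下输出的曝光修正结果作为辅助监督信号，通过 L1 增强损失与检测主损失联合优化，以提升检测模型在极端曝光条件下的鲁棒性与泛化能力。整体训练在多 GPU 分布式环境中完成，并结合多阶段学习率调整与损失平衡策略控制训练稳定性。

> 注：当前代码中检测损失被置为 0，实际只优化权重为 0.1 的 L1 增强损失；`adjust_learning_rate` 已定义，但尚未在训练循环中调用。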

In [None]:
import os
import time
import random
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torch.utils.tensorboard import SummaryWriter

import torch.distributed as dist

from models.enhancer import RetinexNet
from models.factory import build_net, basenet_factory
from layers.modules import EnhanceLoss
from data.config import cfg
from data.widerface import WIDERDetection, detection_collate
from model.IAT_main import IAT
from utils.brightgISP import Low_Illumination_Degrading

import sys
# NOTE(review): this directory is appended to sys.path only after
# `from model.IAT_main import IAT` (above) has already executed, so it cannot
# help resolve that import. Confirm whether it should run before the project
# imports instead.
sys.path.append('/data1/home/chenruoyu/IAD-Net/jiandan/IAT_enhance')

# 训练参数配置

In [None]:
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Args:
        value: The raw token (or an actual bool when used as a default).

    Returns:
        The parsed boolean. The empty string maps to False, which is what
        ``bool('')`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='增强模型训练')

parser.add_argument('--batch_size', default=4, type=int)
parser.add_argument('--resume', default=None, type=str)
parser.add_argument('--num_workers', default=0, type=int)
parser.add_argument('--cuda', default=True, type=_str2bool)
parser.add_argument('--lr', default=5e-4, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--gamma', default=0.1, type=float)
parser.add_argument('--multigpu', default=True, type=_str2bool)
parser.add_argument('--save_folder', default='weights/')
# train() reads args.model to decide which backbone receives the pretrained
# weights; without this option that lookup raises AttributeError.
parser.add_argument('--model', default='dark', type=str)
# Accept both spellings that PyTorch launchers have used, and fall back to
# the LOCAL_RANK environment variable (0 if unset) when the launcher does not
# pass the flag, so local_rank is never None.
parser.add_argument('--local-rank', '--local_rank', dest='local_rank', type=int,
                    default=int(os.environ.get('LOCAL_RANK', 0)))

args = parser.parse_args()
local_rank = args.local_rank
# Mirror the resolved value into the environment for code that reads it there.
os.environ['LOCAL_RANK'] = str(args.local_rank)

# 分布式训练初始化

In [None]:
# Join the default process group with the NCCL backend. No init_method is
# passed, so the library default is used.
dist.init_process_group(backend='nccl')

if torch.cuda.is_available() and args.cuda:
    # Pin this process to one GPU: global rank (from the RANK environment
    # variable) modulo the number of visible devices.
    rank = int(os.environ['RANK'])
    gpu_num = torch.cuda.device_count()
    torch.cuda.set_device(rank % gpu_num)
else:
    # No usable GPU, or CUDA disabled by flag: default to CPU float tensors.
    torch.set_default_tensor_type('torch.FloatTensor')

# 数据集与数据加载器

In [None]:
# Output directory for this experiment. makedirs(exist_ok=True) also creates a
# missing parent folder and does not fail when several distributed processes
# try to create the directory at the same time (exists-then-mkdir races).
save_folder = os.path.join(args.save_folder, 'IAT')
os.makedirs(save_folder, exist_ok=True)

train_dataset = WIDERDetection(cfg.FACE.TRAIN_FILE, mode='train')
val_dataset = WIDERDetection(cfg.FACE.VAL_FILE, mode='val')

# Each process iterates over its own shard of the dataset.
train_sampler = data.distributed.DistributedSampler(train_dataset, shuffle=True)
val_sampler = data.distributed.DistributedSampler(val_dataset, shuffle=True)

# batch_size is per process; validation loads in the main process
# (num_workers=0).
train_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers,
                               collate_fn=detection_collate, sampler=train_sampler, pin_memory=True)
val_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=0,
                             collate_fn=detection_collate, sampler=val_sampler, pin_memory=True)

# 模型初始化部分

In [None]:
def train():
    """Train the detector's enhancement output against IAT pseudo labels.

    Builds the 'dark' DSFD network and an IAT model with weights loaded from
    a fixed checkpoint path, then, for every epoch from the start epoch up to
    ``cfg.EPOCHES``, minimises ``0.1 * L1(out['enhanced'], IAT(images / 255))``
    with SGD and runs ``val`` at the end of the epoch. The detection loss is
    currently a constant 0, so only the enhancement term produces gradients.

    Reads the module-level ``args``, ``local_rank``, ``train_dataset``,
    ``train_sampler`` and ``train_loader``. Only the process whose
    ``local_rank`` is 0 prints progress and writes TensorBoard scalars.

    NOTE(review): no checkpoint is written anywhere in this function.
    """
    writer = None
    per_epoch_size = len(train_dataset) // (args.batch_size * torch.cuda.device_count())
    start_epoch, iteration = 0, 0

    if local_rank == 0:
        log_dir = os.path.join('runs', 'IAT', time.strftime('%Y%m%d-%H%M%S'))
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)

    # ---- Model initialisation ----
    basenet = basenet_factory('dark')
    dsfd_net = build_net('train', cfg.NUM_CLASSES, 'dark')
    net = dsfd_net

    # The IAT model only produces targets: it is kept in eval mode and is not
    # handed to the optimizer. Checkpoints are loaded onto the CPU first so
    # that a file saved from one GPU does not make every process allocate on
    # that same device; load_state_dict copies into the existing parameters.
    IAT_model = IAT(type='exp').cuda()
    IAT_model.load_state_dict(torch.load(
        '/data1/home/chenruoyu/IAD-Net/jiandan/IAT_enhance/quanzhong/1.pth',
        map_location='cpu'))
    IAT_model.eval()

    if args.resume:
        start_epoch = net.load_weights(args.resume)
        iteration = start_epoch * per_epoch_size
    else:
        # os.path.join works whether or not save_folder ends with a slash.
        base_weights = torch.load(os.path.join(args.save_folder, basenet),
                                  map_location='cpu')
        # args.model may be absent from the parser; the network above is
        # always built as 'dark', so that is the fallback.
        model_name = getattr(args, 'model', 'dark')
        if model_name in ('vgg', 'dark'):
            net.vgg.load_state_dict(base_weights)
        else:
            net.resnet.load_state_dict(base_weights)

        # ---- Weight initialisation (fresh runs only) ----
        net.extras.apply(net.weights_init)
        net.fpn_topdown.apply(net.weights_init)
        net.fpn_latlayer.apply(net.weights_init)
        net.fpn_fem.apply(net.weights_init)
        net.ref.apply(net.weights_init)

    # ---- Optimizer configuration ----
    # Scale the base rate by sqrt(batch_size / 4 * GPU count), then by 1e-3.
    lr = args.lr * np.round(np.sqrt(args.batch_size / 4 * torch.cuda.device_count()), 4)
    lr = lr * 0.001

    # The `ref` sub-module trains at a tenth of the rate of the other groups.
    param_group = [
        {'params': dsfd_net.vgg.parameters(), 'lr': lr},
        {'params': dsfd_net.extras.parameters(), 'lr': lr},
        {'params': dsfd_net.fpn_topdown.parameters(), 'lr': lr},
        {'params': dsfd_net.fpn_latlayer.parameters(), 'lr': lr},
        {'params': dsfd_net.fpn_fem.parameters(), 'lr': lr},
        {'params': dsfd_net.ref.parameters(), 'lr': lr / 10.},
    ]
    optimizer = optim.SGD(param_group, lr=lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    if args.cuda and args.multigpu:
        net = torch.nn.parallel.DistributedDataParallel(dsfd_net.cuda(), find_unused_parameters=True)
        cudnn.benchmark = True

    # NOTE(review): constructed but not used by the loop below.
    criterion_enhance = EnhanceLoss()

    # ---- Main training loop ----
    for epoch in range(start_epoch, cfg.EPOCHES):
        # val() puts the model in eval mode, so training mode has to be
        # re-enabled at the start of every epoch, not just once.
        net.train()
        # Without set_epoch the distributed sampler repeats the same shuffle
        # order every epoch.
        train_sampler.set_epoch(epoch)

        losses = 0
        # Annotations are ignored while the detection loss is disabled.
        for batch_idx, (images, _, _) in enumerate(train_loader):
            images = images.cuda()

            # Pseudo ground truth from the IAT model; no gradient is needed.
            with torch.no_grad():
                _, _, R_light_gt = IAT_model(images / 255.)

            out = net(images)
            detection_loss = 0  # only the enhancement term is kept for now
            enhance_loss = F.l1_loss(out['enhanced'], R_light_gt)
            total_loss = detection_loss + 0.1 * enhance_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            losses += total_loss.item()

            if iteration % 100 == 0 and local_rank == 0:
                # Running mean of the loss over the current epoch so far.
                tloss = losses / (batch_idx + 1)
                print('epoch:{} iter:{} Loss:{:.4f}'.format(epoch, iteration, tloss))

                if writer:
                    writer.add_scalar('Loss/train', tloss, iteration)

            iteration += 1

        val(epoch, net, dsfd_net, IAT_model, writer)

# 验证函数

In [None]:
def val(epoch, net, dsfd_net, IAT_model, writer=None):
    """Evaluate the enhancement output on the validation loader.

    Every validation batch is scaled by 1/255, passed image by image through
    ``Low_Illumination_Degrading``, and fed to both ``test_forward`` of the
    network and ``IAT_model``. The per-batch loss is
    ``L1(R_val, R_gt) + (1 - SSIM(R_val, R_gt))``. Per-process sums are
    reduced onto rank 0 and averaged over batches and processes.

    Args:
        epoch: Epoch index, used for the printed message and the scalar step.
        net: Model to evaluate. It may be wrapped (exposing ``.module``) or
            be the bare network.
        dsfd_net: Unused; kept so existing call sites keep working.
        IAT_model: Model whose third output is used as the target.
        writer: Optional TensorBoard writer for the ``Loss/val`` scalar.

    The train/eval mode ``net`` had on entry is restored before returning.
    """
    was_training = net.training
    net.eval()
    # Do not assume a DistributedDataParallel wrapper is present.
    model = net.module if hasattr(net, 'module') else net

    step = 0
    losses = torch.tensor(0.).cuda()

    try:
        for images, _, _ in val_loader:
            images = images.cuda() / 255.
            img_dark = torch.stack(
                [Low_Illumination_Degrading(img)[0] for img in images], dim=0)

            with torch.no_grad():
                _, R_val = model.test_forward(img_dark)
                _, _, R_dark_gt_val = IAT_model(img_dark)
                loss = F.l1_loss(R_val, R_dark_gt_val) + (1. - ssim(R_val, R_dark_gt_val))

            losses += loss.item()
            step += 1
    finally:
        # Leave the model the way the caller handed it in.
        net.train(was_training)

    # The reduce sums over every process in the group, so the mean must divide
    # by the group's world size; the local GPU count only matches it when
    # exactly one process per visible GPU runs on a single node.
    dist.reduce(losses, 0, op=dist.ReduceOp.SUM)
    tloss = losses / step / dist.get_world_size()

    if local_rank == 0:
        tloss = tloss.item()
        print('验证集: epoch:{} Loss:{:.4f}'.format(epoch, tloss))
        if writer:
            writer.add_scalar('Loss/val', tloss, epoch)

# 学习率动态调整

In [None]:
def adjust_learning_rate(optimizer, gamma, step):
    """Multiply the learning rate of every parameter group by ``gamma``.

    The decay is applied in place and compounds on each call.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts that
            each hold an ``'lr'`` entry.
        gamma: Multiplicative factor applied to each group's ``'lr'``.
        step: Accepted for call compatibility; not read by this function.
    """
    for group in optimizer.param_groups:
        group['lr'] *= gamma
# Script entry point: start training only when this file is executed directly.
if __name__ == '__main__':
    train()