# 基本模块导入

In [None]:
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable, Function

from layers import *  # DSFD中使用的工具层，如 L2Norm, PriorBox, Detect 等
from data.config import cfg  # 配置文件，包含模型参数、训练参数等

# 插值模块

In [None]:
class Interpolate(nn.Module):
    """Nearest-neighbour resampling wrapped as a module.

    Wrapping ``F.interpolate`` in a module lets it be placed inside an
    ``nn.Sequential``.

    Args:
        scale_factor: Spatial multiplier forwarded to ``F.interpolate``.
    """

    def __init__(self, scale_factor):
        super(Interpolate, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` resized by ``self.scale_factor`` in 'nearest' mode."""
        factor = self.scale_factor
        resized = F.interpolate(x, scale_factor=factor, mode='nearest')
        return resized

# 特征增强模块 (FEM)

In [None]:
class FEM(nn.Module):
    """Feature enhancement module made of three parallel dilated-conv branches.

    Every branch reads the full ``in_planes``-channel input and uses 3x3
    convolutions with dilation 3 and padding 3, so spatial size is preserved:
      - branch1: one convolution
      - branch2: two convolutions with a ReLU in between
      - branch3: three convolutions with ReLUs in between
    Branches 1 and 2 emit ``in_planes // 3`` channels each and branch 3 emits
    the remainder, so the concatenated output again has ``in_planes`` channels.

    Args:
        in_planes: Number of input (and output) channels.
    """

    def __init__(self, in_planes):
        super(FEM, self).__init__()
        third = in_planes // 3
        remainder = in_planes - 2 * third

        def dilated(c_in, c_out):
            # Shared recipe for every convolution in this module.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=3, dilation=3)

        self.branch1 = dilated(in_planes, third)

        self.branch2 = nn.Sequential(
            dilated(in_planes, third),
            nn.ReLU(inplace=True),
            dilated(third, third),
        )

        self.branch3 = nn.Sequential(
            dilated(in_planes, remainder),
            nn.ReLU(inplace=True),
            dilated(remainder, remainder),
            nn.ReLU(inplace=True),
            dilated(remainder, remainder),
        )

    def forward(self, x):
        """Concatenate the three branch outputs on the channel axis and apply ReLU."""
        branch_outputs = [
            branch(x) for branch in (self.branch1, self.branch2, self.branch3)
        ]
        merged = torch.cat(branch_outputs, dim=1)
        return F.relu(merged, inplace=True)

# 整体 DSFD 主体架构 (包含增强+检测联合)

In [None]:
class DSFD(nn.Module):
    """
    Detector network container that registers the sub-modules passed in.

    Registered components:
    - ``vgg``: backbone layers supplied as ``base``
    - ``extras``: extra layers supplied as ``extras``
    - ``fpn_topdown`` / ``fpn_latlayer`` / ``fpn_fem``: the three layer lists
      in ``fem`` (top-down, lateral and FEM layers)
    - ``loc_pal1`` / ``conf_pal1`` and ``loc_pal2`` / ``conf_pal2``: the two
      detection heads taken from ``head1`` and ``head2``
    - ``ref``: reflectance decoder mapping a 64-channel feature map to a
      3-channel sigmoid output at twice the spatial resolution
    - ``KL``: ``DistillKL`` loss module with temperature 4.0
    """
    def __init__(self, phase, base, extras, fem, head1, head2, num_classes):
        super(DSFD, self).__init__()
        self.phase = phase  # compared against 'test' below to enable the inference-only modules
        self.num_classes = num_classes

        # Backbone layers
        self.vgg = nn.ModuleList(base)

        # L2Norm layers for three backbone feature maps
        # (arguments presumably are channel count and initial scale -- confirm against layers.L2Norm)
        self.L2Normof1 = L2Norm(256, 10)
        self.L2Normof2 = L2Norm(512, 8)
        self.L2Normof3 = L2Norm(512, 5)

        # Extra layers supplied by the caller
        self.extras = nn.ModuleList(extras)

        # FPN top-down layers, lateral layers and FEM modules
        self.fpn_topdown = nn.ModuleList(fem[0])
        self.fpn_latlayer = nn.ModuleList(fem[1])
        self.fpn_fem = nn.ModuleList(fem[2])

        # A second set of L2Norm layers, named for the enhanced ("ef") feature maps
        self.L2Normef1 = L2Norm(256, 10)
        self.L2Normef2 = L2Norm(512, 8)
        self.L2Normef3 = L2Norm(512, 5)

        # Two detection heads (PAL1 and PAL2), each a (loc layers, conf layers) pair
        self.loc_pal1 = nn.ModuleList(head1[0])
        self.conf_pal1 = nn.ModuleList(head1[1])
        self.loc_pal2 = nn.ModuleList(head2[0])
        self.conf_pal2 = nn.ModuleList(head2[1])

        # Reflectance decoder: 64-channel features -> 3-channel image
        self.ref = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            Interpolate(2),  # upsample by a factor of two
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid()     # squashes the output into [0, 1]
        )

        # Distillation-style KL loss module (temperature 4.0)
        self.KL = DistillKL(T=4.0)

        # In the 'test' phase, add the softmax and the Detect post-processing module
        if self.phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(cfg)

# 自定义特征上采样并乘法融合函数（FPN用）

In [None]:
def _upsample_prod(self, x, y):
        _, _, H, W = y.size()
        return F.upsample(x, size=(H, W), mode='bilinear') * y

# 只用于增强解码（纯反射分支推理）

In [None]:
 def enh_forward(self, x):
        """Run only the reflectance decoder on the first sample of ``x``.

        The first sample is passed through ``self.vgg[0]`` .. ``self.vgg[4]``
        and the result is decoded by ``self.ref``.

        Args:
            x: Input batch; only ``x[:1]`` is used.

        Returns:
            The output of ``self.ref`` for that single sample.
        """
        feat = x[:1]  # keep the batch dimension, drop all but the first sample
        for idx in range(5):
            stage = self.vgg[idx]
            feat = stage(feat)
        return self.ref(feat)

# 完整测试前向推理 (检测与增强联合输出)

In [None]:
  def test_forward(self, x):
        """Run detection and reflectance decoding on a single input batch.

        Side effect: stores the generated prior boxes on ``self.priors_pal1``
        and ``self.priors_pal2``.

        Args:
            x: Input batch, indexed as a 4-D tensor (``x.size()[2:]`` is used
                as the spatial size for prior-box generation).

        Returns:
            ``(output, R)``. ``R`` is ``self.ref`` applied to the first
            sample's features after ``self.vgg[4]``. When
            ``self.phase == 'test'``, ``output`` is the result of
            ``self.detect`` on the PAL2 predictions; otherwise it is the tuple
            ``(loc_pal1, conf_pal1, priors_pal1, loc_pal2, conf_pal2,
            priors_pal2)``.
        """
        size = x.size()[2:]  # spatial size of the input, handed to PriorBox
        pal1_sources, pal2_sources = [], []  # feature maps feeding each head
        loc_pal1, conf_pal1 = [], []         # PAL1 head outputs
        loc_pal2, conf_pal2 = [], []         # PAL2 head outputs

        # ===== Backbone, layers 0-15 =====
        for k in range(16):
            x = self.vgg[k](x)
            if k == 4:
                # Early feature map for the reflectance decoder. With the
                # backbone built by vgg(vgg_cfg, 3) in this file, index 4 is
                # the first max-pool (64 channels), matching self.ref's input.
                x_ = x
        R = self.ref(x_[0:1])  # decode the first sample only

        of1 = x
        s = self.L2Normof1(of1)
        pal1_sources.append(s)

        for k in range(16, 23):
            x = self.vgg[k](x)
        of2 = x
        s = self.L2Normof2(of2)
        pal1_sources.append(s)

        for k in range(23, 30):
            x = self.vgg[k](x)
        of3 = x
        s = self.L2Normof3(of3)
        pal1_sources.append(s)

        for k in range(30, len(self.vgg)):
            x = self.vgg[k](x)
        of4 = x
        pal1_sources.append(of4)

        # ===== Extra layers =====
        for k in range(2):
            x = F.relu(self.extras[k](x), inplace=True)
        of5 = x
        pal1_sources.append(of5)

        for k in range(2, 4):
            x = F.relu(self.extras[k](x), inplace=True)
        of6 = x
        pal1_sources.append(of6)

        # ===== Top-down pathway: deepest map first, merged with laterals =====
        conv7 = F.relu(self.fpn_topdown[0](of6), inplace=True)
        x = F.relu(self.fpn_topdown[1](conv7), inplace=True)
        conv6 = F.relu(self._upsample_prod(x, self.fpn_latlayer[0](of5)), inplace=True)

        x = F.relu(self.fpn_topdown[2](conv6), inplace=True)
        convfc7_2 = F.relu(self._upsample_prod(x, self.fpn_latlayer[1](of4)), inplace=True)

        x = F.relu(self.fpn_topdown[3](convfc7_2), inplace=True)
        conv5 = F.relu(self._upsample_prod(x, self.fpn_latlayer[2](of3)), inplace=True)

        x = F.relu(self.fpn_topdown[4](conv5), inplace=True)
        conv4 = F.relu(self._upsample_prod(x, self.fpn_latlayer[3](of2)), inplace=True)

        x = F.relu(self.fpn_topdown[5](conv4), inplace=True)
        conv3 = F.relu(self._upsample_prod(x, self.fpn_latlayer[4](of1)), inplace=True)

        # ===== FEM on every merged map; L2Norm on the three shallowest =====
        ef1 = self.L2Normef1(self.fpn_fem[0](conv3))
        ef2 = self.L2Normef2(self.fpn_fem[1](conv4))
        ef3 = self.L2Normef3(self.fpn_fem[2](conv5))
        ef4 = self.fpn_fem[3](convfc7_2)
        ef5 = self.fpn_fem[4](conv6)
        ef6 = self.fpn_fem[5](conv7)

        pal2_sources = (ef1, ef2, ef3, ef4, ef5, ef6)

        # ===== Apply both heads; move channels last before flattening =====
        for (x, l, c) in zip(pal1_sources, self.loc_pal1, self.conf_pal1):
            loc_pal1.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf_pal1.append(c(x).permute(0, 2, 3, 1).contiguous())

        for (x, l, c) in zip(pal2_sources, self.loc_pal2, self.conf_pal2):
            loc_pal2.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf_pal2.append(c(x).permute(0, 2, 3, 1).contiguous())

        # Height and width of every PAL1 prediction map (channels-last layout).
        features_maps = [[o.size(1), o.size(2)] for o in loc_pal1]

        loc_pal1 = torch.cat([o.view(o.size(0), -1) for o in loc_pal1], 1)
        conf_pal1 = torch.cat([o.view(o.size(0), -1) for o in conf_pal1], 1)
        loc_pal2 = torch.cat([o.view(o.size(0), -1) for o in loc_pal2], 1)
        conf_pal2 = torch.cat([o.view(o.size(0), -1) for o in conf_pal2], 1)

        # ===== Prior boxes =====
        # `Variable(..., volatile=True)` is a removed no-op that only emits a
        # warning on current PyTorch; torch.no_grad() is its replacement.
        # Assumes PriorBox.forward() returns a plain tensor -- confirm in layers.
        with torch.no_grad():
            self.priors_pal1 = PriorBox(size, features_maps, cfg, pal=1).forward()
            self.priors_pal2 = PriorBox(size, features_maps, cfg, pal=2).forward()

        # ===== Output =====
        if self.phase == 'test':
            # `x` here is the last element of pal2_sources (rebound by the loop above).
            output = self.detect.forward(
                loc_pal2.view(loc_pal2.size(0), -1, 4),
                self.softmax(conf_pal2.view(conf_pal2.size(0), -1, self.num_classes)),
                self.priors_pal2.type(type(x.data))
            )
        else:
            output = (
                loc_pal1.view(loc_pal1.size(0), -1, 4),
                conf_pal1.view(conf_pal1.size(0), -1, self.num_classes),
                self.priors_pal1,
                loc_pal2.view(loc_pal2.size(0), -1, 4),
                conf_pal2.view(conf_pal2.size(0), -1, self.num_classes),
                self.priors_pal2
            )

        return output, R

# 训练阶段完整前向传播（增强与检测联合训练）

In [None]:
 def forward(self, x, x_light, I, I_light):
        """Joint enhancement and detection forward pass.

        Side effect: stores the generated prior boxes on ``self.priors_pal1``
        and ``self.priors_pal2``.

        Args:
            x: Low-light image batch; this is the input the detector runs on.
            x_light: Brighter counterpart of ``x`` (per the original notes,
                used as a pseudo label).
            I: Map multiplied with the decoded ``R_light`` to build the
                second-round input (per the original notes, the illumination
                of ``x`` from a Retinex decomposition -- confirm with caller).
            I_light: Same as ``I`` but for ``x_light``; multiplied with ``R_dark``.

        Returns:
            ``(output, [R_dark, R_light, R_dark_2, R_light_2], loss_mutual)``.
            ``output`` is the ``self.detect`` result when
            ``self.phase == 'test'``, otherwise the tuple ``(loc_pal1,
            conf_pal1, priors_pal1, loc_pal2, conf_pal2, priors_pal2)``.
            ``loss_mutual`` is the symmetric KL term scaled by ``cfg.WEIGHT.MC``.
        """
        size = x.size()[2:]  # spatial size of the input, handed to PriorBox
        pal1_sources, pal2_sources = [], []
        loc_pal1, conf_pal1 = [], []
        loc_pal2, conf_pal2 = [], []

        # ===== Early features of the bright image (layers 0-4) =====
        for k in range(5):
            x_light = self.vgg[k](x_light)

        # ===== Backbone on the dark image (layers 0-15) =====
        for k in range(16):
            x = self.vgg[k](x)
            if k == 4:
                # Early feature map for the reflectance decoder. With the
                # backbone built by vgg(vgg_cfg, 3) in this file, index 4 is
                # the first max-pool (64 channels), matching self.ref's input.
                x_dark = x

        # ===== First-round reflectance decoding =====
        R_dark = self.ref(x_dark)
        R_light = self.ref(x_light)

        # ===== Swap reflectances to build second-round inputs =====
        # Detached so the second round does not backpropagate into the first.
        x_dark_2 = (I * R_light).detach()
        x_light_2 = (I_light * R_dark).detach()

        for k in range(5):
            x_light_2 = self.vgg[k](x_light_2)
        for k in range(5):
            x_dark_2 = self.vgg[k](x_dark_2)

        # ===== Second-round reflectance decoding =====
        # Naming follows the reflectance source: x_light_2 was built from R_dark.
        R_dark_2 = self.ref(x_light_2)
        R_light_2 = self.ref(x_dark_2)

        # ===== Symmetric KL between spatially averaged early features =====
        # flatten + mean reduces each map to one value per channel.
        x_light = x_light.flatten(start_dim=2).mean(dim=-1)
        x_dark = x_dark.flatten(start_dim=2).mean(dim=-1)
        x_light_2 = x_light_2.flatten(start_dim=2).mean(dim=-1)
        x_dark_2 = x_dark_2.flatten(start_dim=2).mean(dim=-1)

        loss_mutual = cfg.WEIGHT.MC * (
            self.KL(x_light, x_dark) + self.KL(x_dark, x_light)
            + self.KL(x_light_2, x_dark_2) + self.KL(x_dark_2, x_light_2)
        )

        # ===== Detection branch (same structure as test_forward) =====

        of1 = x
        s = self.L2Normof1(of1)
        pal1_sources.append(s)

        for k in range(16, 23):
            x = self.vgg[k](x)
        of2 = x
        s = self.L2Normof2(of2)
        pal1_sources.append(s)

        for k in range(23, 30):
            x = self.vgg[k](x)
        of3 = x
        s = self.L2Normof3(of3)
        pal1_sources.append(s)

        for k in range(30, len(self.vgg)):
            x = self.vgg[k](x)
        of4 = x
        pal1_sources.append(of4)

        for k in range(2):
            x = F.relu(self.extras[k](x), inplace=True)
        of5 = x
        pal1_sources.append(of5)
        for k in range(2, 4):
            x = F.relu(self.extras[k](x), inplace=True)
        of6 = x
        pal1_sources.append(of6)

        # Top-down pathway: deepest map first, merged with laterals.
        conv7 = F.relu(self.fpn_topdown[0](of6), inplace=True)
        x = F.relu(self.fpn_topdown[1](conv7), inplace=True)
        conv6 = F.relu(self._upsample_prod(x, self.fpn_latlayer[0](of5)), inplace=True)
        x = F.relu(self.fpn_topdown[2](conv6), inplace=True)
        convfc7_2 = F.relu(self._upsample_prod(x, self.fpn_latlayer[1](of4)), inplace=True)
        x = F.relu(self.fpn_topdown[3](convfc7_2), inplace=True)
        conv5 = F.relu(self._upsample_prod(x, self.fpn_latlayer[2](of3)), inplace=True)
        x = F.relu(self.fpn_topdown[4](conv5), inplace=True)
        conv4 = F.relu(self._upsample_prod(x, self.fpn_latlayer[3](of2)), inplace=True)
        x = F.relu(self.fpn_topdown[5](conv4), inplace=True)
        conv3 = F.relu(self._upsample_prod(x, self.fpn_latlayer[4](of1)), inplace=True)

        # FEM on every merged map; L2Norm on the three shallowest.
        ef1 = self.L2Normef1(self.fpn_fem[0](conv3))
        ef2 = self.L2Normef2(self.fpn_fem[1](conv4))
        ef3 = self.L2Normef3(self.fpn_fem[2](conv5))
        ef4 = self.fpn_fem[3](convfc7_2)
        ef5 = self.fpn_fem[4](conv6)
        ef6 = self.fpn_fem[5](conv7)

        pal2_sources = (ef1, ef2, ef3, ef4, ef5, ef6)

        # Apply both heads; move channels last before flattening.
        for (x, l, c) in zip(pal1_sources, self.loc_pal1, self.conf_pal1):
            loc_pal1.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf_pal1.append(c(x).permute(0, 2, 3, 1).contiguous())

        for (x, l, c) in zip(pal2_sources, self.loc_pal2, self.conf_pal2):
            loc_pal2.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf_pal2.append(c(x).permute(0, 2, 3, 1).contiguous())

        # Height and width of every PAL1 prediction map (channels-last layout).
        features_maps = [[o.size(1), o.size(2)] for o in loc_pal1]

        loc_pal1 = torch.cat([o.view(o.size(0), -1) for o in loc_pal1], 1)
        conf_pal1 = torch.cat([o.view(o.size(0), -1) for o in conf_pal1], 1)
        loc_pal2 = torch.cat([o.view(o.size(0), -1) for o in loc_pal2], 1)
        conf_pal2 = torch.cat([o.view(o.size(0), -1) for o in conf_pal2], 1)

        # `Variable(..., volatile=True)` is a removed no-op that only emits a
        # warning on current PyTorch; torch.no_grad() is its replacement.
        # Assumes PriorBox.forward() returns a plain tensor -- confirm in layers.
        with torch.no_grad():
            self.priors_pal1 = PriorBox(size, features_maps, cfg, pal=1).forward()
            self.priors_pal2 = PriorBox(size, features_maps, cfg, pal=2).forward()

        if self.phase == 'test':
            # `x` here is the last element of pal2_sources (rebound by the loop above).
            output = self.detect.forward(
                loc_pal2.view(loc_pal2.size(0), -1, 4),
                self.softmax(conf_pal2.view(conf_pal2.size(0), -1, self.num_classes)),
                self.priors_pal2.type(type(x.data))
            )
        else:
            # Each tensor is reshaped using its own batch dimension.
            output = (
                loc_pal1.view(loc_pal1.size(0), -1, 4),
                conf_pal1.view(conf_pal1.size(0), -1, self.num_classes),
                self.priors_pal1,
                loc_pal2.view(loc_pal2.size(0), -1, 4),
                conf_pal2.view(conf_pal2.size(0), -1, self.num_classes),
                self.priors_pal2
            )

        return output, [R_dark, R_light, R_dark_2, R_light_2], loss_mutual

# 预训练模型权重加载函数

In [None]:
def load_weights(self, base_file):
        """Load a saved state dict from ``base_file`` into this module.

        Args:
            base_file: Path to a ``.pth`` or ``.pkl`` file readable by
                ``torch.load`` whose content is accepted by
                ``self.load_state_dict``.

        Returns:
            int: The starting epoch, fixed at 50.

        Raises:
            ValueError: If the extension is neither ``.pth`` nor ``.pkl``.
        """
        _, ext = os.path.splitext(base_file)
        # The previous check `ext == '.pkl' or '.pth'` was always true, so any
        # file was loaded and the unsupported-format branch was unreachable
        # (it would have failed on an unbound `epoch` anyway).
        if ext not in ('.pkl', '.pth'):
            raise ValueError('只支持 .pth 和 .pkl 格式权重文件')

        print('Loading weights into state dict...')
        # Keep every storage on the CPU regardless of where it was saved from.
        mdata = torch.load(base_file, map_location=lambda storage, loc: storage)
        self.load_state_dict(mdata)
        print('Finished!')

        epoch = 50  # default starting epoch
        return epoch

# Xavier初始化（所有卷积层统一初始化）

In [None]:
  def xavier(self, param):
        """Fill ``param`` in place with Xavier (Glorot) uniform values."""
        init.xavier_uniform_(param, gain=1.0)

    def weights_init(self, m):
        """Initialise the parameters of a single module ``m``.

        - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights via ``self.xavier``,
          bias (when present) set to zero.
        - ``nn.BatchNorm2d``: weight set to one and bias to zero (when the
          layer has affine parameters).
        Other module types are left untouched.

        Args:
            m: The module to initialise.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            self.xavier(m.weight.data)
            # Layers built with bias=False have m.bias == None; the Conv2d
            # path previously dereferenced it unconditionally and crashed.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            # affine=False leaves weight and bias as None.
            if m.weight is not None:
                m.weight.data[...] = 1
            if m.bias is not None:
                m.bias.data.zero_()

# 配置模块：基础网络层级配置

In [None]:
# Backbone layout consumed by vgg(): an int is the output channel count of a
# 3x3 convolution, 'M' is a 2x2 max-pool and 'C' is a 2x2 max-pool with
# ceil_mode=True. One sub-list per stage.
vgg_cfg = (
    [64, 64, 'M']
    + [128, 128, 'M']
    + [256, 256, 256, 'C']
    + [512, 512, 512, 'M']
    + [512, 512, 512, 'M']
)

# Extra layers consumed by add_extras(): 'S' marks a stride-2 convolution
# whose output width is the entry that follows it.
extras_cfg = [256, 'S', 512, 128, 'S', 256]

# Channel count of each pyramid level, consumed by fem_module().
fem_cfg = [256, 512, 512, 1024, 512, 256]

# FEM模块构造函数

In [None]:
def fem_module(cfg):
    """Build the top-down, lateral and FEM layer lists for the pyramid.

    Args:
        cfg: Channel count of each pyramid level, shallowest first.

    Returns:
        ``(topdown_layers, lat_layers, fem_layers)``:
        - ``topdown_layers``: a 1x1 conv on the deepest level, followed by
          1x1 convs mapping each level to the next shallower one;
        - ``lat_layers``: 1x1 convs on each shallower level, deepest first;
        - ``fem_layers``: one ``FEM`` per level, in ``cfg`` order.
    """
    depth = len(cfg)
    topdown_layers = [nn.Conv2d(cfg[-1], cfg[-1], kernel_size=1)]
    lat_layers = []
    fem_layers = []
    for k, channels in enumerate(cfg):
        fem_layers.append(FEM(channels))
        src = depth - 1 - k  # walk the levels from deepest to shallowest
        if src <= 0:
            continue  # the shallowest level has nothing below it
        upper, lower = cfg[src], cfg[src - 1]
        topdown_layers.append(nn.Conv2d(upper, lower, kernel_size=1))
        lat_layers.append(nn.Conv2d(lower, lower, kernel_size=1))
    return (topdown_layers, lat_layers, fem_layers)

# VGG主干网络构造函数

In [None]:
def vgg(cfg, i, batch_norm=False):
    """Build the backbone as a flat list of layers.

    Args:
        cfg: Sequence where an int is the output width of a 3x3 convolution,
            ``'M'`` is a 2x2 max-pool and ``'C'`` is a 2x2 max-pool with
            ``ceil_mode=True``.
        i: Number of input channels of the first convolution.
        batch_norm: When true, insert ``BatchNorm2d`` after each cfg convolution.

    Returns:
        list: The layers, ending with a dilated 512->1024 3x3 convolution and
        a 1024->1024 1x1 convolution, each followed by a ReLU.
    """
    layers = []
    channels = i
    for spec in cfg:
        if spec == 'M':
            layers.append(nn.MaxPool2d(2, 2))
            continue
        if spec == 'C':
            layers.append(nn.MaxPool2d(2, 2, ceil_mode=True))
            continue
        stage = [nn.Conv2d(channels, spec, 3, padding=1)]
        if batch_norm:
            stage.append(nn.BatchNorm2d(spec))
        stage.append(nn.ReLU(inplace=True))
        layers.extend(stage)
        channels = spec

    # Fixed tail appended after the configurable part.
    layers.extend([
        nn.Conv2d(512, 1024, kernel_size=3, padding=3, dilation=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(1024, 1024, kernel_size=1),
        nn.ReLU(inplace=True),
    ])
    return layers

# Extras扩展检测头构造函数

In [None]:
def add_extras(cfg, i, batch_norm=False):
    """Build the extra convolution layers that follow the backbone.

    Kernel size alternates 1, 3, 1, 3, ... across the created layers. An
    ``'S'`` entry creates a stride-2, padding-1 convolution whose output
    width is the next cfg entry; that next entry then creates no layer of
    its own.

    Args:
        cfg: Sequence of output widths and ``'S'`` markers.
        i: Number of input channels of the first layer.
        batch_norm: Accepted for signature compatibility; not used.

    Returns:
        list: The ``nn.Conv2d`` layers in order.
    """
    layers = []
    prev = i
    use_3x3 = False
    for idx, spec in enumerate(cfg):
        if prev != 'S':
            kernel = 3 if use_3x3 else 1
            if spec == 'S':
                conv = nn.Conv2d(prev, cfg[idx + 1],
                                 kernel_size=kernel, stride=2, padding=1)
            else:
                conv = nn.Conv2d(prev, spec, kernel_size=kernel)
            layers.append(conv)
            use_3x3 = not use_3x3
        prev = spec
    return layers

# 多尺度检测头模块构造器（PAL1 / PAL2）

In [None]:
def multibox(vgg, extra_layers, num_classes):
    """Build the localisation and classification convolutions of one head.

    One pair of 3x3 convolutions is created for each tapped layer: backbone
    layers 14, 21, 28 and -2, then every second extra layer starting at
    index 1. Each pair reads that layer's ``out_channels``.

    Args:
        vgg: Indexable backbone layer list.
        extra_layers: Extra layer list.
        num_classes: Output channels of each classification convolution.

    Returns:
        ``(loc_layers, conf_layers)``; localisation convolutions emit 4 channels.
    """
    loc_layers, conf_layers = [], []
    tapped = [vgg[idx] for idx in (14, 21, 28, -2)]
    tapped.extend(extra_layers[1::2])

    for layer in tapped:
        width = layer.out_channels
        loc_layers.append(nn.Conv2d(width, 4, kernel_size=3, padding=1))
        conf_layers.append(nn.Conv2d(width, num_classes, kernel_size=3, padding=1))

    return (loc_layers, conf_layers)


# 完整模型构造函数（统一入口）

In [None]:
def build_net_dark(phase, num_classes=2):
    """Assemble a ``DSFD`` network from the module-level configurations.

    Args:
        phase: Forwarded to ``DSFD``; ``'test'`` enables its inference modules.
        num_classes: Number of classes predicted by both heads.

    Returns:
        The constructed ``DSFD`` instance.
    """
    backbone = vgg(vgg_cfg, 3)
    extra_layers = add_extras(extras_cfg, 1024)
    # Two independent heads with the same layout (PAL1 first, then PAL2).
    heads = [multibox(backbone, extra_layers, num_classes) for _ in range(2)]
    pyramid = fem_module(fem_cfg)
    return DSFD(phase, backbone, extra_layers, pyramid, heads[0], heads[1], num_classes)

# 蒸馏对齐KL散度损失函数

In [None]:
class DistillKL(nn.Module):
    """Temperature-scaled KL divergence between two sets of logits.

    Args:
        T: Softmax temperature; the loss is multiplied by ``T ** 2``.
    """

    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        """Return ``KL(softmax(y_t / T) || softmax(y_s / T)) * T**2``.

        Softmax is taken over dimension 1 and the divergence uses the
        'batchmean' reduction.
        """
        temperature = self.T
        student_log_probs = F.log_softmax(y_s / temperature, dim=1)
        teacher_probs = F.softmax(y_t / temperature, dim=1)
        divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        return divergence * (temperature ** 2)