In [1]:
import os
import argparse
import numpy as np
from einops import rearrange
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
import pytorch_ssim
import time 
from torch.autograd import Variable
from torch.utils.data import DataLoader

from torch.nn.modules.loss import _Loss 
#from net.Ushape_Trans import *
#from dataset import prepare_data, Dataset
#from net.utils import *

import cv2
import matplotlib.pyplot as plt
from utility import plots as plots, ptcolor as ptcolor, ptutils as ptutils, data as data
from loss.LAB import *
from loss.LCH import *
from loss.VGG19_PercepLoss import *
from torchvision.utils import save_image

import warnings
warnings.filterwarnings('ignore')
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"  # 允许程序在存在重复的 OpenMP 库时继续运行

# 设置 GPU 使用和默认 tensor 类型
dtype = 'float32'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # 设置一个 环境变量，告诉程序只使用第 0 号 GPU
torch.set_default_tensor_type(torch.FloatTensor)

In [2]:
# split()：把一张图像缩放到不同分辨率（1/8, 1/4, 1/2, 原图），返回多尺度结果。
# batch_PSNR()：批量计算 PSNR 评价指标。

def split(img):
    output=[]
    output.append(F.interpolate(img, scale_factor=0.125))
    output.append(F.interpolate(img, scale_factor=0.25))
    output.append(F.interpolate(img, scale_factor=0.5))
    output.append(img)
    return output


def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
    return (PSNR/Img.shape[0])

In [3]:
# 从训练集目录里读取输入图像和 GT（目标图像），调整大小到 256×256，存入数组

training_x=[]
path= r'/root/LU2Net-master/LSUI/Train/train/'  #'./data/input/'#要改
path_list = os.listdir(path)
path_list.sort(key=lambda x:int(x.split('.')[0]))
for item in path_list:
    impath=path+item
    #print("开始处理"+impath)
    imgx= cv2.imread(path+item)
    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)
    imgx=cv2.resize(imgx,(256,256))
    training_x.append(imgx)   

X_train = []

for features in training_x:
    X_train.append(features)
    
X_train = np.array(X_train)

X_train=X_train.astype(dtype)
X_train= torch.from_numpy(X_train)
X_train=X_train.permute(0,3,1,2)

X_train=X_train/255.0
X_train.shape

############################################################
############################################################

training_y=[]
path= r'/root/LU2Net-master/LSUI/Train/GT/'     #'./data/GT/'#要改
path_list = os.listdir(path)
path_list.sort(key=lambda x:int(x.split('.')[0]))
for item in path_list:
    impath=path+item
    #print("开始处理"+impath)
    imgx= cv2.imread(path+item)
    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)
    imgx=cv2.resize(imgx,(256,256))
    training_y.append(imgx)

y_train = []

for features in training_y:
    y_train.append(features)
    
y_train = np.array(y_train)

y_train=y_train.astype(dtype)
y_train= torch.from_numpy(y_train)
y_train=y_train.permute(0,3,1,2)

y_train=y_train/255.0
y_train.shape

torch.Size([3449, 3, 256, 256])

In [4]:
# 读取测试集（输入图像和 GT），同样 resize 到 256×256

test_x=[]
path= r'/root/LU2Net-master/LSUI/Test/test/'       #'./test/input/'#要改
path_list = os.listdir(path)
# 只保留文件名是“数字.后缀”格式的 #####增加的
path_list = [f for f in path_list if f.split('.')[0].isdigit()]
path_list.sort(key=lambda x:int(x.split('.')[0]))
for item in path_list:
    impath=path+item
    #print("开始处理"+impath)
    imgx= cv2.imread(path+item)
    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)
    imgx=cv2.resize(imgx,(256,256))
    test_x.append(imgx)
    
x_test = []

for features in test_x:
    x_test.append(features)
    
x_test = np.array(x_test)

x_test=x_test.astype(dtype)
x_test= torch.from_numpy(x_test)
x_test=x_test.permute(0,3,1,2)

x_test=x_test/255.0
x_test.shape

############################################################
############################################################

test_Y=[]
path= r'/root/LU2Net-master/LSUI/Test/GT/'   #'./test/GT/'#要改
path_list = os.listdir(path)
# 只保留文件名是“数字.后缀”格式的 #####增加的
path_list = [f for f in path_list if f.split('.')[0].isdigit()]
path_list.sort(key=lambda x:int(x.split('.')[0]))
for item in path_list:
    impath=path+item
    #print("开始处理"+impath)
    imgx= cv2.imread(path+item)
    imgx = cv2.cvtColor(imgx, cv2.COLOR_BGR2RGB)
    imgx=cv2.resize(imgx,(256,256))
    test_Y.append(imgx)
    
Y_test = []

for features in test_Y:
    Y_test.append(features)
    
Y_test = np.array(Y_test)
#X_train = np.array(X_train)

Y_test=Y_test.astype(dtype)
Y_test= torch.from_numpy(Y_test)
Y_test=Y_test.permute(0,3,1,2)

Y_test=Y_test/255.0
Y_test.shape

torch.Size([830, 3, 256, 256])

In [5]:
# 构造 PyTorch Dataset 和 DataLoader 用于训练
# 定义各种损失函数（GAN、像素、SSIM、VGG感知、Lab颜色空间损失）

import torch.utils.data as dataf
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim

dataset = dataf.TensorDataset(X_train,y_train)
loader = dataf.DataLoader(dataset, batch_size=1, shuffle=True,num_workers=4)

# 定义损失函数
criterion_GAN      = nn.MSELoss(reduction='mean').cuda()
criterion_pixelwise= nn.MSELoss(reduction='mean').cuda()
MSE                = nn.MSELoss(reduction='mean').cuda()
SSIM               = pytorch_ssim.SSIM().cuda()
L_vgg              = VGG19_PercepLoss().cuda()
L_lab              = lab_Loss().cuda()
L_lch              = lch_Loss().cuda()

# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel=0.1
lambda_lab=0.001
lambda_lch=1
lambda_con = 100
lambda_ssim=100

# 修改后

In [None]:
# 初始化主模型 LightUNet；是否加载预训练模型

import HDAMS_Net
LightUNet = HDAMS_Net.HDAMS_Net().cuda()

# 如果有保存的模型，则加载模型，并在其基础上继续训练
# use_pretrain=True
use_pretrain=False
if use_pretrain:
    
    # start_epoch=490
    start_epoch=0
    LightUNet.load_state_dict(torch.load("/root/LU2Net-master/LightUNet_%d.pth" % (start_epoch)))
    print('successfully loading epoch {} 成功！'.format(start_epoch))
else:
    start_epoch = 0
    print('No pretrain model found, training will start from scratch!')

In [6]:
# 初始化主模型 LightUNet；是否加载预训练模型

import LU2Net
LightUNet = LU2Net.LU2Net().cuda()

# 如果有保存的模型，则加载模型，并在其基础上继续训练
# use_pretrain=True
use_pretrain=False
if use_pretrain:
    
    # start_epoch=490
    start_epoch=0
    LightUNet.load_state_dict(torch.load("/root/LU2Net-master/LightUNet_%d.pth" % (start_epoch)))
    print('successfully loading epoch {} 成功！'.format(start_epoch))
else:
    start_epoch = 0
    print('No pretrain model found, training will start from scratch!')

No pretrain model found, training will start from scratch!


In [7]:
# 定义函数 sample_images()，从测试集中随机取一张图片，生成预测结果并保存对比

def sample_images(batches_done):
    """Saves a generated sample from the validation set"""
    LightUNet.eval()
    i=random.randrange(1,90)
    real_A = Variable(x_test[i,:,:,:]).cuda()
    real_B = Variable(Y_test[i,:,:,:]).cuda()
    real_A=real_A.unsqueeze(0)
    real_B=real_B.unsqueeze(0)
    fake_B = LightUNet(real_A)
    #print(fake_B.shape)
    imgx=fake_B.data
    imgy=real_B.data
    x=imgx[:,:,:,:]
    y=imgy[:,:,:,:]
    img_sample = torch.cat((x,y), -2)
    save_image(img_sample, "images/%s/%s.png" % ('results', batches_done), nrow=5, normalize=True)#要改

In [8]:
# 设置优化器 Adam，学习率调度器，CUDA 配置

from torch.optim import lr_scheduler
LR=0.0005

# Optimizers
optimizer = torch.optim.Adam(LightUNet.parameters(), lr=LR,  betas=(0.5, 0.999))
scheduler=lr_scheduler.StepLR(optimizer,step_size=40,gamma=0.8)

cuda = True if torch.cuda.is_available() else False
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [None]:
# 设置训练日志：准备保存训练过程中的 PSNR 和 SSIM 结果到 CSV 文件
# Training

import sys
import csv
import random

f1 = open('psnr.csv','w',encoding='utf-8')#要改
csv_writer1 = csv.writer(f1)
f2 = open('SSIM.csv','w',encoding='utf-8')#要改
csv_writer2 = csv.writer(f2)

checkpoint_interval=10  # 每 10 个 epoch 保存一次模型
epochs=start_epoch # 从上一次训练结束的 epoch 开始（可恢复训练）
n_epochs=1500  # 总训练轮数  ##################################################################################
sample_interval=1000  # 每训练 1000 个 batch 保存一次示例图片和指标

# ingnored when opt.mode=='S'
psnr_list = []   # 可用于保存每个 batch 的 PSNR（这里暂未使用）

for epoch in range(epochs,n_epochs):
    print("epoch"+str(epoch))
    for i, batch in enumerate(loader):

        # Model inputs
        real_A = Variable(batch[0]).cuda() #############################################
        real_B = Variable(batch[1]).cuda() #############################################
        real_A1=split(real_A)
        real_B1=split(real_B)
        # print(real_B1[3].shape) #############################################

        # ------------------
        #  Train Generators
        # ------------------
        optimizer.zero_grad()

        # GAN loss
        fake_B = LightUNet(real_A)
        # print(fake_B.shape) #############################################
        
        # Pixel-wise loss
        loss_pixel =  criterion_pixelwise(fake_B, real_B1[3])
        loss_ssim= -SSIM(fake_B, real_B1[3])
        ssim_value = - loss_ssim.item()
        loss_con = L_vgg(fake_B, real_B1[3])
        loss_lab = L_lab(fake_B, real_B1[3])
        loss_lch = L_lch(fake_B, real_B1[3])   

        # Total loss
        loss =  lambda_pixel * loss_pixel+  lambda_ssim*loss_ssim+\
            lambda_con*loss_con+  lambda_lab*loss_lab+  lambda_lch*loss_lch

        loss.backward(retain_graph=True)
        
        # torch.nn.utils.clip_grad_norm_(LightUNet.parameters(), max_norm=1.0)  ##################### 梯度裁剪
        
        optimizer.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(loader) + i
        out_train= torch.clamp(fake_B, 0., 1.) 
        psnr_train = batch_PSNR(out_train,real_B, 1.)

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)
            csv_writer1.writerow([str(psnr_train)])
            csv_writer2.writerow([str(ssim_value)])
            
    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        # Save model checkpoints  
        torch.save(LightUNet.state_dict(), "LightUNet_%d.pth" % (epoch))
        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d][PSNR: %f] [SSIM: %f] [loss: %f] ,[lab: %f],[lch: %f], [pixel: %f],[VGG_loss: %f]"
            % (
                epoch,
                n_epochs,
                i,
                len(loader),
                psnr_train,
                ssim_value,
                loss.item(),
                0.001*loss_lab.item(),
                0.1*loss_lch.item(),
                0.1*loss_pixel.item(),
                100*loss_con.item(),
            ))

epoch0
[Epoch 0/1500] [Batch 3448/3449][PSNR: 21.729459] [SSIM: 0.832180] [loss: 295.494995] ,[lab: 339.099625],[lch: 1.293802], [pixel: 0.000672],[VGG_loss: 26.674655]epoch1
epoch2
epoch3
epoch4
epoch5
epoch6
epoch7
epoch8
epoch9
epoch10
[Epoch 10/1500] [Batch 3448/3449][PSNR: 23.462398] [SSIM: 0.844405] [loss: 70.317596] ,[lab: 133.586578],[lch: 1.046382], [pixel: 0.000451],[VGG_loss: 10.707246]epoch11
epoch12
epoch13
epoch14
epoch15
epoch16
epoch17
epoch18
epoch19
epoch20
[Epoch 20/1500] [Batch 3448/3449][PSNR: 25.765349] [SSIM: 0.952382] [loss: 26.474304] ,[lab: 108.005469],[lch: 0.817065], [pixel: 0.000267],[VGG_loss: 5.536105]epoch21
epoch22
epoch23
epoch24
epoch25
epoch26
epoch27
epoch28
epoch29
epoch30
[Epoch 30/1500] [Batch 3448/3449][PSNR: 26.804490] [SSIM: 0.902323] [loss: -5.003420] ,[lab: 68.095297],[lch: 0.591323], [pixel: 0.000209],[VGG_loss: 11.220148]epoch31
epoch32
epoch33
epoch34
epoch35
epoch36
epoch37
epoch38
epoch39
epoch40
[Epoch 40/1500] [Batch 3448/3449][PSNR: 

In [None]:
#####################################################################################
# 测试模型参数量及算力需求
#####################################################################################
import torch
from thop import profile

# 模型定义
model = LU2Net.LU2Net().cuda()

# 统计参数量
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Params: {num_params / 1e6:.2f}M")

# 假设输入是 3x256x256 图像，根据你的任务改成实际输入大小
dummy_input = torch.randn(1, 3, 256, 256).to(next(model.parameters()).device)

# 统计 FLOPs 和 Params
flops, params = profile(model, inputs=(dummy_input,))
print(f"FLOPs: {flops / 1e9:.2f}G")
print(f"Params (from thop): {params / 1e6:.2f}M")