#### 导入库

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torchvision.models as models
from Net import *
from Loss import *
from DataLoader import *
from torch.utils.data import DataLoader
import time

#### 基本参数

In [None]:
# ---- training hyper-parameters ----
train_batch_size = 16   # samples per mini-batch
learning_rate = 2e-4    # Adam learning rate
num_epochs = 200        # train for 200 epochs in total
save_point = 5          # write a checkpoint every `save_point` epochs
# Epoch to resume from: 0 trains from scratch, N loads the epoch_N checkpoint.
start_epochs = 0
# Which model to train -- 1: the decomposition model.
model_choose = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### 损失函数

In [None]:
# Two independent mean-squared-error criteria: one for the consistency
# term and one for the reconstruction term.
consLoss, recLoss = nn.MSELoss(), nn.MSELoss()
# Structure-aware total-variation (smoothness) loss.
smoothLoss = TVLoss()

#### 数据预取：包装 DataLoader，始终提前取好下一个批次（不会把整个数据集缓存进内存）

In [None]:
class DataPrefetcher:
    """Iterate over a loader while keeping one batch fetched ahead.

    ``next()`` hands back the batch fetched on the previous call and then
    immediately pulls the following one from the underlying iterator. Once
    the iterator is exhausted, ``next()`` returns ``None``. The class itself
    starts no threads or CUDA streams and caches nothing beyond that single
    look-ahead batch.

    Attributes:
        loader: iterator obtained from the wrapped iterable.
        batch: the look-ahead batch, or ``None`` once exhausted.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.preload()

    def preload(self):
        """Fetch the next item into ``self.batch`` (``None`` when exhausted)."""
        self.batch = next(self.loader, None)

    def next(self):
        """Return the current look-ahead batch, then advance to the next one."""
        try:
            # The return value is captured before the ``finally`` runs,
            # so the caller receives the batch fetched earlier.
            return self.batch
        finally:
            self.preload()

#### 模型1_分解模型

In [None]:
def train_1(start_epoch):
    """Train the Retinex decomposition network (model 1).

    Every epoch iterates over pairs ``(batch[0], batch[1])`` from
    ``retinex_decomposition_data``. Each pair (presumably the same scene
    without / with artificial light -- confirm against the dataset class) is
    run through the network. The optimised loss is the sum of:
    a consistency term between the two ``R`` outputs, a term on the two
    ``I * R - L`` residuals, and a TV smoothness term on each ``R`` output.
    A checkpoint dict ``{'model', 'optimizer', 'epoch'}`` is written every
    ``save_point`` epochs, and per-epoch average losses are printed and
    appended to ``output.txt``.

    Args:
        start_epoch: Epoch to resume from. ``0`` trains from scratch.
            Otherwise ``./checkpoints/Retinex_Decomposition_net/epoch_<start_epoch>.pth``
            is loaded. Both the checkpoint-dict format written by this
            function and a bare model ``state_dict`` are accepted.
    """
    import os  # local import keeps this cell self-contained

    checkpoint_dir = './checkpoints/Retinex_Decomposition_net'

    print("模型导入中")
    model = Retinex_Decomposition_net().to(device)
    checkpoint = None
    if start_epoch != 0:
        # Resume from epoch_<start_epoch>.pth (normally the newest checkpoint in the folder).
        model_path = checkpoint_dir + '/epoch_' + str(start_epoch) + '.pth'
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            # Format written by this function (see torch.save below).
            model.load_state_dict(checkpoint['model'])
        else:
            # Older file holding only the model's state_dict.
            model.load_state_dict(checkpoint)
            checkpoint = None
    print("模型导入完成")
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if checkpoint is not None and 'optimizer' in checkpoint:
        # Restore Adam's moment estimates so resuming continues smoothly.
        optimizer.load_state_dict(checkpoint['optimizer'])
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build the dataset/loader once. With shuffle=True, DataLoader draws a
    # fresh permutation every time it is iterated, so rebuilding per epoch is unnecessary.
    L_no_light_path = r"./dataset/UIALN_datasest/train_data/dataset_no_AL"
    L_light_path = r"./dataset/UIALN_datasest/train_data/dataset_with_AL/train"
    dataset = retinex_decomposition_data(L_no_light_path, L_light_path)
    train_loader = DataLoader(dataset, batch_size=train_batch_size, shuffle=True, num_workers=0)

    model.train()
    total_loss = 0.0
    epochs_run = 0
    for epoch in range(start_epoch + 1, num_epochs + 1):
        print("epoch: ", epoch)
        start_time = time.time()
        prefetcher = DataPrefetcher(train_loader)
        batch = prefetcher.next()
        num_batches = 0
        epoch_loss = 0.0
        while batch is not None:
            num_batches += 1
            L_no_light = batch[0].to(device)
            L_light = batch[1].to(device)
            I_no_light_hat, R_no_light_hat = model(L_no_light)
            I_light_hat, R_light_hat = model(L_light)
            # Consistency: pull the two R estimates towards each other.
            loss_1 = consLoss(R_light_hat, R_no_light_hat)
            # NOTE(review): this compares the two residuals (I*R - L) with each
            # other instead of driving each residual to zero. Confirm this is
            # the intended reconstruction term.
            loss_2 = recLoss(I_light_hat*R_light_hat-L_light, I_no_light_hat*R_no_light_hat-L_no_light)
            loss_3 = smoothLoss(R_light_hat)
            loss_4 = smoothLoss(R_no_light_hat)
            loss = loss_1 + loss_2 + loss_3 + loss_4
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a Python float. Summing the tensor itself would keep
            # every iteration's autograd history alive until the epoch ends.
            epoch_loss += loss.item()
            batch = prefetcher.next()
        if epoch % save_point == 0:
            state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            torch.save(state, checkpoint_dir + '/epoch_' + str(epoch) + '.pth')
        time_epoch = time.time() - start_time
        # max(..., 1) guards against an empty loader (division by zero).
        epoch_loss = epoch_loss / max(num_batches, 1)
        total_loss += epoch_loss
        epochs_run += 1
        message = "==>No: {} epoch, time: {:.2f}, loss: {:.5f}".format(epoch, time_epoch / 60, epoch_loss)
        print(message)
        with open("output.txt", "a") as f:
            f.write(message + "\n")
    # Mean of the per-epoch average losses over the epochs actually trained.
    print("total_loss:", total_loss / max(epochs_run, 1))

#### 主函数-判定训练哪个模型

In [None]:
if __name__ == '__main__':
    print(torch.cuda.is_available())
    # Map each supported model_choose value to its training entry point.
    # An unknown choice raises instead of silently exiting without training.
    trainers = {1: train_1}
    if model_choose not in trainers:
        raise ValueError("unsupported model_choose: {}".format(model_choose))
    trainers[model_choose](start_epochs)