In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from GetdataSet import trainMYDataSet
from LOSS import PerceptualLoss
from torch.utils.data import DataLoader
import time
import os
from os.path import join
import gc

# Training hyper-parameters.
num_epochs = 150          # last epoch trained (inclusive)
train_batch_size = 5      # samples per optimizer step
start_epochs = 0          # 0 = train from scratch; N = resume from the checkpoint of epoch N
# Step schedule matching train(): base LR 0.01, halved once every 30 epochs.
learning_rate = 0.01 * 0.5 ** (start_epochs // 30)

In [2]:
class ShareSepConv(nn.Module):
    """Depthwise convolution whose single k x k kernel is shared by all channels.

    The kernel starts as the identity (1 at the centre, 0 elsewhere), so a freshly
    constructed module returns its input unchanged. Padding is chosen so that the
    spatial size of the output equals that of the input (stride 1).

    Args:
        kernel_size: odd side length of the square kernel.
    """

    def __init__(self, kernel_size):
        super().__init__()
        assert kernel_size % 2 == 1, 'kernel size should be odd'
        centre = kernel_size // 2  # same as (kernel_size - 1) // 2 for odd sizes
        # "Same" padding: output feature map keeps the input's height/width.
        self.padding = centre
        # One output map, one input channel, kernel_size x kernel_size.
        identity = torch.zeros(1, 1, kernel_size, kernel_size)
        identity[0, 0, centre, centre] = 1
        # Learnable kernel, shared across every channel at forward time.
        self.weight = nn.Parameter(identity)
        self.kernel_size = kernel_size

    def forward(self, x):
        channels = x.size(1)
        # Replicate the shared kernel once per channel and run a grouped
        # (one group per channel) convolution.
        shared = self.weight.expand(channels, 1, self.kernel_size, self.kernel_size).contiguous()
        return F.conv2d(x, shared, bias=None, stride=1, padding=self.padding,
                        dilation=1, groups=channels)


class SmoothDilatedResidualBlock(nn.Module):
    """Residual block of two dilated 3x3 convs, each preceded by a ShareSepConv.

    Each ShareSepConv uses a kernel of size ``2 * dilation - 1``. The output is
    ``relu(x + IN2(conv2(pre2(relu(IN1(conv1(pre1(x))))))))`` and has the same
    shape as ``x``.

    Args:
        channel_num: number of input and output channels.
        dilation: dilation (and padding) of both 3x3 convolutions.
        group: ``groups`` argument of both 3x3 convolutions.
    """

    def __init__(self, channel_num, dilation=1, group=1):
        super().__init__()
        pre_size = 2 * dilation - 1
        self.pre_conv1 = ShareSepConv(pre_size)
        self.conv1 = nn.Conv2d(channel_num, channel_num, kernel_size=3, stride=1, padding=dilation,
                               dilation=dilation, groups=group, bias=False)
        self.norm1 = nn.InstanceNorm2d(channel_num, affine=True)
        self.pre_conv2 = ShareSepConv(pre_size)
        self.conv2 = nn.Conv2d(channel_num, channel_num, kernel_size=3, stride=1, padding=dilation,
                               dilation=dilation, groups=group, bias=False)
        self.norm2 = nn.InstanceNorm2d(channel_num, affine=True)

    def forward(self, x):
        hidden = self.conv1(self.pre_conv1(x))
        hidden = F.relu(self.norm1(hidden))
        hidden = self.conv2(self.pre_conv2(hidden))
        residual = self.norm2(hidden)
        return F.relu(x + residual)


class ResidualBlock(nn.Module):
    """Residual block: two (dilated) 3x3 conv + InstanceNorm stages and a skip.

    Computes ``relu(x + IN2(conv2(relu(IN1(conv1(x))))))``; the output shape
    equals the input shape because padding equals the dilation.

    Args:
        channel_num: number of input and output channels.
        dilation: dilation (and padding) of both 3x3 convolutions.
        group: ``groups`` argument of both 3x3 convolutions.
    """

    def __init__(self, channel_num, dilation=1, group=1):
        super().__init__()
        self.conv1 = nn.Conv2d(channel_num, channel_num, kernel_size=3, stride=1, padding=dilation,
                               dilation=dilation, groups=group, bias=False)
        self.norm1 = nn.InstanceNorm2d(channel_num, affine=True)
        self.conv2 = nn.Conv2d(channel_num, channel_num, kernel_size=3, stride=1, padding=dilation,
                               dilation=dilation, groups=group, bias=False)
        self.norm2 = nn.InstanceNorm2d(channel_num, affine=True)

    def forward(self, x):
        hidden = F.relu(self.norm1(self.conv1(x)))
        residual = self.norm2(self.conv2(hidden))
        return F.relu(x + residual)


class CINR(nn.Module):
    """Conv2d (no bias) -> affine InstanceNorm2d -> ReLU.

    Args:
        in_channel_num: input channels.
        out_channel_num: output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channel_num, out_channel_num, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel_num, out_channel_num, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.norm1 = nn.InstanceNorm2d(out_channel_num, affine=True)

    def forward(self, x):
        return F.relu(self.norm1(self.conv1(x)))


# Channel attention.
class Attu_1(nn.Module):
    """Channel attention: rescale each channel by a gate computed from its mean.

    Global average pool -> 1x1 conv -> ReLU -> 1x1 conv -> sigmoid gives one
    weight in (0, 1) per channel, which multiplies the input.

    Args:
        channel_num: number of input (and output) channels.
    """

    def __init__(self, channel_num):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(channel_num, channel_num, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(channel_num, channel_num, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        pooled = self.avg_pool(x)  # spatial size collapsed to 1x1
        gate = torch.sigmoid(self.conv2(F.relu(self.conv1(pooled))))
        return x * gate


# Residual group made of three residual blocks.
class ResidualGroup(nn.Module):
    """Three ResidualBlocks applied one after another (shape preserving).

    Args:
        channel_num: number of input and output channels.
        dilation: dilation passed to every ResidualBlock.
        group: ``groups`` passed to every ResidualBlock.
    """

    def __init__(self, channel_num, dilation=1, group=1):
        super().__init__()
        self.residual_block1 = ResidualBlock(channel_num, dilation, group)
        self.residual_block2 = ResidualBlock(channel_num, dilation, group)
        self.residual_block3 = ResidualBlock(channel_num, dilation, group)

    def forward(self, x):
        out = x
        for block in (self.residual_block1, self.residual_block2, self.residual_block3):
            out = block(out)
        return out


# Pixel attention.
class Attu_2(nn.Module):
    """Pixel attention: rescale every element by a per-position gate.

    1x1 conv -> ReLU -> 1x1 conv -> sigmoid gives a gate in (0, 1) with the
    same shape as the input, which multiplies the input element-wise.

    Args:
        channel_num: number of input (and output) channels.
    """

    def __init__(self, channel_num):
        super().__init__()
        self.conv1 = nn.Conv2d(channel_num, channel_num, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = nn.Conv2d(channel_num, channel_num, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        gate = torch.sigmoid(self.conv2(hidden))
        return x * gate


class AmNet(nn.Module):
    """Dense-prediction network mapping (N, in_c, H, W) to (N, out_c, H, W).

    Pipeline (see ``forward``):
      * an input CINR to 64 channels;
      * multi-scale feature extraction (MFE): three parallel branches
        (plain 3x3 conv, dilation-2 and dilation-4 SmoothDilatedResidualBlock),
        each followed by CINR and channel attention, fused by concatenation
        and a 3x3 conv, plus a skip from the input CINR;
      * a two-level max-pool encoder (64 -> 128 -> 256 channels), a residual
        group at 1/4 resolution, and a transposed-conv decoder with skip
        concatenations back to full resolution;
      * channel attention, pixel attention and a final 3x3 conv to ``out_c``.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.
    """

    # The encoder has two stride-2 max-pools, so H and W must be multiples of
    # 4 for the decoder's skip concatenations to line up.
    _SIZE_MULTIPLE = 4

    def __init__(self, in_c=3, out_c=3):
        super(AmNet, self).__init__()
        # NOTE: attribute names and creation order define the state_dict keys,
        # the parameter order seen by the optimizer, and the RNG consumption
        # of initialisation -- keep them stable for checkpoint compatibility.
        self.Cinr1 = CINR(in_c, 64, 3, 1, 1)

        # MFE branch 1: plain 3x3 conv.
        self.Mfe1_1 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
        self.Mfe1_2 = CINR(64, 64, 3, 1, 1)
        self.Mfe1_3 = Attu_1(64)

        # MFE branch 2: dilation-2 smooth dilated residual block.
        self.Mfe2_1 = SmoothDilatedResidualBlock(64, dilation=2)
        self.Mfe2_2 = CINR(64, 64, 3, 1, 1)
        self.Mfe2_3 = Attu_1(64)

        # MFE branch 3: dilation-4 smooth dilated residual block.
        self.Mfe3_1 = SmoothDilatedResidualBlock(64, dilation=4)
        self.Mfe3_2 = CINR(64, 64, 3, 1, 1)
        self.Mfe3_3 = Attu_1(64)

        # Fuse the three concatenated 64-channel branches back to 64 channels.
        self.Mfe_final = nn.Conv2d(64 * 3, 64, 3, 1, 1, bias=False)

        # Encoder: full -> 1/2 -> 1/4 resolution.
        self.Cinr2_1 = CINR(64, 64, 3, 1, 1)
        self.max_pool2_1 = nn.MaxPool2d(2, 2, 0)
        self.Cinr2_2 = CINR(64, 128, 3, 1, 1)
        self.max_pool2_2 = nn.MaxPool2d(2, 2, 0)
        self.Cinr2_3 = CINR(128, 256, 3, 1, 1)

        self.ResidualGroup3_1 = ResidualGroup(256, dilation=1)

        # Decoder: each transposed conv (k=4, s=2, p=1) doubles H and W; the
        # "* 2" input widths account for the skip concatenations.
        self.Cinr3_1 = CINR(256, 256, 3, 1, 1)
        self.deconv3_1 = nn.ConvTranspose2d(256 * 2, 128, 4, 2, 1, bias=False)
        self.Cinr3_2 = CINR(128, 128, 3, 1, 1)
        self.deconv3_2 = nn.ConvTranspose2d(128 * 2, 64, 4, 2, 1, bias=False)
        self.Cinr3_3 = CINR(64, 64, 3, 1, 1)

        self.Cinr4 = CINR(64 * 2, 64, 3, 1, 1)

        self.Attu_Block1 = Attu_1(64)
        self.Attu_Block2 = Attu_2(64)

        self.conv5 = nn.Conv2d(64, out_c, 3, 1, 1, bias=False)

    def forward(self, x):
        """Run the network on a 4-D batch ``x`` of shape (N, in_c, H, W).

        H and W need not be multiples of 4: the input is edge-replicated up to
        the next multiple of 4 and the output is cropped back to (H, W).
        Inputs already divisible by 4 are processed without any padding.

        Returns:
            Tensor of shape (N, out_c, H, W).
        """
        height, width = x.shape[-2:]
        pad_h = -height % self._SIZE_MULTIPLE
        pad_w = -width % self._SIZE_MULTIPLE
        padded = pad_h > 0 or pad_w > 0
        if padded:
            # Pad only bottom/right so cropping is a simple [:height, :width].
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        y1 = self.Cinr1(x)

        # Multi-scale feature extraction branches.
        y2 = self.Mfe1_3(self.Mfe1_2(self.Mfe1_1(y1)))
        y3 = self.Mfe2_3(self.Mfe2_2(self.Mfe2_1(y1)))
        y4 = self.Mfe3_3(self.Mfe3_2(self.Mfe3_1(y1)))
        # Concatenate the branch features along channels.
        # NOTE(review): the original author was unsure whether concatenation
        # is the intended fusion method -- confirm against the reference design.
        y = self.Mfe_final(torch.cat((y2, y3, y4), 1))
        # Backbone: encoder with skips y6 (full res) and y7 (1/2 res).
        y6 = self.Cinr2_1(y + y1)
        y7 = self.Cinr2_2(self.max_pool2_1(y6))
        y8 = self.Cinr2_3(self.max_pool2_2(y7))
        y = self.Cinr3_1(self.ResidualGroup3_1(y8))
        y = self.deconv3_1(torch.cat((y8, y), 1))
        y = self.deconv3_2(torch.cat((self.Cinr3_2(y), y7), 1))
        y = torch.cat((self.Cinr3_3(y), y6), 1)
        y = self.Attu_Block2(self.Attu_Block1(self.Cinr4(y)))
        y = self.conv5(y)

        if padded:
            y = y[..., :height, :width]
        return y


In [3]:
# Loss terms: pixel-wise Smooth-L1 plus the project's PerceptualLoss (which, per
# the download log of this cell, pulls pretrained VGG16 weights).
# Keep construction order: building modules consumes the global RNG.
smooth_l1 = nn.SmoothL1Loss()
vgg16_Loss = PerceptualLoss()

# Use the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

model = AmNet().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /home/ma-user/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [4]:
class DataPrefetcher:
    """Iterator wrapper that always holds the next batch one step ahead.

    ``next()`` returns the batch fetched previously and immediately fetches the
    following one; once the underlying loader is exhausted it returns ``None``.
    Fetching is synchronous -- no background thread or CUDA stream is used.

    Args:
        loader: any iterable of batches (e.g. a ``DataLoader``).
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.preload()

    def preload(self):
        # ``None`` marks exhaustion, mirroring StopIteration.
        self.batch = next(self.loader, None)

    def next(self):
        current = self.batch
        self.preload()
        return current

In [5]:
def train(model, start_epoch):
    """Train ``model`` for epochs ``start_epoch + 1`` .. ``num_epochs`` (inclusive).

    Relies on notebook globals: ``optimizer``, ``device``, ``smooth_l1``,
    ``vgg16_Loss``, ``train_batch_size`` and ``num_epochs``.

    Side effects:
      * appends a loss line to ``epoch_output.txt`` every 5 batches;
      * appends an epoch summary to ``output.txt`` and prints it;
      * saves ``./Rlablemodel/AmNet-model-epochs{epoch}.pth`` every 5 epochs
        (model + optimizer state + epoch), creating the directory if needed;
      * halves the learning rate of every param group every 30 epochs
        (after that epoch's checkpoint is written).

    Args:
        model: the network to optimize (already moved to ``device``).
        start_epoch: last completed epoch; training resumes at the next one.

    Raises:
        RuntimeError: if the training set yields no batches in an epoch.
    """
    print("start train")
    model.train()
    backup_model_dir = './Rlablemodel/'
    total_loss = 0.0
    for epoch in range(start_epoch + 1, num_epochs + 1):
        # NOTE(review): the dataset/loader are rebuilt every epoch as before;
        # hoist them out of the loop if trainMYDataSet.__init__ is deterministic.
        dataset = trainMYDataSet(src_data_path="./input_train/", lable_data_path="./gt_train/")
        datasetloader = DataLoader(dataset, batch_size=train_batch_size, shuffle=False, num_workers=0)
        gc.collect()
        start_time = time.time()
        prefetcher = DataPrefetcher(datasetloader)
        batch = prefetcher.next()
        i = 0
        epoch_loss = 0.0
        while batch is not None:
            i += 1
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            R = model(inputs)
            loss1 = smooth_l1(R, labels)
            loss2 = vgg16_Loss(R, labels)
            train_loss = loss1 + loss2
            # Accumulate a Python float: summing the tensor itself would chain
            # every batch's autograd history into one epoch-long graph.
            epoch_loss += train_loss.item()
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            batch = prefetcher.next()
            if i % 5 == 0:
                with open("epoch_output.txt", "a") as f:
                    f.write('{:.2f} => Epoch[{}/{}]: train_loss: {:.4f},l1: {:.4f},vgg: {:.4f}\n'.format(
                        time.time() - start_time, epoch, num_epochs,
                        train_loss.item(), loss1.item(), loss2.item()))
        if i == 0:
            raise RuntimeError("training set yielded no batches; check ./input_train/ and ./gt_train/")
        if epoch % 5 == 0:
            state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            os.makedirs(backup_model_dir, exist_ok=True)
            torch.save(state, join(backup_model_dir, '{}-model-epochs{}.pth'.format('AmNet', epoch)))
        if epoch % 30 == 0 and epoch != 0:
            for p in optimizer.param_groups:
                p['lr'] *= 0.5
        time_epoch = time.time() - start_time
        epoch_loss = epoch_loss / i
        total_loss += epoch_loss
        # Time is reported in minutes.
        print("==>No: {} epoch, time: {:.2f}, loss: {:.5f}".format(epoch, time_epoch / 60, epoch_loss))
        with open("output.txt", "a") as f:
            f.write("==>No: {} epoch, time: {:.2f}, loss: {:.5f}\n".format(epoch, time_epoch / 60, epoch_loss))
    # Mean epoch loss over the epochs actually run in this call.
    epochs_run = num_epochs - start_epoch
    print("total_loss:", total_loss / epochs_run if epochs_run > 0 else 0.0)

In [None]:
def main():
    """Entry point: optionally resume from a checkpoint, then call ``train``.

    With ``start_epochs == 0`` training starts from scratch. Otherwise the
    checkpoint ``./Rlablemodel/AmNet-model-epochs{start_epochs}.pth`` written
    by ``train`` is loaded: the model weights and, when present, the optimizer
    state (so Adam's moment estimates are not reset on resume).
    """
    start_epoch = start_epochs
    print(torch.cuda.is_available())
    print(learning_rate)
    if start_epoch == 0:
        # "No saved model, training from scratch!"
        print('==> 无保存模型，将从头开始训练！')
    else:
        # "Loading model"
        print('模型加载')
        checkpoint_model = join('./Rlablemodel/',
                                '{}-model-epochs{}.pth'.format("AmNet", start_epoch))
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_model, map_location=device)
        model.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            # train() saves before applying the every-30-epoch LR halving, so the
            # stored lr can be stale; re-apply the schedule value for this epoch.
            for group in optimizer.param_groups:
                group['lr'] = learning_rate
    train(model, start_epoch)


if __name__ == '__main__':
    main()

True
0.01
模型加载
start train
==>No: 21 epoch, time: 5.35, loss: 0.05980
