In [1]:
import logging
from functools import partial
import functools
import numpy as np
import copy

import cv2
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.autograd as autograd
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable

from collections import deque

from models.fpn_inception import FPNInception
from schedulers import LinearDecay

from dataset import PairedDataset

In [2]:
# Disable OpenCV's internal multithreading (per the OpenCV docs, 0 threads means run
# sequentially). NOTE(review): presumably done so OpenCV does not oversubscribe CPUs
# alongside the DataLoader worker processes (num_workers=cpu_count()) — TODO confirm.
cv2.setNumThreads(0)

In [3]:
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer constructor for ``norm_type``.

    Args:
        norm_type: ``'batch'`` -> BatchNorm2d with affine parameters;
            ``'instance'`` -> InstanceNorm2d without affine parameters but with
            running statistics tracked.

    Returns:
        A ``functools.partial`` that takes the channel count and builds the layer.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

## 生成器的感知loss

In [4]:
class PerceptualLoss:
    """Generator content loss: VGG19 feature distance plus a pixel-space MSE term.

    Usage: construct, then call ``initialize(criterion)`` once; afterwards the
    instance is callable as ``loss(fake, real)``.

    Inputs are expected in [-1, 1] (they are mapped to [0, 1] with ``(x + 1) / 2``)
    and batched as (B, C, H, W) — NOTE(review): C is assumed to be 3 to match the
    ImageNet statistics below; confirm against the generator output.
    """

    # ImageNet per-channel statistics, matching the values passed to
    # transforms.Normalize for the pretrained VGG19 weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def contentFunc(self):
        """Build a frozen VGG19 feature extractor truncated after index 14 (conv3_3)."""
        conv_3_3_layer = 14
        cnn = models.vgg19(pretrained=True).features
        model = nn.Sequential()
        for i, layer in enumerate(list(cnn)):
            model.add_module(str(i), layer)
            if i == conv_3_3_layer:
                break
        model = model.eval()
        # The extractor is never optimized; freezing it stops every backward pass
        # from accumulating unused .grad buffers on the VGG weights. Gradients with
        # respect to the input image still flow.
        for param in model.parameters():
            param.requires_grad_(False)
        return model

    def initialize(self, loss):
        """Store the feature criterion and build the feature extractor.

        Args:
            loss: criterion applied to the VGG feature maps (e.g. ``nn.MSELoss()``).

        NOTE: this replaces the bound ``contentFunc`` method with the module it
        builds, so it must be called only once per instance.
        """
        self.criterion = loss
        self.contentFunc = self.contentFunc()
        self.transform = transforms.Normalize(mean=list(self._IMAGENET_MEAN), std=list(self._IMAGENET_STD))

    def _normalize(self, images):
        """Apply ImageNet mean/std normalization to every image of a (B, C, H, W) batch."""
        mean = torch.tensor(self._IMAGENET_MEAN, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
        std = torch.tensor(self._IMAGENET_STD, dtype=images.dtype, device=images.device).view(1, -1, 1, 1)
        return (images - mean) / std

    def get_loss(self, fakeIm, realIm):
        """Return ``0.006 * feature_loss + 0.5 * pixel_mse`` for a batch.

        Both terms are computed on the ImageNet-normalized images. Previously only
        batch element 0 was normalized; now the whole batch is, which gives the same
        result for batch size 1. The caller's tensors are not modified.
        """
        fakeIm = self._normalize((fakeIm + 1) / 2.0)
        realIm = self._normalize((realIm + 1) / 2.0)
        f_fake = self.contentFunc(fakeIm)
        # Real features are a fixed target; no gradient is needed through them.
        f_real = self.contentFunc(realIm).detach()
        loss = self.criterion(f_fake, f_real)
        return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        return self.get_loss(fakeIm, realIm)

## 判别器的WGAN-GP loss

$$L=\underset{\tilde{\boldsymbol{x}} \sim \mathbb{P}_{g}}{\mathbb{E}}[D(\tilde{\boldsymbol{x}})]-\underset{\boldsymbol{x} \sim \mathbb{P}_{r}}{\mathbb{E}}[D(\boldsymbol{x})]+\lambda \underset{\hat{\boldsymbol{x}} \sim \mathbb{P}_{\hat{\boldsymbol{x}}}}{\mathbb{E}}\left[\left(\left\|\nabla_{\hat{\boldsymbol{x}}} D(\hat{\boldsymbol{x}})\right\|_{2}-1\right)^{2}\right]$$

In [5]:



# class GANLoss(nn.Module):
#     def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
#                  tensor=torch.FloatTensor):
#         super(GANLoss, self).__init__()
#         self.real_label = target_real_label
#         self.fake_label = target_fake_label
#         self.real_label_var = None
#         self.fake_label_var = None
#         self.Tensor = tensor
#         if use_l1:
#             self.loss = nn.L1Loss()
#         else:
#             self.loss = nn.BCEWithLogitsLoss()

#     def get_target_tensor(self, input, target_is_real):
#         if target_is_real:
#             create_label = ((self.real_label_var is None) or
#                             (self.real_label_var.numel() != input.numel()))
#             if create_label:
#                 real_tensor = self.Tensor(input.size()).fill_(self.real_label)
#                 self.real_label_var = Variable(real_tensor, requires_grad=False)
#             target_tensor = self.real_label_var
#         else:
#             create_label = ((self.fake_label_var is None) or
#                             (self.fake_label_var.numel() != input.numel()))
#             if create_label:
#                 fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
#                 self.fake_label_var = Variable(fake_tensor, requires_grad=False)
#             target_tensor = self.fake_label_var
#         return target_tensor

#     def __call__(self, input, target_is_real):
        
#         target_tensor = self.get_target_tensor(input, target_is_real)
#         return self.loss(input, target_tensor)



class DiscLoss(nn.Module):
    """Base class for discriminator-side adversarial losses.

    Subclasses implement ``get_loss`` (the discriminator objective) and
    ``get_g_loss`` (the generator's adversarial objective). Calling an instance
    as ``loss(net, fakeB, realB)`` dispatches to ``get_loss``.
    """

    def name(self):
        return 'DiscLoss'

    def __init__(self):
        super(DiscLoss, self).__init__()

    def get_g_loss(self, net, fakeB, realB):
        """Generator adversarial loss; must be implemented by subclasses."""
        raise NotImplementedError('%s does not implement get_g_loss' % type(self).__name__)

    def get_loss(self, net, fakeB, realB):
        """Discriminator loss; must be implemented by subclasses."""
        raise NotImplementedError('%s does not implement get_loss' % type(self).__name__)

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)


class DiscLossWGANGP(DiscLoss):
    """WGAN-GP critic loss.

    Discriminator: ``mean D(fake) - mean D(real) + LAMBDA * E[(||grad D(x_hat)||_2 - 1)^2]``,
    where x_hat interpolates real and fake samples.
    Generator: ``-mean D(fake)``.
    """

    def name(self):
        return 'DiscLossWGAN-GP'

    def __init__(self):
        super(DiscLossWGANGP, self).__init__()
        self.LAMBDA = 10  # gradient-penalty weight

    def get_g_loss(self, net, fakeB, realB):
        """Return ``-mean(D(fakeB))``; gradients flow back into the generator."""
        self.D_fake = net(fakeB)
        return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """Return ``LAMBDA * mean_i (||grad_x D(x_hat_i)||_2 - 1)^2`` over the batch.

        One interpolation coefficient is drawn per sample, and the gradient norm is
        taken per sample over all non-batch dimensions, as in the penalty formula.
        All tensors are created on ``real_data``'s device and dtype.
        """
        batch_size = real_data.size(0)
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, dtype=real_data.dtype, device=real_data.device)

        # A fresh leaf tensor, so autograd.grad can differentiate with respect to it.
        interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach().requires_grad_(True)
        disc_interpolates = netD(interpolates)

        # create_graph=True lets the penalty itself be backpropagated into netD.
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True)[0]

        grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
        return ((grad_norm - 1) ** 2).mean() * self.LAMBDA

    def get_loss(self, net, fakeB, realB):
        """Return the critic loss for one batch (Wasserstein term + gradient penalty)."""
        # Fake: detached so the critic update does not backprop into the generator.
        self.D_fake = net(fakeB.detach()).mean()
        # Real
        self.D_real = net(realB).mean()

        self.loss_D = self.D_fake - self.D_real
        gradient_penalty = self.calc_gradient_penalty(net, realB.detach(), fakeB.detach())
        return self.loss_D + gradient_penalty

In [6]:
def get_loss(model):
    """Build the generator and discriminator criteria.

    Args:
        model: model config section; accepted for interface compatibility but not read.

    Returns:
        ``(content_loss, disc_loss)``: a PerceptualLoss initialized with an MSE
        feature criterion, and a fresh DiscLossWGANGP.
    """
    perceptual = PerceptualLoss()
    perceptual.initialize(nn.MSELoss())
    return perceptual, DiscLossWGANGP()

## 生成器和判别器的网络

**判别器网络**：
```text
    {'patch': 
          (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (1): LeakyReLU(negative_slope=0.2, inplace=True)
          (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (4): LeakyReLU(negative_slope=0.2, inplace=True)
          (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (7): LeakyReLU(negative_slope=0.2, inplace=True)
          (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
          (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (10): LeakyReLU(negative_slope=0.2, inplace=True)
          (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
     , 'full': 
          (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (1): LeakyReLU(negative_slope=0.2, inplace=True)
          (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (4): LeakyReLU(negative_slope=0.2, inplace=True)
          (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (7): LeakyReLU(negative_slope=0.2, inplace=True)
          (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (10): LeakyReLU(negative_slope=0.2, inplace=True)
          (11): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
          (12): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (13): LeakyReLU(negative_slope=0.2, inplace=True)
          (14): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
          (15): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
          (16): LeakyReLU(negative_slope=0.2, inplace=True)
          (17): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    }
```

In [7]:
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN-style convolutional discriminator.

    Layer layout (every conv uses a 4x4 kernel with padding ceil(3 / 2) = 2):
      conv(input_nc -> ndf, stride 2) + LeakyReLU(0.2)
      (n_layers - 1) x [conv(stride 2) + norm + LeakyReLU(0.2)], width ndf * min(2**k, 8)
      conv(stride 1) + norm + LeakyReLU(0.2)
      conv(-> 1 channel, stride 1)   [+ Sigmoid when use_sigmoid]

    Args:
        input_nc: channels of the input image.
        ndf: width of the first conv layer.
        n_layers: number of stride-2 convs.
        norm_layer: normalization constructor (a class, or a functools.partial of one).
            Convs that are followed by a norm get a bias only when it is InstanceNorm2d.
        use_sigmoid: append a Sigmoid to the output map.
        use_parallel: stored on the instance; not used inside this class.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True):
        super(NLayerDiscriminator, self).__init__()
        self.use_parallel = use_parallel

        base_norm = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        kernel = 4
        pad = int(np.ceil((kernel - 1) / 2))

        def norm_block(c_in, c_out, stride):
            # conv -> norm -> LeakyReLU; order fixes the Sequential indices (state_dict keys).
            return [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=use_bias),
                norm_layer(c_out),
                nn.LeakyReLU(0.2, True),
            ]

        # widths[k] is the channel count after the k-th downsampling stage (widths[0] == ndf).
        widths = [ndf * min(2 ** k, 8) for k in range(n_layers + 1)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]
        for k in range(1, n_layers):
            layers.extend(norm_block(widths[k - 1], widths[k], stride=2))
        layers.extend(norm_block(widths[n_layers - 1], widths[n_layers], stride=1))
        layers.append(nn.Conv2d(widths[n_layers], 1, kernel_size=kernel, stride=1, padding=pad))
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the discriminator's output map for ``input``."""
        return self.model(input)

In [8]:
def get_generator(model_config):
    """Build the FPN-Inception generator wrapped in ``nn.DataParallel``.

    ``model_config['norm_layer']`` selects the normalization via get_norm_layer.
    """
    norm = get_norm_layer(norm_type=model_config['norm_layer'])
    generator = FPNInception(norm_layer=norm)
    return nn.DataParallel(generator)

In [9]:
def get_discriminator(model_config):
    """Build the two discriminators used by DoubleGAN.

    Returns:
        ``{'patch': ..., 'full': ...}``. Each value is an NLayerDiscriminator wrapped in
        ``nn.DataParallel``. The patch net has ``model_config['d_layers']`` layers and
        the full net has 5. Both use ``model_config['norm_layer']`` normalization.
    """
    patch_depth = model_config['d_layers']
    norm_type = model_config['norm_layer']

    def wrapped(n_layers):
        # Each discriminator gets its own norm-layer factory, as before.
        net = NLayerDiscriminator(n_layers=n_layers,
                                  norm_layer=get_norm_layer(norm_type=norm_type),
                                  use_sigmoid=False)
        return nn.DataParallel(net)

    patch_gan = wrapped(patch_depth)
    full_gan = wrapped(5)
    return {'patch': patch_gan, 'full': full_gan}

In [10]:
def get_nets(model_config):
    """Return ``(generator, discriminators)`` built from ``model_config``; the generator is built first."""
    generator = get_generator(model_config)
    discriminators = get_discriminator(model_config)
    return generator, discriminators

In [11]:
class GANFactory:
    """Registry mapping trainer names to factory objects exposing ``create(net_d, criterion)``."""
    factories = {}

    def __init__(self):
        pass

    @staticmethod
    def add_factory(gan_id, model_factory):
        """Register ``model_factory`` under ``gan_id``.

        The previous ``factories.put[...]`` raised AttributeError because dicts have no ``put``.
        """
        GANFactory.factories[gan_id] = model_factory

    @staticmethod
    def create_model(gan_id, net_d=None, criterion=None):
        """Create a trainer via the factory registered as ``gan_id``.

        If no factory is registered yet, ``<gan_id>.Factory()`` is looked up lazily in
        this module's namespace and cached.
        """
        if gan_id not in GANFactory.factories:
            # NOTE(review): eval() executes gan_id as Python code; only pass trusted,
            # hard-coded names here, never configuration or user input.
            GANFactory.factories[gan_id] = eval(gan_id + '.Factory()')
        return GANFactory.factories[gan_id].create(net_d, criterion)


In [12]:
class GANTrainer(object):
    """Interface for adversarial trainers created by GANFactory.

    Stores the discriminator(s) and the discriminator criterion; subclasses
    override the three hooks below.
    """

    def __init__(self, net_d, criterion):
        self.net_d, self.criterion = net_d, criterion

    def loss_d(self, pred, gt):
        """Discriminator-loss hook; the base implementation returns None."""

    def loss_g(self, pred, gt):
        """Generator adversarial-loss hook; the base implementation returns None."""

    def get_params(self):
        """Discriminator-parameters hook; the base implementation returns None."""


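`DoubleGAN.loss_d` calls the discriminator criterion (`DiscLoss.__call__`, which dispatches to `get_loss`) once for the patch discriminator and once for the full-image discriminator, then averages the two losses.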

In [13]:
class DoubleGAN(GANTrainer):
    """Adversarial trainer combining a PatchGAN and a full-image discriminator.

    ``net_d`` must be a mapping with keys ``'patch'`` and ``'full'``. The full
    discriminator gets a deep copy of the criterion. NOTE(review): presumably this
    keeps per-call state the criterion stores on itself (e.g. ``D_fake``) separate
    between the two discriminators — confirm.
    """

    def __init__(self, net_d, criterion):
        GANTrainer.__init__(self, net_d, criterion)
        self.patch_d, self.full_d = net_d['patch'], net_d['full']
        self.full_criterion = copy.deepcopy(criterion)

    def loss_d(self, pred, gt):
        """Mean discriminator loss; ``pred`` is the generated (fake) batch, ``gt`` the real one."""
        patch_loss = self.criterion(self.patch_d, pred, gt)
        full_loss = self.full_criterion(self.full_d, pred, gt)
        return (patch_loss + full_loss) / 2

    def loss_g(self, pred, gt):
        """Mean generator adversarial loss over both discriminators."""
        patch_loss = self.criterion.get_g_loss(self.patch_d, pred, gt)
        full_loss = self.full_criterion.get_g_loss(self.full_d, pred, gt)
        return (patch_loss + full_loss) / 2

    def get_params(self):
        """All parameters of both discriminators, patch first."""
        return [*self.patch_d.parameters(), *self.full_d.parameters()]

    class Factory:
        @staticmethod
        def create(net_d, criterion):
            return DoubleGAN(net_d, criterion)

In [14]:
class DeblurModel(nn.Module):
    """Helper for unpacking batches and converting generator outputs to images."""

    def __init__(self):
        super(DeblurModel, self).__init__()

    def get_input(self, data):
        """Split a batch dict into ``(inputs, targets) = (data['a'], data['b'])``."""
        inputs = data['a']
        targets = data['b']
        return inputs, targets

    def tensor2im(self, image_tensor, imtype=np.uint8):
        """Convert the first image of a batch to an HWC array in [0, 255].

        The image is read as CHW with values mapped from [-1, 1] via ``(x + 1) / 2 * 255``.
        Values outside [0, 255] are clipped.
        """
        # detach(): .numpy() raises on tensors that require grad (e.g. raw generator output).
        image_numpy = image_tensor[0].detach().cpu().float().numpy()
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        # Clip before casting; otherwise out-of-range values wrap around in uint8.
        image_numpy = np.clip(image_numpy, 0.0, 255.0)
        return image_numpy.astype(imtype)


def get_model(model_config):
    """Return a fresh DeblurModel; ``model_config`` is accepted for interface symmetry and ignored."""
    model = DeblurModel()
    return model

In [15]:
class Trainer:
    """Adversarial training loop for the deblurring generator.

    Config keys read: ``model.adv_lambda``, ``warmup_num``, ``optimizer.lr``,
    ``scheduler.min_lr``, ``scheduler.start_epoch``, ``num_epochs`` (scheduler length),
    ``model.d_name``, and optionally ``train_batches_per_epoch`` and
    ``val_batches_per_epoch`` (otherwise the loader lengths are used).
    """

    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.warmup_epochs = config['warmup_num']

    def train(self, num_epochs=1):
        """Build networks/optimizers, then train and validate for ``num_epochs`` epochs.

        Args:
            num_epochs: defaults to 1, matching this notebook's single-epoch run. Pass
                ``self.config['num_epochs']`` for a full run; the LR schedulers are
                configured for that length.
        """
        self._init_params()

        for epoch in range(num_epochs):
            # After warm-up, unfreeze the generator backbone and rebuild its optimizer
            # and scheduler so they cover the newly trainable parameters.
            if epoch == self.warmup_epochs and self.warmup_epochs != 0:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

    def _run_epoch(self, epoch):
        """Run one training epoch of at most ``epoch_size`` batches (critic step, then generator step)."""
        lr = self.optimizer_G.param_groups[-1]['lr']
        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))

        for i, data in enumerate(tq, start=1):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)  # generated (deblurred) images

            # Update the discriminators and get their loss.
            loss_D = self._update_d(outputs, targets)

            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)  # perceptual loss
            loss_adv = self.adv_trainer.loss_g(outputs, targets)  # generator adversarial loss
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()

            tq.set_postfix(loss_G=loss_G.item(), loss_D=loss_D)
            # Stop after exactly epoch_size batches (the old `i > epoch_size` ran one extra).
            if i >= epoch_size:
                break
        tq.close()

    def _validate(self, epoch):
        """Compute generator loss on at most ``epoch_size`` validation batches.

        Returns:
            The mean ``loss_G`` over the processed batches, or NaN if there were none.
        """
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        total_loss = 0.0
        num_batches = 0
        # No parameters are updated here, so skip building autograd graphs.
        with torch.no_grad():
            for data in tq:
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_content = self.criterionG(outputs, targets)
                loss_adv = self.adv_trainer.loss_g(outputs, targets)
                loss_G = loss_content + self.adv_lambda * loss_adv
                total_loss += loss_G.item()
                num_batches += 1
                if num_batches >= epoch_size:
                    break
        tq.close()
        return total_loss / num_batches if num_batches else float('nan')

    def _update_d(self, outputs, targets):
        """Take one discriminator optimizer step and return the scaled loss as a float."""
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        """Adam optimizer over ``params`` with ``config['optimizer']['lr']``."""
        return optim.Adam(params, lr=self.config['optimizer']['lr'])

    def _get_scheduler(self, optimizer):
        """LinearDecay schedule configured from ``config['scheduler']`` and ``config['num_epochs']``."""
        return LinearDecay(optimizer,
                           min_lr=self.config['scheduler']['min_lr'],
                           num_epochs=self.config['num_epochs'],
                           start_epoch=self.config['scheduler']['start_epoch'])

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        """Create the adversarial trainer.

        NOTE(review): ``d_name`` is ignored; DoubleGAN is always used.
        """
        return GANFactory.create_model('DoubleGAN', net_d, criterion_d)

    def _init_params(self):
        """Build losses, networks, adversarial trainer, optimizers and schedulers from config."""
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        # Only currently-trainable generator params; refreshed after warm-up in train().
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)



In [16]:
# Load the experiment configuration. safe_load: plain yaml.load() without an explicit
# Loader is unsafe on untrusted files and raises TypeError on PyYAML >= 6.
with open('config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

batch_size = config.pop('batch_size')
# functools.partial fixes the DataLoader arguments shared by the train and val loaders,
# so each call only needs the dataset.
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=True)

# Pop the 'train'/'val' sections (so the config handed to Trainer no longer has them)
# and build one loader per split, in order.
train, val = (get_dataloader(PairedDataset.from_config(config.pop(split))) for split in ('train', 'val'))
trainer = Trainer(config, train=train, val=val)
trainer.train()

  

I1225 00:43:49.310564 14679 dataset.py:28] Subsampling buckets from 0 to 90.0, total buckets number is 100
I1225 00:43:49.311258 14679 dataset.py:71] Dataset has been created with 16 samples
I1225 00:43:49.312687 14679 dataset.py:28] Subsampling buckets from 90.0 to 100, total buckets number is 100
I1225 00:43:49.313170 14679 dataset.py:71] Dataset has been created with 4 samples

Epoch 0, lr 0.0001:   0%|          | 2/1000 [00:12<1:45:34,  6.35s/it]

KeyboardInterrupt: 