In [None]:
import copy
import functools
import os
from collections import OrderedDict

import torch
import torchvision
import torch.nn as nn
from torch.nn import init
from torchvision.transforms.functional import normalize as F_normalize
from torch.nn.functional import interpolate as F_upsample
import torch.nn.functional as F

In [None]:
class EfficientAttention(nn.Module):
    """Efficient (linear-complexity) attention block with a residual connection.

    Keys and values are 1x1-conv projections of ``value_``; queries are a 1x1-conv
    projection of ``input_`` when it is given, otherwise of ``value_``. Keys are
    softmax-normalised over spatial positions and queries over channels, so the
    global context is a ``key_channels x key_channels`` matrix instead of an
    ``(h*w) x (h*w)`` attention map.

    Parameters:
        val_channels (int) -- channels of ``value_`` (and of the output)
        key_channels (int) -- channels of the key / query / value projections
        in_channels (int)  -- channels of ``input_``; 0 means "same as val_channels"
    """

    def __init__(self, val_channels=3, key_channels=4, in_channels=0):
        super().__init__()
        self.in_channels = in_channels if in_channels else val_channels
        self.key_channels = key_channels
        self.val_channels = val_channels

        self.keys = nn.Conv2d(self.val_channels, self.key_channels, 1)
        # Values are projected to key_channels too; `reprojection` maps the
        # attended features back to val_channels for the residual sum.
        self.values = nn.Conv2d(self.val_channels, self.key_channels, 1)
        self.queries = nn.Conv2d(self.in_channels, self.key_channels, 1)
        self.reprojection = nn.Conv2d(self.key_channels, self.val_channels, 1)

    def forward(self, value_, input_=None):
        """Attend over ``value_`` and add the result back to it.

        Parameters:
            value_ -- 4-D tensor (n, val_channels, h, w); keys and values come from it
            input_ -- optional 4-D tensor (n, in_channels, h_i, w_i) providing the
                      queries. When its spatial size differs from (h, w) the
                      projected queries are adaptively max-pooled to (h, w); for
                      integer size ratios this equals max pooling with kernel
                      (h_i // h, w_i // w), and it also handles non-divisible sizes.

        Returns:
            Tensor shaped like ``value_``: ``reprojection(attention) + value_``.
        """
        n, _, h, w = value_.size()
        keys = self.keys(value_).reshape(n, self.key_channels, h * w)
        values = self.values(value_).reshape(n, self.key_channels, h * w)

        if input_ is None:
            queries = self.queries(value_)
        else:
            queries = self.queries(input_)
            if queries.shape[-2:] != (h, w):
                queries = F.adaptive_max_pool2d(queries, (h, w))
        queries = queries.reshape(n, self.key_channels, h * w)

        key = F.softmax(keys, dim=2)      # normalise over spatial positions
        query = F.softmax(queries, dim=1)  # normalise over channels

        context = key @ values.transpose(1, 2)  # (n, key_channels, key_channels)
        attention = (context.transpose(1, 2) @ query).reshape(n, self.key_channels, h, w)

        return self.reprojection(attention) + value_

In [None]:
class UnetGeneratorBilinear(nn.Module):
    """U-Net generator with bilinear upsampling and optional efficient attention.

    Encoder: three conv-conv-pool stages (32, 64, 128 channels), so the bottleneck
    runs at 1/8 of the input resolution (each pool floors odd sizes).
    Bottleneck: four parallel 128-channel convs with kernel sizes 1/3/5/7,
    concatenated to 512 channels and fused to 256.
    Decoder: bilinear upsampling to the matching encoder resolution, a 3x3 conv,
    concatenation with the encoder skip features, then two 3x3 convs.
    Convs are followed by LeakyReLU(0.2) and then ``norm_layer`` (except the last
    decoder conv, which has no norm, and the final 1x1 output conv).

    Parameters:
        opt        -- options object; reads ``norm_G_out``, ``self_attention``,
                      ``use_avgpool``, ``skip``, ``tanh_G_out`` and, when
                      ``tanh_G_out`` is set, ``hardtanh``.
        norm_layer -- normalization layer class, called with a channel count.

    Input and output have 3 channels. With ``self_attention`` the input is
    concatenated with an attended copy of itself (6 channels into the first conv),
    and the bottleneck, skip features and output are refined by EfficientAttention
    blocks that take their queries from the input image.
    """

    def __init__(self, opt, norm_layer):
        super(UnetGeneratorBilinear, self).__init__()

        self.normalize = opt.norm_G_out
        self.self_attention = opt.self_attention
        self.use_avgpool = opt.use_avgpool
        self.skip = opt.skip
        self.use_tanh = opt.tanh_G_out
        if self.use_tanh:
            self.final_tanh = nn.Hardtanh() if opt.hardtanh else nn.Tanh()

        p = 1  # padding for 3x3 convs: keeps the spatial size unchanged
        pool = nn.AvgPool2d if self.use_avgpool == 1 else nn.MaxPool2d

        if self.self_attention:
            # The first conv sees the input concatenated with its attended copy.
            self.conv1_1 = nn.Conv2d(6, 32, 3, padding=p)
            self.attention_in = EfficientAttention(val_channels=3, key_channels=3, in_channels=3)
            self.attention_out = EfficientAttention(val_channels=3, key_channels=3, in_channels=3)
            self.attention_1 = EfficientAttention(val_channels=32, key_channels=4, in_channels=3)
            self.attention_2 = EfficientAttention(val_channels=64, key_channels=4, in_channels=3)
            self.attention_3 = EfficientAttention(val_channels=128, key_channels=8, in_channels=3)
            self.attention_4 = EfficientAttention(val_channels=512, key_channels=16, in_channels=3)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)

        # ---- Encoder ----
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_1 = norm_layer(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_2 = norm_layer(32)
        self.max_pool1 = pool(2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_1 = norm_layer(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_2 = norm_layer(64)
        self.max_pool2 = pool(2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_1 = norm_layer(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_2 = norm_layer(128)
        self.max_pool3 = pool(2)

        # ---- Bottleneck: parallel 1x1 / 3x3 / 5x5 / 7x7 convs, size-preserving padding ----
        self.conv4_11 = nn.Conv2d(128, 128, 1, padding=0)
        self.LReLU4_11 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_11 = norm_layer(128)
        self.conv4_12 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU4_12 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_12 = norm_layer(128)
        self.conv4_13 = nn.Conv2d(128, 128, 5, padding=2 * p)
        self.LReLU4_13 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_13 = norm_layer(128)
        self.conv4_14 = nn.Conv2d(128, 128, 7, padding=3 * p)
        self.LReLU4_14 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_14 = norm_layer(128)
        self.conv4_2 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_2 = norm_layer(256)

        # ---- Decoder ("deconv" layers are plain 3x3 convs applied after upsampling) ----
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_1 = norm_layer(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_2 = norm_layer(128)

        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_1 = norm_layer(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_2 = norm_layer(64)

        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn9_1 = norm_layer(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)

    @staticmethod
    def _upsample_to(x, reference):
        """Bilinearly resize ``x`` to the spatial size of ``reference``.

        For sizes divisible by 8 this equals a fixed x2 upsampling; it also keeps
        decoder and skip tensors aligned when pooling floored an odd size.
        """
        return F_upsample(x, size=reference.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, input):
        # ---- Encoder ----
        if self.self_attention:
            attended_inp = self.attention_in(input)
            x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat([input, attended_inp], dim=1))))
        else:
            x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
        conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
        x = self.max_pool1(conv1)

        x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
        conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
        x = self.max_pool2(conv2)

        x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
        conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
        x = self.max_pool3(conv3)

        # ---- Bottleneck ----
        x_1 = self.bn4_11(self.LReLU4_11(self.conv4_11(x)))
        x_2 = self.bn4_12(self.LReLU4_12(self.conv4_12(x)))
        x_3 = self.bn4_13(self.LReLU4_13(self.conv4_13(x)))
        x_4 = self.bn4_14(self.LReLU4_14(self.conv4_14(x)))
        x = torch.cat([x_1, x_2, x_3, x_4], dim=1)
        if self.self_attention:
            x = self.attention_4(x, input)
        conv6 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))

        # ---- Decoder ----
        conv6 = self._upsample_to(conv6, conv3)
        if self.self_attention:
            conv3 = self.attention_3(conv3, input)
        up7 = torch.cat([self.deconv6(conv6), conv3], 1)
        x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
        conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

        conv7 = self._upsample_to(conv7, conv2)
        if self.self_attention:
            conv2 = self.attention_2(conv2, input)
        up8 = torch.cat([self.deconv7(conv7), conv2], 1)
        x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
        conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

        conv8 = self._upsample_to(conv8, conv1)
        if self.self_attention:
            conv1 = self.attention_1(conv1, input)
        up9 = torch.cat([self.deconv8(conv8), conv1], 1)
        x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
        conv9 = self.LReLU9_2(self.conv9_2(x))

        latent = self.conv10(conv9)
        if self.self_attention:
            latent = self.attention_out(latent, input)

        if self.skip:
            if self.normalize:
                # Per-channel min-max rescale over batch and space; the clamp
                # avoids 0/0 when a channel is constant.
                min_latent = torch.amin(latent, dim=(0, 2, 3), keepdim=True)
                max_latent = torch.amax(latent, dim=(0, 2, 3), keepdim=True)
                latent = (latent - min_latent) / (max_latent - min_latent).clamp_min(1e-8)
            # `skip` doubles as the weight of the residual input connection.
            output = latent + self.skip * input
        else:
            output = latent

        if self.use_tanh:
            output = self.final_tanh(output)

        return output


In [None]:
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator.

    The network is stored as ``n_layers + 2`` sequential stages named
    ``model0`` ... ``model{n_layers + 1}`` so intermediate features can be
    returned for feature matching.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer; later
                               layers use ndf * min(2**k, 8)
            n_layers (int)  -- the number of stride-2 conv layers
            norm_layer      -- normalization layer class, or a functools.partial of one
        """
        super(NLayerDiscriminator, self).__init__()
        # Convs followed by BatchNorm2d skip their bias (BatchNorm's affine shift
        # makes it redundant); with InstanceNorm2d the conv keeps its bias.
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d

        self.n_layers = n_layers

        kw = 4
        padw = 1
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [[
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)

        sequence += [[
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]]

        # output 1 channel prediction map
        sequence += [[nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]]

        for n, stage in enumerate(sequence):
            setattr(self, 'model' + str(n), nn.Sequential(*stage))

    def forward(self, x, return_features=False):
        """Run the discriminator.

        Returns the patch prediction map, or, when ``return_features`` is True,
        the list of every stage's output (the last entry is the prediction map).
        """
        feats = []
        for n in range(self.n_layers + 2):
            x = getattr(self, 'model' + str(n))(x)
            if return_features:
                feats.append(x)
        return feats if return_features else x

In [None]:
# Pretrained RetinaNet (ResNet-50 FPN) detector; weights file:
# https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth
# The weights enum is referenced through its module so no extra import is needed.
netDet = torchvision.models.detection.retinanet_resnet50_fpn(
    weights=torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT
)

In [None]:
class SPLoss(nn.Module):
    """Negative summed cosine similarity of columns and rows between two tensors.

    For 4-D inputs ``(N, C, H, W)``: each column (vector along dim 2) of ``input``
    is L2-normalised and dotted with the matching column of ``reference``; the same
    is done for rows (vectors along dim 3). The loss is
    ``-(sum of column similarities + sum of row similarities) / H``.

    Implemented as ``forward`` (not ``__call__``) so nn.Module hooks still run.
    """

    def forward(self, input, reference):
        cols = (F.normalize(input, p=2, dim=2) * F.normalize(reference, p=2, dim=2)).sum()
        rows = (F.normalize(input, p=2, dim=3) * F.normalize(reference, p=2, dim=3)).sum()
        return -(cols + rows) / input.size(2)

class GPLoss(nn.Module):
    """Gradient-profile loss: SPLoss between finite-difference gradient maps.

    Implemented as ``forward`` (not ``__call__``) so nn.Module hooks still run.
    """

    def __init__(self):
        super(GPLoss, self).__init__()
        self.trace = SPLoss()

    def get_image_gradients(self, input):
        """Return forward differences along the last two dims.

        ``f_v = input[..., :, :-1] - input[..., :, 1:]`` (differences along width)
        ``f_h = input[..., :-1, :] - input[..., 1:, :]`` (differences along height)
        """
        f_v = input[..., :, :-1] - input[..., :, 1:]
        f_h = input[..., :-1, :] - input[..., 1:, :]
        return f_v, f_h

    def forward(self, input, reference, normalize=True):
        """Sum of SPLoss over the width- and height-gradient maps.

        With ``normalize`` both tensors are mapped via ``(x + 1) / 2`` first
        (presumably inputs in [-1, 1], e.g. tanh outputs — confirm).
        """
        if normalize:
            input = (input + 1) / 2
            reference = (reference + 1) / 2

        input_v, input_h = self.get_image_gradients(input)
        ref_v, ref_h = self.get_image_gradients(reference)

        trace_v = self.trace(input_v, ref_v)
        trace_h = self.trace(input_h, ref_h)
        return trace_v + trace_h

class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- 'vanilla', 'lsgan' or 'wgangp'
            target_real_label (float) -- label value for real images
            target_fake_label (float) -- label value for fake images

        Raises:
            NotImplementedError -- for any other ``gan_mode``.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        # Cast to float so integer labels (e.g. 1) still match float predictions.
        self.register_buffer('real_label', torch.tensor(float(target_real_label)))
        self.register_buffer('fake_label', torch.tensor(float(target_fake_label)))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'wgangp':
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def MSE_loss_weighted(self, prediction, target, mask):
        """Mean over all elements of ``mask * (prediction - target) ** 2``."""
        return (mask * (prediction - target) ** 2).mean()

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground-truth label is "real"

        Returns:
            The real/fake label buffer expanded (without copying) to ``prediction``'s shape.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real, mask=None):
        """Calculate the loss given the discriminator's output and ground-truth label.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground-truth label is "real"
            mask (tensor|None)    -- optional weight map; only used in 'lsgan' mode

        Returns:
            the calculated loss.
        """
        return self.calculate_loss(prediction, target_is_real, mask)

    def calculate_loss(self, prediction, target_is_real, mask=None):
        """Compute the loss for ``self.gan_mode`` (see ``forward``)."""
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            if self.gan_mode == 'lsgan' and mask is not None:
                # Resize the mask (default nearest mode) to the prediction map size.
                mask = F_upsample(mask, size=prediction.shape[-2:])
                loss = self.MSE_loss_weighted(prediction, target_tensor, mask)
            else:
                loss = self.loss(prediction, target_tensor)
        else:  # 'wgangp' (other modes are rejected in __init__)
            loss = -prediction.mean() if target_is_real else prediction.mean()
        return loss
    
def get_optimizer(model, opt, model_name, extra_model=None):
    """Return an optimizer for the model.

    Parameters:
        model               -- module whose parameters will be optimized
        opt (option class)  -- experiment options. Reads
                               ``opt.optim_[model_name]`` ('SGD' or 'Adam', case-insensitive;
                               missing or any other value falls back to Adam),
                               ``opt.lr_[model_name]`` and ``opt.beta1`` (Adam's first beta).
                               For SGD, ``opt.momentum`` is used when present (default 0).
        model_name          -- suffix used to fetch the per-model options, e.g. 'G', 'D', 'Det'
        extra_model         -- optional second module optimized jointly with ``model``

    Returns:
        a ``torch.optim.Optimizer`` over all learnable parameters.
    """
    optim_choice = str(getattr(opt, "optim_" + model_name, "adam")).lower()
    lr = getattr(opt, "lr_" + model_name)

    learnable_params = list(model.parameters())
    if extra_model is not None:
        learnable_params += list(extra_model.parameters())

    if optim_choice == "sgd":
        return torch.optim.SGD(learnable_params, lr=lr, momentum=getattr(opt, "momentum", 0.0))
    return torch.optim.Adam(learnable_params, lr=lr, betas=(opt.beta1, 0.999))
   

In [None]:
# Use the current CUDA device when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
else:
    device = torch.device('cpu')

# NOTE(review): `opt` (experiment options) is not defined in this notebook;
# it must be created in an earlier cell before this one runs.
NORM_LAYER = nn.BatchNorm2d  # normalization for G and D (NLayerDiscriminator's default); adjust as needed
D_INPUT_NC = 6               # D sees torch.cat((real_A, fake_B or real_B), 1): 3 + 3 channels

# Move the models to the device *before* building their optimizers.
netG = UnetGeneratorBilinear(opt, NORM_LAYER).to(device)
netD = NLayerDiscriminator(D_INPUT_NC, norm_layer=NORM_LAYER).to(device)
netDet = netDet.to(device)

criterionGAN = GANLoss(opt.gan_mode).to(device)
criterionFeat = torch.nn.L1Loss()
criterionGPL = GPLoss().to(device)
citerionGPL = criterionGPL  # original (misspelled) name kept for existing references

optimizer_D = get_optimizer(netD, opt, model_name="D")
optimizer_G = get_optimizer(netG, opt, model_name="G")
optimizer_Det = get_optimizer(netDet, opt, model_name="Det")
# TODO: scheduler, backprop, detector, dataloader

In [None]:
def compute_grad2(d_out, x_in):
    """Per-sample squared L2 norm of the gradient of ``d_out.sum()`` w.r.t. ``x_in``.

    ``create_graph=True`` keeps the result differentiable so it can be used as a
    regularization loss. NOTE(review): replaces the unimported
    ``networks.compute_grad2`` the original code called — confirm it matches.
    """
    grad = torch.autograd.grad(outputs=d_out.sum(), inputs=x_in, create_graph=True)[0]
    return grad.pow(2).reshape(x_in.size(0), -1).sum(1)


def backward_D(real_A, fake_B, real_B, mask=None):
    """Run one discriminator update: fake term, real term, optional gradient penalty.

    Parameters:
        real_A -- generator input images; concatenated with fake_B / real_B on dim 1
        fake_B -- generator output; detached so the generator receives no gradient
        real_B -- real target images
        mask   -- optional weight map passed to ``criterionGAN`` for the real term

    Uses the notebook-level ``netD``, ``optimizer_D``, ``criterionGAN`` and ``opt``
    (``opt.feature_matching`` and, if present, ``opt.lambda_gr``).

    Returns:
        None (``features`` is never populated; kept for backward compatibility).
    """
    optimizer_D.zero_grad()

    features = None
    # Fake; stop backprop to the generator by detaching fake_B
    fake_AB = torch.cat((real_A, fake_B), 1)
    pred_fake = netD(fake_AB.detach())
    loss_D_fake = criterionGAN(pred_fake, False)
    loss_D_fake.backward()

    # Real; requires_grad so the gradient penalty can differentiate w.r.t. the input
    real_AB = torch.cat((real_A, real_B), 1).requires_grad_()
    pred_real = netD(real_AB, opt.feature_matching)
    # With feature matching netD returns every stage; the last one is the prediction map.
    pred_real_map = pred_real[-1] if opt.feature_matching else pred_real
    loss_D_real = criterionGAN(pred_real_map, True, mask)

    # Regularize the discriminator with its input-gradient norm. Uses the lowercase
    # `lambda_gr` spelling, consistent with the other `lambda_*` options.
    lambda_gr = getattr(opt, 'lambda_gr', 0)
    if lambda_gr:
        loss_D_real.backward(retain_graph=True)  # graph is reused by compute_grad2
        grad_reg = lambda_gr * compute_grad2(pred_real_map, real_AB).mean()
        grad_reg.backward()
    else:
        loss_D_real.backward()

    optimizer_D.step()
    return features

def backward_G(self, current_iter, current_epoch):
    """Compute the generator losses (GAN, feature matching, gradient profile, detection) and update G.

    Written as a method of a trainer object: reads ``self.opt``, the networks
    (``self.netG``, ``self.netD``, ``self.netDet``), their optimizers, the batch
    tensors (``self.real_A``, ``self.real_B``, ``self.fake_B``, ``self.mask``) and
    helpers (``_update_G``, ``_update_D``, ``_update_Det``, ``detect_objects``,
    ``parse_detection_loss``, ``backward_Det_unrolled``).

    When ``self.opt.debug_grad_norms`` is set and the iteration is inside the debug
    window, each loss term is back-propagated on its own to plot generator gradient
    norms and no optimizer step is taken.
    NOTE(review): ``plot_grad_flow`` is not defined in this notebook.
    """
    debug_grads = False
    is_last = False
    if self.opt.debug_grad_norms:
        grads_iter_modulo = current_iter % 400
        number_of_samples = self.number_of_iters * self.opt.batch_size
        debug_grads = self.opt.debug_grad_norms and grads_iter_modulo <= number_of_samples
        is_last = grads_iter_modulo == number_of_samples

    if self._update_G(current_epoch):

        self.optimizer_G.zero_grad()
        ###############################################################################
        # TODO: Decrease the contribution from all discriminator-related losses      #
        # proportionally to the decay of the learning rate of discriminator.          #
        # I do not do this for now; thus the weight is always 1.                      #
        ###############################################################################

        if self.opt.alpha_mode_disc == 'dec':
            if self.opt.lr_policy_D == 'step':
                disc_weight = 0.1 ** (current_epoch // self.opt.lr_decay_epoch_D)
            else:
                if self.opt.n_epochs_decay_D == 0:
                    disc_weight = 1
                else:
                    disc_weight = (self.opt.epochs_per_model['D'] * self.opt.dataset_size - current_iter) / (self.opt.n_epochs_decay_D * self.opt.dataset_size)
            disc_weight = min(self.opt.alpha_disc, self.opt.alpha_disc * disc_weight)
        else:
            disc_weight = self.opt.alpha_disc

        use_disc_loss = self._update_D(current_epoch)

        # overall loss
        loss_G = 0

        # GAN-related losses
        self.loss_G_GAN = 0
        self.loss_G_GAN_feat = 0
        # gradient profile loss (cosine similarity between gradient/edge maps)
        self.loss_G_GPL = 0.
        if disc_weight > 0 and use_disc_loss:
            # G(A) should fake the discriminator
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            pred_fake = self.netD(fake_AB, self.opt.feature_matching)
            self.loss_G_GAN = self.criterionGAN(pred_fake[-1] if self.opt.feature_matching else pred_fake, True, self.mask)

            if debug_grads:
                self.optimizer_G.zero_grad()
                curr_loss = self.loss_G_GAN * disc_weight
                curr_loss.backward(retain_graph=True)
                plot_grad_flow(
                    self.netG.module.named_parameters(),
                    self.grad_dict["Discriminator"],
                    'Generator Gradients from Discriminator ({:.2f})'.format(disc_weight),
                    os.path.join(
                        self.grad_plot_dir,
                        'generator_gradnorm_discriminator-{:.2f}_it-{}.png'.format(disc_weight, current_iter)),
                    plot=is_last
                )
                if is_last:
                    self.grad_dict["Discriminator"] = OrderedDict()
            else:
                loss_G = loss_G + self.loss_G_GAN * disc_weight

            # feature matching loss
            # match the features of fake and real in the intermediate layers of the discriminator
            if self.opt.feature_matching:
                real_AB = torch.cat((self.real_A, self.real_B), 1)
                pred_real = self.netD(real_AB, return_features=True)
                feat_weights = 4.0 / (self.n_layers_D + 1)
                for i in range(len(pred_fake) - 1):
                    self.loss_G_GAN_feat += feat_weights * self.criterionFeat(pred_fake[i],
                                                                              pred_real[i].detach()) * self.opt.lambda_feat

                if debug_grads:
                    self.optimizer_G.zero_grad()
                    curr_loss = self.loss_G_GAN_feat * disc_weight
                    curr_loss.backward(retain_graph=True)
                    plot_grad_flow(
                        self.netG.module.named_parameters(),
                        self.grad_dict["FM"],
                        'Generator Gradients from Feature-Matching ({:.2f})'.format(self.opt.lambda_feat),
                        os.path.join(
                            self.grad_plot_dir,
                            'generator_gradnorm_FM-{:.2f}_it-{}.png'.format(self.opt.lambda_feat, current_iter)
                        ),
                        plot=is_last
                    )
                    if is_last:
                        self.grad_dict["FM"] = OrderedDict()
                else:
                    loss_G = loss_G + self.loss_G_GAN_feat * disc_weight

            # NOTE(review): GPL does not use the discriminator but is only computed
            # inside this discriminator-gated branch — confirm that is intended.
            if self.opt.lambda_gpl:
                # TODO: can use either real LDR or HDR as a reference. These options should be compared.
                self.loss_G_GPL = self.citerionGPL(self.fake_B, self.real_B, normalize=True)

                if debug_grads:
                    self.optimizer_G.zero_grad()
                    curr_loss = self.loss_G_GPL * self.opt.lambda_gpl
                    curr_loss.backward(retain_graph=True)
                    plot_grad_flow(
                        self.netG.module.named_parameters(),
                        self.grad_dict["GPL"],
                        'Generator Gradients from Gradient-Profile ({:.2f})'.format(self.opt.lambda_gpl),
                        os.path.join(
                            self.grad_plot_dir,
                            'generator_gradnorm_GPL-{:.2f}_it-{}.png'.format(self.opt.lambda_gpl, current_iter)
                        ),
                        plot=is_last
                    )
                    if is_last:
                        self.grad_dict["GPL"] = OrderedDict()
                else:
                    loss_G = loss_G + self.loss_G_GPL * self.opt.lambda_gpl

        # Det(G(A)) should detect objects.
        # `backup` holds detector weights saved before unrolling; None means nothing to restore.
        backup = None
        if self.with_detector:
            if self.opt.unroll > 0:
                # see how detector reacts to real images + fake images
                # by unrolling <unroll> many steps
                backup = copy.deepcopy(self.netDet.module.state_dict())
                for _ in range(self.opt.unroll):
                    self.backward_Det_unrolled(current_iter, current_epoch)

            # detect objects on generated images to update the generator
            self.optimizer_Det.zero_grad()
            self.detect_objects(detach_input=False, current_iter=current_iter)
            detector_loss, log_vars = self.parse_detection_loss(register_loss=self.opt.simult_det_update)

            if debug_grads:
                self.optimizer_G.zero_grad()
                detector_loss.backward(retain_graph=True)
                plot_grad_flow(
                    self.netG.module.named_parameters(),
                    self.grad_dict["Detector"],
                    'Generator Gradients from Detector ({:.2f})'.format(self.opt.alpha_det),
                    os.path.join(
                        self.grad_plot_dir,
                        'generator_gradnorm_detector-{:.2f}_it-{}.png'.format(self.opt.alpha_det, current_iter)
                    ),
                    plot=is_last
                )
                if is_last:
                    self.grad_dict["Detector"] = OrderedDict()
            else:
                loss_G = loss_G + detector_loss

        # loss_G stays the int 0 when no term contributed; there is nothing to backprop then.
        if not debug_grads and torch.is_tensor(loss_G):
            loss_G.backward()
            self.optimizer_G.step()

        # Restore the detector weights only if they were actually backed up
        # (previously an UnboundLocalError when unroll > 0 without a detector).
        if backup is not None:
            self.netDet.module.load_state_dict(backup)
            del backup

    if not debug_grads:
        if self._update_Det(current_epoch) and self.opt.simult_det_update:
            self.optimizer_Det.step()

def backward_Det(self, current_iter, current_epoch):
    """Update the detector on generated images (inputs detached) and optionally real ones.

    Calls ``self.detect_objects(detach_input=True, ...)`` and
    ``self.parse_detection_loss`` on the generated images. When
    ``self.opt.train_on_real`` is set, it adds the detection loss on real images
    (``detect_real=True``) before a single backward pass and optimizer step.
    ``current_epoch`` is accepted for interface parity and is unused here.
    """
    # Prefer the trainer's own optimizer (as backward_G does); fall back to the
    # notebook-level `optimizer_Det` for objects that don't carry one.
    optimizer = getattr(self, 'optimizer_Det', None)
    if optimizer is None:
        optimizer = optimizer_Det

    optimizer.zero_grad()

    self.detect_objects(detach_input=True, current_iter=current_iter)
    detector_loss, log_vars = self.parse_detection_loss(
        register_loss=not self.opt.maximize_detection_loss)

    if self.opt.train_on_real:
        self.detect_objects(current_iter=current_iter, detect_real=True)
        detector_loss_real, log_vars_real = self.parse_detection_loss(
            register_loss=self.opt.maximize_detection_loss)
        detector_loss = detector_loss + detector_loss_real

    detector_loss.backward()
    optimizer.step()

In [None]:
real_A = None  # HDR images -- TODO: take a batch from a dataloader
real_B = None  # LDR images

if real_A is None or real_B is None:
    # Placeholders are not loaded yet: skip the step instead of failing inside netG(None).
    print('real_A / real_B are not loaded yet; skipping the training step.')
else:
    fake_B = netG(real_A)  # forward
    backward_D(real_A, fake_B, real_B)
    # NOTE(review): backward_Det / backward_G are written as methods (first parameter
    # `self`, reading self.opt, self.netG, ...). They need that trainer object as the
    # first argument, and current_iter / current_epoch must come from the training loop.
    backward_Det(current_iter, current_epoch)
    backward_G(current_iter, current_epoch)