In [1]:
import copy
import functools
import os

import Imath
import numpy as np
import OpenEXR
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.nn import init
from torch.nn.functional import interpolate as F_upsample
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import retinanet_resnet50_fpn
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from torchvision.transforms.functional import normalize as F_normalize

In [4]:
batch_size = 8
num_classes = 7  # 6 object classes + 1 background
# Presumably the training-set size; the loss/alpha schedules divide iteration counts by it.
dataset_size = 1270

In [None]:
def load_exr(filename, channels=('R', 'G', 'B')):
    """Load an EXR file and return its pixels as a float32 NumPy array.

    Parameters:
        filename (str)  -- path to the ``.exr`` file
        channels (seq)  -- channel names to read, stacked in this order along the
                           last axis (default: R, G, B)

    Returns:
        np.ndarray of shape (height, width, len(channels)), dtype float32.
    """
    exr_file = OpenEXR.InputFile(filename)
    try:
        dw = exr_file.header()['dataWindow']
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1

        # Ask for every channel as 32-bit float so frombuffer's dtype matches.
        pt = Imath.PixelType(Imath.PixelType.FLOAT)
        planes = [
            np.frombuffer(exr_file.channel(c, pt), dtype=np.float32).reshape(height, width)
            for c in channels
        ]
    finally:
        # Release the file handle even if a channel is missing or reading fails.
        exr_file.close()

    return np.stack(planes, axis=-1)

class HDRDataset(Dataset):
    """Paired HDR (EXR) / LDR (PNG) images with one class label and bounding box each.

    Every ``<name>.png`` in ``png_dir`` is expected to have a matching ``<name>.exr`` in
    ``exr_dir`` and ``<name>.txt`` in ``txt_dir``. Only the first line of each txt file
    is read, as five whitespace-separated numbers ``class x y w h``.
    """

    def __init__(self, png_dir, exr_dir, txt_dir):
        self.png_dir = png_dir
        self.exr_dir = exr_dir
        self.txt_dir = txt_dir
        # Sorted so sample indices do not depend on os.listdir's arbitrary order.
        self.filenames = sorted(
            os.path.splitext(f)[0] for f in os.listdir(png_dir) if f.endswith('.png')
        )

        # One integer class id and one [x, y, w, h] box per sample, aligned with filenames.
        self.classes = []
        self.bboxes = []
        # Alias under the attribute name this class initialised previously.
        self.labels = self.classes

        for filename in self.filenames:
            txt_path = os.path.join(txt_dir, filename + '.txt')
            with open(txt_path, 'r') as file:
                first_line = file.readline()
            fields = first_line.split()
            if len(fields) != 5:
                raise ValueError(
                    f"{txt_path}: expected 'class x y w h' on the first line, got {first_line!r}"
                )
            cls, x, y, w, h = map(float, fields)
            self.classes.append(int(cls))
            self.bboxes.append([x, y, w, h])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(exr_image, png_image, label, bbox)`` for sample ``idx``.

        exr_image -- NumPy array from ``load_exr`` (height, width, 3); NOT channel-first
        png_image -- float32 tensor, channel-first (pixel values unscaled)
        label     -- int class id
        bbox      -- float32 tensor ``[x, y, w, h]``
        """
        name = self.filenames[idx]
        png_path = os.path.join(self.png_dir, name + '.png')
        exr_path = os.path.join(self.exr_dir, name + '.exr')

        exr_image = load_exr(exr_path)

        # Context manager closes the PNG file handle once the pixels are copied out.
        with Image.open(png_path) as img:
            png_array = np.array(img)

        # HWC -> CHW; assumes a colour PNG (a greyscale PNG has no channel axis to permute).
        png_image = torch.tensor(png_array, dtype=torch.float32).permute(2, 0, 1)
        bbox = torch.tensor(self.bboxes[idx], dtype=torch.float32)

        return exr_image, png_image, self.classes[idx], bbox


# Each split lives in <split>/png, <split>/exr and <split>/txt.
train_data = HDRDataset(png_dir='train/png', exr_dir='train/exr', txt_dir='train/txt')
test_data = HDRDataset(png_dir='test/png', exr_dir='test/exr', txt_dir='test/txt')

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
class EfficientAttention(nn.Module):
    def __init__(self, val_channels=3, key_channels=4, in_channels=0):
        super().__init__()
        self.in_channels = in_channels if in_channels else val_channels
        self.key_channels = key_channels
        self.val_channels = val_channels

        self.keys = nn.Conv2d(self.val_channels, self.key_channels, 1)
        self.values = nn.Conv2d(self.val_channels, self.key_channels, 1)
        self.queries = nn.Conv2d(self.in_channels, self.key_channels, 1)
        self.reprojection = nn.Conv2d(self.key_channels, self.val_channels, 1)

    def forward(self, value_, input_=None):
        n, c, h, w = value_.size()
        values = self.values(value_).reshape((n, self.key_channels, h * w))
        keys = self.keys(value_).reshape((n, self.key_channels, h * w))
        
        if input_ is not None:
            queries = self.queries(input_)
            
            # maxpool the query if it is larger than the value 
            _, _, h_i, w_i = input_.size()
            if w_i > w or h_i > h:
                queries = F.max_pool2d(queries, (h_i//h, w_i//w))
            
            queries = queries.reshape(n, self.key_channels, h * w)
        else:
            queries = self.queries(value_).reshape(n, self.key_channels, h * w)

        key = F.softmax(keys, dim=2)
        query = F.softmax(queries, dim=1)
        
        context = key @ values.transpose(1, 2)
        attention = (
            context.transpose(1, 2) @ query
        ).reshape(n, self.key_channels, h, w)

        reprojected_value = self.reprojection(attention)
        attention = reprojected_value + value_
        return attention

In [None]:
class UnetGeneratorBilinear(nn.Module):
    """U-Net generator with bilinear upsampling and optional efficient-attention skips.

    The encoder has three 2x pooling stages (32 -> 64 -> 128 channels). The bottleneck
    concatenates parallel 1x1/3x3/5x5/7x7 convolutions (4 x 128 = 512 channels). The
    decoder upsamples 2x bilinearly three times.

    With ``self_attention``, the input is concatenated with its self-attended version,
    and every skip connection is refined by an EfficientAttention block whose queries
    come from the 3-channel network input. The output is
    Hardtanh(latent + skip * input); when ``normalize`` is set, latent is min-max
    normalised first.

    Input height and width should be divisible by 8, so that the upsampled maps match
    the skip connections they are concatenated with.

    Parameters:
        opt         -- not read; kept so existing positional calls keep working
        norm_layer  -- normalisation layer class (or functools.partial of one), called as
                       ``norm_layer(num_channels)``. Defaults to BatchNorm2d, like
                       NLayerDiscriminator.
    """

    def __init__(self, opt=None, norm_layer=nn.BatchNorm2d):
        super(UnetGeneratorBilinear, self).__init__()

        # Hard-coded configuration (``opt`` is not consulted).
        self.normalize = True
        self.self_attention = True
        self.use_avgpool = True
        self.skip = 0.8
        self.use_tanh = True
        self.final_tanh = nn.Hardtanh()

        p = 1
        if self.self_attention:
            # conv1_1 sees the input concatenated with its self-attended copy (3 + 3 channels).
            self.conv1_1 = nn.Conv2d(6, 32, 3, padding=p)
            self.attention_in = EfficientAttention(val_channels=3, key_channels=3, in_channels=3)
            self.attention_out = EfficientAttention(val_channels=3, key_channels=3, in_channels=3)
            self.attention_1 = EfficientAttention(val_channels=32, key_channels=4, in_channels=3)
            self.attention_2 = EfficientAttention(val_channels=64, key_channels=4, in_channels=3)
            self.attention_3 = EfficientAttention(val_channels=128, key_channels=8, in_channels=3)
            self.attention_4 = EfficientAttention(val_channels=512, key_channels=16, in_channels=3)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)

        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_1 = norm_layer(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn1_2 = norm_layer(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_1 = norm_layer(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn2_2 = norm_layer(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_1 = norm_layer(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn3_2 = norm_layer(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        # Bottleneck: four parallel branches with growing receptive fields, 'same' padding.
        self.conv4_11 = nn.Conv2d(128, 128, 1, padding=p*0)
        self.LReLU4_11 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_11 = norm_layer(128)
        self.conv4_12 = nn.Conv2d(128, 128, 3, padding=p*1)
        self.LReLU4_12 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_12 = norm_layer(128)
        self.conv4_13 = nn.Conv2d(128, 128, 5, padding=p*2)
        self.LReLU4_13 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_13 = norm_layer(128)
        self.conv4_14 = nn.Conv2d(128, 128, 7, padding=p*3)
        self.LReLU4_14 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_14 = norm_layer(128)
        self.conv4_2 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn4_2 = norm_layer(256)


        # Uncomment this block for further downsampling
        '''
        self.max_pool4 = nn.AvgPool2d(2) if self.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn5_1 = norm_layer(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn5_2 = norm_layer(512)

        
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn6_1 = norm_layer(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn6_2 = norm_layer(256)
        '''

        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_1 = norm_layer(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn7_2 = norm_layer(128)

        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_1 = norm_layer(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        self.bn8_2 = norm_layer(64)

        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        self.bn9_1 = norm_layer(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)

    def forward(self, input):
        """Map a 3-channel image batch (N, 3, H, W) to a 3-channel output in [-1, 1]."""
        if self.self_attention:
            attended_inp = self.attention_in(input)
            x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat([input, attended_inp], dim=1))))
        else:
            x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
        conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
        x = self.max_pool1(conv1)

        x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
        conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
        x = self.max_pool2(conv2)

        x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
        conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
        x = self.max_pool3(conv3)

        ########## Starts: Bottom of the U-NET ##########
        x_1 = self.bn4_11(self.LReLU4_11(self.conv4_11(x)))
        x_2 = self.bn4_12(self.LReLU4_12(self.conv4_12(x)))
        x_3 = self.bn4_13(self.LReLU4_13(self.conv4_13(x)))
        x_4 = self.bn4_14(self.LReLU4_14(self.conv4_14(x)))
        x = torch.cat([x_1,x_2,x_3,x_4], dim=1)
        x = self.attention_4(x, input) if self.self_attention else x
        conv6 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
        
        # uncomment this block for further downsampling
        '''
        x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
        conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
        x = self.max_pool4(conv4)

        x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
        #x = x*attention_map5 if self.self_attention else x
        x = self.attention_5(x) if self.self_attention else x
        conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))
        
        conv5 = F_upsample(conv5, scale_factor=2, mode='bilinear')
        #conv4 = conv4*attention_map4 if self.self_attention else conv4
        conv4 = self.attention_4(conv4) if self.self_attention else conv4
        up6 = torch.cat([self.deconv5(conv5), conv4], 1)
        x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
        conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))
        '''
        ########### Ends: Bottom of the U-NET ##########

        conv6 = F_upsample(conv6, scale_factor=2, mode='bilinear')
        conv3 = self.attention_3(conv3, input) if self.self_attention else conv3
        up7 = torch.cat([self.deconv6(conv6), conv3], 1)
        x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
        conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

        conv7 = F_upsample(conv7, scale_factor=2, mode='bilinear')
        conv2 = self.attention_2(conv2, input) if self.self_attention else conv2
        up8 = torch.cat([self.deconv7(conv7), conv2], 1)
        x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
        conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

        conv8 = F_upsample(conv8, scale_factor=2, mode='bilinear')
        conv1 = self.attention_1(conv1, input) if self.self_attention else conv1
        up9 = torch.cat([self.deconv8(conv8), conv1], 1)
        x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
        conv9 = self.LReLU9_2(self.conv9_2(x))

        latent = self.conv10(conv9)
        latent = self.attention_out(latent, input) if self.self_attention else latent

        if self.skip:
            if self.normalize:
                # NOTE(review): min/max are taken over dims (0, 2, 3), i.e. across the whole
                # batch, so each sample's output depends on the other samples in its batch.
                min_latent = torch.amin(latent, dim=(0,2,3), keepdim=True)
                max_latent = torch.amax(latent, dim=(0,2,3), keepdim=True)
                # clamp_min avoids 0/0 = NaN when a channel is constant across the batch.
                latent = (latent - min_latent) / (max_latent - min_latent).clamp_min(1e-8)

            output = latent + self.skip * input
        else:
            output = latent

        if self.use_tanh:
            output = self.final_tanh(output)

        return output


In [None]:
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: strided 4x4 convolutions ending in a 1-channel map of
    per-patch real/fake scores. Each stage is stored as ``model0`` ... ``model{n_layers+1}``."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Build the discriminator stages.

        Parameters:
            input_nc (int)  -- channels of the input images
            ndf (int)       -- filters of the first conv; later stages use ndf * min(2**k, 8)
            n_layers (int)  -- number of stride-2 conv stages
            norm_layer      -- normalization layer class, or functools.partial of one
        """
        super(NLayerDiscriminator, self).__init__()
        # Conv bias is redundant before a norm with affine parameters; only InstanceNorm keeps it.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        self.n_layers = n_layers
        kernel, pad = 4, 1

        def conv_norm_act(in_ch, out_ch, stride):
            # conv -> norm -> LeakyReLU, created in this order.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad, bias=use_bias),
                norm_layer(out_ch),
                nn.LeakyReLU(0.2, True),
            ]

        stages = [[nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad), nn.LeakyReLU(0.2, True)]]

        mult = 1
        for depth in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** depth, 8)
            stages.append(conv_norm_act(ndf * prev_mult, ndf * mult, stride=2))

        prev_mult, mult = mult, min(2 ** n_layers, 8)
        stages.append(conv_norm_act(ndf * prev_mult, ndf * mult, stride=1))

        # Final projection to a single-channel prediction map.
        stages.append([nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad)])

        for idx, stage in enumerate(stages):
            setattr(self, 'model' + str(idx), nn.Sequential(*stage))

    def forward(self, x, return_features=False):
        """Return the score map, or with ``return_features`` the list of every stage's
        output (the score map is the last element)."""
        outputs = []
        for idx in range(self.n_layers + 2):
            x = getattr(self, 'model' + str(idx))(x)
            outputs.append(x)
        return outputs if return_features else x

In [None]:
class SPLoss(nn.Module):
    """Spatial-profile loss: negative summed cosine similarity between two tensors,
    taken along dim 2 and along dim 3, divided by ``input.size(2)``.

    Expects 4-D tensors whose dims 2 and 3 are presumably the spatial axes (H, W).
    """

    def __init__(self):
        super(SPLoss, self).__init__()

    def forward(self, input, reference):
        # Defined as forward (not __call__) so nn.Module hooks and the standard call path apply.
        sim_dim2 = (F.normalize(input, p=2, dim=2) * F.normalize(reference, p=2, dim=2)).sum()
        sim_dim3 = (F.normalize(input, p=2, dim=3) * F.normalize(reference, p=2, dim=3)).sum()
        return -(sim_dim2 + sim_dim3) / input.size(2)

class GPLoss(nn.Module):
    """Gradient-profile loss: SPLoss between the finite-difference gradient maps of two images."""

    def __init__(self):
        super(GPLoss, self).__init__()
        self.trace = SPLoss()

    def get_image_gradients(self, input):
        """Return ``(f_v, f_h)`` neighbour differences ``pixel - next_pixel``.

        f_v -- differences along the last axis (one column shorter)
        f_h -- differences along the second-to-last axis (one row shorter)
        """
        f_v = input[..., :, :-1] - input[..., :, 1:]
        f_h = input[..., :-1, :] - input[..., 1:, :]
        return f_v, f_h

    def forward(self, input, reference, normalize=True):
        """Compare the gradient maps of ``input`` and ``reference``.

        Parameters:
            input, reference   -- image batches of the same shape
            normalize (bool)   -- map both from [-1, 1] to [0, 1] first

        Returns:
            SPLoss of the column-direction gradients plus SPLoss of the row-direction gradients.
        """
        # Defined as forward (not __call__) so nn.Module hooks and the standard call path apply.
        if normalize:
            input = (input + 1) / 2
            reference = (reference + 1) / 2

        input_v, input_h = self.get_image_gradients(input)
        ref_v, ref_h = self.get_image_gradients(reference)
        return self.trace(input_v, ref_v) + self.trace(input_h, ref_h)

class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- GAN objective: 'vanilla', 'lsgan' or 'wgangp'
            target_real_label (float) -- label value used for real images
            target_fake_label (float) -- label value used for fake images

        Raises:
            NotImplementedError -- for any other gan_mode

        Note: do not use sigmoid as the last layer of the discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def MSE_loss_weighted(self, prediction, target, mask):
        """Mean of mask-weighted squared error."""
        return (mask * ((prediction - target)**2)).mean()

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real or fake label broadcast (without copying) to ``prediction``'s shape.

        Parameters:
            prediction (tensor)   -- typically the discriminator output
            target_is_real (bool) -- whether the ground truth is "real"
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real, mask=None):
        """Calculate the loss for a discriminator output and a ground-truth label.

        Parameters:
            prediction            -- discriminator output tensor; or a list of per-stage
                                     outputs (the last one is scored); or a list of such
                                     lists (the last output of each is scored and summed)
            target_is_real (bool) -- whether the ground truth is "real"
            mask (tensor or None) -- optional weight map, used only in 'lsgan' mode

        Returns:
            the calculated loss.
        """
        # Defined as forward (not __call__) so nn.Module hooks and the standard call path apply.
        if isinstance(prediction, list):
            if isinstance(prediction[0], list):
                return sum(self.calculate_loss(pred[-1], target_is_real, mask) for pred in prediction)
            return self.calculate_loss(prediction[-1], target_is_real, mask)
        return self.calculate_loss(prediction, target_is_real, mask)

    def calculate_loss(self, prediction, target_is_real, mask):
        """Loss for a single prediction tensor according to ``self.gan_mode``."""
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            if self.gan_mode == 'lsgan' and mask is not None:
                # Resize the mask to the discriminator's output resolution.
                mask = F_upsample(mask, size=prediction.shape[-2:])
                return self.MSE_loss_weighted(prediction, target_tensor, mask)
            return self.loss(prediction, target_tensor)
        if self.gan_mode == 'wgangp':
            return -prediction.mean() if target_is_real else prediction.mean()
        raise NotImplementedError('gan mode %s not implemented' % self.gan_mode)
    
def get_optimizer(model, lr, extra_model=None, betas=(0.5, 0.999)):
    """Return an Adam optimizer over the parameters of ``model`` (and ``extra_model``).

    Parameters:
        model        -- nn.Module whose parameters will be optimized
        lr (float)   -- learning rate
        extra_model  -- optional second nn.Module optimized jointly with ``model``
        betas        -- Adam (beta1, beta2) coefficients; default (0.5, 0.999)
    """
    learnable_params = list(model.parameters())
    if extra_model is not None:
        learnable_params.extend(extra_model.parameters())

    return torch.optim.Adam(learnable_params, lr=lr, betas=betas)
   

In [None]:
class MulConstant(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by ``constant``
    in the backward pass (a negative constant reverses the gradient)."""

    @staticmethod
    def forward(ctx, tensor, constant):
        # Only the scale is needed later; the values pass through untouched.
        ctx.constant = constant
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: the tensor gets the scaled gradient,
        # the non-tensor `constant` gets None.
        scaled = ctx.constant * grad_output.clone()
        return scaled, None
    
class CustomNormalization(nn.Module):
    """Prepare an image batch for the detector.

    Forward: values pass through MulConstant unchanged. If ``normalize`` is set, they are
    then mapped from [-1, 1] (tanh range) to [0, 255] and standardised per channel with
    fixed mean/std. Backward: MulConstant multiplies the gradient by ``alpha``.

    NOTE(review): torchvision detection models (e.g. retinanet_resnet50_fpn) apply their
    own input normalisation and expect images in [0, 1]. Confirm that feeding this
    0-255-scale standardised output does not normalise twice.
    """

    def __init__(self, normalize):
        super(CustomNormalization, self).__init__()
        self.normalize = normalize

    def forward(self, x, alpha=1.0):
        out = MulConstant.apply(x, alpha)
        if not self.normalize:
            return out

        # [-1, 1] -> [0, 255]
        out = (out + 1) / 2 * 255
        # Per-channel mean/std of the pretrained detection network, in 0-255 units.
        return F_normalize(
            out,
            [123.675, 116.28, 103.53],
            [58.395, 57.12, 57.375],
        )


custom_normalization = CustomNormalization(normalize=True)

In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda:{}'.format(torch.cuda.current_device()))
else:
    device = torch.device('cpu')

lambda_gpl=0.8
lambda_feat=10
lambda_gr=0
gan_mode='lsgan'
lr_Det=5e-4
lr_D=2e-4
lr_G=2e-4
beta1=0.5
lr_decay_epoch_Det=30
lr_decay_epoch_D=20
lr_decay_epoch_G=20
n_layers_D=3
n_epochs_model=20
n_epochs_decay_G=30
n_epochs_decay_D=30
unroll=0
lr_policy_D='linear'
alpha_disc=1.0
alpha_det=1.0
alpha_mode_det='cg_id'
alpha_epochs_det=5
alpha_mode_disc='Det'
maximize_detection_loss=True  #generator tries to maximize the detection loss
simult_det_update=False #if set, update the generator and detector simultaneously
input_nc=3
output_nc=3
feature_matching=True #use feature matching loss on the discriminator
lr_start_epoch_Det=0
with_detector=True
train_on_real=True #train the detector on real images as well
freeze_detector=False #If set, detector network is not updated
no_disc_loss=False #If set, do not use discriminator loss (only for debugging)

# Instance norm without affine parameters or running statistics, for G (norm_layer) and D (norm_layer1).
norm_layer1 = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

# The generator's first parameter `opt` is not read by it; pass it explicitly.
netG = UnetGeneratorBilinear(opt=None, norm_layer=norm_layer).to(device)
# D sees the input and output images concatenated along channels.
netD = NLayerDiscriminator(input_nc=input_nc+output_nc, ndf=64, n_layers=n_layers_D, norm_layer=norm_layer1).to(device)

# COCO-pretrained RetinaNet:
# https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=`.
netDet = retinanet_resnet50_fpn(pretrained=True)
# Swap in a classification head sized for our classes. It must be a
# RetinaNetClassificationHead, not a bare Conv2d: RetinaNet's head calls its
# compute_loss() during training and relies on its forward() reshaping logits per
# anchor and class.
netDet.head.classification_head = RetinaNetClassificationHead(
    in_channels=netDet.backbone.out_channels,
    num_anchors=netDet.head.classification_head.num_anchors,
    num_classes=num_classes,
)
netDet = netDet.to(device)

criterionGAN = GANLoss(gan_mode).to(device)
criterionFeat = nn.L1Loss().to(device)
citerionGPL = GPLoss().to(device)
l1_loss_func = nn.L1Loss().to(device)
class_loss_func = nn.CrossEntropyLoss().to(device)

optimizer_D = get_optimizer(netD, lr_D)
optimizer_G = get_optimizer(netG, lr_G)
optimizer_Det = get_optimizer(netDet, lr_Det)
# TODO: learning-rate schedulers
#scheduler, backprop, detector, dataloader

In [None]:
def compute_grad2(disc_out, real_data):
    '''Squared gradient norm of the discriminator output w.r.t. its input, per sample.

    Used as a zero-centred gradient penalty (unlike WGAN-GP, which centres at 1).
    The graph is kept (create_graph=True) so the penalty itself can be backpropagated.

    Arguments:
        disc_out (tensor)   -- discriminator output computed from real_data
        real_data (tensor)  -- the input batch (must require grad)

    Returns a 1-D tensor of length real_data.size(0).
    '''
    (grad_wrt_input,) = torch.autograd.grad(
        outputs=disc_out.sum(),
        inputs=real_data,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    squared = grad_wrt_input.pow(2)
    assert(squared.size() == real_data.size())

    n_samples = real_data.size(0)
    return squared.view(n_samples, -1).sum(1)

def detect_objects(detach_input=False, current_iter=None, detect_real=False, real_B=None, fake_B=None):
    """Run ``netDet`` on real or generated images, with a gradient scale for the generator.

    Parameters:
        detach_input (bool)  -- detach the detector input so no gradient flows back
        current_iter (int)   -- progress counter; in an 'ig' alpha mode the gradient
                                scale ramps from 0 up to alpha_det over
                                alpha_epochs_det * dataset_size iterations (None -> scale 1.0)
        detect_real (bool)   -- run on ``real_B`` instead of ``fake_B``
        real_B, fake_B       -- image batches, presumably in [-1, 1] as custom_normalization expects

    Returns:
        whatever ``netDet`` returns for the batch.

    Raises:
        ValueError -- if the selected image batch is None.
    """
    # `is not None` so iteration 0 uses the ramp instead of being treated as "no iteration".
    if current_iter is not None:
        if 'ig' in alpha_mode_det:
            alpha_weight = current_iter / (alpha_epochs_det * dataset_size)
            curr_alpha = float(min(alpha_det, alpha_det * alpha_weight))
        else:
            curr_alpha = float(alpha_det)
    else:
        curr_alpha = 1.0

    if detect_real:
        images = real_B
    else:
        images = fake_B
        if maximize_detection_loss:
            # Reverse the gradient so the generator ascends the detection loss.
            curr_alpha = -curr_alpha

    if images is None:
        raise ValueError('detect_objects: %s is None' % ('real_B' if detect_real else 'fake_B'))

    _to_detector = custom_normalization(images, alpha=curr_alpha)
    if detach_input:
        _to_detector = _to_detector.detach()

    return netDet(_to_detector)
    
def parse_detection_loss(register_loss, detector_out, cls_labels, bboxes_true):
    """Reduce detector output to one scalar loss (sum of the per-term means).

    Accepted forms of ``detector_out``:
      * a dict of loss tensors (the form torchvision detection models return in
        training mode when given targets) -- every value is averaged and summed;
      * a list of per-image prediction dicts with 'scores' and 'boxes' -- only the
        first image is scored, against ``cls_labels`` with ``class_loss_func`` and
        against ``bboxes_true`` with ``l1_loss_func``.

    ``register_loss`` is accepted for call compatibility and is not used.

    Raises:
        ValueError -- prediction-list form given without cls_labels / bboxes_true.
    """
    if isinstance(detector_out, dict):
        losses = detector_out
    else:
        if cls_labels is None or bboxes_true is None:
            raise ValueError('parse_detection_loss needs cls_labels and bboxes_true for prediction outputs')
        # NOTE(review): inference 'scores' are per-detection confidences, not per-class
        # logits; cross-entropy on a .view(-1, num_classes) of them is likely not meaningful.
        losses = {
            'classification_loss': class_loss_func(detector_out[0]['scores'].view(-1, num_classes), cls_labels.view(-1)),
            'bbox_loss': l1_loss_func(detector_out[0]['boxes'], bboxes_true),
        }
    return sum(value.mean() for value in losses.values())
        
    
def _update_D(current_epoch):
    '''Whether the discriminator is trained this epoch: its loss is enabled and
    current_epoch has not passed n_epochs_model.'''
    if no_disc_loss:
        return False
    return current_epoch <= n_epochs_model
def _update_Det(current_epoch):
    '''Whether the detector is trained this epoch: it is enabled and not frozen, and
    lr_start_epoch_Det <= current_epoch <= n_epochs_model.'''
    if not with_detector or freeze_detector:
        return False
    return lr_start_epoch_Det <= current_epoch <= n_epochs_model
def _update_G(current_epoch):
    '''Whether the generator is trained this epoch (until n_epochs_model inclusive).'''
    within_schedule = current_epoch <= n_epochs_model
    return within_schedule
def backward_D(real_A, fake_B, real_B):
    """One discriminator update on a fake pair (real_A, fake_B) and a real pair (real_A, real_B).

    Pairs are concatenated along the channel dim. fake_B is detached, so no gradient
    reaches the generator. When ``lambda_gr`` is non-zero, the zero-centred gradient
    penalty from compute_grad2 is added for the real pair.

    Returns:
        None (the former ``features`` return value was always None).
    """
    optimizer_D.zero_grad()

    # Fake; stop backprop to the generator by detaching fake_B.
    fake_AB = torch.cat((real_A, fake_B), 1)
    pred_fake = netD(fake_AB.detach())
    loss_D_fake = criterionGAN(pred_fake, False)
    loss_D_fake.backward()

    # Real; requires_grad so the gradient penalty can differentiate w.r.t. the input.
    real_AB = torch.cat((real_A, real_B), 1).requires_grad_()
    pred_real = netD(real_AB, feature_matching)
    # With feature_matching, netD returns every stage's output; the score map is the last.
    real_scores = pred_real[-1] if feature_matching else pred_real
    loss_D_real = criterionGAN(real_scores, True, mask=None)

    if lambda_gr:
        loss_D_real.backward(retain_graph=True)
        # Penalise the score map itself (a list of features has no .sum()).
        grad_reg = lambda_gr * compute_grad2(real_scores, real_AB).mean()
        grad_reg.backward()
    else:
        loss_D_real.backward()

    optimizer_D.step()
    return None

def backward_G(current_iter, current_epoch, real_A, fake_B, real_B, cls_labels=None, bbox=None):
    """Generator update from GAN, feature-matching, gradient-profile and detection losses.

    Parameters:
        current_iter (int)  -- progress counter used by the disc-weight / alpha schedules
        current_epoch (int) -- gates which networks are updated (_update_G/_update_D/_update_Det)
        real_A              -- generator input batch
        fake_B              -- netG(real_A), not detached, so gradients reach netG
        real_B              -- real target batch
        cls_labels, bbox    -- detection ground truth forwarded to parse_detection_loss

    When simult_det_update is set, optimizer_Det is also stepped here, using the
    gradients that loss_G.backward() left on the detector.
    """
    debug_grads = False  # when True, skip backward and all optimizer steps
    backup = None
    # netDet may be wrapped (e.g. in DataParallel, which exposes .module) or not.
    det_module = getattr(netDet, 'module', netDet)

    if _update_G(current_epoch):
        optimizer_G.zero_grad()
        ###############################################################################
        # #TODO: Decrease the contribution from all discriminator-related losses      #
        # proportionally to the decay of the learning rate of discriminator.          #
        # I do not do this for now; thus the weight is always 1.                      #
        ###############################################################################

        # Weight applied to the discriminator-based terms.
        if alpha_mode_disc == 'dec':
            if lr_policy_D == 'step':
                disc_weight = 0.1 ** (current_epoch // lr_decay_epoch_D)
            elif n_epochs_decay_D == 0:
                disc_weight = 1
            else:
                disc_weight = (n_epochs_model*dataset_size - current_iter)/(n_epochs_decay_D*dataset_size)
            disc_weight = min(alpha_disc, alpha_disc*disc_weight)
        else:
            disc_weight = alpha_disc

        use_disc_loss = _update_D(current_epoch)

        loss_G = 0           # overall generator loss
        loss_G_GAN = 0       # adversarial term
        loss_G_GAN_feat = 0  # discriminator feature-matching term
        loss_G_GPL = 0.      # gradient-profile term (cosine similarity of edge maps)
        if disc_weight > 0 and use_disc_loss:
            # G(A) should fool the discriminator.
            fake_AB = torch.cat((real_A, fake_B), 1)
            pred_fake = netD(fake_AB, feature_matching)
            loss_G_GAN = criterionGAN(pred_fake[-1] if feature_matching else pred_fake, True, mask=None)
            loss_G = loss_G + loss_G_GAN * disc_weight

            # Match fake and real features in the intermediate discriminator layers.
            if feature_matching:
                real_AB = torch.cat((real_A, real_B), 1)
                # Real features are only targets (detached below), so no graph is needed.
                with torch.no_grad():
                    pred_real = netD(real_AB, return_features=True)
                feat_weights = 4.0 / (n_layers_D + 1)
                for i in range(len(pred_fake)-1):
                    loss_G_GAN_feat += feat_weights * criterionFeat(pred_fake[i],
                                                                    pred_real[i].detach()) * lambda_feat
                loss_G = loss_G + loss_G_GAN_feat * disc_weight

            if lambda_gpl:
                #TODO: can use either real LDR or HDR as a reference. These options should be compared.
                loss_G_GPL = citerionGPL(fake_B, real_B, normalize=True)
                loss_G = loss_G + loss_G_GPL * lambda_gpl

        # Det(G(A)) term.
        if with_detector:
            if unroll > 0:
                # Snapshot the detector so it can be restored after unrolled steps.
                backup = copy.deepcopy(det_module.state_dict())

            optimizer_Det.zero_grad()
            detector_out = detect_objects(detach_input=False, current_iter=current_iter, real_B=real_B, fake_B=fake_B)
            detector_loss = parse_detection_loss(register_loss=simult_det_update, detector_out=detector_out,
                                                 cls_labels=cls_labels, bboxes_true=bbox)
            loss_G = loss_G + detector_loss

        # loss_G is still the int 0 when no term was enabled; there is nothing to backprop.
        if not debug_grads and torch.is_tensor(loss_G):
            loss_G.backward()
            optimizer_G.step()

        if backup is not None:
            det_module.load_state_dict(backup)
            backup = None

    if not debug_grads and _update_Det(current_epoch) and simult_det_update:
        optimizer_Det.step()

def backward_Det(current_iter, current_epoch, real_B, fake_B, cls_labels=None, bbox=None):
    """Detector update on generated images, plus real images when train_on_real is set.

    fake_B is detached inside detect_objects, so this step leaves the generator alone.
    Does nothing when _update_Det(current_epoch) is False (detector disabled or frozen,
    or outside its epoch window).

    Parameters:
        current_iter (int)  -- progress counter forwarded to detect_objects
        current_epoch (int) -- used for the _update_Det gate
        real_B, fake_B      -- real and generated image batches
        cls_labels, bbox    -- detection ground truth forwarded to parse_detection_loss
    """
    if not _update_Det(current_epoch):
        return

    optimizer_Det.zero_grad()

    detector_out = detect_objects(detach_input=True, current_iter=current_iter, real_B=real_B, fake_B=fake_B)
    detector_loss = parse_detection_loss(register_loss=not maximize_detection_loss, detector_out=detector_out,
                                         cls_labels=cls_labels, bboxes_true=bbox)

    if train_on_real:
        # real_B must be passed explicitly: detect_objects has no other source for it.
        detector_out = detect_objects(current_iter=current_iter, detect_real=True, real_B=real_B)
        detector_loss_real = parse_detection_loss(register_loss=maximize_detection_loss, detector_out=detector_out,
                                                  cls_labels=cls_labels, bboxes_true=bbox)
        detector_loss = detector_loss + detector_loss_real

    detector_loss.backward()
    optimizer_Det.step()

def optimize_parameters(current_iter, current_epoch, real_A, real_B, cls_labels=None, bbox=None):
    """One training step: fake_B = netG(real_A), then update D, Det and G.

    Det is updated here only when simult_det_update is False; otherwise backward_G
    steps the detector. cls_labels / bbox (the batch's detection ground truth) are
    forwarded to backward_Det.
    """
    # compute fake images: G(A)
    fake_B = netG(real_A)

    # update D
    backward_D(real_A, fake_B, real_B)

    # update Det
    if not simult_det_update:
        backward_Det(current_iter, current_epoch, real_B, fake_B, cls_labels, bbox)

    # update G
    backward_G(current_iter, current_epoch, real_A, fake_B, real_B)

In [None]:
n_epochs = 30
# Counted in samples: the alpha and disc-weight schedules divide it by dataset_size.
current_iter = 0
for epoch in range(n_epochs):
    for exr_image, png_image, label, bbox in train_loader:
        # ASSUMPTION (TODO confirm): the HDR EXR is the generator input (real_A) and the
        # LDR PNG is the real target (real_B).
        # load_exr yields HWC arrays, which the DataLoader stacks to (N, H, W, C); netG wants NCHW.
        real_A = exr_image.permute(0, 3, 1, 2).float().to(device)
        # PNG pixels are 0-255; the tanh output, GPLoss and custom_normalization all work in [-1, 1].
        real_B = (png_image / 127.5 - 1.0).to(device)
        # NOTE(review): label/bbox are loaded but not passed on; the detector losses need them.
        optimize_parameters(current_iter, epoch, real_A, real_B)
        current_iter += real_A.size(0)
