# MLiP Group 33 - Photo to Monet using the CUT Algorithm
### Otto van der Himst, Simon Arends, & Hendrik Hoch

The CUT algorithm is introduced in: 
  
    Park et al (2020) Contrastive Learning for Unpaired Image-to-Image Translation
    
Most of the code is either copied directly from, or is a modified version of, the code in [Park et al.'s GitHub repository](https://github.com/taesungp/contrastive-unpaired-translation).

In [None]:
%config Completer.use_jedi = False # Enables code auto-completion

In [None]:
from kaggle_datasets import KaggleDatasets
import os
import time

import matplotlib.pyplot as plt
import cv2

import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.optim import lr_scheduler
from packaging import version

from collections import OrderedDict
import functools

!pip install GPUtil
from GPUtil import showUtilization as gpu_usage

import albumentations as A

In [None]:
# Root directory of the input data (relative to the notebook's working directory).
BASE_PATH = "../input/gan-getting-started/"
# Sub-directories of BASE_PATH, one per image domain (named monet_jpg and photo_jpg).
MONET_PATH, PHOTO_PATH = (
    os.path.join(BASE_PATH, folder) for folder in ("monet_jpg", "photo_jpg")
)

In [None]:
def print_gpu_usage(id):
    """Print a labelled GPU utilisation report.

    Parameters:
        id -- label printed (followed by a colon) on its own line before the report

    Waits one second, prints the label, then calls GPUtil's ``showUtilization``
    (imported in this file as ``gpu_usage``).
    """
    time.sleep(1)  # pause for one second before reading the utilisation
    header = f"\n{id}:"
    print(header)
    gpu_usage()

## The CUT Algorithm

The following two code cells implement the CUT algorithm. For most purposes these cells can be collapsed and ignored; they only become relevant when looking into specific implementation details.

In [None]:
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a normalization layer.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    Returns:
        A callable that takes the number of channels and returns an nn.Module.

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    For 'none', the factory ignores its argument and returns a pass-through module.

    Raises:
        NotImplementedError -- if norm_type is not one of the names above.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        # Use torch's built-in no-op module instead of depending on a
        # separately defined ``Identity`` class being in scope.
        def norm_layer(x):
            return nn.Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

class Downsample(nn.Module):
    """Pad, filter each channel with a fixed kernel, and subsample by ``stride``.

    The kernel comes from ``get_filter`` and is stored as a buffer (not a
    parameter), repeated once per channel so it can be applied as a grouped
    convolution. The padding module comes from ``get_pad_layer``.
    """

    def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
        """Set up the padding layer and the per-channel filter buffer.

        Parameters:
            channels (int)  -- number of channels the filter is repeated for
            pad_type (str)  -- name handed to get_pad_layer
            filt_size (int) -- size handed to get_filter; 1 means plain strided slicing
            stride (int)    -- subsampling step
            pad_off (int)   -- extra padding added to every side
        """
        super().__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Floor/ceil halves of (filt_size - 1), repeated for both padded axes.
        half = 1. * (filt_size - 1) / 2
        lo = int(half) + pad_off
        hi = int(np.ceil(half)) + pad_off
        self.pad_sizes = [lo, hi, lo, hi]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)  # stored but not read anywhere in this class
        self.channels = channels

        kernel = get_filter(filt_size=self.filt_size)
        self.register_buffer('filt', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Return the filtered and strided version of ``inp``."""
        if self.filt_size != 1:
            # One group per input channel: every channel is filtered independently.
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
        # A size-1 filter needs no convolution; just (optionally pad and) slice.
        source = inp if self.pad_off == 0 else self.pad(inp)
        return source[:, :, ::self.stride, ::self.stride]

class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convolutions ending in a 1-channel prediction map."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
        """Build the discriminator.

        Parameters:
            input_nc (int)      -- the number of channels in input images
            ndf (int)           -- the number of filters produced by the first conv layer
            n_layers (int)      -- the number of downsampling conv stages
            norm_layer          -- normalization layer class or functools.partial factory
            no_antialias (bool) -- if True, downsample with stride-2 convs; otherwise use
                                   stride-1 convs followed by a Downsample layer
        """
        super().__init__()
        # Conv layers keep their bias only when the norm is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw, padw = 4, 1
        conv_stride = 2 if no_antialias else 1

        def maybe_blur(channels):
            # The anti-aliased variant does the actual subsampling in a Downsample layer.
            return [] if no_antialias else [Downsample(channels)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=conv_stride, padding=padw),
            nn.LeakyReLU(0.2, True),
        ] + maybe_blur(ndf)

        mult = 1
        for n in range(1, n_layers):  # channel multiplier doubles per stage, capped at 8
            prev_mult, mult = mult, min(2 ** n, 8)
            layers.extend([
                nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kw, stride=conv_stride, padding=padw, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ])
            layers.extend(maybe_blur(ndf * mult))

        # One more stride-1 stage before the prediction head.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.extend([
            nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ])

        # Head: a single-channel prediction map.
        layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the input through the sequential model."""
        return self.model(input)

class ResnetBlock(nn.Module):
    """Residual block: two 3x3 conv stages whose output is added back to the input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Create the block.

        The conv stack is assembled by <build_conv_block>; the skip connection
        is applied in <forward>.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super().__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the conv stack of the block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns an nn.Sequential of: [pad] conv norm ReLU [dropout] [pad] conv norm.
        Raises NotImplementedError for an unknown padding_type.
        """
        def padded_conv():
            # One "[pad] conv norm" stage; 'zero' padding is done inside the conv itself.
            if padding_type == 'reflect':
                prefix, p = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                prefix, p = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                prefix, p = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return prefix + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        stages = padded_conv() + [nn.ReLU(True)]
        if use_dropout:
            stages.append(nn.Dropout(0.5))
        stages += padded_conv()

        return nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the conv stack's output (skip connection)."""
        return x + self.conv_block(x)

def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        debug (bool)      -- if True, print the class name of every Conv/Linear layer visited

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError -- if a Conv/Linear layer is visited and init_type is unknown.
    """
    def init_func(m):  # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False the layer has weight/bias set to None; skip those.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
        
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
    """Move a network to its device and optionally initialize its weights.

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list)        -- which GPUs the network runs on; only the first entry is used
        debug (bool)              -- forwarded to init_weights
        initialize_weights (bool) -- if False, the weights are left untouched

    Returns the same network object.
    """
    use_gpu = len(gpu_ids) > 0
    if use_gpu:
        assert torch.cuda.is_available()
        # Only the first listed device is used; DataParallel wrapping is left disabled.
        net.to(gpu_ids[0])
    if not initialize_weights:
        return net
    init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net

class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
        """Construct a Resnet-based generator
        Parameters:
            input_nc (int)         -- the number of channels in input images
            output_nc (int)        -- the number of channels in output images
            ngf (int)              -- the number of filters produced by the first conv layer
            norm_layer             -- normalization layer class or functools.partial factory
            use_dropout (bool)     -- if use dropout layers
            n_blocks (int)         -- the number of ResNet blocks
            padding_type (str)     -- the name of padding layer in conv layers: reflect | replicate | zero
            no_antialias (bool)    -- if True, downsample with stride-2 convs instead of conv + Downsample
            no_antialias_up (bool) -- if True, upsample with ConvTranspose2d instead of Upsample + conv
            opt                    -- stored on the instance as-is; not read by this class
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.opt = opt
        # Conv layers keep their bias only when the norm is InstanceNorm2d.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        # NOTE: padding of 2 (ORIGINALLY 3 ***) with an unpadded 7x7 conv shrinks
        # each spatial dimension by 2 at this stage.
        model = [nn.ReflectionPad2d(2),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            if no_antialias:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
            else:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True),
                          Downsample(ngf * mult * 2)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):       # add ResNet blocks
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            if no_antialias_up:
                model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                             kernel_size=3, stride=2,
                                             padding=1, output_padding=1,
                                             bias=use_bias),
                          norm_layer(int(ngf * mult / 2)),
                          nn.ReLU(True)]
            else:
                model += [Upsample(ngf * mult),
                          nn.Conv2d(ngf * mult, int(ngf * mult / 2),
                                    kernel_size=3, stride=1,
                                    padding=1,
                                    bias=use_bias),
                          norm_layer(int(ngf * mult / 2)),
                          nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input, layers=None, encode_only=False):
        """Run the generator, optionally collecting intermediate activations.

        Parameters:
            input (tensor)      -- input batch
            layers (sequence)   -- indices into self.model whose outputs should be collected;
                                   None or empty means a plain forward pass. The sequence is
                                   never modified.
            encode_only (bool)  -- if True, stop after the layer whose index is the last
                                   entry of ``layers`` and return only the collected features

        Returns:
            the output tensor when no layers are requested;
            the list of collected features when encode_only stops early;
            otherwise the tuple (output, collected features).
        """
        # Work on a private copy so neither the caller's list nor a shared
        # default is mutated by the append below.
        layers = [] if layers is None else list(layers)
        if -1 in layers:
            layers.append(len(self.model))
        if not layers:
            # Standard forward
            return self.model(input)

        feat = input
        feats = []
        last_requested = layers[-1]
        for layer_id, layer in enumerate(self.model):
            feat = layer(feat)
            if layer_id in layers:
                feats.append(feat)
            if encode_only and layer_id == last_requested:
                return feats  # return intermediate features alone; stop in the last layers

        return feat, feats  # return both output and intermediate features

class Normalize(nn.Module):
    """Divide the input by the p-th root of the sum of its p-th powers along dimension 1."""

    def __init__(self, power=2):
        """Store the exponent ``power`` (2 by default)."""
        super().__init__()
        self.power = power

    def forward(self, x):
        """Return ``x`` scaled by its norm along dim 1 (plus 1e-7 to avoid division by zero)."""
        p = self.power
        magnitude = torch.sum(torch.pow(x, p), dim=1, keepdim=True) ** (1. / p)
        return x / (magnitude + 1e-7)

class PatchSampleF(nn.Module):
    """Sample spatial locations from feature maps, optionally project them with an MLP, and L2-normalize."""

    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
        """Store the configuration; the MLP heads are created lazily on the first forward call.

        Parameters:
            use_mlp (bool)     -- if True, each feature map gets its own two-layer MLP
            init_type (str)    -- initialization method passed to init_net for the MLPs
            init_gain (float)  -- scaling factor passed to init_net for the MLPs
            nc (int)           -- hidden and output width of each MLP
            gpu_ids (int list) -- if non-empty, the MLPs are moved to CUDA
        """
        # potential issues: currently, we use the same patch_ids for multiple images in the batch (comment from Park et al***)
        super().__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded
        self.mlp_init = False
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids

    def create_mlp(self, feats):
        """Create one MLP per feature map (registered as mlp_0, mlp_1, ...) and initialize them."""
        for mlp_id, feat in enumerate(feats):
            # Input width is the channel count of the corresponding feature map.
            head = nn.Sequential(nn.Linear(feat.shape[1], self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc))
            if len(self.gpu_ids) > 0:
                head.cuda()
            setattr(self, 'mlp_%d' % mlp_id, head)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample and normalize features.

        Parameters:
            feats (list of tensors) -- feature maps indexed as (batch, channel, height, width)
            num_patches (int)       -- number of locations to sample per map; 0 keeps every location
            patch_ids (list)        -- optional per-map location indices to reuse instead of sampling

        Returns (features, ids): two lists with one entry per input feature map.
        """
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)

        sampled_feats, sampled_ids = [], []
        for feat_id, feat in enumerate(feats):
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
            # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
            flat = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is None:
                    perm = torch.randperm(flat.shape[1], device=feats[0].device)
                    patch_id = perm[:int(min(num_patches, perm.shape[0]))]
                else:
                    patch_id = patch_ids[feat_id]
                # The same location indices are used for every image in the batch.
                x_sample = flat[:, patch_id, :].flatten(0, 1)
            else:
                patch_id = []
                x_sample = flat
            if self.use_mlp:
                x_sample = getattr(self, 'mlp_%d' % feat_id)(x_sample)
            sampled_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)

            if num_patches == 0:
                # Restore the (B, C, H, W) layout when nothing was sampled.
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            sampled_feats.append(x_sample)
        return sampled_feats, sampled_ids

class GANLoss(nn.Module):
    """Collection of GAN objectives behind one call signature.

    The class builds the target label tensor itself, so callers only say
    whether the prediction should be scored as real or fake.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Set up the loss for the chosen objective.

        Parameters:
            gan_mode (str)            -- vanilla | lsgan | wgangp | nonsaturating
            target_real_label (float) -- label value for a real image
            target_fake_label (float) -- label value for a fake image

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.

        Raises NotImplementedError for an unknown gan_mode.
        """
        super().__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode in ('wgangp', 'nonsaturating'):
            # These objectives are computed directly in __call__.
            self.loss = None
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Return the real or fake label expanded to the shape of ``prediction``.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is "real"
        """
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Compute the loss for a discriminator output.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is "real"

        Returns a scalar tensor, except for 'nonsaturating', which returns one
        value per sample in the batch.
        """
        bs = prediction.size(0)
        mode = self.gan_mode
        if mode in ['lsgan', 'vanilla']:
            loss = self.loss(prediction, self.get_target_tensor(prediction, target_is_real))
        elif mode == 'wgangp':
            loss = -prediction.mean() if target_is_real else prediction.mean()
        elif mode == 'nonsaturating':
            signed = -prediction if target_is_real else prediction
            loss = F.softplus(signed).view(bs, -1).mean(dim=1)
        return loss

class PatchNCELoss(nn.Module):
    """Patch-wise contrastive (NCE) loss between query and key feature rows.

    Each query row is scored against the key row at the same position
    (positive) and against the other key rows of the same image (negatives);
    the loss is the cross-entropy of picking the positive.
    """

    def __init__(self, opt):
        """Store the options object and build the per-row cross-entropy.

        Parameters:
            opt -- object expected to expose ``batch_size`` and ``nce_T``; if it
                   does not, a module-level ``model`` object is consulted instead
        """
        super().__init__()
        self.opt = opt
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        # torch.bool exists from PyTorch 1.2 on; older releases need uint8 masks.
        self.mask_dtype = getattr(torch, 'bool', torch.uint8)

    def _setting(self, name):
        """Look up a hyper-parameter on ``self.opt``, falling back to a global ``model``."""
        value = getattr(self.opt, name, None)
        if value is None:
            # Backward compatibility: earlier code read these from a global `model`.
            value = getattr(globals().get('model'), name, None)
        if value is None:
            raise AttributeError('PatchNCELoss needs %r on opt (or on a global `model`)' % name)
        return value

    def forward(self, feat_q, feat_k):
        """Return one loss value per row of ``feat_q``.

        Parameters:
            feat_q (tensor) -- query features, indexed as (rows, dim)
            feat_k (tensor) -- key features of the same size; detached before use
        """
        batchSize = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()

        # pos logit: dot product of each query row with the matching key row
        l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
        l_pos = l_pos.view(batchSize, 1)

        # neg logit
        # Negatives are taken only from the same image, so rows are regrouped
        # per image using the configured batch size. (Using negatives from the
        # whole minibatch would instead use a batch dimension of 1.)
        batch_dim_for_bmm = self._setting('batch_size')

        # reshape features to batch size
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # diagonal entries are similarity between same features, and hence meaningless.
        # just fill the diagonal with very small number, which is exp(-10) and almost zero
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        # Column 0 holds the positive, so the target class is always 0.
        out = torch.cat((l_pos, l_neg), dim=1) / self._setting('nce_T')

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
                                                        device=feat_q.device))

        return loss

def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
             init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
    """Create and initialize the generator.

    Always builds a ResnetGenerator with 9 blocks; ``netG`` is only inspected
    to decide whether weight initialization is skipped (names containing
    'stylegan2' skip it).

    Parameters:
        input_nc / output_nc (int)  -- input / output image channels
        ngf (int)                   -- filter count handed to ResnetGenerator
        netG (str)                  -- architecture name (see above)
        norm (str)                  -- normalization name handed to get_norm_layer
        use_dropout (bool)          -- forwarded to ResnetGenerator
        init_type, init_gain        -- forwarded to init_net
        no_antialias, no_antialias_up (bool) -- forwarded to ResnetGenerator
        gpu_ids (int list)          -- forwarded to init_net
        opt                         -- forwarded to ResnetGenerator

    Returns the network as returned by init_net.
    """
    generator = ResnetGenerator(
        input_nc,
        output_nc,
        ngf,
        norm_layer=get_norm_layer(norm_type=norm),
        use_dropout=use_dropout,
        no_antialias=no_antialias,
        no_antialias_up=no_antialias_up,
        n_blocks=9,
        opt=opt,
    )
    skip_init = 'stylegan2' in netG
    return init_net(generator, init_type, init_gain, gpu_ids, initialize_weights=not skip_init)

def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], netF_nc=256):
    """Create and initialize the patch-sampling feature network.

    Always builds a PatchSampleF with its MLP enabled and width ``netF_nc``.
    ``input_nc``, ``netF``, ``norm``, ``use_dropout`` and ``no_antialias`` are
    accepted but not used by this function.

    Returns the network as returned by init_net.
    """
    sampler_kwargs = dict(
        use_mlp=True,
        init_type=init_type,
        init_gain=init_gain,
        gpu_ids=gpu_ids,
        nc=netF_nc,
    )
    sampler = PatchSampleF(**sampler_kwargs)
    return init_net(sampler, init_type, init_gain, gpu_ids)

def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
    """Create and initialize the PatchGAN discriminator.

    Parameters:
        input_nc (int)      -- the number of channels in input images
        ndf (int)           -- filter count handed to NLayerDiscriminator
        netD (str)          -- architecture name: 'n_layers' uses ``n_layers_D`` conv
                               stages, any other name uses 3; names containing
                               'stylegan2' skip weight initialization
        n_layers_D (int)    -- number of conv stages, only used if netD == 'n_layers'
        norm (str)          -- normalization name handed to get_norm_layer
        init_type, init_gain -- forwarded to init_net
        no_antialias (bool) -- forwarded to NLayerDiscriminator
        gpu_ids (int list)  -- forwarded to init_net
        opt                 -- accepted but not used by this function

    Returns the network as returned by init_net.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    # Honour n_layers_D instead of silently ignoring it; every other netD
    # value (including 'basic') keeps the 3-layer default.
    n_layers = n_layers_D if netD == 'n_layers' else 3
    net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers, norm_layer=norm_layer, no_antialias=no_antialias)
    return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netD))

In [None]:
class CUT():
    """CUT (Contrastive Unpaired Translation) model used to translate photos into Monet-style images.

    Bundles the generator netG, the patch feature network netF and (when
    training) the discriminator netD together with their losses, optimizers
    and learning-rate schedulers. The hyper-parameters follow the option files
    of the reference implementation linked in __init__.
    """

    def __init__(self, n_epochs=5, n_epochs_decay=5, lr=0.0002, batch_size=1, num_patches=256):
        """Set the hyper-parameters and build the networks, losses, optimizers and schedulers.

        Parameters:
            n_epochs (int)       -- number of epochs with the initial learning rate
            n_epochs_decay (int) -- number of epochs over which the learning rate is decayed linearly
            lr (float)           -- initial learning rate of Adam
            batch_size (int)     -- input batch size
            num_patches (int)    -- number of patches per layer used by the NCE loss
        """

        # https://github.com/taesungp/contrastive-unpaired-translation/blob/master/options/train_options.py
        self.n_epochs = n_epochs # The number of epochs with the initial learning rate
        self.n_epochs_decay = n_epochs_decay # The number of epochs to linearly decay learning rate to zero
        self.beta1 = 0.5 # Momentum term of adam
        self.beta2 = 0.999 # Momentum term of adam
        self.lr = lr # Initial learning rate of adam
        self.gan_mode = 'lsgan' # The type of GAN objective
        self.lr_policy = 'linear' # Learning rate policy
        self.lr_decay_iters = 50 # Multiply by gamma every lr_decay_iters iterations
        self.isTrain = True # Train or test
        self.epoch_count = 1 # The starting epoch count

        # https://github.com/taesungp/contrastive-unpaired-translation/blob/master/options/base_options.py
        self.gpu_ids = [0] # Determines which GPU to use
        self.input_nc = 3 # Number of input image channels: 3 for RGB and 1 for grayscale
        self.output_nc = 3 # Number of output image channels: 3 for RGB and 1 for grayscale
        self.ngf = 64 # Number of gen filters in the last convolutional layer
        self.ndf = 64 # Number of discrim filters in the first convolutional layer
        self.opt_netD = 'basic' # Specify discriminator architecture; the basic model is a 70x70 PatchGAN.
        self.opt_netG = 'resnet_9blocks' # Specify the generator architecture
        self.n_layers_D = 3 # Only used if netD=='n_layers'
        self.normG = 'instance' # Specify the type of normalization for G
        self.normD = 'instance' # Specify the type of normalization for D
        self.init_type = 'xavier' # Network initialization
        self.init_gain = 0.02 # Scaling factor for normal, xavier and orthogonal initialization
        self.no_dropout = True # No dropout for the generator
        self.direction = 'AtoB'
        # ...
        self.batch_size = batch_size # Input batch size

        # https://github.com/taesungp/contrastive-unpaired-translation/blob/87ab89cdca651f87742844016b0cfa49fa7bd3ee/models/base_model.py#L8
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.optimizers = []
        # Created lazily in data_dependent_initialize, once netF has parameters to optimize
        self.optimizer_F = None

        # https://github.com/taesungp/contrastive-unpaired-translation/blob/87ab89cdca651f87742844016b0cfa49fa7bd3ee/models/cut_model.py#L9
        self.lambda_GAN = 1.0 # Weight for GAN loss: GAN(G(X))
        self.lambda_NCE = 1.0 # Weight for NCE loss: NCE(G(X), X)
        self.nce_idt = True # Use NCE loss for identity mapping: NCE(G(Y), Y); True for CUT, False for SinCUT
        self.nce_layers = [0, 2, 4, 6, 8, 10, 12, 14, 16]  # Compute NCE loss on these layers
        #self.nce_includes_all_negatives_from_minibatch = False # (used for single image translation)
        self.opt_netF = 'mlp_sample' # How to downsample the feature map
        self.netF_nc = 256 # 256 # ***
        self.nce_T = 0.07 # Temperature for NCE loss
        self.num_patches = num_patches # 256 # *** # Number of patches per layer
        self.flip_equivariance = False # Used by FastCUT, but not CUT

        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE', 'NCE_Y']
        self.visual_names = ['photo_data', 'fake_B', 'monet_data', 'idt_B']

        if self.isTrain:
            self.model_names = ['G', 'F', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']

        # Not sure about Park et al's settings for:
        self.no_antialias = True # TRUE OR FALSE?
        self.no_antialias_up = True # TRUE OR FALSE?
        opt = None

        # define networks (both generator and discriminator)
        self.netG = define_G(self.input_nc, self.output_nc, self.ngf, self.opt_netG, self.normG,
                                      not self.no_dropout, self.init_type, self.init_gain,
                                      self.no_antialias, self.no_antialias_up, self.gpu_ids, opt)
        self.netF = define_F(self.input_nc, self.opt_netF, self.normG, not self.no_dropout,
                                      self.init_type, self.init_gain, self.no_antialias, self.gpu_ids, self.netF_nc)

        if self.isTrain:
            self.netD = define_D(self.output_nc, self.ndf, self.opt_netD, self.n_layers_D, self.normD,
                                          self.init_type, self.init_gain, self.no_antialias, self.gpu_ids, opt)

            # define loss functions
            self.criterionGAN = GANLoss(self.gan_mode).to(self.device)
            self.criterionNCE = []

            # One PatchNCELoss instance per layer on which the NCE loss is computed
            for nce_layer in self.nce_layers:
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.schedulers = [self.get_scheduler(optimizer, opt) for optimizer in self.optimizers]

    def get_current_losses(self):
        """Return the current training losses as an OrderedDict mapping each name in loss_names to a float."""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def update_learning_rate(self):
        """Step every learning-rate scheduler once and print the learning rate of the first optimizer."""
        for scheduler in self.schedulers:
            if self.lr_policy == 'plateau':
                # NOTE(review): self.metric is not assigned anywhere in this class; the caller must set it
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def data_dependent_initialize(self, photo_data, monet_data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.

        Parameters:
            photo_data -- batch of photos used for the initialization pass
            monet_data -- batch of Monets used for the initialization pass
        """

        #self.set_input(data)  # basically: set model.photo_data to images from one domain, and model.monet_data to images from the other domain
        self.photo_data = photo_data
        self.monet_data = monet_data

        bs_per_gpu = self.photo_data.size(0) // max(len(self.gpu_ids), 1)
        self.photo_data = self.photo_data[:bs_per_gpu]
        self.monet_data = self.monet_data[:bs_per_gpu]

        self.forward()                     # compute fake images: G(A)

        if self.isTrain:
            self.compute_D_loss().backward()                  # calculate gradients for D

            self.compute_G_loss().backward()                   # calculate gradients for G
            if self.lambda_NCE > 0.0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.lr, betas=(self.beta1, self.beta2))
                self.optimizers.append(self.optimizer_F)
                # The schedulers were built in __init__, before optimizer_F existed; give it the
                # same schedule as G and D so its learning rate is updated alongside theirs.
                self.schedulers.append(self.get_scheduler(self.optimizer_F, None))

    def optimize_parameters(self):
        """Run one training step: a forward pass, a discriminator update, then a generator (and netF) update."""
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()

        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)

        self.optimizer_G.zero_grad()

        # The check must use the option string (opt_netF). Comparing the network object netF to a
        # string is always False, which meant optimizer_F was never zeroed nor stepped.
        update_F = self.opt_netF == 'mlp_sample' and self.optimizer_F is not None
        if update_F:
            self.optimizer_F.zero_grad()

        self.loss_G = self.compute_G_loss()

        self.loss_G.backward()
        self.optimizer_G.step()
        if update_F:
            self.optimizer_F.step()

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad on all parameters of the given networks (False avoids unnecessary computations)
        Parameters:
            nets (network list)   -- a list of networks (a single network is also accepted)
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def get_scheduler(self, optimizer, opt):
        """Return a learning rate scheduler
        Parameters:
            optimizer -- the optimizer of the network
            opt       -- unused; the policy is read from self.lr_policy: linear | step | plateau | cosine
        For 'linear', we keep the same learning rate for the first <n_epochs> epochs
        and linearly decay the rate over the next <n_epochs_decay> epochs.
        For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
        See https://pytorch.org/docs/stable/optim.html for more details.
        Raises:
            NotImplementedError -- if self.lr_policy is not one of the supported policies
        """
        if self.lr_policy == 'linear':
            def lambda_rule(epoch):
                lr_l = 1.0 - max(0, epoch + self.epoch_count - self.n_epochs) / float(self.n_epochs_decay + 1)
                return lr_l
            scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
        elif self.lr_policy == 'step':
            scheduler = lr_scheduler.StepLR(optimizer, step_size=self.lr_decay_iters, gamma=0.1)
        elif self.lr_policy == 'plateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
        elif self.lr_policy == 'cosine':
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.n_epochs, eta_min=0)
        else:
            # Raise (not return) so an unsupported policy fails here instead of later on scheduler.step()
            raise NotImplementedError('learning rate policy [%s] is not implemented' % self.lr_policy)
        return scheduler

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.direction == 'AtoB'
        self.photo_data = input['A' if AtoB else 'B'].to(self.device)
        self.monet_data = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def parallelize(self):
        """Wrap every network listed in model_names in torch.nn.DataParallel over gpu_ids."""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                setattr(self, 'net' + name, torch.nn.DataParallel(net, self.gpu_ids))

    def forward(self):
        """Run forward pass; called by both <optimize_parameters> and the submission code."""
        # When training with the identity NCE loss, photos and Monets go through G in a single batch
        self.real = torch.cat((self.photo_data, self.monet_data), dim=0) if self.nce_idt and self.isTrain else self.photo_data
        if self.flip_equivariance: # True for FastCut, not for CUT
            self.flipped_for_equivariance = self.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real) # This takes a very large part of the GPU memory ***

        # The first photo_data.size(0) outputs are the translated photos ...
        self.fake_B = self.fake[:self.photo_data.size(0)]

        # ... and the remainder (if any) are the Monets passed through G (identity mapping)
        if self.nce_idt:
            self.idt_B = self.fake[self.photo_data.size(0):]

    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""
        fake = self.fake_B.detach()
        # Fake; stop backprop to the generator by detaching fake_B
        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
        # Real
        self.pred_real = self.netD(self.monet_data)
        loss_D_real = self.criterionGAN(self.pred_real, True)
        self.loss_D_real = loss_D_real.mean()

        # combine the two losses; the caller runs backward()
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fake = self.fake_B
        # First, G(A) should fake the discriminator
        if self.lambda_GAN > 0.0:
            pred_fake = self.netD(fake)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.lambda_GAN
        else:
            self.loss_G_GAN = 0.0

        if self.lambda_NCE > 0.0:
            self.loss_NCE = self.calculate_NCE_loss(self.photo_data, self.fake_B)
        else:
            self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0

        if self.nce_idt and self.lambda_NCE > 0.0:
            self.loss_NCE_Y = self.calculate_NCE_loss(self.monet_data, self.idt_B)
            loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
        else:
            # 'NCE_Y' is always listed in loss_names, so keep the attribute defined
            # for get_current_losses() even when the identity term is disabled.
            self.loss_NCE_Y = 0.0
            loss_NCE_both = self.loss_NCE

        self.loss_G = self.loss_G_GAN + loss_NCE_both
        return self.loss_G

    def calculate_NCE_loss(self, src, tgt): # Takes a lot of GPU memory, in particular self.netG(...) ***
        """Return the NCE loss between the features of src and tgt, averaged over the layers in nce_layers.

        Parameters:
            src -- input batch whose features act as keys
            tgt -- output batch whose features act as queries
        """
        n_layers = len(self.nce_layers)

        feat_q = self.netG(tgt, self.nce_layers, encode_only=True)

        if self.flip_equivariance and self.flipped_for_equivariance: # Used by FastCUT, but not CUT
            feat_q = [torch.flip(fq, [3]) for fq in feat_q]

        feat_k = self.netG(src, self.nce_layers, encode_only=True) # Takes a good amount of GPU memory ***

        # Sample patches from the key features, then reuse the same sample ids for the query features
        feat_k_pool, sample_ids = self.netF(feat_k, self.num_patches, None)

        feat_q_pool, _ = self.netF(feat_q, self.num_patches, sample_ids)

        total_nce_loss = 0.0
        for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
            loss = crit(f_q, f_k) * self.lambda_NCE
            total_nce_loss += loss.mean()

        return total_nce_loss / n_layers

## The Data

### Load the data

In [None]:
def load_images(path_images, n_images):
    """Load up to n_images images from the given directory as RGB arrays.

    Parameters:
        path_images (str) -- directory that contains the image files
        n_images (int)    -- maximum number of images to load; if the directory
                             holds fewer files, all of them are loaded
    Returns:
        list with one array per image, converted from OpenCV's BGR order to RGB
    Raises:
        IOError -- if a file in the directory cannot be read as an image
    """

    # Sort the names so the selection does not depend on the arbitrary order of os.listdir
    image_names = sorted(os.listdir(path_images))[:max(n_images, 0)]
    if len(image_names) < n_images:
        print(f"Requested {n_images} images but only {len(image_names)} are available in {path_images}.")

    images = []
    for image_name in image_names:
        image_path = os.path.join(path_images, image_name)
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread reports an unreadable file by returning None
            raise IOError(f"Could not read image: {image_path}")
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Name the source by its folder instead of slicing the path at hard-coded offsets
    folder_name = os.path.basename(os.path.normpath(path_images))
    print(f"Loaded {len(images)} {folder_name} images.")
    return images

def define_transformations():
    """Return the list of albumentations transforms used to augment the data.

    Each transform is always applied (p=1.0). The list holds, in order: a
    random resized crop to 256x256 with scale fixed at 0.9, a horizontal flip
    and a vertical flip.
    """
    # Variants that were tried and are currently disabled: a crop with scale 0.7
    # and rotations by 90, 180 and 270 degrees.
    return [
        A.RandomResizedCrop(width=256, height=256, scale=(0.9, 0.9), p=1.0),
        A.HorizontalFlip(p=1.0),
        A.VerticalFlip(p=1.0),
    ]

def set_data(images, transformations, device="cuda:0"):
    """Stack the images and their transformed variants into one float tensor on `device`.

    Parameters:
        images (list)          -- 256x256x3 image arrays (height, width, channel)
        transformations (list) -- callables invoked as transformation(image=...) that
                                  return a dict whose "image" entry is a 256x256x3 array
        device (str)           -- device on which the dataset tensor is allocated
    Returns:
        float32 tensor of shape [n_images*(n_transformations+1), 3, 256, 256] with values
        divided by 255. Index i holds original image i; index i + n_images*(j+1) holds
        image i after transformation j.
    """

    n_images = len(images)
    n_transformations = len(transformations)
    # Allocate on the requested device; this used to be hard-coded to "cuda:0" regardless of `device`
    image_tensor = torch.empty([n_images*(n_transformations+1), 256, 256, 3], dtype=torch.float32, device=device)
    for i, image in enumerate(images):
        image_tensor[i] = torch.from_numpy(np.ascontiguousarray(image, dtype=np.float32) / 255).to(device)

        for j, transformation in enumerate(transformations):
            transformed_image = transformation(image=image)["image"]
            transformed_image = np.ascontiguousarray(transformed_image, dtype=np.float32) / 255
            image_tensor[i+n_images*(j+1)] = torch.from_numpy(transformed_image).to(device)

    # Channels-last (N, H, W, C) -> channels-first (N, C, H, W)
    image_tensor = image_tensor.permute(0, 3, 1, 2)
    print(f"Created dataset consisting of {image_tensor.shape[0]} images.")
    return image_tensor

In [None]:
# Number of images to read from each folder. Both counts are module-level
# globals that later cells (plotting, training loop, submission) read again.
n_photos = 7028 # Using n_photos out of 7028 photos
n_monets = 300 # Using n_monets out of 300 monets

# Read the raw RGB images from disk
photos = load_images(PHOTO_PATH, n_photos)
monets = load_images(MONET_PATH, n_monets)

In [None]:
# Build the dataset tensors. The photos are used as-is (empty transformation
# list), whereas every Monet is additionally stored once per transformation.
transformations = define_transformations()
n_transformations = len(transformations)
photo_data = set_data(photos, [])
monet_data = set_data(monets, transformations)

### Visual inspection of the data

In [None]:
def plot_data(images, indices):
    """Show the images at the given indices in a grid of at most four columns.

    Any number of indices is accepted; cells of the last row that are not
    needed stay empty. Nothing is drawn when `indices` is empty.

    Parameters:
        images  -- indexable collection of channels-first (C, H, W) image tensors
        indices -- positions in `images` to display
    """

    n_indices = len(indices)
    if n_indices == 0:
        return  # an empty grid cannot be created

    n_columns = min(4, n_indices)
    n_rows = math.ceil(n_indices / n_columns)

    # squeeze=False always yields a 2-D array of axes, also for a single row or a single image.
    # figsize is (width, height): the width follows the columns and the height follows the rows.
    fig, axs = plt.subplots(n_rows, n_columns, figsize=(4*n_columns, 4*n_rows), squeeze=False)
    for ax in axs.ravel():
        ax.axis("off")

    # ravel() walks the grid row by row, so image k lands at row k//n_columns, column k%n_columns
    for ax, index in zip(axs.ravel(), indices):
        image = images[index].permute(1, 2, 0).detach().cpu().numpy()
        ax.imshow(image, interpolation='none')

In [None]:
# Indices of the images to look at. The same lists are reused by the training
# loop and the submission cell to decide at which iterations samples are stored.
photo_inspection_indices = [0, 99, 199, 299] # If len>4 must be multiple of 4
monet_inspection_indices = [0, 99, 199, 299] # If len>4 must be multiple of 4

# Switch for the data-inspection plots in this and the following cells
inspect_data = True
if inspect_data: # Plot photo images according to the selected indices
    plot_data(photo_data, photo_inspection_indices)

In [None]:
# Show Monet number 1 next to each of its transformed variants: set_data stores
# transformation j of image i at index i + n_images*(j+1).
plot_data(monet_data, [1+n_monets*i for i in range(n_transformations+1)])

In [None]:
# Same inspection as for the photos, now for the Monet dataset
if inspect_data: # Plot Monet images according to the selected indices
    plot_data(monet_data, monet_inspection_indices)

## Training the CUT Model

In [None]:
def set_input(model, photo_data, monet_data, epoch, i):
    """Assign the next training batch to the model.

    The photos are drawn uniformly at random (with replacement) from
    photo_data; the Monets are the consecutive slice starting at position i.
    Both batches contain model.batch_size images. `epoch` is not used.
    """
    batch_size = model.batch_size
    random_photo_ids = np.random.randint(0, photo_data.shape[0], batch_size)
    model.photo_data = photo_data[random_photo_ids]
    model.monet_data = monet_data[i:i + batch_size]
    
def print_run_info(epoch, i, print_freq, model, iteration_start_time):
    """Print the elapsed time and the current losses every print_freq iterations.

    Returns the current time when a line was printed, and the unchanged
    iteration_start_time otherwise, so the caller can use the return value as
    the start of the next timing window.
    """
    # Only report on positive multiples of print_freq
    if i <= 0 or i % print_freq != 0:
        return iteration_start_time

    header = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, time.time()-iteration_start_time)
    loss_parts = ['%s: %.3f ' % (name, value) for name, value in model.get_current_losses().items()]
    print(header + ''.join(loss_parts))
    return time.time()

def store_losses(losses_G_GAN, losses_D_real, losses_D_fake, losses_G, losses_NCE, losses_NCE_Y, model=None):
    """Append the model's current losses to the given history lists.

    Parameters:
        losses_G_GAN, losses_D_real, losses_D_fake, losses_G, losses_NCE, losses_NCE_Y (list)
            -- histories that receive the loss of the same name
        model -- object providing get_current_losses(); when omitted, the
                 module-level variable `model` is used (the previous behavior)
    """
    if model is None:
        # Backward compatible fallback for callers that rely on the notebook-level global
        model = globals()["model"]

    losses = model.get_current_losses()
    histories = (
        (losses_G_GAN, "G_GAN"),
        (losses_D_real, "D_real"),
        (losses_D_fake, "D_fake"),
        (losses_G, "G"),
        (losses_NCE, "NCE"),
        (losses_NCE_Y, "NCE_Y"),
    )
    for history, name in histories:
        history.append(losses[name])

def store_image_samples(photo_samples, monet_samples, photo_inspection_indices, monet_inspection_indices, model, i):
    """Keep the model's current input/output pair when iteration i is an inspection index.

    If i is in photo_inspection_indices, (photo, translated photo) is appended
    to photo_samples; if i is in monet_inspection_indices, (Monet, Monet passed
    through the generator) is appended to monet_samples. Only the first image
    of each batch is kept, as a channels-last numpy array.
    """
    def first_image(batch):
        # (N, C, H, W) -> (N, H, W, C), then take the first image of the batch
        return batch.permute(0, 2, 3, 1)[0].detach().cpu().squeeze(0).numpy()

    if i in photo_inspection_indices:
        source, translation = first_image(model.photo_data), first_image(model.fake_B)
        photo_samples.append((source, translation))

    if i in monet_inspection_indices:
        source, translation = first_image(model.monet_data), first_image(model.idt_B)
        monet_samples.append((source, translation))

## The training loop

In [None]:
print('The number of photos = %d' % n_photos)
print('The number of monets = %d' % n_monets)
print('The number of transformations = %d' % n_transformations)

# Initialize the CUT model
model = CUT(n_epochs=30, n_epochs_decay=100, lr = 0.0002, batch_size=1, num_patches=256)
# Run one initialization pass on a sample batch (see CUT.data_dependent_initialize)
model.data_dependent_initialize(photo_data[:model.batch_size], monet_data[np.random.randint(0, monet_data.shape[0], model.batch_size)])
model.parallelize()

print_freq = 1000 # Frequency of printing training results
photo_samples = [] # Will contain photos and their translations for visual inspection
monet_samples = [] # Will contain Monets and their translations for visual inspection
losses_G_GAN, losses_D_real, losses_D_fake, losses_G, losses_NCE, losses_NCE_Y  = [], [], [], [], [], []

time_budget = 18000 # Total number of seconds available for training
time_margin = 500 # Seconds kept in reserve on top of the duration of one epoch
n_total_epochs = model.n_epochs + model.n_epochs_decay

training_start_time = time.time()
# Epochs advance one at a time; the batch size must not be used as the step of the epoch counter
for epoch in range(model.epoch_count, n_total_epochs + 1):    # outer loop for different epochs
    epoch_start_time = time.time()     # Timer for entire epoch
    iteration_start_time = time.time() # Timer for print_freq iterations

    # Randomize the order of the monet data
    monet_data = monet_data[torch.randperm(monet_data.shape[0])]

    # Visit every start position in the Monet dataset. The dataset holds the originals plus
    # one copy per transformation, so its size is taken from the tensor itself instead of
    # n_monets*n_transformations, which left part of the data unused in every epoch.
    n_iterations = monet_data.shape[0] - model.batch_size + 1
    for i in range(n_iterations):

        set_input(model, photo_data, monet_data, epoch, i)

        model.optimize_parameters() # Calculate loss functions, get gradients, update network weights

        # Keep the returned timestamp so the reported time covers the last print_freq iterations only
        iteration_start_time = print_run_info(epoch, i, print_freq, model, iteration_start_time) # Print timer information and training losses

        store_losses(losses_G_GAN, losses_D_real, losses_D_fake, losses_G, losses_NCE, losses_NCE_Y) # Store losses for later inspection

        if epoch % 10 == 0:
            store_image_samples(photo_samples, monet_samples, photo_inspection_indices, monet_inspection_indices, model, i)

    # Update learning rates once at the end of every epoch. This replaces an in-loop check whose
    # modulus (n_iterations//batch_size-1) could be zero.
    model.update_learning_rate()
    print('End of epoch %d / %d \t Time Taken: %d sec\n' % (epoch, n_total_epochs, time.time() - epoch_start_time))

    # Stop when the remaining time budget cannot fit another epoch of the same duration plus the margin
    time_left = time_budget - (time.time() - training_start_time)
    time_needed = time.time() - epoch_start_time + time_margin
    if not time_left > time_needed:
        print(f"Not enough time left ({time_left}<{time_needed}) for another epoch, breaking.")
        break

### Visual inspection of the output

#### Plot the losses

In [None]:
def set_ax(ax, y, x_label, y_label):
    """Plot a loss curve twice: in full on ax[0] and without its first values on ax[1].

    The second plot starts at the largest power of ten (at least 10) such that
    len(y) is still at least 3 times ten times that value, or at 10 if there is none.
    """
    plt.rcParams.update({'font.size': 24})
    plt.xticks(fontsize=24)

    # Find where the second plot starts
    start = 10
    while len(y) / (start*10) >= 3:
        start *= 10

    full_panel, tail_panel = ax[0], ax[1]
    full_panel.plot(y)
    tail_panel.plot(range(start, len(y)), y[start:])

    for panel in (full_panel, tail_panel):
        panel.set_xlabel(x_label, fontsize='x-large')
        panel.set_ylabel(y_label, fontsize='x-large')

# One row of two panels per loss
fig, axs = plt.subplots(6, 2, figsize=(14*2, 8*6))
for ax in axs.ravel():
    ax.grid()

raw_loss_curves = (
    (losses_G_GAN, "G_GAN Loss"),
    (losses_D_real, "D_real Loss"),
    (losses_D_fake, "D_fake Loss"),
    (losses_G, "G Loss"),
    (losses_NCE, "NCE Loss"),
    (losses_NCE_Y, "NCE_Y Loss"),
)
for axs_row, (loss_curve, loss_label) in zip(axs, raw_loss_curves):
    set_ax(axs_row, loss_curve, "Iterations", loss_label)

In [None]:
def smooth_losses(losses, smoothing_factor=0.01):
    """Return an exponentially smoothed copy of a loss curve.

    The first value is kept as is; every later value is
    previous_smoothed * (1 - smoothing_factor) + loss * smoothing_factor.

    Parameters:
        losses (list)            -- raw loss values; may be empty
        smoothing_factor (float) -- weight of the new value in each step
    Returns:
        list with as many values as `losses` (empty for empty input)
    """
    smoothed_losses = []
    for loss in losses:
        if not smoothed_losses:
            # The first value seeds the running average
            smoothed_losses.append(loss)
        else:
            smoothed_losses.append(smoothed_losses[-1] * (1-smoothing_factor) + loss * smoothing_factor)
    return smoothed_losses
        
# Smooth every loss history
(smoothed_losses_G_GAN,
 smoothed_losses_D_real,
 smoothed_losses_D_fake,
 smoothed_losses_G,
 smoothed_losses_NCE,
 smoothed_losses_NCE_Y) = [
    smooth_losses(loss_history)
    for loss_history in (losses_G_GAN, losses_D_real, losses_D_fake, losses_G, losses_NCE, losses_NCE_Y)
]

# One row of two panels per smoothed loss
fig, axs = plt.subplots(6, 2, figsize=(14*2, 8*6))
for ax in axs.ravel():
    ax.grid()

smoothed_loss_curves = (
    (smoothed_losses_G_GAN, "Smoothed G_GAN Loss"),
    (smoothed_losses_D_real, "Smoothed D_real Loss"),
    (smoothed_losses_D_fake, "Smoothed D_fake Loss"),
    (smoothed_losses_G, "Smoothed G Loss"),
    (smoothed_losses_NCE, "Smoothed NCE Loss"),
    (smoothed_losses_NCE_Y, "Smoothed NCE_Y Loss"),
)
for axs_row, (loss_curve, loss_label) in zip(axs, smoothed_loss_curves):
    set_ax(axs_row, loss_curve, "Iterations", loss_label)

#### Plot a selection of images and their translations

In [None]:
def plot_images(images):
    """Plot each (image, translation) pair side by side, one pair per row.

    Parameters:
        images (list) -- (image, image_translation) tuples of arrays accepted by imshow;
                         nothing is drawn when the list is empty
    """
    if not images:
        return  # a grid with zero rows cannot be created

    # squeeze=False keeps axs two-dimensional even when there is only one pair
    fig, axs = plt.subplots(len(images), 2, figsize=(12, 12/2*len(images)), squeeze=False)
    for ax in axs.ravel():
        ax.axis("off")
    for (source_ax, translation_ax), (image, image_translation) in zip(axs, images):
        source_ax.imshow(image, interpolation='none')
        translation_ax.imshow(image_translation, interpolation='none')

In [None]:
# Switch for plotting the samples collected during training
output_inspection = True
# Photos next to their translations
if output_inspection: 
    plot_images(photo_samples)

In [None]:
# Monets next to the result of passing them through the generator
if output_inspection: 
    plot_images(monet_samples)

## Create Submission File

In [None]:
import PIL
! mkdir ../images

In [None]:
# Import the submodule explicitly: a bare `import PIL` does not guarantee that PIL.Image is loaded
from PIL import Image

model.isTrain = False # forward() now translates the photos only
os.makedirs("../images", exist_ok=True) # Make sure the output folder exists before saving
test_photo_samples = []
submission_inspection_indices = photo_inspection_indices + [500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]

# No gradients are needed for inference; this avoids building an autograd graph for every photo
with torch.no_grad():
    for i in range(n_photos):
        model.photo_data = torch.unsqueeze(photo_data[i], 0)
        model.forward()

        # Clamp to [0, 1] before scaling: values outside that range would wrap around in the uint8 cast
        prediction = (model.fake_B.permute(0, 2, 3, 1)[0].clamp(0, 1) * 255).detach().cpu().squeeze(0).numpy().astype(np.uint8)
        im = Image.fromarray(prediction)
        im.save("../images/" + str(i+1) + ".jpg")

        # Keep a few photo/translation pairs for the visual inspection below
        store_image_samples(test_photo_samples, None, submission_inspection_indices, [], model, i)

        if i > 0 and i % 1000 == 0:
            print(f"Made {i} out of {n_photos} predictions.")
print(f"Done, made {n_photos} predictions.")

In [None]:
import shutil
# Zip the generated images into /kaggle/working/images.zip.
# NOTE(review): the images are saved to the relative path "../images/" but archived from the
# absolute path "/kaggle/images"; the two only match when the working directory is
# /kaggle/working -- confirm.
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images") # Make the submission

### Visually inspect the submission

In [None]:
# Show the photo/translation pairs that were stored while the submission was generated
plot_images(test_photo_samples)