Based on the paper : https://ieeexplore.ieee.org/document/9186319

Based on the code : https://github.com/ranery/Bayesian-CycleGAN



In [None]:
!nvidia-smi -L
!pip install --upgrade --force-reinstall --no-deps kaggle

GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-c74f44ec-5471-5c48-a87c-040f983dbcc0)
Collecting kaggle
[?25l  Downloading https://files.pythonhosted.org/packages/99/33/365c0d13f07a2a54744d027fe20b60dacdfdfb33bc04746db6ad0b79340b/kaggle-1.5.10.tar.gz (59kB)
[K     |████████████████████████████████| 61kB 3.2MB/s 
[?25hBuilding wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py) ... [?25l[?25hdone
  Created wheel for kaggle: filename=kaggle-1.5.10-cp36-none-any.whl size=73269 sha256=ea6176d83c165d643950dc9812526389ce3f6065266d65b493e03996c03d11cc
  Stored in directory: /root/.cache/pip/wheels/3a/d1/7e/6ce09b72b770149802c653a02783821629146983ee5a360f10
Successfully built kaggle
Installing collected packages: kaggle
  Found existing installation: kaggle 1.5.10
    Uninstalling kaggle-1.5.10:
      Successfully uninstalled kaggle-1.5.10
Successfully installed kaggle-1.5.10


In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter
import torch.utils.data as data

import functools
import numpy as np
import cv2
from scipy import misc
import time, itertools
import random
from collections import OrderedDict

from PIL import Image

import os
import os.path

# Load Data

In [None]:
from google.colab import files
# Upload your Kaggle API key file (kaggle.json) here; see the "Authentication"
# paragraph of https://www.kaggle.com/docs/api.
# NOTE(review): the value returned by files.upload() is displayed as the cell
# output and contains the raw file contents (i.e. the API key). Clear this
# cell's output before sharing or committing the notebook.
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"<REDACTED>","key":"<REDACTED>"}'}

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle datasets list

In [None]:
# Download the data of the Kaggle competition "gan-getting-started".
# NOTE(review): the next cell expects the archive at
# /content/gan-getting-started.zip, i.e. this presumably runs from /content.
! kaggle competitions download -c gan-getting-started

In [None]:
! unzip /content/gan-getting-started.zip

# Utils

In [None]:
class ImagePool():
    """Bounded history of images that can be replayed in place of new ones.

    With ``pool_size == 0`` the pool is disabled and ``query`` hands its input
    straight back. Otherwise the first ``pool_size`` images are stored and
    returned unchanged; afterwards each incoming image is, with probability
    0.5, swapped with a randomly chosen stored image (the stored one is
    returned) or returned as is.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            # Only allocated when the pool is actually in use.
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of the same length as ``images`` drawn via the pool."""
        if self.pool_size == 0:
            return Variable(images)

        picked = []
        for single in images:
            single = torch.unsqueeze(single, 0)  # keep a leading batch axis

            if self.num_imgs < self.pool_size:
                # Pool not full yet: remember the image and pass it through.
                self.num_imgs += 1
                self.images.append(single)
                picked.append(single)
                continue

            if random.uniform(0, 1) > 0.5:
                # Swap: return an old image and store the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                stored = self.images[slot].clone()
                self.images[slot] = single
                picked.append(stored)
            else:
                picked.append(single)

        return Variable(torch.cat(picked, 0))

def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batch tensor to an (H, W, C) numpy array.

    Values are shifted and scaled with ``(x + 1) / 2 * 255`` and then cast to
    ``imtype``. Only element 0 along the first axis is used.
    """
    first_image = image_tensor.detach().cpu().float().numpy()[0]
    channels_last = np.transpose(first_image, (1, 2, 0))
    scaled = (channels_last + 1) / 2.0 * 255.0
    return scaled.astype(imtype)


def diagnose_network(net, name='network'):
    """Print ``name`` followed by the average of the mean absolute gradients.

    Parameters without a gradient are ignored. When no parameter has a
    gradient the printed value is ``0.0``.
    """
    gradients = [p.grad for p in net.parameters() if p.grad is not None]

    average = 0.0
    for grad in gradients:
        average += torch.mean(torch.abs(grad.data))
    if gradients:
        average = average / len(gradients)

    print(name)
    print(average)


def save_image(image_numpy, image_path):
    """Write the array ``image_numpy`` to ``image_path`` as an image via PIL."""
    Image.fromarray(image_numpy).save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print summary information about the array ``x``.

    Args:
        x: numpy array; it is converted to float64 before any statistic.
        val: when true, print mean, min, max, median and std of all elements.
        shp: when true, print the array shape first.
    """
    values = x.astype(np.float64)
    if shp:
        print('shape,', values.shape)
    if not val:
        return
    flat = values.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)


def mkdirs(paths):
    """Create one directory, or every directory in a list of paths.

    Anything that is not a ``list`` (including a single string) is treated as
    one path.
    """
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        mkdir(target)


def mkdir(path):
    """Create ``path`` (and missing parents) unless something already exists there.

    ``exist_ok=True`` closes the window between the existence check and the
    creation: if the directory appears in between, no ``FileExistsError`` is
    raised.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

class Visualizer():
    """Saves result images and logs losses under ``<checkpoints_dir>/<name>``.

    Args:
        opt: options object; the attributes read here are ``display_id``,
            ``isTrain``, ``no_html``, ``display_winsize``, ``name`` and
            ``checkpoints_dir``.
    """

    def __init__(self, opt):
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.opt = opt
        self.saved = False
        self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'images')
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')

        # Create the experiment folder and its images/ sub-folder up front.
        # Opening the log in append mode already creates a missing *file*; what
        # used to fail was a missing *directory*, for both the log and the
        # images written by display_current_results.
        os.makedirs(self.img_dir, exist_ok=True)
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Clear the ``saved`` flag."""
        self.saved = False

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, save_result):
        """Save every image in ``visuals`` as ``epoch<NNN>_<label>.png``.

        ``save_result`` is accepted for interface compatibility but not used.
        """
        for label, image_numpy in visuals.items():
            img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
            save_image(image_numpy, img_path)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t):
        """Print one line with all values of ``errors`` and append it to the log.

        Args:
            epoch: current epoch number.
            i: iteration counter shown in the message.
            errors: mapping of loss name to numeric value.
            t: time value shown in the message.
        """
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        for k, v in errors.items():
            message += '%s: %.3f ' % (k, v)
        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

# Networks

In [None]:
class GANLoss(nn.Module):
    """GAN objective that builds (and caches) the constant target tensor itself.

    Args:
        use_lsgan: use ``nn.MSELoss`` when true, ``nn.BCELoss`` otherwise.
        target_real_label: fill value of the target for "real".
        target_fake_label: fill value of the target for "fake".
        tensor: tensor constructor used to allocate the targets
            (for example ``torch.FloatTensor`` or ``torch.cuda.FloatTensor``).
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target shaped like ``input``.

        The last real and fake targets are cached and reused while the
        prediction shape stays the same. The cache is keyed on the full shape,
        not just the element count, so a prediction with the same number of
        elements but a different layout gets a correctly shaped target.
        """
        if target_is_real:
            cached = self.real_label_var
            if cached is None or cached.size() != input.size():
                cached = self.Tensor(input.size()).fill_(self.real_label)
                cached.requires_grad_(False)
                self.real_label_var = cached
            return cached

        cached = self.fake_label_var
        if cached is None or cached.size() != input.size():
            cached = self.Tensor(input.size()).fill_(self.fake_label)
            cached.requires_grad_(False)
            self.fake_label_var = cached
        return cached

    # NOTE: __call__ is overridden directly (instead of forward), so module
    # hooks are not run; kept that way to leave the calling convention intact.
    def __call__(self, input, target_is_real):
        """Compute the loss of discriminator output ``input``.

        ``input`` is either a list of lists (one inner list per discriminator,
        whose last element is the prediction) or a single such inner list.
        For the nested form the per-discriminator losses are summed.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)


In [None]:
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)``.

    ``conv_block`` is two rounds of (padding, 3x3 conv, norm, activation) with
    an optional ``Dropout(0.5)`` between them.

    Args:
        dim: number of input and output channels.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable building a normalisation layer from a channel count.
        activation: activation module inserted after each norm layer.
        use_dropout: whether to insert the dropout layer.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble the sequential body of the block.

        Raises:
            NotImplementedError: for an unknown ``padding_type``.
        """

        def padded_conv_stage():
            # One (padding, conv, norm, activation) stage. Explicit padding
            # layers are used for reflect/replicate; 'zero' pads in the conv.
            if padding_type == 'reflect':
                prefix, conv_padding = [nn.ReflectionPad2d(1)], 0
            elif padding_type == 'replicate':
                prefix, conv_padding = [nn.ReplicationPad2d(1)], 0
            elif padding_type == 'zero':
                prefix, conv_padding = [], 1
            else:
                raise NotImplementedError('padding [%s] is not implemented!' % padding_type)
            return prefix + [
                nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding),
                norm_layer(dim),
                activation,
            ]

        layers = padded_conv_stage()
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.extend(padded_conv_stage())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the output of the convolutional body."""
        return x + self.conv_block(x)

class GlobalGenerator(nn.Module):
    """Encoder / residual / decoder image generator ending in ``Tanh``.

    Layout: 7x7 stem conv, ``n_downsampling`` stride-2 convs (channels double
    each time), ``n_blocks`` ``ResnetBlock``s, ``n_downsampling`` stride-2
    transposed convs (channels halve each time), 7x7 output conv.

    Args:
        input_nc: channels of the input.
        output_nc: channels of the output.
        ngf: channels produced by the stem convolution.
        n_downsampling: number of down- and up-sampling stages.
        n_blocks: number of residual blocks (must be >= 0).
        norm_layer: callable building a normalisation layer from a channel count.
        padding_type: padding mode handed to every ``ResnetBlock``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        relu = nn.ReLU(True)  # one instance shared by every stage

        stem = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            relu,
        ]

        encoder = []
        for level in range(n_downsampling):
            channels = ngf * 2**level
            encoder += [
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(channels * 2),
                relu,
            ]

        bottleneck_channels = ngf * 2**n_downsampling
        residual = [
            ResnetBlock(bottleneck_channels, padding_type=padding_type, activation=relu, norm_layer=norm_layer)
            for _ in range(n_blocks)
        ]

        decoder = []
        for level in range(n_downsampling):
            channels = ngf * 2**(n_downsampling - level)
            halved = int(channels / 2)
            decoder += [
                nn.ConvTranspose2d(channels, halved, kernel_size=3, stride=2, padding=1, output_padding=1),
                norm_layer(halved),
                relu,
            ]

        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*(stem + encoder + residual + decoder + head))

    def forward(self, input):
        """Run ``input`` through the whole sequential model."""
        return self.model(input)

class MultiscaleDiscriminator(nn.Module):
    """Set of ``num_D`` discriminators applied to progressively smaller inputs.

    The sequential body of each ``NLayerDiscriminator`` is stored as attribute
    ``layer<i>``. ``forward`` feeds the full-resolution input to the
    highest-numbered layer and a 2x average-pooled copy to each following one.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, num_D=3, use_dropout=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers

        for index in range(num_D):
            single = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, use_dropout)
            setattr(self, 'layer' + str(index), single.model)

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Apply one discriminator and wrap its output in a one-element list."""
        return [model(input)]

    def forward(self, input):
        """Return one ``[prediction]`` list per scale, finest scale first."""
        last = self.num_D - 1
        outputs = []
        current = input
        for step in range(self.num_D):
            scale_net = getattr(self, 'layer' + str(last - step))
            outputs.append(self.singleD_forward(scale_net, current))
            if step != last:
                # Halve the resolution for the next discriminator.
                current = self.downsample(current)
        return outputs

class NLayerDiscriminator(nn.Module):
    """Stack of 4x4 convolutions producing a one-channel prediction map.

    Layout: one stride-2 conv + LeakyReLU, ``n_layers - 1`` stride-2
    conv/norm/LeakyReLU stages, one stride-1 conv/norm/LeakyReLU stage and a
    final stride-1 conv to a single channel. The channel count doubles per
    stage and is capped at 512. ``Sigmoid`` and ``Dropout(0.5)`` are appended
    when requested.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_dropout=False):
        super(NLayerDiscriminator, self).__init__()
        self.n_layers = n_layers

        kernel = 4
        pad = int(np.ceil((kernel - 1.0) / 2))

        def conv(in_channels, out_channels, stride):
            # Every convolution shares the kernel size and padding.
            return nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride, padding=pad)

        stack = [conv(input_nc, ndf, 2), nn.LeakyReLU(0.2, True)]

        width = ndf
        for _ in range(1, n_layers):
            previous, width = width, min(width * 2, 512)
            stack += [conv(previous, width, 2), norm_layer(width), nn.LeakyReLU(0.2, True)]

        previous, width = width, min(width * 2, 512)
        stack += [conv(previous, width, 1), norm_layer(width), nn.LeakyReLU(0.2, True)]
        stack.append(conv(width, 1, 1))

        if use_sigmoid:
            stack.append(nn.Sigmoid())
        if use_dropout:
            stack.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*stack)

    def forward(self, input):
        """Run ``input`` through the sequential model."""
        return self.model(input)

class Encoder(nn.Module):
    """Network returning two 32-dimensional vectors and a feature map.

    ``downsample`` is the stem plus ``n_layers`` stride-2 convolutions; its
    output is average-pooled (window 32), flattened and passed through two
    separate linear layers (``fc`` and ``fcVar``). ``model`` is the same
    down-sampling stack (the layer objects are shared) followed by
    ``n_layers`` transposed convolutions and a 7x7 ``Tanh`` head.

    Args:
        input_nc: channels of the input.
        output_nc: channels of the feature map.
        ngf: channels produced by the stem convolution.
        n_layers: number of down- and up-sampling stages.
        norm_layer: callable building a normalisation layer from a channel count.
        ratio: divisor applied to the input size of the linear layers.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_layers=4, norm_layer=nn.BatchNorm2d, ratio=1):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Down-sampling path.
        for i in range(n_layers):
            mult = 2**i
            layers.extend([
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ])

        self.downsample = nn.Sequential(*layers)
        self.pool = nn.AvgPool2d(32)
        # `mult` still holds the value of the last down-sampling stage here.
        fc_inputs = int(ngf * mult * 2 * 4 / ratio)
        self.fc = nn.Sequential(nn.Linear(fc_inputs, 32))
        self.fcVar = nn.Sequential(nn.Linear(fc_inputs, 32))

        # Up-sampling path, appended to the same layer list.
        for i in range(n_layers):
            mult = 2**(n_layers - i)
            layers.extend([
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ])

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``(fc output, fcVar output, feature map)`` for ``input``."""
        feature = self.model(input)
        pooled = self.pool(self.downsample(input))
        flat = pooled.view(input.size(0), -1)
        return self.fc(flat), self.fcVar(flat), feature

In [None]:
def weights_init_gaussian(m):
    """Initialise module ``m`` in place from normal distributions.

    Modules whose class name contains ``Conv`` or ``Linear`` get weights drawn
    from N(0, 0.02). ``BatchNorm2d`` modules get weights from N(1, 0.02) and a
    zero bias. Every other module is left untouched. Intended for
    ``net.apply(weights_init_gaussian)``.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind or 'Linear' in kind:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def weights_init_uniform(m):
    """Initialise module ``m`` in place from uniform distributions.

    Modules whose class name contains ``Conv`` or ``Linear`` get weights drawn
    from U(-0.06, 0.06). ``BatchNorm2d`` modules get weights from
    U(0.04, 1.06) and a zero bias. Every other module is left untouched.

    The second branch previously tested for ``Conv`` again and could never
    run, so Linear layers kept their default initialisation. It now matches
    ``Linear``, mirroring ``weights_init_gaussian``. The deprecated
    ``init.uniform`` / ``init.constant`` aliases are replaced by the in-place
    ``init.uniform_`` / ``init.constant_``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.uniform_(m.weight.data, -0.06, 0.06)
    elif classname.find('Linear') != -1:
        init.uniform_(m.weight.data, -0.06, 0.06)
    elif classname.find('BatchNorm2d') != -1:
        init.uniform_(m.weight.data, 0.04, 1.06)
        init.constant_(m.bias.data, 0.0)

def get_norm_layer():
    """Return a factory that builds ``nn.InstanceNorm2d`` layers with ``affine=False``."""
    return functools.partial(nn.InstanceNorm2d, affine=False)

def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, norm='instance', ratio=1):
    """Build a generator or encoder and initialise it with ``weights_init_gaussian``.

    Args:
        input_nc: input channels.
        output_nc: output channels.
        ngf: base channel count; only used for ``netG == 'global'`` (the
            encoder is always built with 64).
        netG: ``'global'`` for ``GlobalGenerator`` or ``'encoder'`` for ``Encoder``.
        n_downsample_global: number of down-sampling stages.
        n_blocks_global: residual blocks (``'global'`` only).
        norm: accepted but not used; the norm layer always comes from
            ``get_norm_layer()``.
        ratio: forwarded to ``Encoder`` (``'encoder'`` only).

    Raises:
        NotImplementedError: when ``netG`` is neither of the two known names.
    """
    norm_layer = get_norm_layer()
    if netG == 'global':
        network = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'encoder':
        network = Encoder(input_nc, output_nc, 64, n_downsample_global, norm_layer, ratio)
    else:
        raise NotImplementedError('generator [%s] is not found.' % netG)
    network.apply(weights_init_gaussian)
    return network

def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1):
    """Build a ``MultiscaleDiscriminator`` and initialise it with ``weights_init_gaussian``.

    ``norm`` is accepted but not used; the norm layer always comes from
    ``get_norm_layer()``. Dropout is always disabled.
    """
    discriminator = MultiscaleDiscriminator(
        input_nc, ndf, n_layers_D, get_norm_layer(), use_sigmoid, num_D, use_dropout=False)
    discriminator.apply(weights_init_gaussian)
    return discriminator

# Models

In [None]:
# Shared preprocessing applied by CycleGAN to the images it opens with PIL:
# convert to a tensor, then normalise each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

class CycleGAN():
    def name(self):
        """Return the display name of this model."""
        model_name = 'Bayesian CycleGAN Model'
        return model_name

    def initialize(self, opt):
        """Build all networks, losses and optimizers from the options ``opt``.

        Creates two generators (``netG_A``, ``netG_B``), two encoders
        (``netE_A``, ``netE_B``) and, when training, two discriminators
        (``netD_A``, ``netD_B``), the loss criteria and one Adam optimizer per
        group. Finally stores the two dataset folders and their file listings.

        Args:
            opt: options object. Attributes read here include ``isTrain``,
                ``checkpoints_dir``, ``name``, ``loadSize``, ``ratio``,
                ``input_nc``, ``output_nc``, ``ngf``, ``netG_A``, ``netG_B``,
                ``n_downsample_global``, ``n_blocks_global``, ``norm``,
                ``no_lsgan``, ``ndf``, ``n_layers_D``, ``num_D_A``,
                ``num_D_B``, ``continue_train``, ``which_epoch``, ``lr``,
                ``beta1``, ``dataroot_A`` and ``dataroot_B``.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
        # Pick the tensor type (GPU or CPU) every network and input is cast to.
        if torch.cuda.is_available():
            print('cuda is available, we will use gpu!')
            self.Tensor = torch.cuda.FloatTensor
            #self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            # NOTE(review): this branch seeds only the CUDA generators; the
            # CPU generator is seeded only in the else-branch. If weight
            # initialisation draws from the CPU generator, GPU runs are not
            # reproducible - confirm and consider torch.manual_seed(100) here.
            torch.cuda.manual_seed_all(100)
        else:
            self.Tensor = torch.FloatTensor
            torch.manual_seed(100)
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # get ratio for network initialization (passed to the encoders, where
        # it divides the input size of the linear layers)
        ratio = 256 * 256 / opt.loadSize / (opt.loadSize / opt.ratio)

        # build networks
        # Generators take one extra input channel: forward() concatenates the
        # one-channel encoder feature map to every image along dim 1.
        netG_input_nc = opt.input_nc + 1
        netG_output_nc = opt.output_nc + 1
        # NOTE: opt.norm is forwarded but define_G does not use it.
        self.netG_A = define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG_A,
                                        opt.n_downsample_global, opt.n_blocks_global, opt.norm).type(self.Tensor)#.cuda()
        self.netG_B = define_G(netG_output_nc, opt.input_nc, opt.ngf, opt.netG_B,
                                        opt.n_downsample_global, opt.n_blocks_global, opt.norm).type(self.Tensor)#.cuda()

        # Encoders with a single-channel feature-map output.
        self.netE_A = define_G(opt.input_nc, 1, 64, 'encoder', opt.n_downsample_global, norm=opt.norm, ratio=ratio).type(self.Tensor)#.cuda()
        self.netE_B = define_G(opt.output_nc, 1, 64, 'encoder', opt.n_downsample_global, norm=opt.norm, ratio=ratio).type(self.Tensor)#.cuda()

        # Discriminators are only needed for training.
        if self.isTrain:
            # A sigmoid output is used exactly when LSGAN is disabled.
            use_sigmoid = opt.no_lsgan
            self.netD_A = define_D(opt.output_nc, opt.ndf, opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.num_D_A).type(self.Tensor)#.cuda()
            self.netD_B = define_D(opt.input_nc, opt.ndf, opt.n_layers_D, opt.norm,
                                            use_sigmoid, opt.num_D_B).type(self.Tensor)#.cuda()

        # Restore weights for testing or when resuming training.
        # NOTE(review): load_network is defined outside this view; presumably
        # it loads the checkpoint of epoch opt.which_epoch from save_dir.
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG_A, 'G_A', opt.which_epoch, self.save_dir)
            self.load_network(self.netG_B, 'G_B', opt.which_epoch, self.save_dir)
            self.load_network(self.netE_A, 'E_A', opt.which_epoch, self.save_dir)
            self.load_network(self.netE_B, 'E_B', opt.which_epoch, self.save_dir)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', opt.which_epoch, self.save_dir)
                self.load_network(self.netD_B, 'D_B', opt.which_epoch, self.save_dir)

        # set loss functions and optimizers
        if self.isTrain:
            self.old_lr = opt.lr
            # define loss function
            self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers (both generators share one optimizer)
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_E_A = torch.optim.Adam(self.netE_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_E_B = torch.optim.Adam(self.netE_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

        print('Network initialized!')

        # dataset path and name list (forward() and inference() draw random
        # file names from these listings)
        self.origin_path = os.getcwd()
        self.path_A = self.opt.dataroot_A
        self.path_B = self.opt.dataroot_B
        self.list_A = os.listdir(self.path_A)
        self.list_B = os.listdir(self.path_B)

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.input_A = input['A' if AtoB else 'B']
        self.input_B = input['B' if AtoB else 'A']
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Prepare the current batch and its Monte-Carlo latent samples.

        Casts ``input_A`` / ``input_B`` to ``self.Tensor`` (``real_A``,
        ``real_B``), draws ``opt.mc_x`` random files from folder A and
        ``opt.mc_y`` from folder B, encodes each with ``netE_A`` / ``netE_B``
        and stores:

        * ``mu_x``, ``logvar_x``, ``mu_y``, ``logvar_y`` - the two encoder
          vectors, concatenated over the samples;
        * ``feat_map_zx``, ``feat_map_zy`` - the feature map of the *last*
          sample of each domain;
        * ``real_B_zx``, ``real_A_zy`` - one tensor per sample holding the
          real batch with that sample's feature map appended along dim 1.

        Files are opened through their joined path rather than by changing the
        working directory. The process cwd therefore stays untouched even if
        an image fails to load, and relative ``dataroot_B`` values resolve
        from the original directory instead of from inside ``dataroot_A``.
        """
        self.real_A = Variable(self.input_A).type(self.Tensor)
        self.real_B = Variable(self.input_B).type(self.Tensor)

        # Draw both sample lists first, keeping the original order of RNG use.
        mc_sample_x = random.sample(self.list_A, self.opt.mc_x)
        mc_sample_y = random.sample(self.list_B, self.opt.mc_y)

        (self.mu_x, self.logvar_x,
         self.feat_map_zx, self.real_B_zx) = self._encode_mc_samples(
            self.netE_A, self.path_A, mc_sample_x, self.opt.input_nc, self.real_B)

        (self.mu_y, self.logvar_y,
         self.feat_map_zy, self.real_A_zy) = self._encode_mc_samples(
            self.netE_B, self.path_B, mc_sample_y, self.opt.output_nc, self.real_A)

    def _encode_mc_samples(self, encoder, directory, file_names, n_channels, real_batch):
        """Encode sampled images and pair each feature map with ``real_batch``.

        Returns ``(mu, logvar, last_feature_map, conditioned)`` where
        ``conditioned`` has one entry per file: the first ``opt.batchSize``
        images of ``real_batch`` with the file's feature map concatenated
        along dim 1.
        """
        mus, logvars, conditioned = [], [], []
        feat_map = None
        for file_name in file_names:
            # convert() returns a loaded copy, so the file can be closed here.
            with Image.open(os.path.join(directory, file_name)) as handle:
                z = handle.convert('RGB')
            z = self.img_resize(z, self.opt.loadSize)
            z = transform(z)
            if n_channels == 1:  # RGB to gray
                z = z[0, ...] * 0.299 + z[1, ...] * 0.587 + z[2, ...] * 0.114
                z = z.unsqueeze(0)
            z = torch.unsqueeze(Variable(z).type(self.Tensor), 0)

            mu, logvar, feat_map = encoder(z)
            mus.append(mu)
            logvars.append(logvar)

            per_image = [
                torch.cat([torch.unsqueeze(real_batch[i], 0), feat_map], dim=1)
                for i in range(0, self.opt.batchSize)
            ]
            conditioned.append(torch.cat(per_image))
        return torch.cat(mus), torch.cat(logvars), feat_map, conditioned

    def inference(self):
        """Translate the current ``input_A`` batch to domain B (``self.fake_B``).

        A random file from folder B provides the latent feature map: either
        the ``netE_B`` feature map of that image or, when ``opt.use_feat`` is
        set, the gray-scale image itself. The map is appended to every input
        image along dim 1 before running ``netG_A``.

        Only the A->B direction is produced; the B->A and reconstruction
        branches are disabled.

        Changes compared with the earlier version:
        * runs under ``torch.no_grad()`` so no autograd graph is kept;
        * opens the sample via its joined path instead of ``os.chdir``, so the
          working directory is never left changed after an error;
        * the gray-scale conversion is applied once when ``opt.output_nc == 1``
          and ``opt.use_feat`` are both set (it used to be applied twice and
          fail on the already single-channel tensor).
        """
        with torch.no_grad():
            real_A = Variable(self.input_A).type(self.Tensor)

            # Random latent sample from domain B.
            sample_name = random.sample(self.list_B, 1)[0]
            # convert() returns a loaded copy, so the file can be closed here.
            with Image.open(os.path.join(self.path_B, sample_name)) as handle:
                z_y = handle.convert('RGB')
            z_y = self.img_resize(z_y, self.opt.loadSize)
            z_y = transform(z_y)
            if self.opt.output_nc == 1 or self.opt.use_feat:  # RGB to gray
                z_y = z_y[0, ...] * 0.299 + z_y[1, ...] * 0.587 + z_y[2, ...] * 0.114
                z_y = z_y.unsqueeze(0)
            z_y = torch.unsqueeze(Variable(z_y).type(self.Tensor), 0)

            if not self.opt.use_feat:
                _, _, feat_map_zy = self.netE_B(z_y)
            else:
                feat_map_zy = z_y

            # Combine every input image with the sampled feature map.
            real_A_zy = torch.cat([
                torch.cat((real_A[i:i + 1], feat_map_zy), dim=1)
                for i in range(0, self.opt.batchSize)
            ])

            self.fake_B = self.netG_A(real_A_zy)

    def get_image_paths(self):
        return self.image_paths

    def img_resize(self, img, target_width):
        """Resize ``img`` to ``target_width`` pixels wide, keeping the aspect ratio.

        The image is returned unchanged when its width already matches;
        otherwise it is resized with bicubic interpolation and the height is
        truncated to an integer.
        """
        width, height = img.size
        if width == target_width:
            return img
        new_size = (target_width, int(target_width * height / width))
        return img.resize(new_size, Image.BICUBIC)

    def get_z_random(self, batchSize, nz, random_type='gauss'):
        z = self.Tensor(batchSize, nz)
        if random_type == 'uni':
            z.copy_(torch.rand(batchSize, nz) * 2.0 - 1.0)
        elif random_type == 'gauss':
            z.copy_(torch.randn(batchSize, nz))
        z = Variable(z)
        return z

    def backward_G(self):
        # GAN loss D_A(G_A(A))
        fake_B = []
        for real_A in self.real_A_zy:
            _fake = self.netG_A(real_A)
            fake_B.append(_fake)
        fake_B = torch.cat(fake_B)

        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = []
        for real_B in self.real_B_zx:
            _fake = self.netG_B(real_B)
            fake_A.append(_fake)
        fake_A = torch.cat(fake_A)

        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # cycle loss
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Forward cycle loss
        fake_B_next = []
        for i in range(0, fake_B.size(0)):
        	_fake = fake_B[i:(i+1)]
        	_fake = torch.cat((_fake, self.feat_map_zx), dim=1)
        	fake_B_next.append(_fake)
        fake_B_next = torch.cat(fake_B_next)

        rec_A = self.netG_B(fake_B_next)
        loss_cycle_A = 0
        for i in range(0, self.opt.mc_y):
            loss_cycle_A += self.criterionCycle(rec_A[i*self.real_A.size(0):(i+1)*self.real_A.size(0)], self.real_A) * lambda_A
        pred_cycle_G_A = self.netD_B(rec_A)
        loss_cycle_G_A = self.criterionGAN(pred_cycle_G_A, True)

        # Backward cycle loss
        fake_A_next = []
        for i in range(0, fake_A.size(0)):
        	_fake = fake_A[i:(i+1)]
        	_fake = torch.cat((_fake, self.feat_map_zy), dim=1)
        	fake_A_next.append(_fake)
        fake_A_next = torch.cat(fake_A_next)

        rec_B = self.netG_A(fake_A_next)
        loss_cycle_B = 0
        for i in range(0, self.opt.mc_x):
            loss_cycle_B += self.criterionCycle(rec_B[i*self.real_B.size(0):(i+1)*self.real_B.size(0)], self.real_B) * lambda_B
        pred_cycle_G_B = self.netD_A(rec_B)
        loss_cycle_G_B = self.criterionGAN(pred_cycle_G_B, True)

        # prior loss
        prior_loss_G_A = self.get_prior(self.netG_A.parameters(), self.opt.batchSize)
        prior_loss_G_B = self.get_prior(self.netG_B.parameters(), self.opt.batchSize)

        # KL loss
        kl_element = self.mu_x.pow(2).add_(self.logvar_x.exp()).mul_(-1).add_(1).add_(self.logvar_x)
        loss_kl_EA = torch.sum(kl_element).mul_(-0.5) * self.opt.lambda_kl

        kl_element = self.mu_y.pow(2).add_(self.logvar_y.exp()).mul_(-1).add_(1).add_(self.logvar_y)
        loss_kl_EB = torch.sum(kl_element).mul_(-0.5) * self.opt.lambda_kl

        # total loss
        loss_G =  loss_G_A + loss_G_B + (prior_loss_G_A + prior_loss_G_B) + (loss_cycle_G_A + loss_cycle_G_B) * self.opt.gamma + (loss_cycle_A + loss_cycle_B) + (loss_kl_EA + loss_kl_EB)
        loss_G.backward()

        self.fake_B = fake_B
        self.fake_A = fake_A
        self.rec_A = rec_A
        self.rec_B = rec_B

        self.loss_G_A = loss_G_A.item() + loss_cycle_G_A.item() * self.opt.gamma + prior_loss_G_A.item()
        self.loss_G_B = loss_G_B.item() + loss_cycle_G_B.item() * self.opt.gamma + prior_loss_G_A.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()
        self.loss_kl_EA = loss_kl_EA.item()
        self.loss_kl_EB = loss_kl_EB.item()

    def backward_D_A(self):
        fake_B = Variable(self.fake_B).type(self.Tensor)#.cuda()
        rec_B = Variable(self.rec_B).type(self.Tensor)#.cuda()
        # how well it classifiers fake images
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        pred_cycle_fake = self.netD_A(rec_B.detach())
        loss_D_cycle_fake = self.criterionGAN(pred_cycle_fake, False)

        # how well it classifiers real images
        pred_real = self.netD_A(self.real_B)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_y

        # prior loss
        prior_loss_D_A = self.get_prior(self.netD_A.parameters(), self.opt.batchSize)

        # total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5 + (loss_D_real + loss_D_cycle_fake) * 0.5 * self.opt.gamma + prior_loss_D_A
        loss_D_A.backward()
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        fake_A = Variable(self.fake_A).type(self.Tensor)#.cuda()
        rec_A = Variable(self.rec_A).type(self.Tensor)#.cuda()
        # how well it classifiers fake images
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        pred_cycle_fake = self.netD_B(rec_A.detach())
        loss_D_cycle_fake = self.criterionGAN(pred_cycle_fake, False)

        # how well it classifiers real images
        pred_real = self.netD_B(self.real_A)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_x

        # prior loss
        prior_loss_D_B = self.get_prior(self.netD_B.parameters(), self.opt.batchSize)

        # total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5 + (loss_D_real + loss_D_cycle_fake) * 0.5 * self.opt.gamma + prior_loss_D_B
        loss_D_B.backward()
        self.loss_D_B = loss_D_B.item()


    def optimize(self):
        """Run one optimisation step over all networks.

        Order of operations: ``forward()``; then the generator optimizer and
        the two encoder optimizers are zeroed, ``backward_G()`` runs and the
        same three optimizers step; finally each discriminator is updated in
        turn (zero grad, its own backward pass, step) -- D_A first, then D_B.
        """
        self.forward()

        # G_A / G_B and E_A / E_B share a single backward pass.
        generator_side = (self.optimizer_G, self.optimizer_E_A, self.optimizer_E_B)
        for optimizer in generator_side:
            optimizer.zero_grad()
        self.backward_G()
        for optimizer in generator_side:
            optimizer.step()

        # D_A, then D_B: each gets its own zero_grad / backward / step cycle.
        discriminator_side = (
            (self.optimizer_D_A, self.backward_D_A),
            (self.optimizer_D_B, self.backward_D_B),
        )
        for optimizer, run_backward in discriminator_side:
            optimizer.zero_grad()
            run_backward()
            optimizer.step()

    def get_current_loss(self):
        """Return the most recently recorded losses as an ordered mapping.

        The mapping always starts with ``D_A``, ``D_B``, ``G_A`` and ``G_B``.
        The cycle losses are appended under ``cyc_A``/``cyc_B`` when
        ``opt.gamma == 0`` and under ``cyc_G_A``/``cyc_G_B`` when
        ``opt.gamma > 0``; a negative gamma adds neither pair.  The encoder
        KL terms ``kl_EA``/``kl_EB`` are appended when ``opt.lambda_kl > 0``.

        Returns:
            OrderedDict mapping loss name to the stored value.
        """
        loss = OrderedDict([
            ('D_A', self.loss_D_A),
            ('D_B', self.loss_D_B),
            ('G_A', self.loss_G_A),
            ('G_B', self.loss_G_B)
        ])
        gamma = self.opt.gamma
        if gamma == 0:
            loss['cyc_A'] = self.loss_cycle_A
            loss['cyc_B'] = self.loss_cycle_B
        elif gamma > 0:
            loss['cyc_G_A'] = self.loss_cycle_A
            loss['cyc_G_B'] = self.loss_cycle_B
        # This branch used to be indented with a tab mixed into the spaces;
        # it is now space-indented like the rest of the class.
        if self.opt.lambda_kl > 0:
            loss['kl_EA'] = self.loss_kl_EA
            loss['kl_EB'] = self.loss_kl_EB
        return loss

    def get_stye_loss(self):
        """Return the stored L1 losses of both generators.

        The method name is misspelled ("stye"); it is kept as-is because
        callers reference it by this name.

        Returns:
            OrderedDict with ``L1_A`` -> ``self.loss_G_A_L1`` followed by
            ``L1_B`` -> ``self.loss_G_B_L1``.
        """
        l1_losses = OrderedDict()
        l1_losses['L1_A'] = self.loss_G_A_L1
        l1_losses['L1_B'] = self.loss_G_B_L1
        return l1_losses

    def get_current_visuals(self):
        """Return the images to show or save, keyed by name.

        Only ``self.fake_B`` is exported, after conversion with ``tensor2im``.
        The other tensors held by the model (real_A, rec_A, real_B, fake_A,
        rec_B) are not part of the returned mapping.

        Returns:
            OrderedDict with the single key ``'fake_B'``.
        """
        return OrderedDict([('fake_B', tensor2im(self.fake_B))])

    def get_prior(self, parameters, dataset_size):
        """Sum the mean squared value of each parameter, scaled by ``dataset_size``.

        The accumulator is created on the device of the first parameter
        instead of being forced onto CUDA, so the method also works on hosts
        without a GPU.  When ``parameters`` is empty, the accumulator falls
        back to CUDA if it is available and to the CPU otherwise.

        Args:
            parameters: iterable of tensors (e.g. ``net.parameters()``).
            dataset_size: divisor applied to the accumulated sum.

        Returns:
            A tensor of shape ``(1,)`` holding ``sum(mean(p * p)) / dataset_size``.
        """
        prior_loss = None
        for param in parameters:
            if prior_loss is None:
                prior_loss = torch.zeros(1, device=param.device)
            # .to() is a no-op when the parameter already lives on the
            # accumulator's device.
            prior_loss = prior_loss + torch.mean(param * param).to(prior_loss.device)
        if prior_loss is None:
            fallback_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            prior_loss = torch.zeros(1, device=fallback_device)
        return prior_loss / dataset_size

    def save_model(self, label):
        """Checkpoint every sub-network under the given label.

        Calls ``save_network`` once per network, in the order G_A, G_B, E_A,
        E_B, D_A, D_B, passing the attribute ``self.net<TAG>``, the tag and
        ``label``.
        """
        for tag in ('G_A', 'G_B', 'E_A', 'E_B', 'D_A', 'D_B'):
            self.save_network(getattr(self, 'net' + tag), tag, label)

    def load_network(self, network, network_label, epoch_label, save_dir=''):
        """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``.

        Loading is attempted in three stages:

        1. a strict ``load_state_dict`` of the checkpoint;
        2. if that fails, a strict load of only those checkpoint entries
           whose key exists in the network;
        3. if that fails too, every checkpoint entry whose size matches is
           copied over the network's current state, the top-level names of
           the entries left untouched are printed, and the merged state is
           loaded.

        Args:
            network: module whose state is to be restored.
            network_label: network tag used in the file name (e.g. ``'G_A'``).
            epoch_label: epoch tag used in the file name (e.g. ``'latest'``).
            save_dir: directory holding the checkpoint; when empty,
                ``self.save_dir`` is used.

        Raises:
            Whatever ``torch.load`` raises when the file cannot be read.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        # Honour an explicit directory; previously the argument was ignored.
        save_path = os.path.join(save_dir or self.save_dir, save_filename)
        # Read the checkpoint once and reuse it in every fallback stage.
        pretrained_dict = torch.load(save_path)
        try:
            network.load_state_dict(pretrained_dict)
        except Exception:
            model_dict = network.state_dict()
            try:
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                network.load_state_dict(pretrained_dict)
                print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
            except Exception:
                print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                not_initialized = set()
                # pretrained_dict was filtered above, so every key exists in model_dict.
                for k, v in pretrained_dict.items():
                    if v.size() == model_dict[k].size():
                        model_dict[k] = v

                for k, v in model_dict.items():
                    if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                        not_initialized.add(k.split('.')[0])
                print(sorted(not_initialized))
                network.load_state_dict(model_dict)

    def save_network(self, network, network_label, epoch_label):
        """Write the network's state dict to ``<epoch_label>_net_<network_label>.pth``.

        The file is placed in ``self.save_dir``.  The network is moved to the
        CPU before its state is saved and is moved to CUDA afterwards whenever
        CUDA is available.
        """
        checkpoint_name = '%s_net_%s.pth' % (epoch_label, network_label)
        destination = os.path.join(self.save_dir, checkpoint_name)
        cpu_state = network.cpu().state_dict()
        torch.save(cpu_state, destination)
        if not torch.cuda.is_available():
            return
        network.cuda()

    def print_network(self, net):
        """Print the network followed by its total parameter count.

        The count is the sum of ``numel()`` over ``net.parameters()``.
        """
        total_params = sum(param.numel() for param in net.parameters())
        print(net)
        print('Total number of parameters: %d' % total_params)

    # update learning rate (called once every iter)
    def update_learning_rate(self, epoch, epoch_iter, dataset_size):
        """Apply the exponential learning-rate decay once ``epoch > opt.niter``.

        While ``epoch <= opt.niter`` nothing is changed.  Afterwards the rate
        is ``opt.lr * exp(-min(1, epoch_iter / dataset_size))``; it is written
        to every parameter group of the D_A, D_B and G optimizers, the change
        is printed, and ``self.old_lr`` is updated.

        NOTE(review): the E_A / E_B optimizers are not touched here, exactly
        as before -- confirm whether the encoders are meant to keep the
        initial rate.

        Args:
            epoch: current epoch number.
            epoch_iter: number of samples processed so far in this epoch.
            dataset_size: number of samples per epoch (must be non-zero).
        """
        if epoch <= self.opt.niter:
            # Decay has not started yet; leave every optimizer as it is.
            return

        epoch_fraction = min(1.0, epoch_iter / float(dataset_size))
        lr = self.opt.lr * np.exp(-1.0 * epoch_fraction)
        for optimizer in (self.optimizer_D_A, self.optimizer_D_B, self.optimizer_G):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

# Dataset

In [None]:
# File-name suffixes treated as images by is_image_file(); matching is
# case-sensitive, so each suffix is listed in lower and upper case.
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
]

class BaseDataLoader():
    """Minimal base class for the notebook's data-loader wrappers.

    Subclasses are expected to override ``initialize`` and ``load_data``.
    """

    def __init__(self):
        pass

    def initialize(self, opt):
        """Remember the options object on ``self.opt``."""
        self.opt = opt

    def load_data(self):
        """Return the loaded data; the base implementation has none.

        ``self`` was missing from the signature, so calling this on an
        instance raised ``TypeError``.
        """
        return None

class BaseDataset(data.Dataset):
    """Common base of the notebook's datasets; adds ``name`` and ``initialize`` hooks."""

    def __init__(self):
        super().__init__()

    def initialize(self, opt):
        """Hook for option-driven set-up; the base class does nothing."""
        return None

    def name(self):
        """Return the dataset's display name."""
        return 'BaseDataset'

def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``."""
    for extension in IMG_EXTENSIONS:
        if filename.endswith(extension):
            return True
    return False


def make_dataset(dir):
    """Collect the paths of all image files found under ``dir``.

    The tree is walked recursively with ``os.walk`` (entries visited in
    sorted order) and every file accepted by ``is_image_file`` is returned
    as ``root/filename``.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    return [
        os.path.join(root, fname)
        for root, _, fnames in sorted(os.walk(dir))
        for fname in fnames
        if is_image_file(fname)
    ]

def get_transform(opt):
    """Build the torchvision preprocessing pipeline selected by ``opt``.

    ``opt.resize_or_crop`` chooses the geometric step:

    * ``'resize'``: resize to ``loadSize x loadSize`` (bicubic);
    * ``'crop'``: random crop of ``fineSize``;
    * ``'scale_width'``: rescale so the width equals ``loadSize``;
    * ``'scale_width_and_crop'``: the width rescale followed by a random
      crop of ``fineSize``;
    * any other value: no geometric step.

    A random horizontal flip is added when ``opt.isTrain`` is set and
    ``opt.no_flip`` is not.  The pipeline always ends with ``ToTensor`` and a
    per-channel ``Normalize`` with mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` instance.
    """
    transform_list = []
    mode = opt.resize_or_crop
    if mode == 'resize':
        osize = [opt.loadSize, opt.loadSize]
        # transforms.Scale has been removed from torchvision; Resize takes
        # the same (size, interpolation) arguments.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
    elif mode == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
    elif mode == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    """Rescale ``img`` to ``target_width``, scaling the height by the same ratio.

    The image is returned untouched when its width already matches;
    otherwise it is resized with bicubic interpolation and the new height is
    truncated to an integer.
    """
    ow, oh = img.size
    if ow == target_width:
        return img
    new_size = (target_width, int(target_width * oh / ow))
    return img.resize(new_size, Image.BICUBIC)

class UnalignedDataset(BaseDataset):
    """Dataset pairing an image from folder A with an image from folder B.

    Item ``index`` always maps to A image ``index % A_size``.  The B image is
    ``index % B_size`` when ``opt.serial_batches`` is set and a uniformly
    random B image otherwise.
    """

    def initialize(self, opt):
        """Index both image folders and build the shared transform."""
        self.opt = opt
        self.dir_A = opt.dataroot_A
        self.dir_B = opt.dataroot_B

        found_A = make_dataset(self.dir_A)
        found_B = make_dataset(self.dir_B)
        self.A_paths = sorted(found_A)
        self.B_paths = sorted(found_B)

        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.transform = get_transform(opt)

    @staticmethod
    def _to_gray(rgb):
        """Collapse a 3-channel tensor into one luminance channel (0.299/0.587/0.114 weights)."""
        luminance = rgb[0, ...] * 0.299 + rgb[1, ...] * 0.587 + rgb[2, ...] * 0.114
        return luminance.unsqueeze(0)

    def __getitem__(self, index):
        """Return ``{'A', 'B', 'A_paths', 'B_paths'}`` for the given index."""
        index_A = index % self.A_size
        A_path = self.A_paths[index_A]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]

        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        A = self.transform(A_img)
        B = self.transform(B_img)

        # With direction 'BtoA' the channel counts of the two sides are swapped.
        if self.opt.which_direction == 'BtoA':
            input_nc, output_nc = self.opt.output_nc, self.opt.input_nc
        else:
            input_nc, output_nc = self.opt.input_nc, self.opt.output_nc

        if input_nc == 1:  # RGB to gray
            A = self._to_gray(A)
        if output_nc == 1:  # RGB to gray
            B = self._to_gray(B)

        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        """Length is that of the larger of the two folders."""
        return max(self.A_size, self.B_size)

    def name(self):
        """Return the dataset's display name."""
        return 'UnalignedDataset'

def CreateDataset(opt):
    """Create an ``UnalignedDataset``, announce it, and initialise it with ``opt``."""
    unaligned = UnalignedDataset()
    print("dataset [%s] was created" % (unaligned.name()))
    unaligned.initialize(opt)
    return unaligned

class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps a torch ``DataLoader`` over the dataset built by ``CreateDataset``.

    Iteration and ``len()`` are both capped by ``opt.max_dataset_size``,
    which is expressed in samples.
    """

    def name(self):
        """Return the loader's display name."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and the underlying ``DataLoader`` from ``opt``.

        Batches are shuffled unless ``opt.serial_batches`` is set.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        """Return the object to iterate over -- the loader itself."""
        return self

    def __len__(self):
        """Number of samples exposed, capped at ``opt.max_dataset_size``."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches until ``opt.max_dataset_size`` samples have been served."""
        for i, batch in enumerate(self.dataloader):
            # i counts batches while the cap counts samples, so convert before
            # comparing; with batchSize == 1 this matches the previous check.
            if i * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield batch

def CreateDataLoader(opt):
    """Build a ``CustomDatasetDataLoader`` and initialise it with ``opt``."""
    loader = CustomDatasetDataLoader()
    loader.initialize(opt)
    return loader


# Training

In [None]:
class Config:
  """Plain options container used by the data loader, model and visualizer.

  ``Config()`` yields the notebook's default settings.  Individual options
  can be overridden by keyword, e.g. ``Config(batchSize=4, isTrain=True)``;
  an unknown option name raises ``TypeError`` so that typos do not silently
  create new attributes.
  """

  def __init__(self, **overrides):
    # --- experiment / model selection ---
    self.name = "Monet"
    self.checkpoints_dir = "/content/drive/MyDrive/photo2monet/cycleganbayesian/"
    self.model = 'CycleGAN'
    self.norm = 'instance'
    self.use_dropout = False
    self.gpu_ids = '0'
    self.which_direction = 'AtoB'

    # --- batch and image geometry ---
    self.batchSize = 1
    self.loadSize = 256
    self.ratio = 1
    self.fineSize = 256
    self.input_nc = 3
    self.output_nc = 3

    # --- data loading ---
    self.dataroot_A = '/content/photo_jpg/'
    self.dataroot_B = '/content/drive/MyDrive/photo2monet/monetphotos/'
    self.resize_or_crop = "scale_width"
    self.serial_batches = False
    self.no_flip = True
    self.nThreads = 1
    self.max_dataset_size = float("inf")

    # --- display ---
    self.display_winsize = 256
    self.display_id = 0
    self.display_port = 8097

    # --- generators ---
    self.netG_A = 'global'
    self.netG_B = 'global'
    self.ngf = 32
    self.n_downsample_global = 2
    self.n_blocks_global = 6

    # --- discriminators ---
    # NOTE(review): 'mult_sacle' looks like a typo of 'mult_scale'; it is a
    # runtime value and is kept as-is -- confirm what the model code expects.
    self.netD = 'mult_sacle'
    self.num_D_A = 1
    self.num_D_B = 1
    self.n_layers_D = 3
    self.ndf = 64

    # NOTE(review): isTrain is False while phase is "train" -- confirm this
    # combination is intended for the inference run in this notebook.
    self.initialized = True
    self.isTrain = False

    # --- logging / checkpoint frequencies ---
    self.display_freq = 100
    self.display_single_pane_ncols = 0
    self.update_html_freq = 1000
    self.print_freq = 100
    self.save_latest_freq = 5000
    self.save_epoch_freq = 5

    # --- training schedule and loss weights ---
    self.continue_train = True
    self.gamma = 0.1
    self.epoch_count = 1
    self.phase = "train"
    self.which_epoch = "latest"
    self.niter = 50
    self.niter_decay = 50
    self.beta1 = 0.5
    self.lr = 0.0002
    self.no_lsgan = False
    self.lambda_A = 10.0
    self.lambda_B = 10.0
    self.lambda_kl = 0.1
    self.mc_y = 3
    self.mc_x = 3
    self.no_html = False
    self.lr_policy = 'lambda'
    self.lr_decay_iters = 50
    self.debug = False
    self.need_match = False
    self.use_feat = False

    # Apply caller overrides last; only names defined above are accepted.
    unknown = sorted(set(overrides) - set(vars(self)))
    if unknown:
      raise TypeError('Unknown Config option(s): %s' % ', '.join(unknown))
    vars(self).update(overrides)

In [None]:
# Build the options object and the data loader, then report the sample count.
opt = Config()
data_loader = CreateDataLoader(opt)
# CustomDatasetDataLoader.load_data() returns the loader itself, so `dataset`
# is the object that is iterated batch by batch below.
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('training images = %d' % dataset_size)

dataset [UnalignedDataset] was created
training images = 7038


In [None]:
# continue train or not
# NOTE(review): when resuming, the start epoch is hard-coded to 11 rather than
# derived from a saved checkpoint -- confirm it matches the weights being loaded.
if opt.continue_train:
    start_epoch = 11
    epoch_iter = 0
    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
    # Fresh run: first epoch, no samples processed yet.
    start_epoch, epoch_iter = 1, 0

Resuming from epoch 11 at iteration 0


In [None]:
# Visualizer is defined outside this section; it is built from the same options object.
visualizer = Visualizer(opt)

In [None]:
# Instantiate the model and configure it from the options object.
model = CycleGAN()
model.initialize(opt)

cuda is available, we will use gpu!
Network initialized!


In [None]:
# Run inference over the whole dataset and save each result as <index>.jpg.
img_dir = "/content/drive/MyDrive/photo2monet/cycleganbayesian/epoch40/"
# Images are written inside img_dir, so make sure the folder exists first.
os.makedirs(img_dir, exist_ok=True)
# The loop variable is `batch` rather than `data`: `data` is the notebook's
# alias for torch.utils.data, and rebinding it at top level would break
# re-running the Dataset cell (class BaseDataset(data.Dataset)).
for i, batch in enumerate(dataset):
  model.set_input(batch)
  model.inference()
  visuals = model.get_current_visuals()
  # Every visual of one batch is written to the same <index>.jpg path.
  for label, image_numpy in visuals.items():
    img_path = os.path.join(img_dir, f'{i}.jpg')
    save_image(image_numpy, img_path)
  if i%100 == 0:
    print(f'{i} images saved')

0 images saved
100 images saved
200 images saved
300 images saved
400 images saved
500 images saved
600 images saved
700 images saved
800 images saved
900 images saved
1000 images saved
1100 images saved
1200 images saved
1300 images saved
1400 images saved
1500 images saved
1600 images saved
1700 images saved
1800 images saved
1900 images saved
2000 images saved
2100 images saved
2200 images saved
2300 images saved
2400 images saved
2500 images saved
2600 images saved
2700 images saved
2800 images saved
2900 images saved
3000 images saved
3100 images saved
3200 images saved
3300 images saved
3400 images saved
3500 images saved
3600 images saved
3700 images saved
3800 images saved
3900 images saved
4000 images saved
4100 images saved
4200 images saved
4300 images saved
4400 images saved
4500 images saved
4600 images saved
4700 images saved
4800 images saved
4900 images saved
5000 images saved
5100 images saved
5200 images saved
5300 images saved
5400 images saved
5500 images saved
5600