<a href="https://colab.research.google.com/github/kotetsu-n/pytorch_GANs_impl/blob/master/cycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the horse2zebra dataset archive into the current working directory.
!wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip

In [None]:
# Extract the downloaded archive (the later cells expect a ./horse2zebra directory).
!unzip horse2zebra.zip

In [None]:
# Build the directory layout consumed by the training/testing code below.
# Each extracted split is moved *inside* a same-named directory, e.g.
# ./dataset/trainA/trainA/, so that every split root contains exactly one
# sub-folder. NOTE(review): this nesting is presumably required because the
# loaders use torchvision's ImageFolder, which expects one sub-folder per class.
!mkdir ./dataset
!mkdir ./dataset/trainA ./dataset/trainB ./dataset/testA ./dataset/testB
!mv ./horse2zebra/trainA ./dataset/trainA/
!mv ./horse2zebra/trainB ./dataset/trainB/
!mv ./horse2zebra/testA ./dataset/testA/
!mv ./horse2zebra/testB ./dataset/testB/

In [None]:
'''
To compose this notebook, 
I referred to https://github.com/arnab39/cycleGAN-PyTorch
If there are any problems, please contact me.
'''

import argparse
import copy
import functools
import itertools
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.nn import init

def init_weights(net, init_type='normal', gain=0.02):
    """Initialise the parameters of ``net`` in place from normal distributions.

    Modules whose class name contains ``Conv`` or ``Linear`` get weights drawn
    from N(0, gain) and zero biases; ``BatchNorm2d`` modules get weights drawn
    from N(1, gain) and zero biases. Every other module is left untouched.

    Args:
        net: module whose sub-modules are visited through ``net.apply``.
        init_type: kept for interface compatibility; only normal
            initialisation is implemented, so the value is not consulted.
        gain: standard deviation of the normal distributions.
    """
    def init_func(m):
        classname = m.__class__.__name__
        is_conv_or_linear = 'Conv' in classname or 'Linear' in classname
        if is_conv_or_linear and getattr(m, 'weight', None) is not None:
            init.normal_(m.weight.data, 0.0, gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname and m.weight is not None:
            # Non-affine BatchNorm has weight/bias set to None; nothing to do.
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    # Report the standard deviation actually used instead of a fixed 0.02.
    print(f'Network initialized with weights sampled from N(0,{gain}).')
    net.apply(init_func)

def init_network(net, gpu_id):
    """Optionally move ``net`` to a GPU, then initialise its weights.

    Args:
        net: module to prepare.
        gpu_id: CUDA device index; any value of -1 or lower skips the move.

    Returns:
        The same ``net`` object after ``init_weights`` has been applied to it.
    """
    wants_gpu = gpu_id > -1
    if wants_gpu:
        # Fail early if a GPU was requested but CUDA is not usable.
        assert(torch.cuda.is_available())
        net.cuda(gpu_id)
    init_weights(net)
    return net

def conv_norm_lrelu(in_dim, out_dim, kernel_size, stride=1, padding=0,
                    norm_layer = nn.BatchNorm2d, bias = False):
    """Build a Conv2d -> normalisation -> LeakyReLU(0.2) stack.

    Args:
        in_dim: number of input channels of the convolution.
        out_dim: number of output channels (also given to ``norm_layer``).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        norm_layer: callable that builds the normalisation layer from a
            channel count.
        bias: whether the convolution has a bias term.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    convolution = nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding, bias=bias)
    normalisation = norm_layer(out_dim)
    activation = nn.LeakyReLU(0.2, True)  # in-place variant
    return nn.Sequential(convolution, normalisation, activation)

class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator producing a map of per-patch scores.

    Layout: one stride-2 convolution without normalisation, ``n_layers - 1``
    stride-2 convolutions with normalisation, one stride-1 convolution with
    normalisation, and a final stride-1 convolution down to a single channel.
    Channel widths double at each step and are capped at ``8 * base_ch``.
    """
    def __init__(self, input_ch, base_ch=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_bias=False):
        """
        Args:
            input_ch: number of channels of the input images.
            base_ch: channel width after the first convolution.
            n_layers: controls the number of stride-2 stages (see class doc).
            norm_layer: callable building a normalisation layer from a
                channel count.
            use_bias: whether the normalised convolutions carry a bias.
        """
        super(NLayerDiscriminator, self).__init__()

        # Channel multipliers seen along the normalised stages: start at 1,
        # then one entry per stride-2 stage, then the stride-1 stage.
        multipliers = [1]
        multipliers.extend(min(2 ** depth, 8) for depth in range(1, n_layers))
        multipliers.append(min(2 ** n_layers, 8))
        final_stage = len(multipliers) - 2

        stack = [
            nn.Conv2d(input_ch, base_ch, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        for stage, (mult_in, mult_out) in enumerate(zip(multipliers, multipliers[1:])):
            # Only the last normalised stage keeps the spatial resolution.
            stride = 1 if stage == final_stage else 2
            stack.append(nn.Conv2d(base_ch * mult_in, base_ch * mult_out,
                                   kernel_size=4, stride=stride, padding=1, bias=use_bias))
            stack.append(norm_layer(base_ch * mult_out))
            stack.append(nn.LeakyReLU(0.2, True))
        stack.append(nn.Conv2d(base_ch * multipliers[-1], 1, kernel_size=4, stride=1, padding=1))

        self.dis_layers = nn.Sequential(*stack)

    def forward(self, input):
        """Apply the layer stack to ``input`` and return the score map."""
        return self.dis_layers(input)

class PixelDiscriminator(nn.Module):
    """Discriminator built only from 1x1 convolutions.

    Because every convolution uses kernel size 1, stride 1 and no padding,
    the output score map has the same spatial size as the input.
    """
    def __init__(self, input_ch, base_ch=64, norm_layer=nn.BatchNorm2d, use_bias=False):
        """
        Args:
            input_ch: number of channels of the input images.
            base_ch: channel width after the first convolution.
            norm_layer: callable building a normalisation layer from a
                channel count.
            use_bias: whether the second and third convolutions carry a bias.
        """
        super(PixelDiscriminator, self).__init__()
        dis_layers = [
            nn.Conv2d(input_ch, base_ch, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(base_ch, base_ch * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            # Was misspelled ``norm_lAyer``, which raised NameError on construction.
            norm_layer(base_ch * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(base_ch * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        self.dis_layers = nn.Sequential(*dis_layers)

    def forward(self, input):
        """Apply the layer stack to ``input`` and return the score map."""
        return self.dis_layers(input)

def def_discriminator(input_ch, base_ch, network_type='n_layers', n_layers=3, norm='batch', gpu_id=0):
    """Build, place and initialise a discriminator.

    Args:
        input_ch: number of channels of the input images.
        base_ch: channel width of the first convolution.
        network_type: ``'n_layers'`` (NLayerDiscriminator) or ``'pixel'``
            (PixelDiscriminator).
        n_layers: depth option forwarded to NLayerDiscriminator.
        norm: ``'batch'`` or ``'instance'``; convolution biases are enabled
            only together with instance normalisation.
        gpu_id: device index forwarded to ``init_network``.

    Returns:
        The discriminator returned by ``init_network``.

    Raises:
        NotImplementedError: if ``norm`` or ``network_type`` is not one of
            the supported names.
    """
    if norm == 'batch':
        norm_layer = nn.BatchNorm2d
        use_bias = False
    elif norm == 'instance':
        norm_layer = nn.InstanceNorm2d
        use_bias = True
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise NotImplementedError('normalization layer [%s] is not found' % norm)

    if network_type not in ('n_layers', 'pixel'):
        raise NotImplementedError('discriminator network [%s] is not found' % network_type)

    if network_type == 'n_layers':
        discriminator = NLayerDiscriminator(input_ch, base_ch, n_layers, norm_layer=norm_layer, use_bias=use_bias)
    else:
        discriminator = PixelDiscriminator(input_ch, base_ch, norm_layer=norm_layer, use_bias=use_bias)

    return init_network(discriminator, gpu_id)

def get_norm_layer(norm_type='instance'):
    """Return a factory for a 2-D normalisation layer.

    The factory is a ``functools.partial``, so this function needs the
    ``functools`` module to be imported at file level (it previously was not,
    which made every call fail with NameError).

    Args:
        norm_type: ``'batch'`` for ``nn.BatchNorm2d`` with affine parameters,
            or ``'instance'`` for ``nn.InstanceNorm2d`` without affine
            parameters and without running statistics.

    Returns:
        A callable that takes the channel count and builds the layer.

    Raises:
        NotImplementedError: if ``norm_type`` is not a supported name.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

In [None]:
class UnetSkipConnectionBlock(nn.Module):
    """One level of the U-Net: downsample, run the nested block, upsample.

    Every block except the outermost returns its input concatenated with its
    own output along the channel dimension (the skip connection), which is
    why the up-convolution of an enclosing block takes ``inner_nc * 2``
    input channels.
    """
    def __init__(self, outer_nc, inner_nc, input_ch=None, submodule=None,
                 outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """
        Args:
            outer_nc: number of channels produced by the up-convolution.
            inner_nc: number of channels produced by the down-convolution.
            input_ch: number of input channels; defaults to ``outer_nc``.
            submodule: the nested block placed between down and up paths
                (unused by the innermost block).
            outermost: build the outermost block (no normalisation, Tanh
                output, no skip concatenation).
            innermost: build the innermost block (no submodule, no
                normalisation after the down-convolution).
            norm_layer: callable building a normalisation layer from a
                channel count (a class or a ``functools.partial`` of one).
            use_dropout: append ``Dropout(0.5)`` to intermediate blocks.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # This used to be the chained assignment
        # ``use_bias = norm_layer = nn.InstanceNorm2d``, which silently
        # replaced the requested normalisation. The intent is a comparison:
        # convolution biases are enabled only with instance normalisation,
        # the same rule def_discriminator applies.
        norm_class = getattr(norm_layer, 'func', norm_layer)  # unwrap functools.partial
        use_bias = norm_class == nn.InstanceNorm2d
        if input_ch is None:
            input_ch = outer_nc
        downconv = nn.Conv2d(input_ch, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [nn.ReLU(True), upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # No skip input from below, so the up-convolution sees inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            down = [nn.LeakyReLU(0.2, True), downconv]
            up = [nn.ReLU(True), upconv, norm_layer(outer_nc)]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc*2, outer_nc, kernel_size=4,
                                        stride=2, padding=1, bias=use_bias)
            down = [nn.LeakyReLU(0.2, True), downconv, norm_layer(inner_nc)]
            up = [nn.ReLU(True), upconv, norm_layer(outer_nc)]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run the block; non-outermost blocks prepend ``x`` on the channel axis."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)


class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested UnetSkipConnectionBlock levels.

    The network is built from the innermost block outwards: one innermost
    block, ``num_downs - 5`` intermediate blocks at ``8 * base_ch`` channels,
    three blocks narrowing to ``base_ch``, and the outermost block.
    """
    def __init__(self, input_ch, output_ch, num_downs, base_ch=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        """
        Args:
            input_ch: number of channels of the input images.
            output_ch: number of channels of the generated images.
            num_downs: total number of downsampling steps; values below 5
                still yield the five fixed levels.
            base_ch: channel width of the outermost down-convolution.
            norm_layer: callable building a normalisation layer from a
                channel count.
            use_dropout: enable dropout in the ``8 * base_ch`` levels.
        """
        super(UnetGenerator, self).__init__()

        unet_block = UnetSkipConnectionBlock(base_ch*8, base_ch*8, submodule=None,
                                             norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(base_ch*8, base_ch*8,
                                                 submodule=unet_block,
                                                 norm_layer=norm_layer,
                                                 use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(base_ch*4, base_ch*8, submodule=unet_block,
                                             norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(base_ch*2, base_ch*4,
                                             submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(base_ch, base_ch*2,
                                             submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_ch, base_ch, input_ch=input_ch,
                                             submodule=unet_block, outermost=True,
                                             norm_layer=norm_layer)
        self.unet_model = unet_block

    def forward(self, input):
        """Translate ``input`` with the nested U-Net."""
        return self.unet_model(input)


def def_generator(input_ch, output_ch, base_ch, network_type='unet128',
               norm='batch', use_dropout=False, gpu_id=0):
    """Build, place and initialise a U-Net generator.

    Args:
        input_ch: number of channels of the input images.
        output_ch: number of channels of the generated images.
        base_ch: channel width of the outermost down-convolution.
        network_type: ``'unet128'`` (7 downsampling steps) or ``'unet256'``
            (8 downsampling steps).
        norm: ``'batch'`` or ``'instance'``.
        use_dropout: enable dropout inside the generator.
        gpu_id: device index forwarded to ``init_network``.

    Returns:
        The generator returned by ``init_network``.

    Raises:
        NotImplementedError: if ``norm`` or ``network_type`` is not one of
            the supported names.
    """
    if norm == 'batch':
        norm_layer = nn.BatchNorm2d
    elif norm == 'instance':
        norm_layer = nn.InstanceNorm2d
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm)

    if network_type == 'unet128':
        num_downs = 7
    elif network_type == 'unet256':
        num_downs = 8
    else:
        raise NotImplementedError('generator network [%s] is not found' % network_type)

    # Pass the resolved layer class; the string ``norm`` was passed before.
    generator = UnetGenerator(input_ch, output_ch, num_downs, base_ch,
                              norm_layer=norm_layer, use_dropout=use_dropout)

    return init_network(generator, gpu_id)

class LambdaLR:
    """Learning-rate multiplier: constant at 1.0, then decaying linearly.

    The multiplier stays at 1.0 until ``epoch + offset`` passes
    ``decay_epoch`` and then falls linearly, reaching 0.0 when
    ``epoch + offset`` equals ``epochs``.
    """

    def __init__(self, epochs, offset, decay_epoch):
        """
        Args:
            epochs: epoch count at which the multiplier reaches zero.
            offset: number added to the epoch passed to ``step``.
            decay_epoch: epoch count at which the decay starts.
        """
        self.epochs = epochs
        self.offset = offset
        self.decay_epoch = decay_epoch

    def step(self, epoch):
        """Return the multiplier for ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_epoch)
        decay_length = self.epochs - self.decay_epoch
        return 1.0 - epochs_into_decay / decay_length

In [None]:
class sample_from_pool(object):
    """
    Bounded history of items with random replay.

    While the pool is not full, incoming items are stored and returned
    unchanged. Once it is full, each incoming item is returned as-is when the
    random draw is at most 0.5; otherwise it replaces a randomly chosen
    stored item and a copy of that stored item is returned.

    Args:
        max_elements (int): max number of pool items
    """
    def __init__(self, max_elements=50):
        self.max_elements = max_elements
        self.cur_elements = 0
        self.items = []

    def __call__(self, in_items):
        """Return one output item for every item of ``in_items``."""
        picked = []
        for candidate in in_items:
            if self.cur_elements < self.max_elements:
                # Still filling the pool: remember the item and pass it through.
                self.items.append(candidate)
                self.cur_elements += 1
                picked.append(candidate)
                continue

            if np.random.ranf() <= 0.5:
                picked.append(candidate)
                continue

            # Swap: hand back an older item and keep the new one instead.
            slot = np.random.randint(0, self.max_elements)
            picked.append(copy.copy(self.items[slot]))
            self.items[slot] = candidate
        return picked

class cycleGAN(object):
    """CycleGAN trainer for two unpaired image domains A and B.

    Generators: G_A: A -> B; G_B: B -> A.
    Discriminators: D_A scores domain-A images (real A vs. G_B(B));
    D_B scores domain-B images (real B vs. G_A(A)).
    """
    def __init__(self, args):
        """Build networks, optimizers and LR schedulers; resume if a checkpoint exists.

        Args:
            args: parsed options. Reads ``base_ch``, ``generator_net``,
                ``discriminator_net``, ``norm``, ``dropout``, ``gpu_id``,
                ``lr``, ``epochs``, ``decay_epoch`` and ``checkpoint_dir``.
        """
        gpu_id = int(args.gpu_id)
        # Batches and restored checkpoints must live on the same device the
        # networks are placed on; a bare .cuda() would ignore gpu_id.
        self.device = torch.device(f'cuda:{gpu_id}' if gpu_id > -1 else 'cpu')

        # Generators use args.norm like the discriminators and like test(),
        # so checkpoints written here fit the networks rebuilt there.
        self.G_A = def_generator(input_ch=3, output_ch=3, base_ch=args.base_ch, network_type=args.generator_net,
                                 norm=args.norm, use_dropout=args.dropout, gpu_id=gpu_id)
        self.G_B = def_generator(input_ch=3, output_ch=3, base_ch=args.base_ch, network_type=args.generator_net,
                                 norm=args.norm, use_dropout=args.dropout, gpu_id=gpu_id)

        self.D_A = def_discriminator(input_ch=3, base_ch=args.base_ch, network_type=args.discriminator_net,
                                    n_layers=3, norm=args.norm, gpu_id=gpu_id)
        self.D_B = def_discriminator(input_ch=3, base_ch=args.base_ch, network_type=args.discriminator_net,
                                    n_layers=3, norm=args.norm, gpu_id=gpu_id)

        # losses
        self.MSE = nn.MSELoss()
        self.L1 = nn.L1Loss()

        print(self.G_A)

        # optimizers: one for both generators, one for both discriminators
        self.g_optimizer = torch.optim.Adam(itertools.chain(self.G_A.parameters(), self.G_B.parameters()),
                                            lr=args.lr, betas=(0.5, 0.999))
        self.d_optimizer = torch.optim.Adam(itertools.chain(self.D_A.parameters(), self.D_B.parameters()),
                                            lr=args.lr, betas=(0.5, 0.999))

        # reserve (filled by prepare_data)
        self.training_dirs = None
        self.a_loader = None
        self.b_loader = None
        self.a_fake_sample = None
        self.b_fake_sample = None

        os.makedirs(args.checkpoint_dir, exist_ok=True)

        # Only a missing file means "start from scratch"; a corrupt or
        # incompatible checkpoint now raises instead of being silently
        # ignored and later overwritten.
        try:
            ckpt = torch.load(f'{args.checkpoint_dir}/latest.ckpt', map_location=self.device)
        except FileNotFoundError:
            print('No checkpoints!!')
            self.start_epoch = 0
        else:
            self.start_epoch = ckpt['epoch']
            self.D_A.load_state_dict(ckpt['D_A'])
            self.D_B.load_state_dict(ckpt['D_B'])
            self.G_A.load_state_dict(ckpt['G_A'])
            self.G_B.load_state_dict(ckpt['G_B'])
            self.d_optimizer.load_state_dict(ckpt['d_optimizer'])
            self.g_optimizer.load_state_dict(ckpt['g_optimizer'])

        # Schedulers are created after restoring so that the decay schedule
        # continues from start_epoch (via the offset) instead of restarting.
        self.g_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.g_optimizer,
            lr_lambda=LambdaLR(args.epochs, self.start_epoch, args.decay_epoch).step)
        self.d_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.d_optimizer,
            lr_lambda=LambdaLR(args.epochs, self.start_epoch, args.decay_epoch).step)

    def prepare_data(self, args):
        """Create the training data loaders and the fake-image history pools.

        Args:
            args: parsed options. Reads ``load_height``, ``load_width``,
                ``crop_height``, ``crop_width``, ``dataset_dir`` and
                ``batch_size``.
        """
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((args.load_height, args.load_width)),
            transforms.RandomCrop((args.crop_height, args.crop_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        self.training_dirs = {'trainA': os.path.join(args.dataset_dir, 'trainA'),
                              'trainB': os.path.join(args.dataset_dir, 'trainB')}

        self.a_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(self.training_dirs['trainA'], transform=transform),
            batch_size=args.batch_size, shuffle=True, num_workers=4)
        self.b_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(self.training_dirs['trainB'], transform=transform),
            batch_size=args.batch_size, shuffle=True, num_workers=4)

        self.a_fake_sample = sample_from_pool()
        self.b_fake_sample = sample_from_pool()

    def set_grad_status(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of every network in ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad = requires_grad

    def _sample_history(self, pool, fake):
        """Pass a generated batch through ``pool`` and return a detached tensor.

        The pool keeps whole batches as numpy arrays on the CPU; the returned
        batch is moved back to the training device.
        """
        pooled = pool([fake.detach().cpu().numpy()])[0]
        return torch.from_numpy(pooled).to(self.device)

    def train(self, args, test=None, output_interval=None):
        """Run the training loop, checkpointing after every epoch.

        Args:
            args: parsed options. Reads ``epochs``, ``lamda``, ``idt_coef``
                and ``checkpoint_dir`` (plus what ``prepare_data`` reads).
            test: optional callable invoked as ``test(args, epoch)``.
            output_interval: call ``test`` on epochs divisible by this
                value; ``None`` or 0 disables the calls.
        """
        self.prepare_data(args)
        # zip() stops at the shorter loader, so that is the epoch length.
        n_batches = min(len(self.a_loader), len(self.b_loader))

        for epoch in range(self.start_epoch, args.epochs):
            lr = self.g_optimizer.param_groups[0]['lr']
            print(f'current learning rate = {lr}')

            # Accumulate plain floats so no tensors are kept across steps.
            epoch_gen_loss = 0.0
            epoch_dis_loss = 0.0

            for i, (a_batch, b_batch) in enumerate(zip(self.a_loader, self.b_loader)):
                # ImageFolder yields (images, labels); only images are used.
                a_real = a_batch[0].to(self.device)
                b_real = b_batch[0].to(self.device)

                # ------ GENERATOR ------
                # Freeze the discriminators so the generator backward pass
                # does not compute gradients for their parameters.
                self.set_grad_status([self.D_A, self.D_B], False)
                self.g_optimizer.zero_grad()

                a_fake = self.G_B(b_real)
                b_fake = self.G_A(a_real)

                a_recon = self.G_B(b_fake)
                b_recon = self.G_A(a_fake)

                a_idt = self.G_B(a_real)
                b_idt = self.G_A(b_real)

                # identity losses L1(idt <-> real)
                a_idt_loss = self.L1(a_idt, a_real) * args.lamda * args.idt_coef
                b_idt_loss = self.L1(b_idt, b_real) * args.lamda * args.idt_coef

                # adversarial losses MSE(pred(fake) <-> 1.0)
                pred_a_fake = self.D_A(a_fake)
                pred_b_fake = self.D_B(b_fake)
                a_gen_loss = self.MSE(pred_a_fake, torch.ones_like(pred_a_fake))
                b_gen_loss = self.MSE(pred_b_fake, torch.ones_like(pred_b_fake))

                # cycle consistency losses L1(recon <-> real)
                a_cycle_loss = self.L1(a_recon, a_real) * args.lamda
                b_cycle_loss = self.L1(b_recon, b_real) * args.lamda

                # total generator loss
                gen_loss = a_gen_loss + b_gen_loss + a_cycle_loss + b_cycle_loss + a_idt_loss + b_idt_loss

                # update generators
                gen_loss.backward()
                self.g_optimizer.step()

                # ------ DISCRIMINATOR ------
                self.set_grad_status([self.D_A, self.D_B], True)
                self.d_optimizer.zero_grad()

                # sample from the history of generated images
                a_fake = self._sample_history(self.a_fake_sample, a_fake)
                b_fake = self._sample_history(self.b_fake_sample, b_fake)

                # forward
                pred_a_real = self.D_A(a_real)
                pred_b_real = self.D_B(b_real)
                pred_a_fake = self.D_A(a_fake)
                pred_b_fake = self.D_B(b_fake)

                # discriminator losses; labels are sized per prediction so
                # differently sized A and B batches do not clash
                dis_a_real_loss = self.MSE(pred_a_real, torch.ones_like(pred_a_real))
                dis_b_real_loss = self.MSE(pred_b_real, torch.ones_like(pred_b_real))
                dis_a_fake_loss = self.MSE(pred_a_fake, torch.zeros_like(pred_a_fake))
                dis_b_fake_loss = self.MSE(pred_b_fake, torch.zeros_like(pred_b_fake))

                # total discriminator losses
                dis_a_loss = (dis_a_real_loss + dis_a_fake_loss)*0.5
                dis_b_loss = (dis_b_real_loss + dis_b_fake_loss)*0.5

                # update discriminators
                dis_a_loss.backward()
                dis_b_loss.backward()
                self.d_optimizer.step()

                gen_value = gen_loss.item()
                dis_value = 0.5 * (dis_a_loss.item() + dis_b_loss.item())

                # Single-line message: a backslash continuation inside the
                # literal used to inject the source indentation into the
                # output, making the "\r" progress line overly long.
                log_str = (f'Epoch: {epoch}, ({i+1}/{n_batches}) '
                           f'Gen loss:{gen_value:.4f}, Dis loss:{dis_value:.4f}')
                print("\r"+log_str, end="")

                epoch_gen_loss += gen_value
                epoch_dis_loss += dis_value

            # max(..., 1) keeps an empty dataset from raising ZeroDivisionError.
            epoch_gen_loss /= max(n_batches, 1)
            epoch_dis_loss /= max(n_batches, 1)
            print(f'Epoch: {epoch}, Gen loss:{epoch_gen_loss:.4f}, Dis loss:{epoch_dis_loss:.4f}')

            # Saved under the same name __init__ loads ('latest.ckpt'), with
            # each optimizer's own state under its own key.
            torch.save({'epoch':epoch+1,
                        'D_A': self.D_A.state_dict(),
                        'D_B': self.D_B.state_dict(),
                        'G_A': self.G_A.state_dict(),
                        'G_B': self.G_B.state_dict(),
                        'd_optimizer': self.d_optimizer.state_dict(),
                        'g_optimizer': self.g_optimizer.state_dict()},
                        f'{args.checkpoint_dir}/latest.ckpt')

            # update learning rate
            self.g_lr_scheduler.step()
            self.d_lr_scheduler.step()

            if test is not None and output_interval and epoch % output_interval == 0:
                test(args, epoch)

In [None]:
def test(args, epoch=-1):
    """Translate one batch from each test set and save the results as an image.

    Freshly built generators are loaded from
    ``<checkpoint_dir>/latest.ckpt`` when that file exists; otherwise their
    initial weights are used. The saved grid contains, in order: real A,
    fake B, reconstructed A, identity A, real B, fake A, reconstructed B,
    identity B.

    Args:
        args: parsed options. Reads ``crop_height``, ``crop_width``,
            ``dataset_dir``, ``batch_size``, ``base_ch``, ``generator_net``,
            ``norm``, ``dropout``, ``gpu_id``, ``checkpoint_dir`` and
            ``results_dir``.
        epoch: when greater than -1 the file is named ``sample_<epoch>.jpg``,
            otherwise ``sample.jpg``.
    """
    gpu_id = int(args.gpu_id)
    device = torch.device(f'cuda:{gpu_id}' if gpu_id > -1 else 'cpu')

    transform = transforms.Compose(
        [transforms.Resize((args.crop_height,args.crop_width)),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

    test_dirs = {'testA': os.path.join(args.dataset_dir, 'testA'),
                 'testB': os.path.join(args.dataset_dir, 'testB')}

    a_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(test_dirs['testA'], transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)
    b_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(test_dirs['testB'], transform=transform),
        batch_size=args.batch_size, shuffle=True, num_workers=4)

    G_A = def_generator(input_ch=3, output_ch=3, base_ch=args.base_ch, network_type=args.generator_net,
                        norm=args.norm, use_dropout= args.dropout, gpu_id=gpu_id)
    G_B = def_generator(input_ch=3, output_ch=3, base_ch=args.base_ch, network_type=args.generator_net,
                        norm=args.norm, use_dropout= args.dropout, gpu_id=gpu_id)

    # The weights used to be loaded through ``self.G_A`` although this is a
    # plain function; the resulting NameError was swallowed by a bare except,
    # so the trained weights were never used. Only a missing file is treated
    # as "no checkpoint" now.
    try:
        ckpt = torch.load(f'{args.checkpoint_dir}/latest.ckpt', map_location=device)
    except FileNotFoundError:
        print('No checkpoints!!')
    else:
        G_A.load_state_dict(ckpt['G_A'])
        G_B.load_state_dict(ckpt['G_B'])

    ### Generation
    # Shuffled loaders: this picks one random batch (images only) per domain.
    a_real = next(iter(a_loader))[0].to(device)
    b_real = next(iter(b_loader))[0].to(device)

    G_A.eval()
    G_B.eval()

    with torch.no_grad():
        a_fake = G_B(b_real)
        b_fake = G_A(a_real)
        a_recon = G_B(b_fake)
        b_recon = G_A(a_fake)
        a_idt = G_B(a_real)
        b_idt = G_A(b_real)

    # Undo Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1].
    pic = (torch.cat([a_real, b_fake, a_recon, a_idt, b_real, a_fake, b_recon, b_idt], dim=0) + 1) / 2.0

    os.makedirs(args.results_dir, exist_ok=True)

    save_name = f'/sample_{epoch}.jpg' if epoch > -1 else '/sample.jpg'
    torchvision.utils.save_image(pic, args.results_dir+save_name, nrow=4)

In [None]:
# Entry point: declare the options, then train and/or test.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='cycleGAN PyTorch')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--decay_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--lr', type=float, default=.0002)
    parser.add_argument('--load_height', type=int, default=286)
    parser.add_argument('--load_width', type=int, default=286)
    # Parsed as int: init_network compares it with -1 and passes it to
    # net.cuda(), both of which fail for the string a str-typed option yields.
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--crop_height', type=int, default=256)
    parser.add_argument('--crop_width', type=int, default=256)
    # Loss weight; float so fractional weights are accepted as well.
    parser.add_argument('--lamda', type=float, default=10)
    parser.add_argument('--idt_coef', type=float, default=0.5)
    parser.add_argument('--training', action='store_true')
    parser.add_argument('--testing', action='store_true')
    parser.add_argument('--results_dir', type=str, default='./results')
    parser.add_argument('--dataset_dir', type=str, default='./dataset')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/horse2zebra')
    parser.add_argument('--output_interval', type=int, default=5)
    parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
    parser.add_argument('--dropout', action='store_true', help='dropout for the generator')
    parser.add_argument('--base_ch', type=int, default=64, help='# of gen filters in first conv layer')
    parser.add_argument('--generator_net', type=str, default='unet256')
    parser.add_argument('--discriminator_net', type=str, default='n_layers')
    # The argument list is given explicitly instead of being read from
    # sys.argv; edit it here to change the run configuration.
    args = parser.parse_args(args=['--training', '--testing', '--output_interval', '1'])

    if args.training:
        model = cycleGAN(args)
        if args.testing and args.output_interval > 0:
            # Train and write sample images every output_interval epochs.
            print('training & testing')
            model.train(args, test, args.output_interval)
        else:
            print('training')
            model.train(args)

    elif args.testing:
        print('testing')
        test(args)
