In [57]:
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import itertools
import matplotlib.pyplot as plt
import skimage as sk
import skimage.io as skio
import os
% matplotlib inline

In [84]:
def load_images(path):
    return skio.imread_collection(['{0}/{1}'.format(path, filename) for filename in os.listdir(path)])

train_img = load_images('data/trainB')
train_art = load_images('data/trainA')
test_img = load_images('data/testB')
test_art = load_images('data/testA')

## Image Pool

In [69]:
class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return Variable(images)
        return_images = []
        for image in images:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size-1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images

## GAN Loss

In [60]:
class GANLoss(nn.Module):
    def __init__(self, real_label=1.0, fake_label=0.0):
        super(GANLoss, self).__init__()
        self.real_label = real_label
        self.fake_label = fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.loss = nn.MSELoss()
    
    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or 
                           (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = torch.FloatTensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or 
                           (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = torch.FloatTensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.real_label_var 
        return target_tensor
        
    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

## Discriminator

In [113]:
class NLayerDiscriminator(nn.Module):
    def __init__(self, n_input_channels, n_filters=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        use_bias = norm_layer == nn.InstanceNorm2d

        kernel_size = 4
        padding_size = 1
        sequence = [nn.Conv2d(n_input_channels, n_filters, kernel_size=kernel_size, stride=2, padding=padding_size),
                    nn.LeakyReLU(0.2, True)]

        n_filters_mult = 1
        n_filters_mult_prev = 1
        for n in range(1, n_layers):
            n_filters_mult_prev = n_filters_mult
            nf_mult = min(2**n, 8)
            sequence += [nn.Conv2d(n_filters * n_filters_mult_prev, n_filters * n_filters_mult,
                                  kernel_size=kernel_size, stride=2, padding=padding_size, bias=use_bias), 
                        norm_layer(n_filters * n_filters_mult), 
                        nn.LeakyReLU(0.2, True)]
            
        n_filters_mult_prev = n_filters_mult
        n_filters_mult = min(2**n_layers, 8)
        sequence += [nn.Conv2d(n_filters * n_filters_mult_prev, n_filters * n_filters_mult,
                              kernel_size=kernel_size, stride=1, padding=padding_size, bias=use_bias), 
                    norm_layer(n_filters * n_filters_mult),
                    nn.LeakyReLU(0.2, True)]
        sequence += [nn.Conv2d(n_filters * n_filters_mult, 1, kernel_size=kernel_size, stride=1, padding=padding_size)]
        
        if use_sigmoid: sequence += [nn.Sigmoid()]
        self.model = nn.Sequential(*sequence)
        
    def forward(self, input):
        return self.model(input)

## Generator

In [98]:
class ResNetGenerator(nn.Module):
    def __init__(self, n_input_channels, n_output_channels, n_filters=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, gpu_ids=[], padding_type='reflect'):
        super(ResNetGenerator, self).__init__()
        self.n_input_channels = n_input_channels
        self.n_output_channels = n_output_channels
        self.n_filters = n_filters
        self.gpu_ids = gpu_ids
        use_bias = norm_layer == nn.InstanceNorm2d
        
        model = [nn.ReflectionPad2d(3),
                nn.Conv2d(n_input_channels, n_filters, kernel_size=7, padding=0, bias=use_bias),
                norm_layer(n_filters),
                nn.ReLU(True)]
        
        # Downsampling
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(n_filters * mult, n_filters * mult * 2, kernel_size=3,
                               stride=2, padding=1, bias=use_bias),
                     norm_layer(n_filters * mult * 2),
                     nn.ReLU(True)]
            
        # ResNet Blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(n_filters * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        
        # Upsampling
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(n_filters * mult, int(n_filters * mult/2),
                                        kernel_size=3, stride=2,
                                        padding=1, output_padding=1,
                                        bias=use_bias),
                     norm_layer(int(n_filters * mult / 2)),
                     nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(n_filters, n_output_channels, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        
        self.model = nn.Sequential(*model) 

    def forward(self, input):
        return self.model(input)

## ResNet Block 

In [107]:
class ResNetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
        
    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        block = []
        p = 0
        
#         if padding_type == 'reflect': block += [nn.ReflectionPad2d(1)]
#         elif padding_type == 'replicate': block += [nn.ReplicationPad2d(1)]
#         elif padding_type == 'zero': p = 1
            
        block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        
        if use_dropout: 
            block += [nn.Dropout(0.5)]
        
        p = 0
        if padding_type == 'reflect': block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate': block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero': p = 1
            
        return nn.Sequential(*block)
    
    def forward(self, x):
        out = x + self.conv_block(x)
        return out

## Cycle GAN

In [116]:
class CycleGan():
    def __init__(self, train_img, train_art):
        size = train_img[0].shape[0]
        n_input_channels, n_output_channels = train_img[0].shape[-1], train_img[0].shape[-1]
        n_filters = 64
        batch_size = 1
        lr = 0.0002
        beta1 = 0.5
        pool_size = 50
        
        self.input_img = torch.Tensor(batch_size, n_input_channels, size, size)
        self.input_art = torch.Tensor(batch_size, n_input_channels, size, size)

        self.fake_img_pool = ImagePool(pool_size)
        self.fake_art_pool = ImagePool(pool_size)
        
        # Define G's and D's
        self.G_img = ResNetGenerator(n_input_channels, n_input_channels, n_filters)
        self.G_art = ResNetGenerator(n_input_channels, n_output_channels, n_filters)
        self.D_img = NLayerDiscriminator(3)
        self.D_art = NLayerDiscriminator(3)
        
        # Define loss functions
        self.criterion_GAN = GANLoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()
        
        # Optimizers
        self.optimizer_G = torch.optim.Adam(itertools.chain(self.G_img.parameters(), self.G_art.parameters()), 
                                            lr=lr, betas=(beta1, 0.999))
        self.optimizer_D_img = torch.optim.Adam(self.D_img.parameters(), lr=lr, betas=(beta1, 0.999))
        self.optimizer_D_art = torch.optim.Adam(self.D_art.parameters(), lr=lr, betas=(beta1, 0.999))
        self.optimizers = [self.optimizer_G, self.optimizer_D_img, self.optimizer_D_art]
        self.schedulers = []
        for optimizer in self.optimizers:
            def lambda_rule(epoch):
                epoch_count = 1
                niter = 100
                niter_decay = 100
                lr_l = 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)
                return lr_l
            scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
            self.schedulers.append(scheduler)
    
    def set_input(self, input):
        self.input_A = input['A']
        self.input_B = input['B']
        
    def forward(self):
        self.real_img = Variable(self.input_img)
        self.real_art = Variable(self.input_art)
        
    def backward_D_basic(self, D, real, fake):
        # Real
        pred_real = D(real)
        loss_D_real = self.criterion_GAN(pred_real, True)
        # Fake
        pred_fake = D(fake.detach())
        loss_D_fake = self.criterion_GAN(pred_fake, False)
        # Combined Loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D
        
    def backward_D_img(self):
        fake_art = self.fake_art_pool(self.fake_art)
        loss_D_img = self.backward_D_basic(self.D_img, self.real_art, fake_art)
        self.loss_D_img = loss_D_img.data[0]

    def backward_D_art(self):
        fake_img = self.fake_img_pool(self.fake_img)
        loss_D_art = self.backward_D_basic(self.D_art, self.real_img, fake_img)
        self.loss_D_art = loss_D_art.data[0]
        
    def backward_G(self):
        lambda_idt = 0.5
        lambda_img = 10.0
        lambda_art = 10.0
        
        # Passing in art to art generator should keep art
        idt_img = self.G_img(self.real_art)
        loss_idt_img = self.criterion_identity(idt_img, self.real_art) * lambda_art * lambda_idt
        
        # Passing in images to img generator should keep image
        idt_art = self.G_art(self.real_img)
        loss_idt_art = self.criterion_identity(idt_art, self.real_art) * lambda_img * lambda_idt
        
        self.idt_img = idt_img.data
        self.idt_art = idt_art.data
        self.loss_idt_img = loss_idt_img.data[0]
        self.loss_idt_art = loss_idt_art.data[0]
        
        # GAN loss for generating fake art
        fake_art = self.G_img(self.real_img)
        pred_fake = self.D_img(fake_art)
        loss_G_img = self.criterionGAN(pred_fake, True)
        
        # GAN loss for generating fake images
        fake_img = self.G_art(self.real_art)
        pred_fake = self.D_art(fake_img)
        loss_G_art = self.criterionGAN(pred_fake, True)
        
        # Forward cycle loss
        rec_img = self.G_art(fake_art)
        loss_cycle_img = self.criterion_cycle(rec_img, self.real_img) * lambda_img
        
        # Backward cycle loss 
        rec_art = self.G_img(fake_img)
        loss_cycle_art = self.criterion_cycle(rec_art, self.real_art) * lambda_art
        
        loss_G = loss_G_img + loss_G_art + loss_cycle_img + loss_cycle_art + loss_idt_img + loss_idt_art
        loss_G.backward()
        
        self.fake_art = fake_art.data
        self.fake_img = fake_img.data
        self.rec_img = rec_img.data
        self.rec_art = rec_art.data
        
        self.loss_G_img = loss_G_img.data[0]
        self.loss_G_art = loss_G_art.data[0]
        self.loss_cycle_img = loss_cycle_img.data[0]
        self.loss_cycle_art = loss_cycle_art.data[0]
        
    def optimize_parameters(self):
        self.forward()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.opimtizer_G.step()
        
        self.optimizer_D_img.zero_grad()
        self.backward_D_img()
        self.optimizer_D_img.step()
        
        self.optimizer_D_art.zero_grad()
        self.backward_D_art()
        self.optimizer_D_art.step()
        

In [117]:
class Options():
    def __init__(self):
        self.niter = 100
        self.niter_decay = 100

# Training

In [118]:
model = CycleGan(train_img, train_art)
options = Options()

for epoch in range(options.niter + options.niter_decay):
    for i, _ in enumerate(train_img):
        data = {'A': train_img[i], 'B': train_art[i]}
        model.set_input(data)
        model.optimize_parameters()
        print('works')

AttributeError: 'CycleGan' object has no attribute 'criterionGAN'