# pix2pix in TensorFlow

## Load images

In [None]:
from tensorflow.keras.layers import Layer
from tensorflow.keras import Model
import tensorflow as tf
import numpy as np
import math

### Common layers
The pix2pix paper defines the following layer types:
- C(k)
    - Convolution (k filters; 4x4; stride 2)
    - BatchNorm (in testing and training)
    - ReLU
- CD(k)
    - Convolution (k filters; 4x4; stride 2)
    - BatchNorm (in testing and training)
    - Dropout (50% rate)
    - ReLU

In [None]:
class C(Layer):
    """C(k) block of the pix2pix paper: convolution, optional batch norm, activation.

    The convolution always uses ``k`` filters, a 4x4 kernel, stride 2 and
    'same' padding. The activation is a parameter so that the encoder and the
    decoder can use different functions (e.g. leaky ReLU and ReLU).

    Args:
        k: Number of convolution filters.
        activation: Callable applied to the block output, or None to skip it.
        sampling: 'down' builds a strided ``Conv2D``; 'up' builds a strided
            ``Conv2DTranspose``.
        batchnorm: If True, a ``BatchNormalization`` layer follows the
            convolution.

    Raises:
        AttributeError: If ``sampling`` is neither 'up' nor 'down'.
    """

    def __init__(self, k, activation=None, sampling='down', batchnorm=True):
        super(C, self).__init__()
        if sampling == 'up':
            self.conv = tf.keras.layers.Conv2DTranspose(
                k, kernel_size=4, strides=2, activation=None, padding='same')
        elif sampling == 'down':
            self.conv = tf.keras.layers.Conv2D(
                k, kernel_size=4, strides=2, activation=None, padding='same')
        else:
            # Exception type kept as-is so existing callers that catch it still work.
            raise AttributeError('illegal sampling mode: "{}"'.format(sampling))

        # None marks "no batch norm" and is checked again in call().
        self.batchnorm = tf.keras.layers.BatchNormalization() if batchnorm else None
        self.activation = activation

    def call(self, x):
        """Apply convolution, then batch norm and activation when configured."""
        x = self.conv(x)

        if self.batchnorm is not None:
            x = self.batchnorm(x)

        if self.activation is not None:
            x = self.activation(x)

        return x

class CD(C):
    """CD(k) block of the pix2pix paper: convolution, optional batch norm, dropout, activation.

    Identical to ``C`` except that a 50% ``Dropout`` is inserted between the
    batch norm and the activation.

    Args:
        k: Number of convolution filters.
        activation: Callable applied to the block output, or None to skip it.
        sampling: 'down' or 'up'; forwarded to ``C``. Defaults to 'down' like
            ``C`` (the previous default of None was always rejected by ``C``).
        batchnorm: If True, a ``BatchNormalization`` layer follows the
            convolution.

    Raises:
        AttributeError: If ``sampling`` is neither 'up' nor 'down' (raised by ``C``).
    """

    def __init__(self, k, activation=None, sampling='down', batchnorm=True):
        super(CD, self).__init__(
            k, activation=activation, sampling=sampling, batchnorm=batchnorm)
        self.dropout = tf.keras.layers.Dropout(rate=0.5)

    def call(self, x):
        """Apply convolution, optional batch norm, dropout, then optional activation."""
        x = self.conv(x)

        if self.batchnorm is not None:
            x = self.batchnorm(x)

        # Dropout sits before the activation, so C.call cannot simply be reused.
        x = self.dropout(x)

        if self.activation is not None:
            x = self.activation(x)

        return x

### Define the discriminator
**16 x 16:**
- C64
- C128 
- convolution mapping to a 1-dimensional output
- sigmoid

**70 x 70:**
- C64 
- C128 
- C256 
- C512 
- convolution mapping to a 1-dimensional output
- sigmoid
    
**286 x 286:**
- C64 
- C128 
- C256 
- C512 
- C512
- C512 
- convolution mapping to a 1-dimensional output
- sigmoid

In [None]:
# 16 x 16 discriminator:
class Discriminator16(Model):
    """16 x 16 discriminator: C64 -> C128 -> Flatten -> Dense(1, sigmoid)."""

    def __init__(self):
        super().__init__()
        leaky_relu = tf.keras.layers.LeakyReLU

        # The first block is built without batch norm.
        self.conv1 = C(k=64, activation=leaky_relu(alpha=0.2), sampling="down", batchnorm=False)
        self.conv2 = C(k=128, activation=leaky_relu(alpha=0.2), sampling="down")

        # Flatten + a single sigmoid unit stands in for the
        # "conv to 1D, then sigmoid" head.
        self.flatten = tf.keras.layers.Flatten()
        self.out = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid)

    def call(self, x, y):
        """Score the pair (input x, candidate output y) with the discriminator."""
        features = tf.keras.layers.concatenate([x, y])
        for stage in (self.conv1, self.conv2, self.flatten, self.out):
            features = stage(features)
        return features

In [None]:
# 70 x 70 discriminator:
class Discriminator70(Model):
    """70 x 70 discriminator: C64 -> C128 -> C256 -> C512 -> Flatten -> Dense(1, sigmoid)."""

    def __init__(self):
        super().__init__()
        leaky_relu = tf.keras.layers.LeakyReLU

        # The first block is built without batch norm.
        self.conv1 = C(k=64, activation=leaky_relu(alpha=0.2), sampling="down", batchnorm=False)
        self.conv2 = C(k=128, activation=leaky_relu(alpha=0.2), sampling="down")
        self.conv3 = C(k=256, activation=leaky_relu(alpha=0.2), sampling="down")
        self.conv4 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")

        # Flatten + a single sigmoid unit stands in for the
        # "conv to 1D, then sigmoid" head.
        self.flatten = tf.keras.layers.Flatten()
        self.out = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid)

    def call(self, x, y):
        """Score the pair (input x, candidate output y) with the discriminator."""
        features = tf.keras.layers.concatenate([x, y])
        pipeline = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.flatten,
            self.out,
        )
        for stage in pipeline:
            features = stage(features)
        return features

In [None]:
# 286 x 286 discriminator:
class Discriminator286(Model):
    """286 x 286 discriminator: C64 -> C128 -> C256 -> C512 x 3 -> Flatten -> Dense(1, sigmoid)."""

    def __init__(self):
        super().__init__()
        leaky_relu = tf.keras.layers.LeakyReLU

        # The first block is built without batch norm.
        self.conv1 = C(k=64, activation=leaky_relu(alpha=0.2), sampling="down", batchnorm=False)
        self.conv2 = C(k=128, activation=leaky_relu(alpha=0.2), sampling="down")
        self.conv3 = C(k=256, activation=leaky_relu(alpha=0.2), sampling="down")
        self.conv4 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")
        self.conv5 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")
        self.conv6 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")

        # Flatten + a single sigmoid unit stands in for the
        # "conv to 1D, then sigmoid" head.
        self.flatten = tf.keras.layers.Flatten()
        self.out = tf.keras.layers.Dense(1, activation=tf.keras.activations.sigmoid)

    def call(self, x, y):
        """Score the pair (input x, candidate output y) with the discriminator."""
        features = tf.keras.layers.concatenate([x, y])
        pipeline = (
            self.conv1,
            self.conv2,
            self.conv3,
            self.conv4,
            self.conv5,
            self.conv6,
            self.flatten,
            self.out,
        )
        for stage in pipeline:
            features = stage(features)
        return features

### Define the Generator model

Autoencoder

Encoder (leaky ReLU, slope = 0.2):
- C64
- C128
- C256
- C512
- C512
- C512
- C512
- C512

U-Net decoder (ReLU):
- CD512
- CD1024
- CD1024
- C1024
- C1024
- C512
- C256
- C128
- reduction to output channels
- tanh

In [None]:
class Generator(Model):
    """Encoder/decoder generator with skip connections between mirrored stages.

    Eight down-sampling ``C`` blocks feed eight up-sampling ``CD``/``C``
    blocks; every decoder stage after the first receives the previous decoder
    output concatenated with the matching encoder output. A final stride-1
    convolution with tanh produces ``output_dim`` channels.
    """

    def __init__(self, output_dim=3):
        super().__init__()
        leaky_relu = tf.keras.layers.LeakyReLU
        relu = tf.keras.activations.relu

        # Encoder: leaky ReLU (slope 0.2); the first block has no batch norm.
        self.enc_conv1 = C(k=64, activation=leaky_relu(alpha=0.2), sampling="down", batchnorm=False)
        self.enc_conv2 = C(k=128, activation=leaky_relu(alpha=0.2), sampling="down")
        self.enc_conv3 = C(k=256, activation=leaky_relu(alpha=0.2), sampling="down")
        self.enc_conv4 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")
        self.enc_conv5 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")
        self.enc_conv6 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")
        self.enc_conv7 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")
        self.enc_conv8 = C(k=512, activation=leaky_relu(alpha=0.2), sampling="down")

        # Decoder: ReLU; the first three blocks include dropout (CD).
        self.dec_conv1 = CD(k=512, activation=relu, sampling="up")
        self.dec_conv2 = CD(k=1024, activation=relu, sampling="up")
        self.dec_conv3 = CD(k=1024, activation=relu, sampling="up")
        self.dec_conv4 = C(k=1024, activation=relu, sampling="up")
        self.dec_conv5 = C(k=1024, activation=relu, sampling="up")
        self.dec_conv6 = C(k=512, activation=relu, sampling="up")
        self.dec_conv7 = C(k=256, activation=relu, sampling="up")
        self.dec_conv8 = C(k=128, activation=relu, sampling="up")

        # Projection to the requested number of output channels.
        self.out = tf.keras.layers.Conv2D(
            output_dim,
            kernel_size=3,
            strides=1,
            activation=tf.keras.activations.tanh,
            padding='same',
        )

    def call(self, x):
        """Run the encoder, then the decoder with skip connections, then the output conv."""
        encoder = (
            self.enc_conv1, self.enc_conv2, self.enc_conv3, self.enc_conv4,
            self.enc_conv5, self.enc_conv6, self.enc_conv7, self.enc_conv8,
        )
        skip_decoders = (
            self.dec_conv2, self.dec_conv3, self.dec_conv4, self.dec_conv5,
            self.dec_conv6, self.dec_conv7, self.dec_conv8,
        )

        # Encoder: remember every intermediate output for the skip connections.
        skips = []
        hidden = x
        for encode in encoder:
            hidden = encode(hidden)
            skips.append(hidden)

        # The deepest encoder output feeds the first decoder stage directly.
        hidden = self.dec_conv1(skips.pop())

        # Remaining stages consume the skips from deepest to shallowest.
        for decode in skip_decoders:
            hidden = decode(tf.keras.layers.concatenate([hidden, skips.pop()]))

        return self.out(hidden)

## Define the combined pix2pix model

In [1]:
class Pix2pix(Model):
    """Combined pix2pix model: a ``Generator`` trained against a discriminator.

    The Keras entry points ``fit``, ``predict``, ``predict_on_batch``,
    ``train_on_batch`` and ``test_on_batch`` are overridden with a hand-written
    GAN training loop. Many keyword arguments exist only for signature
    compatibility with Keras and are ignored.
    """

    def __init__(self, discriminator=None, output_dim=3):
        """Build the generator, the discriminator, the loss and the optimizers.

        Args:
            discriminator: Model invoked as ``discriminator(x, y)``. When None,
                a new ``Discriminator70`` is created for this instance.
            output_dim: Number of channels produced by the generator.
        """
        super(Pix2pix, self).__init__()
        self.output_dim = output_dim

        self.g = Generator(output_dim=output_dim)
        # Create the default here, not in the signature: a default argument is
        # evaluated once and would be shared by every Pix2pix instance.
        self.d = Discriminator70() if discriminator is None else discriminator

        # The discriminators defined in this file end in a sigmoid, so the
        # loss receives probabilities rather than logits.
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

        # One optimizer per network so that each keeps its own Adam state.
        # `optimizer` (generator) keeps its original attribute name.
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)
        self.d_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.999)

    def _disc_loss(self, disc_real, disc_fake):
        """Return the discriminator loss: real pairs labelled 1, fake pairs labelled 0."""
        real_loss = self.cross_entropy(tf.ones_like(disc_real), disc_real)
        fake_loss = self.cross_entropy(tf.zeros_like(disc_fake), disc_fake)
        return real_loss + fake_loss

    def _gen_loss(self, y, generated, disc_fake):
        """Return the generator loss: adversarial term plus weighted mean absolute error.

        The L1 term is weighted by ``6 / output_dim``.
        """
        gan_loss = self.cross_entropy(tf.ones_like(disc_fake), disc_fake)
        l1_loss = tf.reduce_mean(tf.abs(y - generated))
        return gan_loss + (6 / self.output_dim) * l1_loss

    def split_dataset(self, x, y, validation_split=0.0):
        """Split x and y into a training part and a leading validation part.

        The first ``floor(len(x) * validation_split)`` samples become the
        validation set; the rest is the training set. All parts are returned
        as float32 numpy arrays.

        Returns:
            ``((train_x, train_y), (test_x, test_y))``
        """
        num_test = math.floor(len(x) * validation_split)

        test_x = np.array(x[:num_test], dtype=np.float32)
        test_y = np.array(y[:num_test], dtype=np.float32)

        train_x = np.array(x[num_test:], dtype=np.float32)
        train_y = np.array(y[num_test:], dtype=np.float32)

        return ((train_x, train_y), (test_x, test_y))

    def predict_on_batch(self, x):
        """Return the generator output for one batch ``x``."""
        return self.g(x)

    def predict(self, x, batch_size=None, verbose=0, steps=None, callbacks=None,
                max_queue_size=10, workers=1, use_multiprocessing=False):
        """Run the generator over ``x`` in batches and return the stacked outputs.

        Args:
            x: Input samples, sliced along the first axis.
            batch_size: Batch size; 1 when None.
            steps: Maximum number of batches to process; all batches when None.
            callbacks: Optional iterable of Keras callbacks; their predict
                hooks are invoked around the run and around each batch.
            verbose, max_queue_size, workers, use_multiprocessing: Accepted for
                signature compatibility and ignored.

        Returns:
            A numpy array with the outputs concatenated along axis 0, or an
            empty array when no batch was processed.
        """
        callbacks = [] if callbacks is None else callbacks
        dataset = tf.data.Dataset.from_tensor_slices(x)
        dataset = dataset.batch(batch_size=1 if batch_size is None else batch_size)

        for callback in callbacks:
            callback.on_predict_begin()

        # Collect per-batch outputs and concatenate once at the end instead of
        # re-copying the growing result on every batch.
        outputs = []
        for n, batch in enumerate(dataset):
            for callback in callbacks:
                callback.on_predict_batch_begin(n)

            outputs.append(self.predict_on_batch(batch).numpy())

            for callback in callbacks:
                callback.on_predict_batch_end(n)

            # n is zero-based, so n + 1 batches have been processed by now.
            if steps is not None and n + 1 >= steps:
                break

        for callback in callbacks:
            callback.on_predict_end()

        if not outputs:
            return np.array([])
        return np.concatenate(outputs, axis=0)

    def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=None,
            validation_split=0.0, validation_data=None, shuffle=True,
            class_weight=None, sample_weight=None, initial_epoch=0,
            steps_per_epoch=None, validation_steps=None, validation_freq=1,
            max_queue_size=10, workers=1, use_multiprocessing=False):
        """Train generator and discriminator on the paired data ``(x, y)``.

        Args:
            x: Generator inputs.
            y: Target outputs paired with ``x``.
            batch_size: Batch size; 1 when None.
            epochs: Index of the epoch at which training stops.
            callbacks: Optional iterable of Keras callbacks.
            validation_split: Fraction of leading samples held out of training
                when ``validation_data`` is None.
            validation_data: When given, all of ``(x, y)`` is used for training.
            shuffle: Whether to shuffle the training set.
            initial_epoch: Index of the first epoch to run.
            steps_per_epoch: Maximum number of batches per epoch; all when None.
            Remaining arguments are accepted for signature compatibility and
            ignored.

        Note:
            Validation samples are only excluded from training; this loop does
            not run a validation pass.
        """
        callbacks = [] if callbacks is None else callbacks

        if validation_data is not None:
            train_x, train_y = x, y
        else:
            (train_x, train_y), _ = self.split_dataset(x, y, validation_split=validation_split)

        train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
        if shuffle:
            train_dataset = train_dataset.shuffle(buffer_size=50000)
        train_dataset = train_dataset.batch(batch_size=1 if batch_size is None else batch_size)

        # Hand the run configuration and the model to every callback.
        if callbacks:
            params = {
                'batch_size': batch_size,
                'epochs': epochs,
                'steps': steps_per_epoch,
                'samples': len(train_x),
                'verbose': verbose,
                'do_validation': validation_steps is not None,
                'metrics': []
            }
            for callback in callbacks:
                callback.set_params(params)
                callback.set_model(self)

        for callback in callbacks:
            callback.on_train_begin()

        for epoch in range(initial_epoch, epochs):
            for callback in callbacks:
                callback.on_epoch_begin(epoch)

            for n, (batch_x, batch_y) in enumerate(train_dataset):
                for callback in callbacks:
                    callback.on_batch_begin(n)

                self.train_on_batch(batch_x, batch_y)

                for callback in callbacks:
                    callback.on_batch_end(n, logs={'size': batch_x.shape[0]})

                # n is zero-based, so n + 1 batches have been processed by now.
                if steps_per_epoch is not None and n + 1 >= steps_per_epoch:
                    break

            for callback in callbacks:
                callback.on_epoch_end(epoch)

        for callback in callbacks:
            callback.on_train_end()

    def train_on_batch(self, x, y, sample_weight=None, class_weight=None, reset_metrics=True):
        """Run one discriminator update followed by one generator update.

        ``sample_weight``, ``class_weight`` and ``reset_metrics`` are accepted
        for signature compatibility and ignored.
        """
        # training=True is passed explicitly so that the Dropout and
        # BatchNormalization sub-layers run in training mode; without it Keras
        # treats these direct calls as inference.

        # Discriminator step.
        with tf.GradientTape() as tape:
            generated = self.g(x, training=True)
            disc_real = self.d(x, y, training=True)
            disc_fake = self.d(x, generated, training=True)

            loss = self._disc_loss(disc_real, disc_fake)

        gradients = tape.gradient(loss, self.d.trainable_variables)
        self.d_optimizer.apply_gradients(zip(gradients, self.d.trainable_variables))

        # Generator step, evaluated against the just-updated discriminator.
        with tf.GradientTape() as tape:
            generated = self.g(x, training=True)
            disc_fake = self.d(x, generated, training=True)

            loss = self._gen_loss(y, generated, disc_fake)

        gradients = tape.gradient(loss, self.g.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.g.trainable_variables))

    def test_on_batch(self, x, y, sample_weight=None, reset_metrics=True):
        """Return ``np.array([discriminator_loss, generator_loss])`` for one batch.

        No weights are updated. ``sample_weight`` and ``reset_metrics`` are
        accepted for signature compatibility and ignored.
        """
        generated = self.g(x)
        disc_real = self.d(x, y)
        disc_fake = self.d(x, generated)

        disc_loss = self._disc_loss(disc_real, disc_fake)
        gen_loss = self._gen_loss(y, generated, disc_fake)

        return np.array([disc_loss, gen_loss])
            

NameError: name 'Model' is not defined