In [None]:
# Switch read by the final cell: when False the CycleGAN is trained and the
# Monet generator is saved; when True training is skipped and
# "monet_generator5.keras" is loaded from disk instead.
USE_CACHED_MODELS=False

This is just the code required for the Monet generator challenge. See the notebook "you never give me your monet" to see where all of this came from.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
#import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

# Pick a distribution strategy: TPU when one can be resolved and initialised,
# otherwise all visible GPUs (mirrored), otherwise the default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable name for the deprecated tf.distribute.experimental.TPUStrategy.
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as tpu_error:
    # Still a best-effort fallback, but `except Exception` lets
    # KeyboardInterrupt / SystemExit propagate, and the reason the TPU path
    # failed is reported instead of being silently discarded.
    print('TPU not available:', repr(tpu_error))
    if tf.config.list_physical_devices('GPU'):
        strategy = tf.distribute.MirroredStrategy()
        print("Running on GPU.")
    else:
        strategy = tf.distribute.get_strategy()
        print("Running on CPU.")
print('Number of replicas:', strategy.num_replicas_in_sync)

# Sentinel handed to num_parallel_calls in load_dataset().
AUTOTUNE = tf.data.experimental.AUTOTUNE

print(tf.__version__)

# Base path of the Kaggle dataset; the TFRecord shards are globbed beneath it.
GCS_PATH = KaggleDatasets().get_gcs_path()

MONET_FILENAMES = tf.io.gfile.glob(f"{GCS_PATH}/monet_tfrec/*.tfrec")
print('Monet TFRecord Files:', len(MONET_FILENAMES))

PHOTO_FILENAMES = tf.io.gfile.glob(f"{GCS_PATH}/photo_tfrec/*.tfrec")
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

# Height and width that decode_image() reshapes every image to.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB tensor scaled to [-1, 1].

    Args:
        image: scalar string tensor holding the encoded JPEG.

    Returns:
        A float32 tensor reshaped to [*IMAGE_SIZE, 3].
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    pixels = tf.cast(pixels, tf.float32)
    # uint8 range 0..255 maps onto -1..1.
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    All three string features ("image_name", "image", "target") are parsed,
    but only "image" is used; the others are discarded.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) to read.
        labeled: accepted for interface compatibility; not used here.
        ordered: accepted for interface compatibility; not used here.

    Returns:
        A tf.data dataset yielding one decoded image tensor per record.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)


# Number of channels produced by the final layer of Generator().
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolution block.

    Args:
        filters: number of output channels of the convolution.
        size: kernel size of the convolution.
        apply_instancenorm: when True, a GroupNormalization(groups=-1) layer
            sits between the convolution and the activation.

    Returns:
        keras.Sequential: Conv2D -> [GroupNormalization] -> LeakyReLU.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block = keras.Sequential()

    stages = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=kernel_init, use_bias=False)
    ]
    if apply_instancenorm:
        # casey: GroupNormalization(groups=-1) replaces the defunct
        # tfa.layers.InstanceNormalization.
        stages.append(layers.GroupNormalization(groups=-1,
                                                gamma_initializer=norm_gamma_init))
    stages.append(layers.LeakyReLU())

    for stage in stages:
        block.add(stage)
    return block

def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block.

    Args:
        filters: number of output channels of the transposed convolution.
        size: kernel size of the transposed convolution.
        apply_dropout: when True, Dropout(0.5) is inserted after the
            normalization layer.

    Returns:
        keras.Sequential: Conv2DTranspose -> GroupNormalization ->
        [Dropout] -> ReLU.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block = keras.Sequential()

    stages = [
        layers.Conv2DTranspose(filters, size, strides=2,
                               padding='same',
                               kernel_initializer=kernel_init,
                               use_bias=False),
        # casey: GroupNormalization(groups=-1) replaces the defunct
        # tfa.layers.InstanceNormalization.
        layers.GroupNormalization(groups=-1, gamma_initializer=norm_gamma_init),
    ]
    if apply_dropout:
        stages.append(layers.Dropout(0.5))
    stages.append(layers.ReLU())

    for stage in stages:
        block.add(stage)
    return block

def Generator():
    """Build the U-Net generator for 256x256x3 images.

    Eight downsample blocks feed seven upsample blocks; each upsample output
    is concatenated with the matching encoder activation (skip connection).
    A final tanh transposed convolution produces OUTPUT_CHANNELS channels.

    Returns:
        keras.Model mapping a (256, 256, 3) input to a (256, 256, 3) output.
    """
    image = layers.Input(shape=[256,256,3])

    # Encoder, bs = batch size:
    # (bs,128,128,64) -> (bs,64,64,128) -> (bs,32,32,256) -> (bs,16,16,512)
    # -> (bs,8,8,512) -> (bs,4,4,512) -> (bs,2,2,512) -> (bs,1,1,512)
    encoder = [downsample(64, 4, apply_instancenorm=False)]
    encoder += [downsample(width, 4)
                for width in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder (channel counts after concatenation with the skip):
    # (bs,2,2,1024) -> (bs,4,4,1024) -> (bs,8,8,1024) -> (bs,16,16,1024)
    # -> (bs,32,32,512) -> (bs,64,64,256) -> (bs,128,128,128)
    decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    decoder += [upsample(width, 4) for width in (512, 256, 128, 64)]

    kernel_init = tf.random_normal_initializer(0., 0.02)
    to_image = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                      strides=2,
                                      padding='same',
                                      kernel_initializer=kernel_init,
                                      activation='tanh') # (bs, 256, 256, 3)

    # Run the encoder, remembering every intermediate activation.
    features = image
    encoder_outputs = []
    for stage in encoder:
        features = stage(features)
        encoder_outputs.append(features)

    # Run the decoder; skips are every encoder output except the bottleneck,
    # deepest first.
    for stage, skip in zip(decoder, encoder_outputs[-2::-1]):
        features = stage(features)
        features = layers.Concatenate()([features, skip])

    return keras.Model(inputs=image, outputs=to_image(features))

def Discriminator():
    """Build the patch discriminator for 256x256x3 images.

    Three downsample blocks are followed by two zero-padded stride-1
    convolutions; the last convolution has no activation.

    Returns:
        tf.keras.Model mapping a (256, 256, 3) input named 'input_image' to
        a (30, 30, 1) output.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image = layers.Input(shape=[256, 256, 3], name='input_image')

    # (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
    features = image
    for width, with_norm in ((64, False), (128, True), (256, True)):
        features = downsample(width, 4, with_norm)(features)

    features = layers.ZeroPadding2D()(features) # (bs, 34, 34, 256)
    features = layers.Conv2D(512, 4, strides=1,
                             kernel_initializer=kernel_init,
                             use_bias=False)(features) # (bs, 31, 31, 512)

    # casey: GroupNormalization(groups=-1) replaces the defunct
    # tfa.layers.InstanceNormalization.
    features = layers.GroupNormalization(groups=-1,
                                         gamma_initializer=norm_gamma_init)(features)
    features = layers.LeakyReLU()(features)

    features = layers.ZeroPadding2D()(features) # (bs, 33, 33, 512)

    patch_scores = layers.Conv2D(1, 4, strides=1,
                                 kernel_initializer=kernel_init)(features) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=image, outputs=patch_scores)


My own additions start after this point.

In [None]:
def getAltDiscriminator():
    """Build an alternative patch discriminator with a 3x3 convolution stem.

    The stem (zero padding, three unpadded 3x3 leaky-ReLU convolutions and a
    2x2 max pool) stands in for the first downsample block of
    Discriminator(); the rest mirrors that model.

    Returns:
        keras.Sequential taking a (256, 256, 3) input; the output needs to be
        30x30x1.
    """
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    kernel_init = tf.random_normal_initializer(0., 0.02)

    stem = [keras.Input(shape=(256,256,3)), layers.ZeroPadding2D()]
    stem += [
        layers.Conv2D(64, kernel_size=(3, 3), activation="leaky_relu")
        for _ in range(3)
    ]
    stem.append(layers.MaxPooling2D(pool_size=(2, 2)))

    ## the following is basically the same as the original discriminator
    ## output needs to be 30x30x1
    patch_head = [
        downsample(128,4),
        downsample(256,4),
        layers.ZeroPadding2D(),
        layers.Conv2D(512, 4, strides=1,
                      kernel_initializer=kernel_init,
                      use_bias=False),
        layers.GroupNormalization(groups=-1, gamma_initializer=norm_gamma_init),
        layers.LeakyReLU(),
        layers.ZeroPadding2D(),
        layers.Conv2D(1, 4, strides=1,
                      kernel_initializer=kernel_init),
    ]

    return keras.Sequential(stem + patch_head)

In [None]:
from keras.layers import Input, concatenate, Conv2D, Add, MaxPooling2D, Conv2DTranspose, BatchNormalization, GroupNormalization
from keras.models import Model

def get_connected_unet2():
    """Build a generator made of two chained U-Net-style encoder/decoder passes.

    The first pass encodes and decodes the input; the second pass re-encodes
    the first decoder's output, reusing the first decoder's activations as
    skip inputs, and then decodes again.

    Returns:
        A new keras Model taking a (256, 256, 3) input and producing a
        3-channel output whose last layer is BatchNormalization (no bounding
        activation such as tanh is applied).
    """
    def aspp_block(x, num_filters, rate_scale=1):
        """Sum four parallel batch-normalized 3x3 convolutions of x, then apply a 1x1 convolution.

        Three branches use dilation rates 6, 12 and 18 (each multiplied by
        rate_scale); the fourth is undilated. All use num_filters channels.
        """
        x1 = Conv2D(num_filters, (3, 3), dilation_rate=(6 * rate_scale, 6 * rate_scale), padding="same")(x)
        x1 = BatchNormalization()(x1)

        x2 = Conv2D(num_filters, (3, 3), dilation_rate=(12 * rate_scale, 12 * rate_scale), padding="same")(x)
        x2 = BatchNormalization()(x2)

        x3 = Conv2D(num_filters, (3, 3), dilation_rate=(18 * rate_scale, 18 * rate_scale), padding="same")(x)
        x3 = BatchNormalization()(x3)

        x4 = Conv2D(num_filters, (3, 3), padding="same")(x)
        x4 = BatchNormalization()(x4)

        # Branches are combined by element-wise addition, not concatenation.
        y = Add()([x1, x2, x3, x4])
        y = Conv2D(num_filters, (1, 1), padding="same")(y)
        return y 

    def get_wnet():
        """Wire up both encoder/decoder passes and return the keras Model."""
        inputs = Input((256, 256, 3))

        # ---- First encoder: four (conv, BN, conv, BN, 2x2 max-pool) stages,
        # 32 -> 64 -> 128 -> 256 channels.
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
        conv1 = BatchNormalization()(conv1)
        conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
        conv1 = BatchNormalization()(conv1)
        pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
        conv2 = BatchNormalization()(conv2)
        conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
        conv2 = BatchNormalization()(conv2)
        pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
        conv3 = BatchNormalization()(conv3)
        conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
        conv3 = BatchNormalization()(conv3)
        pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

        conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
        conv4 = BatchNormalization()(conv4)
        conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
        conv4 = BatchNormalization()(conv4)
        pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

        # First bottleneck: dilated-convolution block with 512 channels.
        conv5 = aspp_block(pool4, 512)

        # ---- First decoder: each stage upsamples 2x with a transposed
        # convolution and concatenates the same-resolution encoder output
        # (conv4, conv3, conv2, conv1) along the channel axis.
        up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
        conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
        conv6 = BatchNormalization()(conv6)
        conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
        conv6 = BatchNormalization()(conv6)

        up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
        conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
        conv7 = BatchNormalization()(conv7)
        conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
        conv7 = BatchNormalization()(conv7)

        up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
        conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
        conv8 = BatchNormalization()(conv8)
        conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
        conv8 = BatchNormalization()(conv8)

        up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
        conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
        conv9 = BatchNormalization()(conv9)
        conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
        conv9 = BatchNormalization()(conv9)

        # ---- Second encoder: starts from the first decoder's output (conv9).
        # Each stage concatenates a fresh convolution of the running tensor
        # with the first decoder's activation at the same resolution
        # (conv9, conv8, conv7, conv6), linking the two passes.
        down10 = concatenate([Conv2D(32, (3, 3), activation='relu', padding='same')(conv9), conv9], axis=3)
        conv10 = Conv2D(32, (3, 3), activation='relu', padding='same')(down10)
        conv10 = BatchNormalization()(conv10)
        conv10 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv10)
        conv10 = BatchNormalization()(conv10)    
        pool10 = MaxPooling2D(pool_size=(2, 2))(conv10)

        down11 = concatenate([Conv2D(64, (3, 3), activation='relu', padding='same')(pool10), conv8], axis=3)
        conv11 = Conv2D(64, (3, 3), activation='relu', padding='same')(down11)
        conv11 = BatchNormalization()(conv11)
        conv11 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv11)
        conv11 = BatchNormalization()(conv11)   
        pool11 = MaxPooling2D(pool_size=(2, 2))(conv11)

        down12 = concatenate([Conv2D(128, (3, 3), activation='relu', padding='same')(pool11), conv7], axis=3)
        conv12 = Conv2D(128, (3, 3), activation='relu', padding='same')(down12)
        conv12 = BatchNormalization()(conv12)
        conv12 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv12)    
        conv12 = BatchNormalization()(conv12)
        pool12 = MaxPooling2D(pool_size=(2, 2))(conv12)

        down13 = concatenate([Conv2D(256, (3, 3), activation='relu', padding='same')(pool12), conv6], axis=3)
        conv13 = Conv2D(256, (3, 3), activation='relu', padding='same')(down13)
        conv13 = BatchNormalization()(conv13)
        conv13 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv13)  
        conv13 = BatchNormalization()(conv13)    
        pool13 = MaxPooling2D(pool_size=(2, 2))(conv13)

        # Second bottleneck.
        conv14 = aspp_block(pool13, 512)

        # ---- Second decoder: skips come from the second encoder
        # (conv13, conv12, conv11, conv10).
        up15 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv14), conv13], axis=3)
        conv15 = Conv2D(256, (3, 3), activation='relu', padding='same')(up15)
        conv15 = BatchNormalization()(conv15)    
        conv15 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv15)
        conv15 = BatchNormalization()(conv15) 

        up16 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv15), conv12], axis=3)
        conv16 = Conv2D(128, (3, 3), activation='relu', padding='same')(up16)
        conv16 = BatchNormalization()(conv16)     
        conv16 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv16)
        conv16 = BatchNormalization()(conv16)      

        up17 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv16), conv11], axis=3)
        conv17 = Conv2D(64, (3, 3), activation='relu', padding='same')(up17)
        conv17 = BatchNormalization()(conv17)      
        conv17 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv17)
        conv17 = BatchNormalization()(conv17)  

        #casey: switched these to leaky relu, I think the relus are forcing the more 
        #extreme primary color patches.
        up18 = concatenate([Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(conv17), conv10], axis=3)
        conv18 = Conv2D(32, (3, 3), activation='leaky_relu', padding='same')(up18)
        conv18 = BatchNormalization()(conv18)      
        conv18 = Conv2D(32, (3, 3), activation='leaky_relu', padding='same')(conv18)
        conv18 = BatchNormalization()(conv18)
        
        # Output head: reduce to 3 channels with the dilated block, then
        # batch-normalize. NOTE(review): the output is not squashed to
        # [-1, 1], unlike the tanh head in Generator() - confirm intended.
        conv18 = aspp_block(conv18, 3)
        conv19 = BatchNormalization()(conv18)

        model = Model(inputs=[inputs], outputs=[conv19])
        return model
    return get_wnet()

In [None]:
# Build the two generators, two discriminators, their optimizers and the
# CycleGan wrapper inside the distribution strategy's scope.
# NOTE(review): CycleGan, generator_loss, discriminator_loss, calc_cycle_loss,
# identity_loss and plot_results are not defined in the cells visible here;
# they must exist before this cell runs on a fresh kernel - confirm.
with strategy.scope():
    monet_generator5 = get_connected_unet2()# transforms photos to Monet-esque paintings
    photo_generator5 = get_connected_unet2() # transforms Monet paintings to be more like photos

    monet_discriminator5 = getAltDiscriminator() # differentiates real Monet paintings and generated Monet paintings
    photo_discriminator5 = getAltDiscriminator() # differentiates real photos and generated photos
    
    # All four optimizers use Adam's default hyperparameters; the trailing
    # "2e-4, beta_1=0.5" is commented out and therefore NOT applied.
    monet_generator_optimizer = tf.keras.optimizers.Adam()#2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam()#2e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam()#2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam()#2e-4, beta_1=0.5)

    cycle_gan_model5 = CycleGan(
        monet_generator5, photo_generator5, 
        monet_discriminator5, photo_discriminator5
    )

    cycle_gan_model5.compile(
        m_gen_optimizer = monet_generator_optimizer,
        p_gen_optimizer = photo_generator_optimizer,
        m_disc_optimizer = monet_discriminator_optimizer,
        p_disc_optimizer = photo_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )
    

    class CustomCallback(keras.callbacks.Callback):
        """After every epoch, checkpoint the model's m_gen and plot its results."""

        def on_epoch_end(self, epoch, logs=None):
            """Save self.model.m_gen to "model5.<epoch>.keras" and pass it to plot_results."""
            self.model.m_gen.save(f"model5.{epoch}.keras")
            plot_results(self.model.m_gen)
            
SUBMIT = True # TODO: move

# Either reuse the generator saved by a previous run, or train the CycleGAN
# for 25 epochs on zipped (monet, photo) pairs and save the Monet generator.
if USE_CACHED_MODELS:
    monet_generator5 = tf.keras.models.load_model("monet_generator5.keras")
else:
    hist5 = cycle_gan_model5.fit(
        tf.data.Dataset.zip((monet_ds, photo_ds)),
        callbacks=[CustomCallback()],
        epochs=25
    )
    monet_generator5.save("monet_generator5.keras")