In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tensorflow as tf
import os
import re
import tqdm
import PIL
import shutil

In [None]:
# Directory of the trained CycleGAN run to load — edit this to point at your own run.
SAVE = "../input/allaug-cyclegan20200623/allaug_CycleGan_IGDK_sd2021_0"

# Let tf.data choose the parallelism for dataset.map() calls.
AUTO = tf.data.experimental.AUTOTUNE

# Shape every decoded image is reshaped to (see decode_image).
HEIGHT = 256
WIDTH = 256
CHANNELS = 3

### Dataset Functions:

In [None]:
def count_data_items(filenames):
    """Return the total number of records across a list of TFRecord shard filenames.

    Relies on the shard naming convention ``<prefix>-<count>.<ext>`` (e.g.
    ``photo00-352.tfrec``), where ``<count>`` is the number of records in that shard.

    Args:
        filenames: Iterable of shard paths; only the base name of each path is inspected.

    Returns:
        int: Sum of the per-shard counts (0 for an empty list).

    Raises:
        ValueError: If a filename's base name contains no ``-<digits>.`` count.
    """
    # `+` (not `*`) so an empty count such as "foo-.tfrec" is rejected instead of int('').
    count_pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Search only the base name so dashes/dots in directory names cannot match.
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(
                "Cannot read record count from filename {!r}; "
                "expected '<prefix>-<count>.<ext>'".format(filename)
            )
        total += int(match.group(1))
    # Plain int accumulator: np.sum([]) would have returned the float 0.0.
    return total


def decode_image(image):
    """Decode JPEG bytes into a float32 image tensor scaled to [-1, 1].

    Args:
        image: Scalar string tensor holding encoded JPEG bytes.

    Returns:
        float32 tensor of shape [HEIGHT, WIDTH, CHANNELS]; pixel values 0..255 are
        mapped linearly onto [-1, 1].
    """
    pixels = tf.image.decode_jpeg(image, channels=CHANNELS)
    scaled = tf.cast(pixels, tf.float32) / 127.5 - 1
    # Fix the static shape so downstream layers see a fully defined image size.
    return tf.reshape(scaled, [HEIGHT, WIDTH, CHANNELS])


def read_tfrecord(example):
    """Parse one serialized tf.train.Example and return its decoded image.

    'image_name' and 'target' are declared as FixedLenFeature with no default, so they
    must be present in every record. Only 'image' is decoded and returned.
    """
    string_feature = tf.io.FixedLenFeature([], tf.string)
    feature_spec = {
        'image_name': string_feature,
        'image': string_feature,
        'target': string_feature,
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])


def load_dataset(filenames):
    """Build a tf.data pipeline that yields decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) accepted by tf.data.TFRecordDataset.

    Returns:
        tf.data.Dataset whose elements are the images produced by read_tfrecord,
        in file/record order.
    """
    # AUTO is only read here, so no `global` declaration is needed.
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTO)

### Create "Photo" dataset:

In [None]:
# Competition data directory (mounted locally on Kaggle).
GCS_PATH = "/kaggle/input/gan-getting-started"
_photo_dir = os.path.join(GCS_PATH, "photo_tfrec")
# Sort the shards so the dataset order, and therefore the output image numbering, is stable.
PHOTO_FILENAMES = sorted(tf.io.gfile.glob(os.path.join(_photo_dir, "*.tfrec")))
if not PHOTO_FILENAMES:
    # Fail fast: an empty glob would otherwise silently yield zero generated images.
    raise FileNotFoundError("No .tfrec files found under {}".format(_photo_dir))
n_photo_samples = count_data_items(PHOTO_FILENAMES)
photo_ds = load_dataset(PHOTO_FILENAMES)


### Predict and Save function:

In [None]:
def predict_and_save(path, input_ds, generator_model):
    """Run generator_model over input_ds and write each generated image as a JPEG.

    Files go to ``<path>/images/`` and are numbered consecutively from ``1.jpg``, in
    the order the dataset yields them. Every image in every batch is saved.

    Args:
        path: Directory containing the ``images`` sub-directory (created if missing).
        input_ds: Iterable of batched image tensors; presumably scaled to [-1, 1] as
            decode_image produces — TODO confirm for other inputs.
        generator_model: Callable invoked as ``generator_model(batch, training=False)``
            that returns a tensor batch of images in the same [-1, 1] scale.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    from PIL import Image

    out_dir = os.path.join(path, "images")
    os.makedirs(out_dir, exist_ok=True)

    index = 1
    for batch in input_ds:
        predictions = generator_model(batch, training=False).numpy()  # make prediction
        # Iterate over the whole batch; indexing [0] would drop all but the first image.
        for prediction in predictions:
            # Re-scale [-1, 1] -> [0, 255]; clip so out-of-range values cannot wrap in uint8.
            pixels = np.clip(prediction * 127.5 + 127.5, 0, 255).astype(np.uint8)
            Image.fromarray(pixels).save(os.path.join(out_dir, "{}.jpg".format(index)))
            index += 1

In [None]:
# Checkpoint directory of the run selected by SAVE; the Keras model lives in its "model" sub-dir.
chkpath = os.path.join(SAVE, "checkpoints")
model_dir = os.path.join(chkpath, "model")
print("Loading model from {}".format(model_dir))
loaded_model = tf.keras.models.load_model(model_dir)

In [None]:
# exist_ok keeps this cell re-runnable; plain os.mkdir raises FileExistsError the second time.
os.makedirs("./images", exist_ok=True)
predict_and_save(
    path="./",
    input_ds=photo_ds.batch(1),
    generator_model=loaded_model.m_gen,  # Monet generator
)

In [None]:
# Zip the generated images into ./images.zip, then report how many files were written.
images_path = os.path.join(".", 'images')
shutil.make_archive(images_path, 'zip', "images")
n_generated = sum(
    1
    for entry_name in os.listdir(images_path)
    if os.path.isfile(os.path.join(images_path, entry_name))
)
print('| Generated samples: {}'.format(n_generated))