In [None]:
import tensorflow as tf
import keras
tf.__version__

import os
import random
import numpy as np
from PIL import Image
import pathlib

import matplotlib.pyplot as plt
%matplotlib inline

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

In [None]:
from model_settings import GAN, Generator, Discriminator
from utils import normalize_images

In [None]:
# Image shape fed to both networks: (height, width, channels).
SHAPE = [476, 476, 3]

CURRENT_EPOCH = 0  # updated by SaveImageCallback at the end of every epoch
BATCH_SIZE = 8
EPOCHS = 30
# Initial stddev of the GaussianNoise augmentation, applied to raw pixel values
# before normalize_images; SaveImageCallback anneals it during training.
noise_level = 3
val_size = BATCH_SIZE * 100  # number of images held out for validation

RANDOM_SEED = 101
# Seed every RNG in play. Python's `random` alone does not cover the latent seed,
# augmentation noise or dataset shuffles, which are drawn by TensorFlow.
# (Bit-exact determinism on GPU is still not guaranteed.)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
AUTOTUNE = tf.data.AUTOTUNE

latent_dim = 512  # length of each latent vector fed to the generator

input_folder = r""  # Folder where GEE imagery is stored
output_folder = r""  # Folder where output results will be saved

In [None]:
# Fail fast on unset placeholders: pathlib.Path("") is ".", so an empty
# input_folder would silently load whatever images live in the working directory.
if not input_folder or not output_folder:
    raise ValueError("Set input_folder and output_folder in the configuration cell before loading data.")

# makedirs also creates missing parent directories and is a no-op if the folder exists.
os.makedirs(output_folder, exist_ok=True)

# Unlabelled, unbatched image dataset resized to the model's spatial size.
data_dir = pathlib.Path(input_folder)
import_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels=None,
    shuffle=True,
    image_size=(SHAPE[0], SHAPE[1]),
    seed=RANDOM_SEED,
    batch_size=None,
)

In [None]:
# Training-time augmentation: random flips plus additive Gaussian noise, applied
# to raw pixels before normalize_images.
data_augmentation = keras.Sequential(
    [
        keras.layers.RandomFlip("horizontal_and_vertical"),
        keras.layers.GaussianNoise(noise_level),
    ]
)

# The first val_size images are held out for validation; the rest are training data.
# NOTE(review): import_ds was built with shuffle=True; if it reshuffles on every
# iteration, take()/skip() may not yield disjoint train/val sets across epochs — confirm.
train_ds = (
    import_ds.skip(val_size)
    .shuffle(512)
    .map(
        lambda x: normalize_images(data_augmentation(x, training=True)),
        num_parallel_calls=AUTOTUNE,
    )
    # Batch here, as the pipeline rebuilt in SaveImageCallback does; without it
    # every element is a single 3-D image rather than a batch.
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
val_ds = (
    import_ds.take(val_size)
    .map(normalize_images, num_parallel_calls=AUTOTUNE)
    .batch(1)
)

In [None]:
# Preview 25 augmented training images (the first image of each element).
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
for ax in axes.flat:
    ax.axis("off")
for ax, element in zip(axes.flat, train_ds.take(25)):
    images = element.numpy()
    # Accept both batched (N, H, W, C) and single-image (H, W, C) elements.
    img = images[0] if images.ndim == 4 else images
    # Assumes normalize_images maps pixels to [0, 1] (the *255 relies on it).
    # Augmentation noise can push values out of range, and a float -> uint8 cast
    # of out-of-range values is undefined (in practice it wraps), so clip first.
    ax.imshow(np.clip(img * 255, 0, 255).astype("uint8"))
fig.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()

In [None]:
# Fixed latent vectors, one per cell of the 5x5 preview grid. They are reused at
# every epoch so the progress images are comparable, and saved (np.save appends
# ".npy") so the same grid can be regenerated later.
seed = tf.random.normal(shape=(25, latent_dim))
seed_file = os.path.join(output_folder, 'seed')
np.save(seed_file, seed)

#### Custom Model

In [None]:
# Build both networks for the configured input shape, then print each architecture.
generator = Generator(SHAPE)
discriminator = Discriminator(SHAPE)
for network in (generator, discriminator):
    network.summary()

In [None]:
# Adam with a low beta_1 for both players; the discriminator learns at a quarter
# of the generator's rate.
generator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-4,
    beta_1=0.25,
)
discriminator_optimizer = tf.keras.optimizers.Adam(
    learning_rate=2.5e-5,
    beta_1=0.25,
)
# from_logits=False: predictions are treated as probabilities
# (presumably the discriminator ends in a sigmoid — confirm in model_settings).
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

In [None]:
# Track both networks and both optimizer states so training can be resumed.
# SaveImageCallback writes a checkpoint under checkpoint_prefix every second epoch.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)
checkpoint_prefix = os.path.join(output_folder, "ckpt")

# To resume a previous run, uncomment and fill in the paths:
# checkpoint.restore(r"").expect_partial()  # restore networks + optimizers
# seed = np.load(r"")  # reuse the saved latent seed

In [None]:
# Wrap both networks in the project's GAN model (from model_settings) and attach
# the generator/discriminator optimizers and the loss defined above.
model = GAN(generator, discriminator, latent_dim)
model.compile(generator_optimizer, discriminator_optimizer, loss_fn)

In [None]:
class SaveImageCallback(keras.callbacks.Callback):
    """Per-epoch housekeeping for GAN training.

    At the end of every epoch it:
      * renders the generator's output for the fixed latent ``seed`` to
        ``output_folder/image_at_epoch_XXXX.png``;
      * rebuilds the global ``train_ds`` with a GaussianNoise stddev that decays
        linearly from ``noise_level`` to 0 at epoch ``EPOCHS / 2`` and stays there;
      * saves a checkpoint every second epoch.

    Relies on notebook globals: generator, seed, output_folder, noise_level,
    EPOCHS, import_ds, val_size, BATCH_SIZE, AUTOTUNE, checkpoint,
    checkpoint_prefix, normalize_images.

    NOTE: rebinding the global ``train_ds`` cannot change the dataset that an
    in-progress ``model.fit`` call iterates over. The rebuilt pipeline only takes
    effect on the next ``fit`` call, e.g. when training one epoch per call.
    """

    def __init__(self):
        super().__init__()

    def generate_and_save_images(self, generator, epoch, test_image):
        """Save a 5x5 grid of generator samples for the latent batch ``test_image``.

        ``epoch`` is 0-based; the saved filename uses ``epoch + 1``.
        """
        epoch = epoch + 1
        predictions = generator(test_image, training=False)
        fig = plt.figure(figsize=(12, 12))
        for i in range(predictions.shape[0]):
            plt.subplot(5, 5, i + 1)
            # Clip before the uint8 cast: out-of-range floats would otherwise wrap
            # instead of saturating.
            arr = np.clip(predictions[i].numpy() * 255, 0, 255).astype("uint8")
            plt.imshow(arr)
            plt.axis('off')
        plt.subplots_adjust(hspace=0.01, wspace=0.01)
        plt.tight_layout()
        fig.savefig(os.path.join(output_folder, 'image_at_epoch_{:04d}.png'.format(epoch)))
        # Close explicitly so figures do not accumulate over many epochs.
        plt.close(fig)

    def on_epoch_end(self, epoch, logs=None):
        global CURRENT_EPOCH
        global train_ds
        CURRENT_EPOCH = epoch + 1
        self.generate_and_save_images(generator, epoch, seed)

        # Linear decay: noise_level at the start, reaching 0 at EPOCHS / 2, clamped at 0.
        noise_calc = max(0, noise_level - (noise_level * CURRENT_EPOCH / EPOCHS) * 2)
        epoch_augmentation = keras.Sequential(
            [keras.layers.RandomFlip("horizontal_and_vertical"),
             keras.layers.GaussianNoise(noise_calc)])
        train_ds = (
            import_ds.skip(val_size)
            .shuffle(512)
            .map(lambda x: normalize_images(epoch_augmentation(x, training=True)),
                 num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE)
        )

        if CURRENT_EPOCH % 2 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

In [None]:
# Train one epoch per `fit` call. A single `fit(train_ds, epochs=EPOCHS)` binds the
# dataset object once, so the annealed `train_ds` that SaveImageCallback rebuilds
# at each epoch end would never be used. Re-reading the global every iteration
# makes the noise schedule take effect. `initial_epoch` keeps epoch numbering
# (and hence image names / checkpoint cadence) continuous.
save_image_callback = SaveImageCallback()
merged_history = {}
history = None
for epoch_index in range(EPOCHS):
    history = model.fit(
        train_ds,
        initial_epoch=epoch_index,
        epochs=epoch_index + 1,
        validation_data=val_ds,
        callbacks=[save_image_callback],
    )
    for metric_name, values in history.history.items():
        merged_history.setdefault(metric_name, []).extend(values)
if history is not None:
    # Expose the whole run through the usual `history.history` mapping so the
    # plotting cells downstream keep working unchanged.
    history.history = merged_history

In [None]:
# results = [model.evaluate(val,verbose=0)[0][0] for val in val_ds.take(50).as_numpy_iterator()]
# print(np.asarray(results).mean(),np.asarray(results).std())

In [None]:
# Persist the trained GAN, both networks and the latent seed into output_folder.
model.save_weights(os.path.join(output_folder, "gan_model.h5"))
# NOTE(review): Model.save() writes the full model (architecture + weights),
# not only the weights that these filenames suggest.
for network, filename in ((generator, "generator_weights.h5"),
                          (discriminator, "discriminator_weights.h5")):
    network.save(os.path.join(output_folder, filename))
np.save(os.path.join(output_folder, 'seed'), seed)

In [None]:
# Render the trained generator's output for the fixed latent seed as a 5x5 grid.
predictions = generator(seed, training=False)

fig = plt.figure(figsize=(15, 15))
for i in range(predictions.shape[0]):
    plt.subplot(5, 5, i + 1)
    # Clip before the uint8 cast so out-of-range outputs saturate instead of wrapping.
    arr = np.clip(predictions[i].numpy() * 255, 0, 255).astype("uint8")
    plt.imshow(arr)
    plt.axis('off')
plt.subplots_adjust(hspace=0.01, wspace=0.01)
plt.tight_layout()

In [None]:
# Per-epoch generator / discriminator losses plus the validation discriminator metric.
fig, ax = plt.subplots()
ax.plot(history.history['Gen Loss'], label='Generator')
ax.plot(history.history['Disc Loss'], label='Discriminator')
# NOTE(review): each 'val_discriminator' entry is indexed [0][0], so it is presumably
# a nested array returned by GAN's test step — confirm in model_settings.
ax.plot([v[0][0] for v in history.history['val_discriminator']], label='Validation')
ax.set_title('Model Losses')
ax.set_ylabel('Loss')
ax.set_xlabel('Epoch')
ax.legend(loc='upper left')
# Save BEFORE show(): the inline backend closes the figure on show(), so a later
# savefig() would write an empty image. Saved next to the other run artifacts.
fig.savefig(os.path.join(output_folder, 'gan_fig.png'))
plt.show()