In [1]:
!pip install opendatasets



In [1]:
# Mount Google Drive; the following cell changes into the project folder on the mounted Drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [4]:
# Work from the project folder so the local modules (data_loader, generator, discriminator, utils)
# can be imported and the relative dataset/checkpoint paths used below resolve.
%cd /content/drive/MyDrive/GAN/pix2pix

/content/drive/MyDrive/GAN/pix2pix


## Import required packages

In [5]:
from data_loader import create_dataset
from discriminator import Discriminator
from generator import Generator
from utils import plot_generator_loss, generate_and_visualize, plot_loss

import tensorflow as tf
import time
from tqdm import tqdm
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, MAE
import os
import matplotlib.pyplot as plt

from IPython import display
import imageio.v2 as imageio

## Data preparation:

In [5]:
import opendatasets as od

# Kaggle pix2pix "maps" dataset; lands in ./pix2pix-maps and is skipped if already downloaded.
DATASET_URL = 'https://www.kaggle.com/datasets/alincijov/pix2pix-maps'
od.download(DATASET_URL)

Skipping, found downloaded files in "./pix2pix-maps" (use force=True to force download)


In [6]:
TRAIN_FOLDER = 'pix2pix-maps/train'
VAL_FOLDER = 'pix2pix-maps/val'

BATCH_SIZE = 1
BUFFER_SIZE = 256
SHUFFLE = True  # whether create_dataset shuffles the *training* set

train_dataset = create_dataset(TRAIN_FOLDER, BATCH_SIZE, BUFFER_SIZE, SHUFFLE)
# The validation set is not shuffled so that the monitoring samples drawn from it are the same
# images on every run and after every resume. Otherwise per-epoch snapshots (and the GIF built
# from them) compare different images across sessions.
# NOTE(review): assumes create_dataset's file listing is otherwise deterministic when shuffle is
# False. TODO: confirm in data_loader.py.
val_dataset = create_dataset(VAL_FOLDER, BATCH_SIZE, BUFFER_SIZE, False)

1096
1098


## Hyperparameter configuration

In [7]:
# Adam settings shared by the generator and discriminator optimizers.
LEARNING_RATE = 2e-4
BETA_1, BETA_2 = 0.5, 0.999

# Weight of the L1 reconstruction term relative to the adversarial term in the generator loss.
L1_SCALE = 100


## Training

In [None]:
# Training settings
EPOCHS = 200


def _make_adam():
    """Build an Adam optimizer from the notebook's shared learning-rate/beta hyperparameters."""
    return Adam(learning_rate=LEARNING_RATE, beta_1=BETA_1, beta_2=BETA_2)


# One optimizer instance per network so each keeps its own Adam state.
gen_optim = _make_adam()
disc_optim = _make_adam()

# Networks being trained
generator = Generator()
discriminator = Discriminator()

# Checkpoint bundling both networks and both optimizers; the keyword names are the keys
# stored in the checkpoint files, so they must stay stable for restores to work.
checkpoint_dir = './checkpoint'
checkpoint_prefix = os.path.join(checkpoint_dir, 'pix2pix-ckpt')
checkpoint = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    generator_optim=gen_optim,
    discriminator_optim=disc_optim,
)

# Fixed set of validation samples used to visualize progress after every epoch.
N_VIS_SAMPLES = 6
# unbatch() first so exactly N_VIS_SAMPLES images are taken whatever BATCH_SIZE is.
# take() on the batched dataset would count batches, not images. batch() then
# restacks them along a new leading axis, matching the old concat of size-1 batches.
test_inputs, test_targets = next(iter(
    val_dataset.unbatch().take(N_VIS_SAMPLES).batch(N_VIS_SAMPLES)
))

def _discriminator_step(input_image, target_image, bce_loss):
    """Apply one discriminator update on a batch and return the discriminator loss as a float."""
    # The generator pass runs outside the tape: only discriminator gradients are needed here,
    # so recording generator ops on this tape would just waste memory.
    fake_image = generator(input_image, training=True)
    with tf.GradientTape() as d_tape:
        real_disc = discriminator(input_image, target_image, training=True)
        real_loss = bce_loss(tf.ones_like(real_disc), real_disc)
        fake_disc = discriminator(input_image, fake_image, training=True)
        fake_loss = bce_loss(tf.zeros_like(fake_disc), fake_disc)
        # Halve the objective to slow down the discriminator relative to the generator.
        d_loss = (real_loss + fake_loss) * 0.5

    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    disc_optim.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    return float(d_loss)


def _generator_step(input_image, target_image, bce_loss):
    """Apply one generator update (adversarial + L1_SCALE * L1 loss) and return its loss as a float."""
    with tf.GradientTape() as g_tape:
        # The generator must run inside g_tape so its gradients can be taken; this pass also
        # sees the discriminator as just updated by _discriminator_step.
        fake_image = generator(input_image, training=True)
        fake_disc = discriminator(input_image, fake_image, training=True)
        adversarial_loss = bce_loss(tf.ones_like(fake_disc), fake_disc)
        l1_loss = tf.reduce_mean(MAE(target_image, fake_image))
        g_loss = adversarial_loss + L1_SCALE * l1_loss

    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    gen_optim.apply_gradients(zip(g_grads, generator.trainable_variables))
    return float(g_loss)


def train(dataset, continue_training=False, checkpoint_dir=None, start_epoch=0,
          log_every=100, checkpoint_every=10):
    """Run the pix2pix training loop on the notebook's global generator/discriminator.

    Relies on the globals `generator`, `discriminator`, `gen_optim`, `disc_optim`,
    `checkpoint`, `checkpoint_prefix`, `EPOCHS`, `L1_SCALE`, `test_inputs` and
    `test_targets` defined earlier in the notebook.

    Args:
        dataset: iterable of (input_image, target_image) batches.
        continue_training: if True, restore networks and optimizers from `checkpoint_dir` first.
        checkpoint_dir: path handed to `checkpoint.restore`. Despite the name (it shadows the
            global directory variable), this is a checkpoint *path* such as
            './checkpoint/pix2pix-ckpt-7'.
        start_epoch: first epoch index; epochs start_epoch .. EPOCHS - 1 are run.
        log_every: print the current losses every `log_every` batches.
        checkpoint_every: save a checkpoint after every `checkpoint_every`-th epoch.
    """
    if continue_training:
        checkpoint.restore(checkpoint_dir)

    # Per-step losses stored as Python floats rather than eager tensors, so the history stays
    # cheap to keep for the whole run and to plot.
    G_losses = []
    D_losses = []
    # NOTE(review): BinaryCrossentropy defaults to from_logits=False, i.e. it assumes the
    # discriminator ends in a sigmoid. If discriminator.py returns raw logits this should be
    # BinaryCrossentropy(from_logits=True). TODO: confirm.
    bce_loss = BinaryCrossentropy()

    # Tracked explicitly so the final summary works even when no epoch runs
    # (start_epoch >= EPOCHS), where the loop variables would otherwise be undefined.
    last_epoch = None
    last_epoch_seconds = None

    for epoch in range(start_epoch, EPOCHS):
        start = time.perf_counter()

        progress = tqdm(dataset)
        for batch_idx, (input_image, target_image) in enumerate(progress):
            d_loss = _discriminator_step(input_image, target_image, bce_loss)
            D_losses.append(d_loss)
            g_loss = _generator_step(input_image, target_image, bce_loss)
            G_losses.append(g_loss)

            if (batch_idx + 1) % log_every == 0:
                # tqdm.write keeps the message from corrupting the progress bar; progress.total
                # is None when the dataset length is unknown (len(dataset) would raise).
                tqdm.write(
                    f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{progress.total} "
                    f"Loss D: {d_loss:.4f}, loss G: {g_loss:.4f}"
                )

        last_epoch = epoch
        last_epoch_seconds = time.perf_counter() - start

        # Monitor training process
        display.clear_output(wait=True)
        print(f'Training epoch {epoch} took {last_epoch_seconds:.1f} seconds')
        generate_and_visualize(generator, test_inputs, test_targets, epoch)
        plot_loss(G_losses, D_losses)

        if (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(checkpoint_prefix)

    # Display the last result. "Done!" is printed after clear_output so it is not wiped.
    display.clear_output(wait=True)
    if last_epoch is not None:
        print(f'Training epoch {last_epoch} took {last_epoch_seconds:.1f} seconds')
    generate_and_visualize(generator, test_inputs, test_targets, EPOCHS)
    if G_losses:
        plot_loss(G_losses, D_losses)
    print('Done!')


# Resume from the newest checkpoint if there is one; otherwise train from scratch.
# (A hard-coded '-7' checkpoint fails on a fresh run and goes stale after new saves.)
# train() saves every 10 epochs by default and checkpoint.save numbers files
# 'pix2pix-ckpt-N', so save N means epochs 0 .. 10*N - 1 are complete.
CHECKPOINT_EVERY_EPOCHS = 10
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    train(train_dataset)
else:
    saves_done = int(latest_checkpoint.rsplit('-', 1)[-1])
    train(train_dataset, continue_training=True, checkpoint_dir=latest_checkpoint,
          start_epoch=saves_done * CHECKPOINT_EVERY_EPOCHS)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
import cv2


def to_uint8(image):
    """Convert an image in [-1, 1] (tf tensor or numpy array) to uint8 pixels in [0, 255].

    The [-1, 1] range is implied by the x * 127.5 + 127.5 rescaling used for display.
    """
    pixels = image.numpy() if hasattr(image, 'numpy') else image
    # Clip before casting: float values outside [0, 255] do not saturate when cast to uint8
    # (the result is undefined and typically wraps), which shows up as speckle noise.
    return (pixels * 127.5 + 127.5).clip(0, 255).astype('uint8')


output_images = generator.predict(test_inputs)

# Two panels per row; each panel is input | target | generated, side by side.
n_samples = len(output_images)
n_rows = (n_samples + 1) // 2
fig = plt.figure(figsize=(10, 10))
for i in range(n_samples):
    plt.subplot(n_rows, 2, i + 1)
    final_img = cv2.hconcat([to_uint8(test_inputs[i]),
                             to_uint8(test_targets[i]),
                             to_uint8(output_images[i])])
    plt.imshow(final_img)
    plt.axis('off')
plt.show()

In [6]:
import shutil

FRAME_DIR = 'training-images'
GIF_PATH = 'pix2pix_training_process.gif'

# Use every saved per-epoch snapshot instead of a hard-coded epoch count.
# Assumes frames are named image_at_epoch_{:04d}.png; the zero padding makes a plain sort
# put them in epoch order.
frame_files = sorted(
    name for name in os.listdir(FRAME_DIR)
    if name.startswith('image_at_epoch_') and name.endswith('.png')
)
if not frame_files:
    raise FileNotFoundError(f'No image_at_epoch_*.png frames found in {FRAME_DIR!r}')
frames = [imageio.imread(os.path.join(FRAME_DIR, name)) for name in frame_files]

# mimsave creates/overwrites the GIF itself, so there is no need to pre-create the file.
imageio.mimsave(GIF_PATH, frames)

# Workaround to display the GIF inside the notebook: show a copy with a .png extension.
# shutil.copyfile runs no shell and raises on failure (os.system silently ignored errors).
shutil.copyfile(GIF_PATH, GIF_PATH + '.png')

display.Image(filename=GIF_PATH + '.png')