In [3]:
!pip install git+https://github.com/tensorflow/examples.git
!pip install tensorflow
!pip install IPython

Collecting git+https://github.com/tensorflow/examples.git
  Cloning https://github.com/tensorflow/examples.git to /tmp/pip-req-build-zreg_u0r
  Running command git clone --filter=blob:none --quiet https://github.com/tensorflow/examples.git /tmp/pip-req-build-zreg_u0r
  Resolved https://github.com/tensorflow/examples.git to commit 4a468da622d8827db09c57a4df1f3bb59d95d621
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: tensorflow-examples
  Building wheel for tensorflow-examples (setup.py) ... [?25l[?25hdone
  Created wheel for tensorflow-examples: filename=tensorflow_examples-0.1733136619.424038708570368490133280075616484262886170285601-py3-none-any.whl size=301602 sha256=39798cbf825ed338a8b79942e4dff3b68270970476f688fc8da0cfb00f4ee9cb
  Stored in directory: /tmp/pip-ephem-wheel-cache-k7n84j5w/wheels/72/5f/d0/7fe769eaa229bf20101d11a357eb23c83c481bee2d7f710599
Successfully built tensorflow-examples
Installing collected packages: tensorflow-e

In [None]:
# Mount Google Drive; every dataset, checkpoint and output path below lives
# under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
#import libraries

import tensorflow as tf
import os
from tensorflow_examples.models.pix2pix import pix2pix
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [5]:
###### Training code block

AUTOTUNE = tf.data.AUTOTUNE

# Choose style: one of "van_gogh", "Ukiyo_e", "engraving".
style = "Ukiyo_e"

# Define hyper-parameters
EPOCHS = 30   # number of passes over the paired content/style datasets
LAMBDA = 15   # weight of the cycle-consistency loss (identity loss uses LAMBDA * 0.5)

OUTPUT_PATH = f"/content/drive/MyDrive/Colab Notebooks/preprocessed_dataset_{style}"

# Element spec of the saved datasets: batches of 256x256 RGB float32 images.
IMAGE_SPEC = tf.TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32)


def _load_saved_dataset(name):
  """Load the saved tf.data dataset stored at OUTPUT_PATH/name.

  Raises:
    FileNotFoundError: if the directory does not exist (e.g. Drive not mounted
      or `style` misspelled), instead of a less readable error from TensorFlow.
  """
  path = os.path.join(OUTPUT_PATH, name)
  if not os.path.isdir(path):
    raise FileNotFoundError(f"Preprocessed dataset not found: {path}")
  # tf.data.Dataset.load supersedes the deprecated tf.data.experimental.load.
  return tf.data.Dataset.load(path, element_spec=IMAGE_SPEC)


train_content = _load_saved_dataset('train_content')
train_style = _load_saved_dataset('train_style')

print(f"Datasets of {style} loaded successfully!")

# First batch of each domain; sample_content is used for the per-epoch preview.
sample_content = next(iter(train_content))
sample_style = next(iter(train_style))

OUTPUT_CHANNELS = 3

# CycleGAN networks built with the pix2pix factories:
#   generator_g: content (X) -> style (Y)   generator_f: style (Y) -> content (X)
#   discriminator_x scores X images         discriminator_y scores Y images
# Built in the same order as before: G, F, then D_x, D_y.
generator_g, generator_f = [
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
]

discriminator_x, discriminator_y = [
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
]

# Shared binary cross-entropy on raw discriminator logits (from_logits=True).
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real, generated):
  """Adversarial loss for a discriminator.

  Args:
    real: discriminator logits on real images (target label 1).
    generated: discriminator logits on generated images (target label 0).

  Returns:
    Half the sum of the two binary cross-entropy terms.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  return (loss_on_real + loss_on_fake) * 0.5


def generator_loss(generated):
  """Adversarial loss for a generator.

  Binary cross-entropy that pushes the discriminator's logits on generated
  images towards the "real" label (1).
  """
  all_real = tf.ones_like(generated)
  return loss_obj(all_real, generated)


def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss.

  Returns LAMBDA times the mean absolute error between `real_image` and
  `cycled_image` (the image after a round trip through both generators).
  """
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

def identity_loss(real_image, same_image):
  """Identity loss.

  Returns (LAMBDA * 0.5) times the mean absolute error between `real_image`
  and `same_image` (a generator's output for an image already in its target
  domain).
  """
  weight = LAMBDA * 0.5
  return weight * tf.reduce_mean(tf.abs(real_image - same_image))

# One Adam optimizer per network (learning rate 2e-4, beta_1 = 0.5).
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Checkpoint directory encodes style, epoch count and lambda of this run.
checkpoint_path = f"/content/drive/MyDrive/Colab Notebooks/checkpoints_{style}_e{EPOCHS}_L{LAMBDA}/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

# Keep only the 5 most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# If a checkpoint exists, restore the latest checkpoint.
# NOTE(review): the epoch counter is not stored in the checkpoint, so the
# training loop still runs a full EPOCHS epochs after a restore.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print('Latest checkpoint restored!!')

# Per-epoch preview images are saved here by generate_images().
output_dir = f'/content/drive/MyDrive/Colab Notebooks/{style}_training_images_epochs{EPOCHS}_lambda{LAMBDA}'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(output_dir, exist_ok=True)

def generate_images(model, test_input, epoch_index=None):
  """Save a side-by-side preview of `model` applied to the first image of a batch.

  Args:
    model: generator to run (called without an explicit `training` flag).
    test_input: image batch, assumed scaled to [-1, 1]; only element 0 is shown.
    epoch_index: zero-based epoch used in the file name. When omitted, the
      global `epoch` set by the training loop is used, as before.

  The figure is written to
  `output_dir/generated_image_epoch_{epoch_index + 1}.png` and then closed.
  """
  if epoch_index is None:
    epoch_index = epoch  # global loop variable of the training loop

  prediction = model(test_input)

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  fig, axes = plt.subplots(1, 2, figsize=(12, 12))
  try:
    for ax, image, label in zip(axes, display_list, title):
      ax.set_title(label)
      # Map [-1, 1] back to [0, 1]; clip so imshow never has to clip (and warn).
      ax.imshow(tf.clip_by_value(image * 0.5 + 0.5, 0.0, 1.0))
      ax.axis('off')
    fig.savefig(os.path.join(output_dir, f'generated_image_epoch_{epoch_index + 1}.png'))
  finally:
    # Always release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


@tf.function
def train_step(real_x: tf.Tensor, real_y: tf.Tensor) -> None:
  """Run one CycleGAN optimisation step on one batch from each domain.

  Args:
    real_x: batch from domain X (the training loop passes train_content batches).
    real_y: batch from domain Y (the training loop passes train_style batches).

  Updates generator_g, generator_f, discriminator_x and discriminator_y in place
  through their Adam optimizers; returns None.
  """
  # persistent=True because tape.gradient is called four times (once per network).
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    # (each generator applied to an image already in its output domain)
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    # Discriminator logits on real and translated images of each domain.
    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    # Cycle loss in both directions; added to both generators' objectives.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss + identity loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  # (all four are computed before any weights are updated below).
  generator_g_gradients = tape.gradient(total_gen_g_loss,
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss,
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss,
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss,
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))


# Training loop. NOTE(review): restoring a checkpoint does not restore the epoch
# counter, so a resumed run trains for another full EPOCHS epochs.
CHECKPOINT_EVERY = 5  # save every N epochs, and always after the final epoch

for epoch in range(EPOCHS):
  print(f"Epoch {epoch+1}/{EPOCHS}")
  start = time.perf_counter()

  # zip stops at the shorter of the two datasets.
  for n, (image_x, image_y) in enumerate(zip(train_content, train_style)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print('.', end='')  # progress marker every 10 batches

  clear_output(wait=True)

  # Save a preview of G(sample_content) for this epoch.
  generate_images(generator_g, sample_content)

  # Also save after the last epoch so trailing epochs are never lost when
  # EPOCHS is not a multiple of CHECKPOINT_EVERY.
  is_last_epoch = epoch + 1 == EPOCHS
  if (epoch + 1) % CHECKPOINT_EVERY == 0 or is_last_epoch:
    ckpt_save_path = ckpt_manager.save()
    print(f'Saving checkpoint for epoch {epoch+1} at {ckpt_save_path}')

  print(f'Time taken for epoch {epoch + 1} is {time.perf_counter() - start} sec\n')

def resize_images(image):
    """Resize `image` (or a batch of images) to 256x256 using tf.image.resize defaults.

    NOTE(review): not called anywhere in this notebook.
    """
    target_size = [256, 256]
    return tf.image.resize(image, target_size)



Saving checkpoint for epoch 30 at /content/drive/MyDrive/Colab Notebooks/checkpoints_engraving_e30_L15/train/ckpt-6
Time taken for epoch 30 is 24.34049415588379 sec



In [None]:
###### Image generation code block

# Choose style: one of "van_gogh", "Ukiyo_e", "engraving".
# Paths below mirror the training cell's naming so this run's checkpoints are found.
style = "Ukiyo_e"

# Hyper-parameters of the training run to load; in this cell they only select
# which checkpoint and output directories are used.
EPOCHS, LAMBDA = 30, 15

version = f"{style}_e{EPOCHS}_L{LAMBDA}"

# Checkpoints written by the training cell for this version.
checkpoint_path = f'/content/drive/MyDrive/Colab Notebooks/checkpoints_{version}/train'
# Folder of content images (*.jpg) to stylise.
test_content_path = '/content/drive/MyDrive/Colab Notebooks/test/content_test'


OUTPUT_CHANNELS = 3

# Rebuild the same networks and optimizers as the training cell so this
# Checkpoint's object structure matches the saved one.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)


generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Generation needs trained weights. Continuing "from scratch" would silently
# write images from a randomly initialised generator, so fail loudly instead.
if not ckpt_manager.latest_checkpoint:
    raise FileNotFoundError(f'No checkpoint found in {checkpoint_path}')

# No training step runs in this cell, so optimizer slot variables are never
# created; expect_partial() marks the unmatched checkpoint values as intentional.
ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
print('Latest checkpoint restored!!')


def load_image(image_file):
    """Read a JPEG from disk and prepare it for the generator.

    Args:
      image_file: path of the image file (string tensor).

    Returns:
      (image, image_file): the image decoded as RGB, resized to 256x256 and
      scaled from [0, 255] to [-1, 1]; the path is passed through unchanged.
    """
    raw_bytes = tf.io.read_file(image_file)
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(rgb, [256, 256])
    scaled = resized / 127.5 - 1
    return scaled, image_file


# list_files shuffles by default; shuffle=False gives a deterministic order.
# NOTE(review): only files matching '*.jpg' are picked up (not .jpeg/.JPG).
test_content = tf.data.Dataset.list_files(
    os.path.join(test_content_path, '*.jpg'), shuffle=False)
test_content = (
    test_content
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(1)
    .prefetch(tf.data.AUTOTUNE)  # overlap file decoding with generation
)


# Stylised images are written here under their original file names.
output_dir = f'/content/drive/MyDrive/Colab Notebooks/generated_images_{version}'

os.makedirs(output_dir, exist_ok=True)


def generate_images(model, test_input, image_file):
    """Stylise a batch with `model` and write each result as a JPEG to output_dir.

    NOTE: this redefines (shadows) the training cell's `generate_images`, which
    has a different signature, so cells must be run in order.

    Args:
      model: generator to run.
      test_input: image batch, assumed scaled to [-1, 1] (as load_image does).
      image_file: batch of source file paths aligned with `test_input`; each
        output is saved under the basename of its source file.
    """
    prediction = model(test_input)
    # Map [-1, 1] -> [0, 1], then convert to uint8; saturate=True clamps any
    # value outside [0, 1] instead of letting the integer cast wrap around.
    images = tf.image.convert_image_dtype(
        prediction * 0.5 + 0.5, dtype=tf.uint8, saturate=True)

    # Write every element of the batch (previously only index 0 was saved).
    for image, path in zip(images, image_file.numpy()):
        filename = os.path.basename(path.decode('utf-8'))
        tf.io.write_file(os.path.join(output_dir, filename),
                         tf.image.encode_jpeg(image))

# Stylise every test image with generator G (content -> style).
for test_input, image_file in test_content:
    generate_images(generator_g, test_input, image_file)

print("Image generation complete.")

Latest checkpoint restored!!
Image generation complete.
