<a href="https://colab.research.google.com/github/erikroruiz/OCT_style_transfer/blob/main/pix2pix_oct_v3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# pix2pix: Image-to-image translation with a conditional GAN

This tutorial demonstrates how to build and train a conditional generative adversarial network (cGAN) called pix2pix that learns a mapping from input images to output images, as described in [Image-to-image translation with conditional adversarial networks](https://arxiv.org/abs/1611.07004) by Isola et al. (2017). pix2pix is not application specific—it can be applied to a wide range of tasks, including synthesizing photos from label maps, generating colorized photos from black and white images, turning Google Maps photos into aerial images, and even transforming sketches into photos.

In this example, your network will generate images of building facades using the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/) provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep it short, you will use a [preprocessed copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset created by the pix2pix authors.

In the pix2pix cGAN, you condition on input images and generate corresponding output images. cGANs were first proposed in [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784) (Mirza and Osindero, 2014).

The architecture of your network will contain:

- A generator with a [U-Net](https://arxiv.org/abs/1505.04597)-based architecture.
- A discriminator represented by a convolutional PatchGAN classifier (proposed in the [pix2pix paper](https://arxiv.org/abs/1611.07004)).




### Check the GPU environment

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

### Check the available RAM

In [None]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

## Import TensorFlow and other required libraries

In [None]:
import tensorflow as tf

import os
import cv2 
import time
import pathlib
import datetime
import numpy as np
from PIL import Image
from google.colab import drive
# import matplotlib.pyplot as plt

from matplotlib import pyplot as plt
from IPython import display

In [None]:



# Base folder in Google Drive
BASE_FOLDER = '/content/drive/My Drive/TFM/'
# Mount Google Drive
drive.mount('/content/drive/')

In [None]:
# Experiment name
EXPERIMENT_NAME = "pix2pix_119c_119n_paired_bs1"
# Location of the training and test images
PATH_TRAIN = BASE_FOLDER+'datasets/images_pix2pix/119c_119n_paired/train'
PATH_TEST = BASE_FOLDER+'datasets/images_pix2pix/119c_119n_paired/test'
# The OCT training set consists of 364 images
BUFFER_SIZE = 119

# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment
BATCH_SIZE = 1
# Each image is 256x256 in size
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Number of training steps (a multiple of 1000); 200 epochs = 80000 steps
N_STEPS = 80000



# Number of output channels of the U-Net generator
OUTPUT_CHANNELS = 3

OPTIMIZER = "Adam"

# Adam optimizers parameters
LEARNING_RATE = 2e-4
BETA1=0.5
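
The comment on `N_STEPS` equates 200 epochs with 80000 steps. A minimal sketch of that conversion follows. The helper names are hypothetical, and the image count of 400 is only an assumption: it is the value implied by that comment at a batch size of 1, not a figure confirmed elsewhere in the notebook.

```python
import math

def steps_for_epochs(n_images, batch_size, epochs):
    """Number of optimizer steps needed to see the dataset `epochs` times."""
    steps_per_epoch = math.ceil(n_images / batch_size)
    return steps_per_epoch * epochs

def epochs_for_steps(n_images, batch_size, steps):
    """Inverse conversion: how many passes over the data `steps` steps cover."""
    return steps / math.ceil(n_images / batch_size)

# Assumed image count (hypothetical), chosen to match "200 epochs = 80000 steps".
example_steps = steps_for_epochs(n_images=400, batch_size=1, epochs=200)
example_epochs = epochs_for_steps(n_images=400, batch_size=1, steps=80000)
print(example_steps, example_epochs)
```

With a different dataset size, the same number of steps corresponds to a different number of epochs, so the two should be recomputed together.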



## Load the dataset

### Show a training image

In [None]:
sample_image = tf.io.read_file(str(PATH_TRAIN+'/001.jpg'))
sample_image = tf.io.decode_jpeg(sample_image)
print(sample_image.shape)

In [None]:
plt.figure()
plt.imshow(sample_image)

### Show a test image

In [None]:
sample_image = tf.io.read_file(str(PATH_TEST+'/001.jpg'))
sample_image = tf.io.decode_jpeg(sample_image)
print(sample_image.shape)

In [None]:
plt.figure()
plt.imshow(sample_image)

Each original image is `300 x 600` pixels and contains two `300 x 300` images side by side.

The noisy image must be separated from the noise-free image. Each of the two resulting images is `300 x 300`.

Define a function that splits the input image into two images:

In [None]:
def load(image_file):
  # Read and decode an image file to a uint8 tensor
  image = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image)

  # Split each image tensor into two tensors:
  # - the right half is used as the input image (noisy)
  # - the left half is used as the target image (noise-free)
  w = tf.shape(image)[1]
  w = w // 2
  input_image = image[:, w:, :]
  real_image = image[:, :w, :]

  # Convert both images to float32 tensors
  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image
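
The slicing in `load` can be checked without TensorFlow. The sketch below repeats the same split with NumPy on a synthetic side-by-side image. The function name `split_pair` and the pixel values are made up for illustration.

```python
import numpy as np

def split_pair(image):
    """Split an (H, 2W, C) side-by-side image into (right half, left half)."""
    w = image.shape[1] // 2
    input_image = image[:, w:, :].astype(np.float32)  # right half
    real_image = image[:, :w, :].astype(np.float32)   # left half
    return input_image, real_image

# Synthetic 300 x 600 RGB image: left half filled with 10, right half with 200.
left = np.full((300, 300, 3), 10, dtype=np.uint8)
right = np.full((300, 300, 3), 200, dtype=np.uint8)
pair = np.concatenate([left, right], axis=1)

inp_half, real_half = split_pair(pair)
print(inp_half.shape, real_half.shape)
```

Both halves come out as `300 x 300 x 3` float arrays, and the input is taken from the right half, as in `load`.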

In [None]:
inp, re = load(str(PATH_TRAIN+'/050.jpg'))
# Scale to [0, 1] so matplotlib can display the float images
plt.subplot(121)
plt.title('Noisy image')
plt.imshow(inp / 255.0)

plt.subplot(122)
plt.title('Noise-free image')
plt.imshow(re / 255.0)



As described in the [pix2pix paper](https://arxiv.org/abs/1611.07004), you need to apply random jittering and mirroring to preprocess the training set.

Define several functions that:

1. Resize each `300 x 300` image to `286 x 286`.
2. Randomly crop it back to `256 x 256`.
3. Randomly flip the image horizontally i.e. left to right (random mirroring).
4. Normalize the images to the `[-1, 1]` range.

In [None]:
def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image

In [None]:
def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]

In [None]:
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
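
As a quick check of the normalization, the sketch below applies the same formula with NumPy and then maps the result back to `[0, 1]` with `x * 0.5 + 0.5`, which is the rescaling used later when plotting. The helper names are hypothetical.

```python
import numpy as np

def normalize_np(image):
    """Map pixel values from [0, 255] to [-1, 1]."""
    return image / 127.5 - 1.0

def to_unit_range(image):
    """Map values from [-1, 1] to [0, 1] for display."""
    return image * 0.5 + 0.5

pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)
normalized = normalize_np(pixels)
restored = to_unit_range(normalized)
print(normalized, restored)
```

The `[-1, 1]` range matches the `tanh` activation of the generator's last layer, so targets and generator outputs share the same scale.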

In [None]:
@tf.function()
def random_jitter(input_image, real_image):
  # Resizing to 286x286
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # Random cropping back to 256x256
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # Random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image
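
The key property of `random_jitter` is that the input and the target receive the same crop offsets and the same flip decision, so the pair stays aligned. A simplified NumPy sketch of that idea follows. The nearest-neighbor resize here uses plain index scaling, which is an assumption and may not pick exactly the same pixels as `tf.image.resize`.

```python
import numpy as np

def resize_nearest(image, height, width):
    """Simplified nearest-neighbor resize by index scaling."""
    rows = np.arange(height) * image.shape[0] // height
    cols = np.arange(width) * image.shape[1] // width
    return image[rows][:, cols]

def random_jitter_np(input_image, real_image, rng, resize_to=286, crop_to=256):
    input_image = resize_nearest(input_image, resize_to, resize_to)
    real_image = resize_nearest(real_image, resize_to, resize_to)
    # One pair of offsets shared by both images keeps them aligned.
    top = rng.integers(0, resize_to - crop_to + 1)
    left = rng.integers(0, resize_to - crop_to + 1)
    input_image = input_image[top:top + crop_to, left:left + crop_to]
    real_image = real_image[top:top + crop_to, left:left + crop_to]
    # One coin flip shared by both images.
    if rng.random() > 0.5:
        input_image = input_image[:, ::-1]
        real_image = real_image[:, ::-1]
    return input_image, real_image

rng = np.random.default_rng(0)
base = rng.random((300, 300, 3)).astype(np.float32)
# Feeding the same image twice makes misalignment easy to detect.
jit_inp, jit_real = random_jitter_np(base, base.copy(), rng)
print(jit_inp.shape, np.array_equal(jit_inp, jit_real))
```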

Define helper functions that preprocess the training and test images:

In [None]:
def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image

In [None]:
def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image

## Build the input pipeline with `tf.data`

In [None]:
# Load the training images
train_dataset = tf.data.Dataset.list_files(str(PATH_TRAIN+'/*.jpg'))
# Apply the preprocessing functions:
#  - random jitter: resize to 286 x 286 >> randomly crop to 256 x 256 >> random mirroring
#  - normalize to [-1, 1]
train_dataset = train_dataset.map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
# Shuffle the images
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
# Combine consecutive images of the dataset into batches
train_dataset = train_dataset.batch(BATCH_SIZE)

# Load the test images
test_dataset = tf.data.Dataset.list_files(str(PATH_TEST+'/*.jpg'))
# Apply the preprocessing functions:
#  - resize to 256 x 256
#  - normalize to [-1, 1]
test_dataset = test_dataset.map(load_image_test)
# Combine consecutive images of the dataset into batches
test_dataset = test_dataset.batch(BATCH_SIZE)


## Build the generator

The generator of your pix2pix cGAN is a _modified_ [U-Net](https://arxiv.org/abs/1505.04597). A U-Net consists of an encoder (downsampler) and decoder (upsampler). (You can find out more about it in the [Image segmentation](https://www.tensorflow.org/tutorials/images/segmentation) tutorial and on the [U-Net project website](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/).)

- Each block in the encoder is: Convolution -> Batch normalization -> Leaky ReLU
- Each block in the decoder is: Transposed convolution -> Batch normalization -> Dropout (applied to the first 3 blocks) -> ReLU
- There are skip connections between the encoder and decoder (as in the U-Net).

Define the downsampler (encoder):

In [None]:
def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
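
Each `downsample` block uses a stride of 2 with `'same'` padding, so the spatial size becomes the ceiling of half the input size. A small sketch of that arithmetic (the helper name is hypothetical):

```python
import math

def same_padding_output_size(size, stride=2):
    """Spatial output size of a strided convolution with 'same' padding."""
    return math.ceil(size / stride)

# Eight stride-2 blocks applied to a 256 x 256 input.
size = 256
encoder_sizes = []
for _ in range(8):
    size = same_padding_output_size(size)
    encoder_sizes.append(size)
print(encoder_sizes)

# A single block applied to an unresized 300 x 300 image.
print(same_padding_output_size(300))
```

Eight blocks therefore reduce a `256 x 256` input to `1 x 1`, while one block on a `300 x 300` image gives `150 x 150`.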

### Example: reduce the dimensions of the sample image (`inp`) from 300 x 300 to 150 x 150

In [None]:
# down_model = downsample(3, 4)
# down_result = down_model(tf.expand_dims(inp, 0))
# print (down_result.shape)

Define the upsampler (decoder):

In [None]:
def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result

### Example: increase the dimensions of the sample image (`inp`) from 150 x 150 to 300 x 300

In [None]:
# up_model = upsample(3, 4)
# up_result = up_model(down_result)
# print (up_result.shape)

Define the generator with the downsampler and the upsampler:

In [None]:
def Generator():
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
    downsample(128, 4),  # (batch_size, 64, 64, 128)
    downsample(256, 4),  # (batch_size, 32, 32, 256)
    downsample(512, 4),  # (batch_size, 16, 16, 512)
    downsample(512, 4),  # (batch_size, 8, 8, 512)
    downsample(512, 4),  # (batch_size, 4, 4, 512)
    downsample(512, 4),  # (batch_size, 2, 2, 512)
    downsample(512, 4),  # (batch_size, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True),  # (batch_size, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
    upsample(512, 4),  # (batch_size, 16, 16, 1024)
    upsample(256, 4),  # (batch_size, 32, 32, 512)
    upsample(128, 4),  # (batch_size, 64, 64, 256)
    upsample(64, 4),  # (batch_size, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
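
The shape comments in `Generator` can be reproduced with plain arithmetic. The sketch below traces the spatial size and channel count through the encoder and decoder, assuming, as in the code above, that each decoder output is concatenated with the encoder output of the same spatial size. The function name `trace_unet_shapes` is hypothetical.

```python
down_filters = [64, 128, 256, 512, 512, 512, 512, 512]
up_filters = [512, 512, 512, 512, 256, 128, 64]

def trace_unet_shapes(input_size=256):
    """Return (size, channels) after every encoder block and every skip concat."""
    size = input_size
    encoder = []
    for filters in down_filters:
        size = -(-size // 2)  # ceiling division: stride 2, 'same' padding
        encoder.append((size, filters))

    # The bottleneck output is not used as a skip connection.
    skips = list(reversed(encoder[:-1]))

    decoder = []
    for filters, (skip_size, skip_channels) in zip(up_filters, skips):
        size *= 2  # stride-2 transposed convolution doubles the size
        if size != skip_size:
            raise ValueError("skip connection size mismatch")
        decoder.append((size, filters + skip_channels))  # channel concat
    return encoder, decoder

encoder_shapes, decoder_shapes = trace_unet_shapes()
for size, channels in decoder_shapes:
    print(f"({size}, {size}, {channels})")
```

The final transposed convolution then doubles `128 x 128` to `256 x 256` and reduces the channels to `OUTPUT_CHANNELS`.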

Visualize the generator model architecture:

In [None]:
generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

### Define the generator loss

GANs learn a loss that adapts to the data, while cGANs learn a structured loss that penalizes structural differences between the network output and the target image, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- The generator loss is a sigmoid cross-entropy loss of the generated images and an **array of ones**.
- The pix2pix paper also mentions the L1 loss, which is a MAE (mean absolute error) between the generated image and the target image.
- This allows the generated image to become structurally similar to the target image.
- The formula to calculate the total generator loss is `gan_loss + LAMBDA * l1_loss`, where `LAMBDA = 100`. This value was decided by the authors of the paper.

In [None]:
LAMBDA = 100

In [None]:
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # Mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss
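
To make the formula `gan_loss + LAMBDA * l1_loss` concrete, here is a NumPy sketch of the same computation on synthetic arrays. It is an illustration only: `bce_with_logits` is a hand-written, numerically stable sigmoid cross-entropy averaged over all elements, and the input values are made up.

```python
import numpy as np

def bce_with_logits(labels, logits):
    """Mean sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|))."""
    per_element = (np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return per_element.mean()

def generator_loss_np(disc_generated_logits, gen_output, target, lam=100):
    # The generator wants the discriminator to output "real" (ones).
    gan_loss = bce_with_logits(np.ones_like(disc_generated_logits),
                               disc_generated_logits)
    l1_loss = np.mean(np.abs(target - gen_output))
    return gan_loss + lam * l1_loss, gan_loss, l1_loss

# Synthetic case: an undecided discriminator (logits of 0) and a generated
# image that differs from the target by 0.1 at every pixel.
logits = np.zeros((1, 30, 30, 1))
gen = np.zeros((1, 256, 256, 3))
tgt = np.full((1, 256, 256, 3), 0.1)
total, gan, l1 = generator_loss_np(logits, gen, tgt)
print(total, gan, l1)
```

In this case the GAN term equals `log(2)` and the L1 term is `0.1`, so after weighting by `LAMBDA = 100` the L1 term dominates the total.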

In [None]:
# # Run the trained model on a few examples from the test set
# for inp, tar in test_dataset.take(5):
#   generate_images(generator, inp, tar)

The training procedure for the generator is as follows:

![Generator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/gen.png?raw=1)


## Build the discriminator

The discriminator in the pix2pix cGAN is a convolutional PatchGAN classifier—it tries to classify if each image _patch_ is real or not real, as described in the [pix2pix paper](https://arxiv.org/abs/1611.07004).

- Each block in the discriminator is: Convolution -> Batch normalization -> Leaky ReLU.
- The shape of the output after the last layer is `(batch_size, 30, 30, 1)`.
- Each element of the `30 x 30` output classifies a `70 x 70` patch of the input image.
- The discriminator receives 2 inputs: 
    - The input image and the target image, which it should classify as real.
    - The input image and the generated image (the output of the generator), which it should classify as fake.
    - Use `tf.concat([inp, tar], axis=-1)` to concatenate these 2 inputs together.

Let's define the discriminator:

In [None]:
def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x)  # (batch_size, 128, 128, 64)
  down2 = downsample(128, 4)(down1)  # (batch_size, 64, 64, 128)
  down3 = downsample(256, 4)(down2)  # (batch_size, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (batch_size, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)
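
Both the `30 x 30` output size and the `70 x 70` patch size follow from the kernel sizes and strides of the five convolutions. The sketch below recomputes them with the standard receptive-field recurrence `r = (r - 1) * stride + kernel`, applied from the last layer back to the first. The helper names are hypothetical.

```python
# (kernel, stride) of the five convolutions, from input to output.
disc_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

def receptive_field(layers):
    """Side length of the input patch seen by one output element."""
    r = 1
    for kernel, stride in reversed(layers):
        r = (r - 1) * stride + kernel
    return r

def patchgan_output_size(size=256):
    """Spatial size of the discriminator output for a square input."""
    for _ in range(3):
        size = -(-size // 2)  # stride-2 convolution with 'same' padding
    for _ in range(2):
        # ZeroPadding2D adds 1 pixel per side, then a 4x4 'valid' convolution
        # with stride 1 removes 3 pixels.
        size = (size + 2) - 4 + 1
    return size

print(receptive_field(disc_layers), patchgan_output_size())
```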

Visualize the discriminator model architecture:

In [None]:
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

### Define the discriminator loss

- The `discriminator_loss` function takes 2 inputs: **real images** and **generated images**.
- `real_loss` is a sigmoid cross-entropy loss of the **real images** and an **array of ones (since these are the real images)**.
- `generated_loss` is a sigmoid cross-entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**.
- The `total_loss` is the sum of `real_loss` and `generated_loss`.

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss
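
For intuition about the scale of this loss, the NumPy sketch below evaluates it for two synthetic discriminators: one that is undecided (all logits 0) and one that is confidently correct. The `bce_with_logits` helper is a hand-written stand-in for the Keras loss, averaged over all elements.

```python
import numpy as np

def bce_with_logits(labels, logits):
    """Mean sigmoid cross-entropy: max(x, 0) - x*z + log(1 + exp(-|x|))."""
    per_element = (np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))
    return per_element.mean()

def discriminator_loss_np(real_logits, generated_logits):
    real_loss = bce_with_logits(np.ones_like(real_logits), real_logits)
    generated_loss = bce_with_logits(np.zeros_like(generated_logits),
                                     generated_logits)
    return real_loss + generated_loss

shape = (1, 30, 30, 1)
undecided = discriminator_loss_np(np.zeros(shape), np.zeros(shape))
confident = discriminator_loss_np(np.full(shape, 10.0), np.full(shape, -10.0))
print(undecided, confident)
```

An undecided discriminator scores `2 * log(2)`, roughly 1.39, which is a useful reference value when reading the training curves.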

The training procedure for the discriminator is shown below.

To learn more about the architecture and the hyperparameters you can refer to the [pix2pix paper](https://arxiv.org/abs/1611.07004).

![Discriminator Update Image](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/images/dis.png?raw=1)


## Define the optimizers 


In [None]:
# Adam optimizer parameters
# learning_rate: A Tensor, floating point value, or a schedule that is a tf.keras.optimizers.schedules.LearningRateSchedule, or a callable that takes no arguments and returns the actual value to use, The learning rate. Defaults to 0.001.
# beta_1: A float value or a constant float tensor, or a callable that takes no arguments and returns the actual value to use. The exponential decay rate for the 1st moment estimates. Defaults to 0.9.

generator_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA1) # Args: learning rate, beta_1
discriminator_optimizer = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA1) # Args: learning rate, beta_1
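
To illustrate what `LEARNING_RATE` and `beta_1` control, here is a scalar sketch of one textbook Adam update. It is not the Keras implementation. The values `beta_2=0.999` and `eps=1e-7` are assumed to match the Keras defaults, and the parameter and gradient values are made up.

```python
import math

def adam_step(param, grad, m, v, t, lr=2e-4, beta_1=0.5, beta_2=0.999, eps=1e-7):
    """One textbook Adam update for a scalar parameter at time step t >= 1."""
    m = beta_1 * m + (1 - beta_1) * grad         # first-moment estimate
    v = beta_2 * v + (1 - beta_2) * grad * grad  # second-moment estimate
    m_hat = m / (1 - beta_1 ** t)                # bias correction
    v_hat = v / (1 - beta_2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

param, m, v = adam_step(param=1.0, grad=3.0, m=0.0, v=0.0, t=1)
print(param, m, v)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by about `lr` regardless of the gradient's magnitude. A lower `beta_1` such as 0.5 makes the first-moment estimate follow recent gradients more closely than the default of 0.9.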

## Define checkpoint-saver

In [None]:
checkpoint_dir = BASE_FOLDER+''+EXPERIMENT_NAME+'/checkpoints'

checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=2)

INIT_STEP = 0
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  print('Restoring checkpoint')
  checkpoint.restore(ckpt_manager.latest_checkpoint)
  INIT_STEP = int(ckpt_manager.latest_checkpoint.split(sep='ckpt-')[-1])*5000
  print('Latest checkpoint restored!!')
print(INIT_STEP)
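
The resume logic derives the starting step from the checkpoint file name, relying on a checkpoint being saved every 5000 steps. A standalone sketch of that parsing follows, with a made-up path and a hypothetical function name.

```python
def init_step_from_checkpoint(latest_checkpoint, steps_per_checkpoint=5000):
    """Recover the training step from a path that ends in 'ckpt-<number>'."""
    if not latest_checkpoint:
        return 0  # no checkpoint yet: start from scratch
    checkpoint_number = int(latest_checkpoint.split('ckpt-')[-1])
    return checkpoint_number * steps_per_checkpoint

# Hypothetical path, for illustration only.
resumed_step = init_step_from_checkpoint('/hypothetical/checkpoints/ckpt-3')
fresh_step = init_step_from_checkpoint(None)
print(resumed_step, fresh_step)
```

This only stays correct if the save interval in the training loop and the multiplier used here are changed together.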


## Generate images

Write a function to plot some images during training.

- Pass images from the test set to the generator.
- The generator will then translate the input image into the output.
- The last step is to plot the predictions and _voila_!

Note: The `training=True` is intentional here since you want the batch statistics while running the model on the test dataset. If you use `training=False`, you get the accumulated statistics learned from the training dataset (which you don't want).

In [None]:
def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15, 15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # Getting the pixel values in the [0, 1] range to plot.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
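
The difference between `training=True` and `training=False` comes from batch normalization. The NumPy sketch below is a simplified version of that layer, without the learned scale and offset. The moving statistics are deliberately set far from the batch, which is an exaggerated, hypothetical situation chosen to make the contrast visible.

```python
import numpy as np

def batch_norm_np(x, moving_mean, moving_var, training, eps=1e-3):
    """Simplified batch normalization over the batch and spatial axes."""
    if training:
        # Statistics of the current batch.
        mean = x.mean(axis=(0, 1, 2))
        var = x.var(axis=(0, 1, 2))
    else:
        # Statistics accumulated during training.
        mean, var = moving_mean, moving_var
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=2.0, size=(1, 8, 8, 4))
moving_mean = np.zeros(4)  # hypothetical accumulated statistics
moving_var = np.ones(4)

train_out = batch_norm_np(x, moving_mean, moving_var, training=True)
infer_out = batch_norm_np(x, moving_mean, moving_var, training=False)
print(train_out.mean(), infer_out.mean())
```

With batch statistics the output is centered on zero, while the accumulated statistics leave it shifted whenever they do not match the current input.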

Test the function:

In [None]:
# for example_input, example_target in test_dataset.take(1):
#   generate_images(generator, example_input, example_target)

## Training

- For each example input, the generator produces an output.
- The discriminator is called twice: once with the `input_image` and the `target_image` (the real pair), and once with the `input_image` and the generated image (the fake pair).
- Next, calculate the generator and the discriminator loss.
- Then, calculate the gradients of each loss with respect to the generator and discriminator variables, and apply them with the corresponding optimizer.


In [None]:
# log_dir=BASE_FOLDER+''+EXPERIMENT_NAME+'/logs/'

# summary_writer = tf.summary.create_file_writer(
#   log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))

In [None]:
# @tf.function
def train_step(input_image, target, step):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  # with summary_writer.as_default():
  #   tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
  #   tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
  #   tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
  #   tf.summary.scalar('disc_loss', disc_loss, step=step//1000)
  
  return np.array([step, gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss])

In [None]:
# Save an array to disk in .npy format
def guardar_array (url_nombre, array):
    np.save(url_nombre, array)

The actual training loop. Since this tutorial can run on more than one dataset, and the datasets vary greatly in size, the training loop is set up to work in steps instead of epochs.

- Iterates over the number of steps.
- Every 10 steps print a dot (`.`).
- Every 1k steps: clear the display and run `generate_images` to show the progress.
- Every 5k steps: save a checkpoint.

In [None]:
def fit(train_ds, test_ds, steps):
  example_input, example_target = next(iter(test_ds.take(1)))
  start = time.time()

  if (INIT_STEP == 0): # First run: initialize the metrics array
    metrics_array = np.zeros((int(N_STEPS/1000),5))
    print("Training from scratch")
  else: # Resuming: load the metrics array from disk
    print("Training from checkpoint")
    metrics_array = np.load(BASE_FOLDER+''+EXPERIMENT_NAME+'/results/metrics.npy')
    steps = N_STEPS - INIT_STEP

  print("Training ",steps," steps from ", INIT_STEP," to ",N_STEPS)

  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    step = step + INIT_STEP
    
    if (step) % 1000 == 0:
      display.clear_output(wait=True)
      print(int(step),' ',N_STEPS)
      if step != 0:
        print('\n')
        print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')
        # Save the metrics array
        indice = int((step / 1000)-1)
        print("\nSaving row ",indice," of the metrics array")
        metrica = train_step(input_image, target, step)
        metrics_array[indice] = metrica
        guardar_array (BASE_FOLDER+''+EXPERIMENT_NAME+'/results/metrics.npy', metrics_array)
      else:
        metrica = train_step(input_image, target, step)

      start = time.time()
      generate_images(generator, example_input, example_target)
      print(f"Step: {step//1000}k")

    else:
      metrica = train_step(input_image, target, step)

    # Training step
    if (step+1) % 10 == 0:
      print('.', end='', flush=True)

    # Save (checkpoint) the model every 5k steps
    if (step + 1) % 5000 == 0:
      ckpt_save_path = ckpt_manager.save()
      print ('\nSaving checkpoint for step {} at {}'.format(step+1, ckpt_save_path))
  
  # Save the last metrics row
  indice = int((N_STEPS / 1000)-1)
  print("Index",indice)
  print("\nSaving the last row ",indice," of the metrics array")
  metrics_array[indice] = metrica
  guardar_array (BASE_FOLDER+''+EXPERIMENT_NAME+'/results/metrics.npy', metrics_array)

  # Save the generator model G
  print("Saving the resulting model")
  generator.save(BASE_FOLDER+''+EXPERIMENT_NAME+'/models')  

  # Save the parameters array
  params_array = np.array([EXPERIMENT_NAME,str(BATCH_SIZE),str(N_STEPS),str(N_STEPS/400),str(BUFFER_SIZE),OPTIMIZER])
  guardar_array (BASE_FOLDER+''+EXPERIMENT_NAME+'/results/params.npy', params_array)

  return metrics_array, metrica
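
The bookkeeping inside `fit` is easy to get wrong by one, so here is a pure-Python sketch of which events fire at a given step and which metrics row is written. The helper names and event labels are hypothetical, and the conditions are copied from the loop above.

```python
def events_at(step):
    """Events triggered by the training loop at a given global step."""
    events = []
    if step % 1000 == 0:
        events.append('preview')      # clear the display and plot an example
        if step != 0:
            events.append('metrics')  # store one metrics row and save it to disk
    if (step + 1) % 10 == 0:
        events.append('dot')
    if (step + 1) % 5000 == 0:
        events.append('checkpoint')
    return events

def metrics_row(step):
    """Row of the metrics array written at a step that is a multiple of 1000."""
    return step // 1000 - 1

def remaining_steps(n_steps, init_step):
    """Steps left to run when resuming from a checkpoint."""
    return n_steps - init_step

print(events_at(0), events_at(1000), events_at(4999))
print(metrics_row(1000), metrics_row(79000), remaining_steps(80000, 15000))
```

Rows 0 to 78 are filled during the loop at steps 1000 to 79000, and the last row, 79, is written once after the loop finishes.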


Finally, run the training loop:

In [None]:
metrics_array, metrica = fit(train_dataset, test_dataset, steps=N_STEPS)

In [None]:
metrics_array = np.load(BASE_FOLDER+''+EXPERIMENT_NAME+'/results/metrics.npy')
params_array = np.load(BASE_FOLDER+''+EXPERIMENT_NAME+'/results/params.npy')

title = "PIX2PIX"
subtitle = "Experiment: ["+params_array[0]+"] - Batch size: ["+params_array[1]+"] - Steps: ["+params_array[2]+"] Epochs: ["+params_array[3]+"] - Buffer size: ["+params_array[4]+"] - Optimizer: ["+params_array[5]+"]"

fig = plt.figure(figsize=(16,12))
axis = fig.add_subplot(111)

plt.title(title+'\n\n'+subtitle+'\n', fontsize=16, pad=10)

axis.plot(metrics_array[:,0].astype(int),metrics_array[:,1], marker='o', color="#1974D2" ,label='TOTAL LOSS',linewidth=3, linestyle="dashed") 
axis.plot(metrics_array[:,0].astype(int),metrics_array[:,2], marker='o', color="#FF007F" ,label='GAN LOSS',linewidth=2) 
axis.plot(metrics_array[:,0].astype(int),metrics_array[:,3], marker='o', color="#FFAA1D" ,label='L1 LOSS',linewidth=2) 
axis.plot(metrics_array[:,0].astype(int),metrics_array[:,4], marker='o', color="#66FF00" ,label='DISC LOSS',linewidth=2)
axis.legend(loc='best')
plt.grid()
plt.show() 

fig.savefig(BASE_FOLDER+''+EXPERIMENT_NAME+'/results/plot.png')