<a href="https://colab.research.google.com/github/ficle-fr/pix2pix_change_style/blob/colabs/colabs/train_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Fetch the project's modules from GitHub. `-O <file>` writes to a fixed filename,
# so re-running this cell overwrites the previous copy instead of saving
# `common.py.1`, `common.py.2`, ... next to a stale `common.py` that would keep
# being imported. `-nv` keeps the log short but still reports errors.
!wget -nv -O common.py https://raw.githubusercontent.com/ficle-fr/pix2pix_change_style/main/common.py
!wget -nv -O descriminator.py https://raw.githubusercontent.com/ficle-fr/pix2pix_change_style/main/descriminator.py
!wget -nv -O generator.py https://raw.githubusercontent.com/ficle-fr/pix2pix_change_style/main/generator.py
!wget -nv -O img_generator.py https://raw.githubusercontent.com/ficle-fr/pix2pix_change_style/main/img_generator.py
!wget -nv -O train.py https://raw.githubusercontent.com/ficle-fr/pix2pix_change_style/main/train.py

# read_write_db.py is currently fetched from the `colabs` branch, not `main`.
!wget -nv -O read_write_db.py https://raw.githubusercontent.com/ficle-fr/pix2pix_change_style/colabs/colabs/read_write_db.py


In [None]:
# Install the cairo development headers that pip needs when it builds pycairo.
# `-y` answers apt's confirmation prompt, which would otherwise block an
# unattended "Run all". `apt-get` is used because `apt` has no stable CLI for scripts.
# Refreshing the package lists first avoids 404s from a stale apt cache.
!sudo apt-get update -q
!sudo apt-get install -y -q libcairo2-dev pkg-config python3-dev
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install pycairo

In [None]:
import time

import tensorflow as tf

from generator import Generator, generator_loss
from descriminator import Discriminator, discriminator_loss

from img_generator import img_pair_gen1
from read_write_db import write, read


In [None]:
# Write a small TFRecord file to train on, using the project's `write` helper.
# NOTE(review): the second argument is assumed to be the number of records
# (the filename says 100) — confirm against read_write_db.write.
NUM_RECORDS = 100
write("temp_records100.tfrecords", NUM_RECORDS)

In [None]:
# Global batch size. TPUStrategy splits each batch across its replicas
# (`tpu_strategy.num_replicas_in_sync`), so a multiple of the TPU core count
# keeps every core busy. 5 presumably leaves some cores idle — TODO tune.
BATCH_SIZE = 5

# NOTE(review): `read` (from read_write_db) is assumed to return a tf.data.Dataset
# of (input_image, target) pairs — confirm.
# NOTE(review): on Colab TPU nodes the TPU host may be unable to read files from
# the local VM disk; a GCS path may be needed — confirm.
# drop_remainder=True gives every batch the same static shape, which TPU
# compilation requires; building the dataset in one expression keeps the cell
# idempotent on re-run.
dataset = read("./temp_records100.tfrecords").batch(BATCH_SIZE, drop_remainder=True)

In [None]:
print("Tensorflow version " + tf.__version__)

# Locate the TPU attached to this runtime and report its core count. The
# resolver is expected to raise ValueError when no TPU runtime is attached.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print("Total number of TPU cores:", tpu.get_tpu_system_metadata().num_cores)
except ValueError as err:
    # RuntimeError rather than BaseException, so that ordinary `except Exception`
    # handlers and tooling treat it as a normal error. The original cause is chained.
    raise RuntimeError(
        "ERROR: Not connected to a TPU runtime; in Colab choose "
        "Runtime > Change runtime type > TPU, then re-run the notebook."
    ) from err

# Connect to and initialize the TPU, then build the distribution strategy that
# the training code uses.
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

In [None]:
def train_multiple_steps(generator, discriminator,
                         generator_optimizer, discriminator_optimizer,
                         dataset):
    """Run one pass of pix2pix GAN training over `dataset` on the TPU.

    The dataset is distributed with the global `tpu_strategy`. Each batch then
    goes through one compiled, replicated training step. After each step the
    function prints the replica-averaged losses and the wall-clock time.

    Relies on the globals `tpu_strategy`, `generator_loss`,
    `discriminator_loss` and `time`.

    Args:
        generator: model called as `generator(input_image, training=True)`.
        discriminator: model called as `discriminator([input_image, image], training=True)`.
        generator_optimizer: optimizer for `generator`, created under
            `tpu_strategy.scope()`.
        discriminator_optimizer: optimizer for `discriminator`, created under
            `tpu_strategy.scope()`.
        dataset: `tf.data.Dataset` yielding `(input_image, target)` batches.
            It is assumed to have a static batch size (`drop_remainder=True`).
    """
    num_replicas = tpu_strategy.num_replicas_in_sync

    def step_fn(batch):
        """One training step on a single replica; returns the unscaled losses."""
        input_image, target = batch
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = generator(input_image, training=True)

            disc_real_output = discriminator([input_image, target], training=True)
            disc_generated_output = discriminator([input_image, gen_output], training=True)

            gen_total_loss, _, _ = generator_loss(disc_generated_output, gen_output, target)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

            # Under a distribution strategy, apply_gradients SUMs gradients across
            # replicas. Dividing a per-replica mean loss by the replica count makes
            # the update equal to the gradient of the global-batch mean.
            # NOTE(review): assumes both loss functions return per-replica means — confirm.
            scaled_gen_loss = gen_total_loss / num_replicas
            scaled_disc_loss = disc_loss / num_replicas

        # Take gradients after leaving the recording context, so neither tape
        # records the gradient computation itself.
        generator_gradients = gen_tape.gradient(scaled_gen_loss,
                                                generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(scaled_disc_loss,
                                                     discriminator.trainable_variables)

        generator_optimizer.apply_gradients(zip(generator_gradients,
                                                generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                    discriminator.trainable_variables))
        return gen_total_loss, disc_loss

    @tf.function
    def distributed_step(batch):
        # `args` must be a tuple: the whole (input_image, target) pair is the
        # single argument to step_fn.
        gen_losses, disc_losses = tpu_strategy.run(step_fn, args=(batch,))
        return (tpu_strategy.reduce(tf.distribute.ReduceOp.MEAN, gen_losses, axis=None),
                tpu_strategy.reduce(tf.distribute.ReduceOp.MEAN, disc_losses, axis=None))

    # The dataset is distributed outside tf.function; this is what feeds each
    # TPU replica its slice of every batch.
    dist_dataset = tpu_strategy.experimental_distribute_dataset(dataset)

    # Timing stays in the Python loop: inside tf.function, time.time()/print run
    # only once, at trace time. Converting the losses with float() waits for the
    # step to finish, so the time covers the whole step. The first step also
    # includes XLA compilation.
    for step, batch in enumerate(dist_dataset):
        start = time.time()
        gen_loss, disc_loss = distributed_step(batch)
        print(f'step {step}: gen_loss={float(gen_loss):.4f} '
              f'disc_loss={float(disc_loss):.4f} '
              f'time={time.time() - start:.2f} sec')

In [None]:
# Adam hyperparameters shared by both networks.
ADAM_LEARNING_RATE = 2e-4
ADAM_BETA_1 = 0.5

# Models and optimizers are built inside the strategy scope, so their variables
# are created as distributed variables on the TPU replicas.
with tpu_strategy.scope():
    generator = Generator([256, 256, 3], 3)
    discriminator = Discriminator([256, 256, 3])

    generator_optimizer = tf.keras.optimizers.Adam(ADAM_LEARNING_RATE, beta_1=ADAM_BETA_1)
    discriminator_optimizer = tf.keras.optimizers.Adam(ADAM_LEARNING_RATE, beta_1=ADAM_BETA_1)

train_multiple_steps(
    generator=generator,
    discriminator=discriminator,
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    dataset=dataset,
)