In [6]:
from tensorflow_examples.models.pix2pix import pix2pix
import tensorflow as tf
import tensorflow as tf
from tensorflow_examples.models.pix2pix import pix2pix
from loss_functions import *
from cycle_data_helper import *
from img_helpers import *
import time
import argparse
from datetime import datetime, timezone
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

In [8]:
generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [9]:
@tf.function
def train_step(real_x, real_y):
    # persistent is set to True because the tape is used more than
    # once to calculate the gradients.
    with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for identity loss.
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # calculate the loss
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

        # Total generator loss = adversarial loss + cycle loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Calculate the gradients for generator and discriminator
    generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

    # Apply the gradients to the optimizer
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

In [107]:
def get_datasets(batch_size=BATCH_SIZE,test=False, cartoon_path="jlbaker361/avatar_captioned-augmented", movie_path="jlbaker361/movie_captioned-augmented"):
    if test:
        data_frame_cartoon= [np.array(img) for img in load_dataset("jlbaker361/little_dataset",split="train")["image"]]
        data_frame_movie=[np.array(img) for img in load_dataset("jlbaker361/little_dataset",split="train")["image"]]
    else:
        data_frame_cartoon= [np.array(img) for img in load_dataset(cartoon_path,split="train")["image"]]
        data_frame_movie=[np.array(img) for img in load_dataset(movie_path,split="train")["image"]]
    data_frame_cartoon, data_frame_movie = equal_length(data_frame_cartoon, data_frame_movie)
    train_cartoon= tf.data.Dataset.from_tensor_slices(data_frame_cartoon).map(
        preprocess_image_train, num_parallel_calls=AUTOTUNE).batch(batch_size)
    train_movie = tf.data.Dataset.from_tensor_slices(data_frame_movie).map(
        preprocess_image_train, num_parallel_calls=AUTOTUNE).batch(batch_size)
    return train_cartoon, train_movie

In [108]:
train_cartoon, train_movie = get_datasets(batch_size=1,test=True,movie_path="jlbaker361/little-dataset",cartoon_path="jlbaker361/little-dataset")

Using custom data configuration jlbaker361--little_dataset-16ebd2a0f72cf497
Found cached dataset parquet (/home/jlb638/.cache/huggingface/datasets/jlbaker361___parquet/jlbaker361--little_dataset-16ebd2a0f72cf497/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Using custom data configuration jlbaker361--little_dataset-16ebd2a0f72cf497
Found cached dataset parquet (/home/jlb638/.cache/huggingface/datasets/jlbaker361___parquet/jlbaker361--little_dataset-16ebd2a0f72cf497/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [109]:
[t for t in train_movie]

2023-04-02 22:38:14.606144: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype uint8 and shape [5,512,512,3]
	 [[{{node Placeholder/_0}}]]


[<tf.Tensor: shape=(1, 512, 512, 3), dtype=float32, numpy=
 array([[[[ 0.09019613, -0.04313725, -0.44313723],
          [ 0.09803927,  0.01176476, -0.45098037],
          [ 0.12941182,  0.07450986, -0.42745095],
          ...,
          [-0.52156866, -0.54509807, -0.5686275 ],
          [-0.5294118 , -0.54509807, -0.5921569 ],
          [-0.5529412 , -0.5686275 , -0.5921569 ]],
 
         [[-0.12941176, -0.27058822, -0.654902  ],
          [-0.11372548, -0.21568626, -0.6627451 ],
          [ 0.01176476, -0.03529412, -0.5372549 ],
          ...,
          [-0.5372549 , -0.5686275 , -0.60784316],
          [-0.52156866, -0.5529412 , -0.6       ],
          [-0.52156866, -0.5529412 , -0.5921569 ]],
 
         [[-0.12941176, -0.26274508, -0.654902  ],
          [-0.11372548, -0.21568626, -0.654902  ],
          [-0.01960784, -0.05882353, -0.5764706 ],
          ...,
          [-0.52156866, -0.5529412 , -0.6       ],
          [-0.52156866, -0.5529412 , -0.6156863 ],
          [-0.5137255 ,

In [110]:
train_movie_sample=[t for t in train_movie][0][0]
train_cartoon_sample=[t for t in train_cartoon][0][0]

2023-04-02 22:38:47.143743: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype uint8 and shape [5,512,512,3]
	 [[{{node Placeholder/_0}}]]


In [111]:
for epoch in range(2):
    start = time.time()
    n = 0
    for image_x, image_y in tf.data.Dataset.zip((train_cartoon, train_movie)):
        train_step(image_x, image_y)
        if n % 10 == 0:
            print ('.', end='')
        n += 1
    print ('\nTime taken for epoch {} is {} sec\n'.format(epoch + 1,time.time()-start))

2023-04-02 22:39:14.003300: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_7' with dtype uint8 and shape [5,512,512,3]
	 [[{{node Placeholder/_7}}]]
2023-04-02 22:39:14.004054: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype uint8 and shape [5,512,512,3]
	 [[{{node Placeholder/_0}}]]


.
Time taken for epoch 1 is 40.92991375923157 sec



2023-04-02 22:39:54.922179: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype uint8 and shape [5,512,512,3]
	 [[{{node Placeholder/_0}}]]
2023-04-02 22:39:54.922586: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype uint8 and shape [5,512,512,3]
	 [[{{node Placeholder/_0}}]]


.
Time taken for epoch 2 is 10.353836297988892 sec

