## Setup

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

from cyclegan.cyclegan import CycleGan, get_discriminator, get_resnet_generator

## Data

In [None]:
# Mount the user's Google Drive into the Colab VM's filesystem.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
import os

# Project directory on the mounted Drive (mounted at /content/drive above).
# The .npy files are read from, and generated images written to, this directory.
# An absolute path is used so re-running this cell after the chdir below does
# not create and enter a nested 'drive/MyDrive/...' tree relative to the new cwd.
newpath = os.path.join('/content/drive', 'MyDrive', 'Colab Notebooks', 'DL', 'project')

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(newpath, exist_ok=True)

# Reuse `newpath` rather than repeating the literal so the two cannot diverge.
os.chdir(newpath)
print(os.getcwd())

### Train data preparation

In [None]:
import cv2

# Loading: the first 1100 samples form the training split.
with open('images.npy', 'rb') as f:
    images = np.load(f)[:1100]

with open('depths.npy', 'rb') as f:
    # Add a trailing channel axis so depths are (N, H, W, 1).
    depths = np.expand_dims(np.load(f), axis=3)[:1100]

print(images.shape, depths.shape)

# Resize every sample by `ratio` (the test-split cell reuses this value).
# cv2.resize drops a trailing channel axis of size 1, so the resized depth
# maps come back as (H, W) and the channel axis is re-added afterwards.
ratio = 0.25
images = np.array([cv2.resize(img, None, fx=ratio, fy=ratio) for img in images])
depths = np.expand_dims(
    np.array([cv2.resize(dep, None, fx=ratio, fy=ratio) for dep in depths]),
    axis=3,
)

print(images.shape, depths.shape)

# Normalization to [-1, 1].
# NOTE(review): assumes both arrays hold values in [0, 255]; confirm for depths.
# Cast to float32 (Keras' default float dtype) instead of keeping the float64
# that the division produces, which halves the memory of the cached datasets.
images = (images / 127.5 - 1).astype(np.float32)
depths = (depths / 127.5 - 1).astype(np.float32)

# TF datasets. Images and depths are shuffled independently, so the pairs
# produced when they are zipped for training are not aligned. Presumably this
# is intended for unpaired CycleGAN training.
img_ds = tf.data.Dataset.from_tensor_slices(images).cache().shuffle(256).batch(1)
dep_ds = tf.data.Dataset.from_tensor_slices(depths).cache().shuffle(256).batch(1)
print(img_ds, dep_ds)

### Test data preparation

In [None]:
# Loading: every sample after the first 1100 forms the test split.
with open('images.npy', 'rb') as f:
    te_images = np.load(f)[1100:]

with open('depths.npy', 'rb') as f:
    # Add a trailing channel axis so depths are (N, H, W, 1).
    te_depths = np.expand_dims(np.load(f), axis=3)[1100:]

# Resize with the same `ratio` as the training split (defined in the train cell).
# cv2.resize drops a size-1 channel axis, so it is re-added for the depths.
te_images = np.array([cv2.resize(img, None, fx=ratio, fy=ratio) for img in te_images])
te_depths = np.expand_dims(
    np.array([cv2.resize(dep, None, fx=ratio, fy=ratio) for dep in te_depths]),
    axis=3,
)
print(te_images.shape, te_depths.shape)

# Normalization to [-1, 1], cast to float32 to match the training split.
te_images = (te_images / 127.5 - 1).astype(np.float32)
te_depths = (te_depths / 127.5 - 1).astype(np.float32)

# TF datasets in batches of 5 (the epoch-end monitor draws from te_img_ds).
te_img_ds = tf.data.Dataset.from_tensor_slices(te_images).cache().shuffle(256).batch(5)
te_dep_ds = tf.data.Dataset.from_tensor_slices(te_depths).cache().shuffle(256).batch(5)
# Print the *test* datasets (this previously printed the training ones).
print(te_img_ds, te_dep_ds)

## G and D

In [None]:
# Input shapes come from the prepared training arrays, so they stay in sync
# with the resize `ratio`. For this dataset at ratio=0.25 they are expected to
# be (120, 160, 3) for RGB images (domain X) and (120, 160, 1) for depth maps
# (domain Y).
size_X = tuple(images.shape[1:])
size_Y = tuple(depths.shape[1:])

# Hyperparameters shared by both generators / both discriminators.
gen_kwargs = dict(
    filters=48,
    num_downsampling_blocks=2,
    num_residual_blocks=5,
    num_upsample_blocks=2,
)
disc_kwargs = dict(
    filters=48,
    num_downsampling=3,
)

# Generators: G maps X -> Y, F maps Y -> X. Each generator's output channel
# count is that of the domain it translates into.
gen_G = get_resnet_generator(
    name="generator_G",
    in_size=size_X,
    out_channel=size_Y[-1],
    **gen_kwargs,
)

gen_F = get_resnet_generator(
    name="generator_F",
    in_size=size_Y,
    out_channel=size_X[-1],
    **gen_kwargs,
)

# Discriminators, one per domain.
disc_X = get_discriminator(
    name="discriminator_X",
    in_size=size_X,
    **disc_kwargs,
)

disc_Y = get_discriminator(
    name="discriminator_Y",
    in_size=size_Y,
    **disc_kwargs,
)

In [None]:
# Print generator G's layer-by-layer architecture and parameter counts.
gen_G.summary()

## Callbacks

In [None]:
class GANMonitor(keras.callbacks.Callback):
    """A callback to generate, show and save depth predictions after each epoch.

    Takes ``num_img`` batches from the module-level test dataset ``te_img_ds``.
    For the first sample of each batch it plots the input image next to the
    output of ``self.model.gen_G``. The prediction is saved as
    ``generated_img_{i}_{epoch}.png`` in the current working directory.

    Args:
        num_img: Number of test batches to visualize per epoch.
    """

    def __init__(self, num_img=4):
        # Initialize keras.callbacks.Callback's own attributes before ours.
        super().__init__()
        self.num_img = num_img

    @staticmethod
    def _to_uint8(x):
        """Map values in [-1, 1] back to uint8 [0, 255].

        Rounds before casting so float error cannot truncate (e.g. 254.9999
        to 254). Clips so out-of-range values saturate instead of wrapping.
        """
        return np.clip(np.rint(np.asarray(x) * 127.5 + 127.5), 0, 255).astype(np.uint8)

    def on_epoch_end(self, epoch, logs=None):
        # One row per visualized batch. squeeze=False keeps `ax` 2-D even
        # when num_img == 1. Height stays at 3 inches per row.
        _, ax = plt.subplots(
            self.num_img, 2, figsize=(12, 3 * self.num_img), squeeze=False
        )
        for i, batch in enumerate(te_img_ds.take(self.num_img)):
            prediction = self._to_uint8(self.model.gen_G(batch, training=False)[0])
            source = self._to_uint8(batch[0])

            ax[i, 0].imshow(source)
            ax[i, 1].imshow(prediction.squeeze())
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            keras.preprocessing.image.array_to_img(prediction).save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()

## Training the model

In [None]:
# Least-squares adversarial criterion: MSE against constant 1/0 targets.
adv_loss_fn = keras.losses.MeanSquaredError()


def generator_loss_fn(fake):
    """Generator loss: MSE between the discriminator's scores on fakes and 1."""
    all_real_target = tf.ones_like(fake)
    return adv_loss_fn(all_real_target, fake)


def discriminator_loss_fn(real, fake):
    """Discriminator loss: mean of MSE(real scores, 1) and MSE(fake scores, 0)."""
    loss_on_real = adv_loss_fn(tf.ones_like(real), real)
    loss_on_fake = adv_loss_fn(tf.zeros_like(fake), fake)
    return 0.5 * (loss_on_real + loss_on_fake)


# Create cycle gan model
cycle_gan_model = CycleGan(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

# Compile the model
cycle_gan_model.compile(
    gen_G_optimizer=keras.optimizers.Adam(learning_rate=5e-4, beta_1=0.5),
    gen_F_optimizer=keras.optimizers.Adam(learning_rate=5e-4, beta_1=0.5),
    disc_X_optimizer=keras.optimizers.Adam(learning_rate=5e-4, beta_1=0.5),
    disc_Y_optimizer=keras.optimizers.Adam(learning_rate=5e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# Callbacks
plotter = GANMonitor()
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath
)
pass

In [None]:
# Zip the image and depth datasets into (image_batch, depth_batch) tuples
# and train for 10 epochs, visualizing test predictions after each epoch.
train_ds = tf.data.Dataset.zip((img_ds, dep_ds))
hist = cycle_gan_model.fit(train_ds, epochs=10, callbacks=[plotter])

## Predictions

In [None]:
def _to_uint8(x):
    """Map values in [-1, 1] back to uint8 [0, 255].

    Rounds before casting so float error cannot truncate (e.g. 254.9999 to
    254). Clips so out-of-range values saturate instead of wrapping.
    """
    return np.clip(np.rint(np.asarray(x) * 127.5 + 127.5), 0, 255).astype(np.uint8)


# Show test samples 0, stride, 2*stride, ...: input, prediction, ground truth.
n_rows, stride = 5, 10
_, ax = plt.subplots(n_rows, 3, figsize=(10, 3 * n_rows), squeeze=False)

for i in range(n_rows):
    idx = stride * i
    # Slice (not index) to keep a batch axis of size 1 for the generator.
    batch = te_images[idx:idx + 1]
    prediction = _to_uint8(cycle_gan_model.gen_G(batch, training=False)[0])
    source = _to_uint8(batch[0])
    dep = _to_uint8(te_depths[idx])

    panels = [
        (source, "Input image"),
        (prediction.squeeze(), "Translated image"),
        (dep.squeeze(), "Actual Depth"),
    ]
    for col, (image, title) in enumerate(panels):
        ax[i, col].imshow(image)
        ax[i, col].set_title(title)
        ax[i, col].axis("off")

plt.tight_layout()
plt.show()