In [None]:
import tensorflow as tf
import os
import zipfile
from tensorflow import keras
import sys
from tensorflow.keras import layers
import numpy as np
try:
  import tensorflow_addons as tfa
except:
  !pip install tensorflow_addons
  import tensorflow_addons as tfa

print(tf.__version__)

In [None]:
# Attach Google Drive to the Colab runtime. The archive paths and the module
# search path used further down all point below this mount point.
from google.colab import drive
drive.mount(mountpoint='/content/drive', force_remount=True)

In [None]:
def unzip_dataset(zip_file_path, destination_folder):
    """Extract a zip archive into a folder, creating the folder when needed.

    Args:
        zip_file_path: Path of the zip archive to read.
        destination_folder: Directory the archive members are extracted into.
            Missing directories (including parents) are created.

    Returns:
        None. A confirmation message is printed once extraction has finished.

    Raises:
        FileNotFoundError: If ``zip_file_path`` does not exist.
        zipfile.BadZipFile: If the file is not a valid zip archive.
    """
    # exist_ok=True replaces the separate existence check: a single call, and
    # no window in which the folder can appear between check and creation.
    os.makedirs(destination_folder, exist_ok=True)

    # Unzip the dataset
    with zipfile.ZipFile(zip_file_path, 'r') as archive:
        archive.extractall(destination_folder)

    print("Dataset successfully unzipped.")

# Extract the training archive, then the validation archive, into the same
# Drive folder that holds the archives themselves.
for archive_path in (
    '/content/drive/MyDrive/PetsGAN/PetsGAN_train.zip',
    '/content/drive/MyDrive/PetsGAN/PetsGAN_validation.zip',
):
    unzip_dataset(archive_path, '/content/drive/MyDrive/PetsGAN/')

In [None]:
# Put the project folder on the module search path so that its `model` and
# `train` modules can be imported from Drive.
sys.path.append('/content/drive/MyDrive/PetsGAN')
from model import CycleGAN
from train import (
    create_training_dataloader as train_loader,
    create_validation_dataloader as val_loader,
    VisualizationCallback,
    LinearAnnealingScheduler,
)

# Hyper-parameters: the batch size is handed to both data loaders, the epoch
# count to the training run
BATCH_SIZE = 16
EPOCHS = 150

# Training data: the loader's two return values are unpacked as the pets set
# and the art set
pets_train, art_train = train_loader(
    "/content/drive/MyDrive/PetsGAN/PetsGAN_train",
    BATCH_SIZE,
)

# Validation data, loaded the same way from the validation folder
pets_val, art_val = val_loader("/content/drive/MyDrive/PetsGAN/PetsGAN_validation",
                               BATCH_SIZE)


print(f"Creating PetsArtistGAN model...")

# Instantiate the CycleGAN and compile it with its default arguments
model = CycleGAN()
model.compile()

# Push one random 1x256x256x3 input through the model before requesting the
# summary (presumably this forward pass is what builds the layers — confirm
# against model.py)
dummy_batch = np.random.rand(1, 256, 256, 3)
model(dummy_batch)

# Show the layer/parameter overview
model.summary()

In [None]:
# First element of the validation pets set, handed to the visualisation callback
callback_image = next(iter(pets_val))
visualize_callback = VisualizationCallback(callback_image)

# final_epoch is tied to EPOCHS instead of repeating the literal 150, so that
# changing the run length cannot leave the schedule out of step with it
# (assumes final_epoch is meant to be the total epoch count — confirm against
# train.LinearAnnealingScheduler).
lr_callback = LinearAnnealingScheduler(initial_lr=2e-4, start_epoch=70, final_epoch=EPOCHS)

# Train the model on the zipped (pets, art) training sets, validating on the
# zipped validation sets
history = model.fit(tf.data.Dataset.zip((pets_train, art_train)),
                    epochs=EPOCHS,
                    validation_data=tf.data.Dataset.zip((pets_val, art_val)),
                    callbacks=[visualize_callback,
                               lr_callback]
                    )

In [None]:
# Show which keys the training history holds
recorded_keys = history.history.keys()
print(recorded_keys)

In [None]:
# Save whole model; the epoch count and batch size of the run are part of the name
full_model_path = f'/content/drive/MyDrive/PetsGAN/PetsArtistGAN_{EPOCHS}_{BATCH_SIZE}'
model.save(full_model_path)

In [None]:
# Save just the weights, to an .h5 file named after the epoch count and batch size
weights_path = f'/content/drive/MyDrive/PetsGAN/PetsArtistGAN_weights_{EPOCHS}_{BATCH_SIZE}.h5'
model.save_weights(weights_path)