In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import os

from google.colab import drive
from google.colab import files

# custom modules
from vis_utils import show_ds_examples
from artist import Artist
from models import Generator, Discriminator, GAN
from model_monitor import ModelMonitor

In [None]:
# Mount personal Google Drive; generated images and the trained model are saved
# under /content/drive/MyDrive/ai_artist (see the config cell).
drive.mount('/content/drive')

# Upload your Kaggle API key (kaggle.json) into the working directory, where the
# setup cell below copies it from. Skip the file picker if it is already there,
# so re-running this cell doesn't block waiting for an upload.
# The result is assigned rather than left as the cell's last expression:
# files.upload() returns {filename: file bytes}, and displaying that would print
# the API key into the saved notebook output.
if not os.path.exists('kaggle.json'):
    uploaded = files.upload()

In [None]:
# Commands to be run in google collab to import kaggle dataset
! pip install -q kaggle
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! kaggle datasets download -d oliverbgibbons/abstract-art
! unzip /content/abstract-art.zip

In [None]:
# Dataset location — expected to be where the setup cell's `unzip` extracts the
# archive (relative to /content); confirm against the archive's folder layout.
data_dir = '/content/data'

# Locations on Google Drive to save generated images, the final model and losses
save_imgs_fp = '/content/drive/MyDrive/ai_artist/images'
save_model_fp = '/content/drive/MyDrive/ai_artist/model'
save_losses_fp = '/content/drive/MyDrive/ai_artist/losses'  # NOTE(review): not referenced by any cell below; confirm it is still needed

# Create the output folders up front so later cells (the image monitor and the
# final model save) don't fail on a fresh Drive; exist_ok makes this re-runnable.
for output_dir in (save_imgs_fp, save_model_fp, save_losses_fp):
    os.makedirs(output_dir, exist_ok=True)

# Hyper parameters
batch_size = 32
num_epochs = 700
gen_learning_rate = 0.0001
dis_learning_rate = 0.00001   # discriminator learns 10x slower than the generator
latent_space_size = 200
discriminator_noise = 0.15    # passed to the Discriminator constructor

# Seed TensorFlow's global RNG so training runs are more repeatable (GAN training
# is stochastic; GPU kernels may still introduce some nondeterminism).
seed = 42
tf.random.set_seed(seed)

# Constants — images are square
img_height = 128
img_width = img_height

In [None]:
# Artist drives both data preparation (here) and training (artist.learn below).
artist = Artist()

# Artist processes the images under data_dir into a normalised, batched dataset
# of img_height x img_width images with batch_size images per batch. The return
# value, if any, is not used; the dataset is presumably kept on the artist,
# since artist.learn() below is not passed a dataset — confirm in artist.py.
artist.process_data(data_dir, img_height, img_width, batch_size)

In [None]:
# Initialising generator and discriminator models.
# Generator only takes img_height as its image size, so square images are assumed.
generator = Generator(latent_dim=latent_space_size, image_size=img_height)
# NOTE(review): arguments are positional — confirm Discriminator's order is
# (width, height, noise); a swap would go unnoticed today only because
# img_width == img_height.
discriminator = Discriminator(img_width, img_height, discriminator_noise)

# Initialising GAN from the two models. No noise value is passed here;
# discriminator_noise is given to the Discriminator above.
gan = GAN(generator, discriminator)

# Model monitor gets the latent size and the Drive image folder; per the original
# note it saves an example image after each epoch into save_imgs_fp.
model_monitor = ModelMonitor(latent_space_size, save_imgs_fp)

In [None]:
# Train the GAN. This is the long-running cell (num_epochs is 700 in the config
# cell). Generator and discriminator use separate learning rates; model_monitor is
# handed to the training loop, presumably as a per-epoch callback — confirm in artist.py.
artist.learn(gan, num_epochs, gen_learning_rate, dis_learning_rate, model_monitor)

In [None]:
# Save only the trained generator of the final GAN (the discriminator is not
# saved) to Drive once training has finished.
# Create the folder defensively rather than relying on Keras to do it, so this
# cell also works on a fresh Drive; exist_ok keeps it safe to re-run.
os.makedirs(save_model_fp, exist_ok=True)
# NOTE(review): .h5 is Keras's legacy HDF5 format (newer Keras prefers .keras);
# kept as-is so existing consumers of final_model.h5 keep working.
final_model_path = os.path.join(save_model_fp, 'final_model.h5')
artist.final_gan.generator_model.save(final_model_path)