In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from data import *
import random

In [151]:
initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
def discriminator_bad(dim_length,d=32,input_shape=(64,64,1),name='discriminator',output_dim=10):
    dims=[d *(2**i) for i in range(2,2+dim_length)]
    img_inputs = keras.Input(shape=(input_shape))
    x = layers.Conv2D(d, (4, 4), strides=(2, 2), padding="same",kernel_initializer=initializer)(img_inputs)
    x=layers.LeakyReLU(alpha=0.2)(x)

    for dim in dims:
        x = layers.Conv2D(d, (4, 4), strides=(2, 2), padding="same",kernel_initializer=initializer)(x)
        x= layers.BatchNormalization()(x)
        x=layers.LeakyReLU(alpha=0.2)(x)

    binary_x =layers.Conv2D(d, (4,4),strides=(2,2),kernel_initializer=initializer)(x)
    binary_x=layers.Flatten(name='binary_flatten')(binary_x)
    binary_output = layers.Dense(1,name='binary_output')(binary_x)

    multiclass_x=layers.Dense(d,kernel_initializer=initializer)(x)
    mutliclass_x=layers.Conv2D(dims[-1], (4,4),strides=(2,2),kernel_initializer=initializer)(multiclass_x)
    multiclass_x=layers.Flatten(name='multiclass_flatten')(multiclass_x)
    mutlticlass_output=layers.Dense(output_dim,name='multiclass_output')(multiclass_x)
    model = keras.Model(inputs=img_inputs, outputs=[binary_output,mutlticlass_output], name=name)
    return model

def discriminator(input_shape=(64,64,1),name='discriminator',output_dim=10):
    img_inputs = keras.Input(shape=(input_shape))
    d=input_shape[0]
    #x=layers.Flatten()(img_inputs)
    x=layers.Conv2D(16, (4, 4), strides=(2, 2), padding="same")(img_inputs)
    x=layers.LeakyReLU(alpha=.2)(x)
    x= layers.BatchNormalization()(x)
    x=layers.Conv2D(32, (4, 4), strides=(2, 2), padding="same")(x)
    x=layers.LeakyReLU(alpha=.2)(x)
    x= layers.BatchNormalization()(x)
    x=layers.Conv2D(64, (4, 4), strides=(2, 2), padding="same")(x)
    x=layers.LeakyReLU(alpha=.2)(x)
    x= layers.BatchNormalization()(x)
    x=layers.Flatten()(x)
    #binary_x=layers.Flatten(name='binary_flatten')(x)
    binary_output = layers.Dense(1,name='binary_output')(x)
    #multiclass_x=layers.Dense(,kernel_initializer=initializer)(x)
    #multiclass_x=layers.Flatten(name='multiclass_flatten')(multiclass_x)
    mutlticlass_output=layers.Dense(output_dim,name='multiclass_output')(x)
    model = keras.Model(inputs=img_inputs, outputs=[binary_output,mutlticlass_output], name=name)
    return model

^^The above functions define a discriminator

In [88]:
d=dataset_limited(['cubism'],1)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.30it/s]


dataset_limited is defined in data.py; it just gets the feature maps (represented as numpy arrays) for the artistic genre

In [152]:
dick=discriminator()

In [153]:
dick.summary()

Model: "discriminator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_97 (InputLayer)           [(None, 64, 64, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 32, 32, 16)   272         input_97[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_136 (LeakyReLU)     (None, 32, 32, 16)   0           conv2d_42[0][0]                  
__________________________________________________________________________________________________
batch_normalization_143 (BatchN (None, 32, 32, 16)   64          leaky_re_lu_136[0][0]            
______________________________________________________________________________________

In [118]:
def make_generator():
    latent_dim = 128
    generator = keras.Sequential(
        [
            keras.Input(shape=(latent_dim,)),
            # We want to generate 128 coefficients to reshape into a 7x7x128 map
            layers.Dense(8 * 8 * 128),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Reshape((8, 8, 128)),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.Conv2D(1, (8, 8), padding="same", activation="sigmoid"),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(name=style_blocks[0]),
            layers.UpSampling2D(),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(name=style_blocks[1]),
            layers.UpSampling2D(),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(name=style_blocks[2]),
            layers.UpSampling2D(),
            layers.LeakyReLU(alpha=0.2),
            layers.BatchNormalization(name=style_blocks[3]),
            layers.Conv2D(1, (2,2), padding="same", activation="sigmoid"),
            layers.LeakyReLU(alpha=0.2,name=style_blocks[4])
        ],
        name="generator",
    )
    return generator

A basic generator function

In [None]:
generator.summary()

In [11]:
genres=random_genres(8)
labels,matrices=dataset_batched(genres,limit=20)

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 25.47it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 25.65it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 27.28it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 22.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 25.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 24.12it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 24.04it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 30.61it/s]


random_genres is a function defined in extras.py; it just gets a bunch of random genres. dataset_batched is defined in data.py, it just gets the labels (if a painting is in 'cubism', label =cubism, represented as a one-hot encoding vector) and numpy representations of the paintings

In [12]:
matrices_zipped=tf.data.Dataset.zip(tuple(m for m in matrices))

In [143]:
class GAN(keras.Model):
    def __init__(self, discriminators, generator, latent_dim):
        super(GAN, self).__init__()
        self.discriminators = discriminators
        self.generator = generator
        self.latent_dim = latent_dim


    def compile(self, d_optimizer, g_optimizer, binary_loss_fn,multiclass_loss_fn):
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.binary_loss_fn = binary_loss_fn
        self.multiclass_loss_fn = multiclass_loss_fn

    def train_step(self, labels_multiclass,matrices,batch_size):
        #fake images
        net_loss=[]
        random_latent_vectors = [tf.random.normal(shape=(1, self.latent_dim)) for _ in range(batch_size)]
        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)
        ret={}
        for real_mats,block,discriminator in zip(matrices,style_blocks,self.discriminators):
            func=keras.backend.function([self.generator.get_layer(index=0).input], 
                                        self.generator.get_layer(name=block).output)
            gen_mats=[func(vector)[0] for vector in random_latent_vectors]
            #print(gen_mats[0].shape)
            #print(real_mats[0].shape)
            combined_mats = tf.concat([gen_mats, real_mats], axis=0)
            # Assemble labels discriminating real from fake images
            labels_binary = tf.concat(
                [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
            )
            # Add random noise to the labels - important trick!
            labels_binary += 0.05 * tf.random.uniform(tf.shape(labels_binary))
            
            # Train the discriminator
            with tf.GradientTape() as tape:
                pred_binary,pred_multiclass = discriminator(combined_mats)
                pred_multiclass=pred_multiclass[-len(real_mats):]
                d_loss = self.binary_loss_fn(labels_binary, pred_binary)+self.multiclass_loss_fn(labels_multiclass,pred_multiclass)
            grads = tape.gradient(d_loss, discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(
                zip(grads, discriminator.trainable_weights)
            )
            net_loss.append(d_loss)
        return net_loss

defines the GAN with its own training loop

In [154]:
shapes=[(64, 64,1), (128, 128,1), (256, 256,1), (512, 512, 1), (512, 512,1)]
discriminators=[discriminator(input_shape=s,name=str(s),output_dim=len(genres)) for s in shapes]
generator=make_generator()

Five discriminators, since we're using feature maps from 5 different layers

In [155]:
for d in discriminators:
    d.summary()

Model: "(64, 64, 1)"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_98 (InputLayer)           [(None, 64, 64, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 32, 32, 16)   272         input_98[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_139 (LeakyReLU)     (None, 32, 32, 16)   0           conv2d_45[0][0]                  
__________________________________________________________________________________________________
batch_normalization_146 (BatchN (None, 32, 32, 16)   64          leaky_re_lu_139[0][0]            
________________________________________________________________________________________

In [156]:
gan = GAN(discriminators=discriminators, generator=generator, latent_dim=latent_dim)

In [157]:
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    binary_loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
    multiclass_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
)

In [160]:
epochs=20
for epoch in range(epochs):
    print("\nStart epoch", epoch)

    step=0
    for step,(l,m) in enumerate(zip(labels,matrices_zipped)):
        # Train the discriminator & generator on one batch of real images.
        d_loss= gan.train_step(l,m,10)
        print("discriminator loss at step {}: {}".format(step,sum([(np.round(loss,3)) for loss in d_loss])))


Start epoch 0
discriminator loss at step 0: 23091.149077415466
discriminator loss at step 1: 23543.658094406128
discriminator loss at step 2: 1327523.0825195312
discriminator loss at step 3: 2446776.3416748047
discriminator loss at step 4: 3947724.1041259766
discriminator loss at step 5: 1375585.420135498
discriminator loss at step 6: 5074869.900695801
discriminator loss at step 7: 1500128.050857544
discriminator loss at step 8: 7285997.040222168
discriminator loss at step 9: 916189.2157592773
discriminator loss at step 10: 1837345.3723144531
discriminator loss at step 11: 3749354.8834228516
discriminator loss at step 12: 4873091.0361328125
discriminator loss at step 13: 3799913.098388672
discriminator loss at step 14: 4094925.545639038
discriminator loss at step 15: 2693076.171890259

Start epoch 1
discriminator loss at step 0: 234254.8247680664
discriminator loss at step 1: 111770.27681732178
discriminator loss at step 2: 10335151.901428223
discriminator loss at step 3: 3005288.4188

discriminator loss at step 5: 24847959.281723022
discriminator loss at step 6: 4146589.866882324


KeyboardInterrupt: 