In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
from ganutilities import manage_batch_2c
from createganmodels import new_gan_discriminator, new_gan_generator

In [15]:
# Training configuration: every tunable value used by the cells below.

# Data
data_path = 'fillsplit3200.pickle'  # file handed to manage_batch_2c to build batches
input_ticks = 3200                  # passed to new_gan_* model builders (presumably sequence length)

# Model / optimisation
noise_dim = 100                     # dimension of the random noise vector fed to the generator
batch_size = 50
# Shared by both Adam optimizers.
# NOTE(review): logged losses swing into the hundreds; a lower rate may be worth trying.
learning_rate = 1e-3

# Resuming: epoch number of the checkpoints being continued.
# New checkpoints are numbered from starting_epoch + 1.
starting_epoch = 25

In [18]:
# You can either create new models (set starting_epoch = 0) or load the models
# saved at the end of epoch `starting_epoch` to continue training.
# The paths are derived from `starting_epoch` rather than hard-coded. That keeps
# them in sync with the "./bigepochs/gene<N>" / "./bigepochs/dise<N>" names that
# train_dcgan writes, and with the epoch numbers it prints.
gen_path = './bigepochs/gene' + str(starting_epoch)
disc_path = './bigepochs/dise' + str(starting_epoch)

if starting_epoch > 0:
    generator = tf.keras.models.load_model(gen_path)
    discriminator = tf.keras.models.load_model(disc_path)
else:
    # Fresh start: build untrained models sized for `input_ticks`.
    discriminator = new_gan_discriminator(input_ticks)
    generator = new_gan_generator(input_ticks)











In [4]:
# Binary cross-entropy, the usual choice for a standard GAN.
# from_logits=True expects the discriminator to emit raw (un-sigmoided) scores.
# NOTE(review): confirm the discriminator built in createganmodels has no final sigmoid.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Loss for the discriminator.

    Real samples are targeted at label 1 and generated samples at label 0.
    Returns the sum of the two cross-entropy terms.
    """
    return (cross_entropy(tf.ones_like(real_output), real_output)
            + cross_entropy(tf.zeros_like(fake_output), fake_output))


def generator_loss(fake_output):
    """Loss for the generator.

    Generated samples are scored against the "real" label (1). The more
    confidently the discriminator calls them fake, the higher this loss.
    """
    real_labels = tf.ones_like(fake_output)
    return cross_entropy(real_labels, fake_output)


# One independent Adam optimizer per network, both using `learning_rate`.
# These are new objects each time this cell runs. When training resumes from
# saved models, Adam's internal state therefore starts from scratch.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=learning_rate),
    tf.keras.optimizers.Adam(learning_rate=learning_rate),
)

In [19]:
# Training loop: one adversarial step per batch, checkpoint after every epoch.

def _train_step(real_batch):
    """Run one simultaneous generator/discriminator update on a batch of real data.

    Returns the (generator_loss, discriminator_loss) tensors computed for this step.
    """
    # Draw one noise vector per real example, so the discriminator always sees as
    # many fakes as reals, even if a batch is shorter than `batch_size`.
    # Assumes the first axis of `real_batch` is the batch axis; TODO confirm
    # against manage_batch_2c.
    noise = tf.random.normal([tf.shape(real_batch)[0], noise_dim])

    # Both tapes record the same forward pass. Each is later differentiated only
    # with respect to its own network's variables.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_sequences = generator(noise, training=True)

        real_output = discriminator(real_batch, training=True)
        fake_output = discriminator(generated_sequences, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    return gen_loss, disc_loss


def train_dcgan(epochs, start_epoch=None, checkpoint_dir='./bigepochs'):
    """Train the global `generator` / `discriminator` pair for `epochs` epochs.

    Args:
        epochs: number of passes over the data to run.
        start_epoch: epoch number already completed by the current models. New
            checkpoints are numbered from start_epoch + 1. Defaults to the
            global `starting_epoch`. Pass it explicitly when calling this
            function again in the same session, so earlier checkpoints are not
            overwritten.
        checkpoint_dir: directory that receives "gene<N>" / "dise<N>" saves.

    Side effects: prints both losses after every batch, and saves both models
    via Model.save after every epoch. Optimizer state is not saved.
    """
    if start_epoch is None:
        start_epoch = starting_epoch

    # The batch manager reshuffles the data into new batches on every epoch.
    batch_manager = manage_batch_2c(data_path, batch_size)
    for epoch_index in range(epochs):
        epoch_number = start_epoch + epoch_index + 1

        for batch in batch_manager.new_epoch():
            gen_loss, disc_loss = _train_step(batch)
            print(gen_loss, disc_loss)

        generator.save(checkpoint_dir + "/gene" + str(epoch_number))
        discriminator.save(checkpoint_dir + "/dise" + str(epoch_number))
        print("finished epoch!", epoch_number)


In [20]:
# Run ten more epochs. Checkpoints continue numbering after `starting_epoch`.
train_dcgan(epochs=10)

tf.Tensor(1.173255, shape=(), dtype=float32) tf.Tensor(7.2446275, shape=(), dtype=float32)
tf.Tensor(346.95258, shape=(), dtype=float32) tf.Tensor(357.7815, shape=(), dtype=float32)
tf.Tensor(83.99507, shape=(), dtype=float32) tf.Tensor(91.46596, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(107.936554, shape=(), dtype=float32)
tf.Tensor(1.2369725, shape=(), dtype=float32) tf.Tensor(9.884946, shape=(), dtype=float32)
tf.Tensor(274.7884, shape=(), dtype=float32) tf.Tensor(245.25587, shape=(), dtype=float32)
tf.Tensor(207.21594, shape=(), dtype=float32) tf.Tensor(161.24785, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene26/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene26/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise26/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise26/assets


finished epoch! 26
tf.Tensor(31.231565, shape=(), dtype=float32) tf.Tensor(2.665008, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(134.65463, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(142.17696, shape=(), dtype=float32)
tf.Tensor(0.66249484, shape=(), dtype=float32) tf.Tensor(27.634062, shape=(), dtype=float32)
tf.Tensor(256.28748, shape=(), dtype=float32) tf.Tensor(108.10621, shape=(), dtype=float32)
tf.Tensor(199.83578, shape=(), dtype=float32) tf.Tensor(36.365894, shape=(), dtype=float32)
tf.Tensor(47.107212, shape=(), dtype=float32) tf.Tensor(0.24139467, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene27/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene27/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise27/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise27/assets


finished epoch! 27
tf.Tensor(4.8363744e-16, shape=(), dtype=float32) tf.Tensor(73.97549, shape=(), dtype=float32)
tf.Tensor(2.4914083e-08, shape=(), dtype=float32) tf.Tensor(78.31111, shape=(), dtype=float32)
tf.Tensor(5.5010166, shape=(), dtype=float32) tf.Tensor(18.385324, shape=(), dtype=float32)
tf.Tensor(95.64334, shape=(), dtype=float32) tf.Tensor(79.59271, shape=(), dtype=float32)
tf.Tensor(81.23978, shape=(), dtype=float32) tf.Tensor(83.75076, shape=(), dtype=float32)
tf.Tensor(5.1071, shape=(), dtype=float32) tf.Tensor(21.330233, shape=(), dtype=float32)
tf.Tensor(7.000683, shape=(), dtype=float32) tf.Tensor(21.439692, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene28/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene28/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise28/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise28/assets


finished epoch! 28
tf.Tensor(89.66344, shape=(), dtype=float32) tf.Tensor(1.1598139e-05, shape=(), dtype=float32)
tf.Tensor(181.94373, shape=(), dtype=float32) tf.Tensor(1.9150891e-21, shape=(), dtype=float32)
tf.Tensor(234.63618, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(235.51773, shape=(), dtype=float32) tf.Tensor(4.101904e-07, shape=(), dtype=float32)
tf.Tensor(180.10147, shape=(), dtype=float32) tf.Tensor(0.20123129, shape=(), dtype=float32)
tf.Tensor(91.16605, shape=(), dtype=float32) tf.Tensor(5.723792, shape=(), dtype=float32)
tf.Tensor(19.716066, shape=(), dtype=float32) tf.Tensor(33.412453, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene29/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene29/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise29/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise29/assets


finished epoch! 29
tf.Tensor(18.272846, shape=(), dtype=float32) tf.Tensor(26.941988, shape=(), dtype=float32)
tf.Tensor(152.2041, shape=(), dtype=float32) tf.Tensor(7.411597e-28, shape=(), dtype=float32)
tf.Tensor(317.04227, shape=(), dtype=float32) tf.Tensor(57.3675, shape=(), dtype=float32)
tf.Tensor(50.285, shape=(), dtype=float32) tf.Tensor(7.584506, shape=(), dtype=float32)
tf.Tensor(1.6086531e-16, shape=(), dtype=float32) tf.Tensor(125.2349, shape=(), dtype=float32)
tf.Tensor(50.54313, shape=(), dtype=float32) tf.Tensor(20.374016, shape=(), dtype=float32)
tf.Tensor(249.40309, shape=(), dtype=float32) tf.Tensor(0.32384983, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene30/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene30/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise30/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise30/assets


finished epoch! 30
tf.Tensor(428.46857, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(476.8269, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(362.67743, shape=(), dtype=float32) tf.Tensor(8.299388e-32, shape=(), dtype=float32)
tf.Tensor(131.63765, shape=(), dtype=float32) tf.Tensor(0.0067921923, shape=(), dtype=float32)
tf.Tensor(47.704174, shape=(), dtype=float32) tf.Tensor(1.1022213, shape=(), dtype=float32)
tf.Tensor(146.2677, shape=(), dtype=float32) tf.Tensor(118.7604, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(316.13824, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene31/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene31/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise31/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise31/assets


finished epoch! 31
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(349.7865, shape=(), dtype=float32)
tf.Tensor(1.2924381, shape=(), dtype=float32) tf.Tensor(163.5806, shape=(), dtype=float32)
tf.Tensor(266.19797, shape=(), dtype=float32) tf.Tensor(281.00882, shape=(), dtype=float32)
tf.Tensor(142.59502, shape=(), dtype=float32) tf.Tensor(170.38416, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(186.52516, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(249.55295, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(140.36781, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene32/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene32/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise32/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise32/assets


finished epoch! 32
tf.Tensor(135.44666, shape=(), dtype=float32) tf.Tensor(97.497734, shape=(), dtype=float32)
tf.Tensor(208.37422, shape=(), dtype=float32) tf.Tensor(118.81989, shape=(), dtype=float32)
tf.Tensor(98.874054, shape=(), dtype=float32) tf.Tensor(5.473489, shape=(), dtype=float32)
tf.Tensor(0.123780005, shape=(), dtype=float32) tf.Tensor(39.632042, shape=(), dtype=float32)
tf.Tensor(26.51674, shape=(), dtype=float32) tf.Tensor(0.47041076, shape=(), dtype=float32)
tf.Tensor(111.09982, shape=(), dtype=float32) tf.Tensor(4.5340722e-18, shape=(), dtype=float32)
tf.Tensor(192.69722, shape=(), dtype=float32) tf.Tensor(0.0, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene33/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene33/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise33/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise33/assets


finished epoch! 33
tf.Tensor(253.53008, shape=(), dtype=float32) tf.Tensor(8.0415063e-11, shape=(), dtype=float32)
tf.Tensor(281.0504, shape=(), dtype=float32) tf.Tensor(0.00052313716, shape=(), dtype=float32)
tf.Tensor(282.4862, shape=(), dtype=float32) tf.Tensor(2.1448876e-23, shape=(), dtype=float32)
tf.Tensor(264.7468, shape=(), dtype=float32) tf.Tensor(0.0012291339, shape=(), dtype=float32)
tf.Tensor(233.87744, shape=(), dtype=float32) tf.Tensor(0.32470378, shape=(), dtype=float32)
tf.Tensor(203.69945, shape=(), dtype=float32) tf.Tensor(4.5270605, shape=(), dtype=float32)
tf.Tensor(142.84663, shape=(), dtype=float32) tf.Tensor(1.8592765, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene34/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene34/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise34/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise34/assets


finished epoch! 34
tf.Tensor(66.08867, shape=(), dtype=float32) tf.Tensor(0.18229018, shape=(), dtype=float32)
tf.Tensor(9.689105, shape=(), dtype=float32) tf.Tensor(21.315649, shape=(), dtype=float32)
tf.Tensor(131.22803, shape=(), dtype=float32) tf.Tensor(0.00051252526, shape=(), dtype=float32)
tf.Tensor(266.16873, shape=(), dtype=float32) tf.Tensor(30.01282, shape=(), dtype=float32)
tf.Tensor(179.33607, shape=(), dtype=float32) tf.Tensor(3.966597e-24, shape=(), dtype=float32)
tf.Tensor(127.89675, shape=(), dtype=float32) tf.Tensor(0.41862983, shape=(), dtype=float32)
tf.Tensor(94.47291, shape=(), dtype=float32) tf.Tensor(3.19908, shape=(), dtype=float32)




INFO:tensorflow:Assets written to: ./bigepochs/gene35/assets


INFO:tensorflow:Assets written to: ./bigepochs/gene35/assets






INFO:tensorflow:Assets written to: ./bigepochs/dise35/assets


INFO:tensorflow:Assets written to: ./bigepochs/dise35/assets


finished epoch! 35
