In [1]:
import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
#import matplotlib.pyplot as plt
import time

In [5]:
class LocationAdd(layers.Layer):
    def __init__(self, input_dim):
        super(LocationAdd, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(initial_value=w_init(shape=(input_dim,), dtype='float32'), trainable=True)

    def call(self, inputs):
        return tf.add(inputs, self.w)

In [14]:
class WGAN:
    
    ''' A static model, with fixed input size.
        Model should not be defined in train(), so it should not depend on dataset dimension.
        Model API has an advantage that it can save weights between different calls of train.
        Instead of using tf.Session() as before, where only one training can happen, next will refresh,
        using Model() API avoids this, it provides a model which saves weights outside tf.Session()!
    '''
    
    def __init__(self, dim_x):
        self.dim_x = dim_x
        self.generator = self.generator_model(dim_x)
        self.discriminator = self.discriminator_model(dim_x)

    
    def generator_model(self, dim):
        inputs = layers.Input(shape=(dim,))
        out = LocationAdd(dim)(inputs)
        model = tf.keras.Model(inputs=inputs, outputs=out)
        return model
    
    def discriminator_model(self, dim):
        inputs = layers.Input(shape=(dim,))
        dense1 = layers.Dense(2*dim, activation=tf.nn.sigmoid)(inputs)
        out = layers.Dense(1, activation=tf.nn.sigmoid)(dense1)
        model = tf.keras.Model(inputs=inputs, outputs=out)
        return model
    
    @staticmethod
    def discriminator_loss(real_output, fake_output):
        cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        total_loss = real_loss + fake_loss
        return total_loss
    
    @staticmethod
    def generator_loss(fake_output):
        cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        return cross_entropy(tf.ones_like(fake_output), fake_output)

        
       
    def train(self, dataset, epochs, batch_size, step_size):
        self.generator_optimizer = tf.keras.optimizers.Adam(step_size)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(step_size)
        for epoch in range(epochs):
            start = time.time()
            for i in range(dataset.shape[0]//batch_size):
                noise = tf.random.normal([batch_size, self.dim_x])
                with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                    generated = self.generator(noise, training=True)
                    real_output = self.discriminator(dataset[i*batch_size:(i+1)*batch_size], training=True)
                    fake_output = self.discriminator(generated, training=True)

                    gen_loss = WGAN.generator_loss(fake_output)
                    disc_loss = WGAN.discriminator_loss(real_output, fake_output)

                    gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
                    gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

                    self.generator_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
                    self.discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))
    
            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
            print(self.generator.trainable_variables[0].numpy())
            #print("generator loss:", gen_loss.numpy(), "discriminator loss: ", disc_loss.numpy())

In [15]:
data = np.random.normal(size=(100,4)) + np.array([1,2,3,4])
np.mean(data, axis=0)

array([0.93648027, 1.978882  , 2.84213983, 4.09846861])

In [16]:
wgan = WGAN(dim_x=4)

In [17]:
wgan.train(data, 10, 10, 0.0001)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Time for epoch 1 is 3.6977691650390625 sec
[ 0.02811069 -0.02587136  0.03036069  0.00514906]
Time for epoch 2 is 3.4377169609069824 sec
[ 0.02911377 -0.02486555  0.031369    0.00614878]
Time for epoch 3 is 3.5907018184661865 sec
[ 0.03010978 -0.02385991  0.03237041  0.00716677]
Time for epoch 4 is 3.810439109802246 sec
[ 0.03111349 -0.02285625  0.03338799  0.00818473]
Time for epoch 5 is 3.245352029800415 sec
[ 0.03211852 -0.0218489   0.03441513  0.00922087]
Time for epoch 6 is 3.166031837463379 sec
[ 0.03312031 -0.02083381  0.03545093  0.01025299]
Time for epoch 7 is 3.2040462493896484 sec
[ 0.03412016 -0.01982057  0.03648442  0.01129337]
Time for epoch 8 is 3.3208160400390625 sec
[ 0.0351

In [8]:
wgan.generator.trainable_variables[0].numpy()

array([-0.11304637,  0.1813672 ,  0.06781566, -0.09694262], dtype=float32)

In [18]:
wgan.train(data, 10, 10, 0.01)

Time for epoch 1 is 3.7638468742370605 sec
[0.13746914 0.08385    0.14084247 0.1158528 ]
Time for epoch 2 is 3.9754319190979004 sec
[0.24260731 0.18970871 0.24846962 0.22440149]
Time for epoch 3 is 3.3422088623046875 sec
[0.35535112 0.3032939  0.3636259  0.34145534]
Time for epoch 4 is 3.193582057952881 sec
[0.4696641  0.42172754 0.48498908 0.46409452]
Time for epoch 5 is 3.183464765548706 sec
[0.58293366 0.544464   0.6132251  0.5924657 ]
Time for epoch 6 is 3.151916027069092 sec
[0.683527   0.6688246  0.74297917 0.7250354 ]
Time for epoch 7 is 4.061507940292358 sec
[0.7425081  0.78306454 0.86300606 0.8577488 ]
Time for epoch 8 is 3.815953016281128 sec
[0.7360032  0.8851273  0.97896034 0.99006027]
Time for epoch 9 is 3.9077441692352295 sec
[0.6628707 0.9772671 1.0994866 1.128501 ]
Time for epoch 10 is 3.224695920944214 sec
[0.55274636 1.0543734  1.220586   1.2663777 ]


In [19]:
wgan.train(data, 10, 10, 0.0001)

Time for epoch 1 is 3.673401117324829 sec
[0.55174184 1.0553634  1.2215897  1.2673678 ]
Time for epoch 2 is 3.4440970420837402 sec
[0.5507361 1.0563568 1.222593  1.2683685]
Time for epoch 3 is 3.53598690032959 sec
[0.54972816 1.0573869  1.2236145  1.2693708 ]
Time for epoch 4 is 3.1795761585235596 sec
[0.5487161 1.058404  1.2246329 1.2703743]
Time for epoch 5 is 3.1852569580078125 sec
[0.5477174 1.0593982 1.2256329 1.2713748]
Time for epoch 6 is 3.2666141986846924 sec
[0.5467418 1.0603408 1.2266039 1.2723691]
Time for epoch 7 is 3.986191749572754 sec
[0.5457898 1.0612903 1.2275658 1.2733355]
Time for epoch 8 is 4.065059185028076 sec
[0.54484683 1.0621921  1.2285038  1.2742845 ]
Time for epoch 9 is 4.0624473094940186 sec
[0.5439114 1.06309   1.2294383 1.2752315]
Time for epoch 10 is 4.251392126083374 sec
[0.54292923 1.06405    1.2304255  1.2762266 ]


In [20]:
wgan.train(data, 10, 10, 0.01)

Time for epoch 1 is 3.52024507522583 sec
[0.4438564 1.1600479 1.3298092 1.37583  ]
Time for epoch 2 is 3.163336992263794 sec
[0.3555892 1.2473421 1.4249778 1.4718797]
Time for epoch 3 is 3.1986138820648193 sec
[0.2922245 1.3416042 1.525468  1.5727109]
Time for epoch 4 is 3.188176155090332 sec
[0.2723909 1.4309354 1.6241009 1.6750035]
Time for epoch 5 is 3.1690871715545654 sec
[0.3037954 1.5044838 1.7115577 1.7684765]
Time for epoch 6 is 3.166724920272827 sec
[0.4063912 1.5771102 1.806678  1.8742937]
Time for epoch 7 is 3.3385846614837646 sec
[0.5362746 1.6321973 1.8983239 1.9768611]
Time for epoch 8 is 3.68928599357605 sec
[0.6713016 1.6723317 1.9910977 2.076459 ]
Time for epoch 9 is 3.490811824798584 sec
[0.80498004 1.6982911  2.099583   2.1841578 ]
Time for epoch 10 is 3.5169661045074463 sec
[0.9095998 1.6906655 2.2086623 2.2905853]


In [10]:
wgan.generator.trainable_variables[0].numpy()

array([-0.11961144,  0.19694817,  0.08017154, -0.10297753], dtype=float32)

In [None]:
wgan.discriminator.trainable_variables

In [None]:
# correct test version of model with self defined layers
def build_model():
    a = tf.keras.Input(shape=(4,))
    out = LocationAdd(input_dim=4)(a+5)
    model = tf.keras.Model(inputs=a, outputs=out)
    return model
model = build_model()
model2 = build_model()
print(model.trainable_variables)
print(model2.trainable_variables)
model.compile(optimizer='rmsprop', loss=tf.keras.losses.MeanSquaredError())
model.fit(x=data,y=data, batch_size=1, epochs=100)
print(model.trainable_variables)
print(model2.trainable_variables)

## tf.keras.layers.add can make variables not trainable, below is not correct
# a = tf.keras.Input(shape=(4,))
# b = tf.Variable(initial_value=tf.random_normal_initializer()(shape=(4,)), trainable=True)
# out = tf.keras.layers.add([a+5,b])