In [None]:
'''
Initial code was copied from:
https://github.com/jason71995/Keras-GAN-Library
'''

In [33]:
import tensorflow as tf
import keras.backend as K
from keras.models import Sequential
from keras.layers import Conv2D,GlobalAveragePooling2D,LeakyReLU,Conv2DTranspose,Activation,BatchNormalization
from keras.optimizers import Adam

from keras.layers import Dense
from keras import initializers
from keras.layers.core import Dropout



noise_dim = 5

def build_generator(input_shape):

    generator = Sequential()
    
    generator.add(Dense(128,
                        input_dim = noise_dim))
    generator.add(BatchNormalization())
    generator.add(LeakyReLU(0.2))
    
#     generator.add(Dense(512))
#     generator.add(BatchNormalization())
#     generator.add(LeakyReLU(0.2))
    

    


    
    generator.add(Dense(105, activation='tanh'))
    return generator


def build_discriminator(input_shape):

    discriminator = Sequential()
    discriminator.add(Dense(128,
                            input_dim=105))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    
    
#     discriminator.add(Dense(512))
#     discriminator.add(LeakyReLU(0.2))
#     discriminator.add(Dropout(0.3))
    

        

    
    discriminator.add(Dense(1, activation='sigmoid'))

    return discriminator

def build_functions(batch_size, noise_size, image_size, generator, discriminator):

    noise = K.random_normal((batch_size,) + noise_size,0.0,1.0,"float32")
    real_image = K.placeholder((batch_size,) + image_size)

    fake_image = generator(noise)

    d_input = K.concatenate([real_image, fake_image], axis=0)
    pred_real, pred_fake = tf.split(discriminator(d_input), num_or_size_splits = 2, axis = 0)

    pred_real = K.clip(pred_real,K.epsilon(),1-K.epsilon())
    pred_fake = K.clip(pred_fake,K.epsilon(),1-K.epsilon())

    d_loss = -(K.mean(K.log(pred_real)) + K.mean(K.log(1-pred_fake)))
    g_loss = -K.mean(K.log(pred_fake))

    # get updates of mean and variance in batch normalization layers
    d_updates = discriminator.get_updates_for([d_input])
    g_updates = generator.get_updates_for([noise])

    d_training_updates = Adam(lr=0.0001, beta_1=0.0, beta_2=0.9).get_updates(d_loss, discriminator.trainable_weights)
    d_train = K.function([real_image, K.learning_phase()], [d_loss],d_updates + d_training_updates)

    g_training_updates = Adam(lr=0.0001, beta_1=0.0, beta_2=0.9).get_updates(g_loss, generator.trainable_weights)
    g_train = K.function([real_image, K.learning_phase()], [g_loss], g_updates + g_training_updates)

    return d_train,g_train

In [3]:
# Loading MHAD data for action1, all persons and all repeatations of each person
from utils.data_loader import data_loader
data_object= data_loader(matlab_action_path='../gan/')
myData, mymin, mymax = data_object.actions_normalised([1], twoD_true_or_threeD_false=False)
myData.shape

(117562, 105)

In [35]:
import keras

#from gan_libs.DCGAN import build_generator, build_discriminator, build_functions
# from gan_libs.LSGAN import build_generator, build_discriminator, build_functions
# from gan_libs.SNGAN import build_generator, build_discriminator, build_functions
# from gan_libs.WGAN_GP import build_generator, build_discriminator, build_functions

from utils.common import set_gpu_config, predict_images
from utils.draw_pose import draw_pose
import numpy as np

#set_gpu_config("0",0.5)

epoch = 30 + 1
image_size = (1,1,105)
noise_size = (1,1,5)
batch_size = 16

x_train = myData
np.random.shuffle(x_train)

generator = build_generator(noise_size)
#print(generator.summary())
discriminator = build_discriminator(image_size)
#print(discriminator.summary())
d_train, g_train = build_functions(batch_size, noise_size, image_size, generator, discriminator)

# generator.load_weights("e25_generator.h5".format(e))
# discriminator.load_weights("e25_discriminator.h5".format(e))


number_of_all_data = x_train.shape[0]
number_of_batches = int(number_of_all_data/batch_size)
print('Number of Batches passed in each epoch: ',number_of_batches)

import time
start_time = time.time()
is_first_epoch = True

for e in range(epoch):
    index = 0
    for batch in range(number_of_batches):
        real_images = x_train[index:index+batch_size]
        index =+ batch_size
        real_images.shape = (batch_size,1,1,105)
        d_loss, = d_train([real_images, 1])
        g_loss, = g_train([real_images, 1])
        if np.random.randint(low = 0, high = 100) == 1:   
            print ("[{0}/{1}] d_loss: {2:.4}, g_loss: {3:.4}".format(e, epoch, d_loss, g_loss))
            #generating a sample
            image = generator.predict(np.zeros(shape=(1,5)))
            image = np.array(image)
            draw_pose(image.reshape(105),'output',"e{0}_batch{1}".format(e,batch))
    
    if(is_first_epoch):
        elapsed_time = time.time() - start_time 
        print('\n\nTime Taken for single epoch:')
        print(time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        is_first_epoch = False
    
    if e % 30 == 0 and e > 0:
        generator.save_weights("e{0}_generator.h5".format(e))
        discriminator.save_weights("e{0}_discriminator.h5".format(e))
 
#just monitoring:
##########################################################################################
elapsed_time = time.time() - start_time 
print('\n\n\n\nNumber of parameter for the Generator and discriminator respectively:\n')
print(generator.count_params())
print('')
print(discriminator.count_params())
print('\n\nNumber of Epochs and steps for each epoch:\n')
print('epochs: ',epoch, '   batches: ', number_of_batches)

print('\n\nTime Taken:')
print(time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
##########################################################################################
#NOTE: previously in each epoch 1000 steps were iterated. in each iteration data was permitated and a batch was chosen. 
#now data is shuffled first and there are number_of_data/batch_size steps to be passed in a single epoch. 
# since 200 epochs and 1000 steps each was used, now 25 epochs and ˜7300 steps is used so the balance is kept.

Number of Batches passed in each epoch:  7347
[0/31] d_loss: 0.9363, g_loss: 0.7686
[0/31] d_loss: 0.9352, g_loss: 0.6858
[0/31] d_loss: 0.8258, g_loss: 0.7834
[0/31] d_loss: 0.836, g_loss: 0.8136
[0/31] d_loss: 0.9052, g_loss: 0.7502
[0/31] d_loss: 0.9916, g_loss: 0.7573
[0/31] d_loss: 1.008, g_loss: 0.7738
[0/31] d_loss: 0.969, g_loss: 0.7815
[0/31] d_loss: 1.163, g_loss: 0.7907
[0/31] d_loss: 1.228, g_loss: 0.8195
[0/31] d_loss: 1.121, g_loss: 0.9573
[0/31] d_loss: 1.28, g_loss: 0.9008
[0/31] d_loss: 1.334, g_loss: 0.909
[0/31] d_loss: 1.269, g_loss: 0.9739
[0/31] d_loss: 1.274, g_loss: 1.005
[0/31] d_loss: 1.241, g_loss: 1.0
[0/31] d_loss: 1.3, g_loss: 1.027
[0/31] d_loss: 1.283, g_loss: 0.9336
[0/31] d_loss: 1.253, g_loss: 0.9775
[0/31] d_loss: 1.292, g_loss: 0.9004
[0/31] d_loss: 1.31, g_loss: 0.9835
[0/31] d_loss: 1.196, g_loss: 0.9697
[0/31] d_loss: 1.168, g_loss: 1.024
[0/31] d_loss: 1.228, g_loss: 1.052
[0/31] d_loss: 1.193, g_loss: 0.9878
[0/31] d_loss: 1.176, g_loss: 1.024


[4/31] d_loss: 1.274, g_loss: 0.7523
[4/31] d_loss: 1.293, g_loss: 0.7373
[4/31] d_loss: 1.307, g_loss: 0.8122
[4/31] d_loss: 1.3, g_loss: 0.7915
[4/31] d_loss: 1.236, g_loss: 0.8204
[4/31] d_loss: 1.343, g_loss: 0.7867
[4/31] d_loss: 1.304, g_loss: 0.7912
[4/31] d_loss: 1.233, g_loss: 0.822
[4/31] d_loss: 1.281, g_loss: 0.7261
[4/31] d_loss: 1.312, g_loss: 0.7523
[4/31] d_loss: 1.271, g_loss: 0.8044
[4/31] d_loss: 1.396, g_loss: 0.7346
[4/31] d_loss: 1.245, g_loss: 0.8004
[4/31] d_loss: 1.241, g_loss: 0.8101
[4/31] d_loss: 1.325, g_loss: 0.7725
[4/31] d_loss: 1.257, g_loss: 0.8144
[4/31] d_loss: 1.292, g_loss: 0.7594
[4/31] d_loss: 1.255, g_loss: 0.8376
[4/31] d_loss: 1.319, g_loss: 0.8912
[4/31] d_loss: 1.282, g_loss: 0.7506
[5/31] d_loss: 1.305, g_loss: 0.7519
[5/31] d_loss: 1.293, g_loss: 0.7696
[5/31] d_loss: 1.264, g_loss: 0.734
[5/31] d_loss: 1.274, g_loss: 0.8102
[5/31] d_loss: 1.258, g_loss: 0.7847
[5/31] d_loss: 1.287, g_loss: 0.8274
[5/31] d_loss: 1.305, g_loss: 0.7166
[5/31

[7/31] d_loss: 1.249, g_loss: 0.9277
[8/31] d_loss: 1.194, g_loss: 0.9014
[8/31] d_loss: 1.205, g_loss: 0.877
[8/31] d_loss: 1.212, g_loss: 0.8481
[8/31] d_loss: 1.244, g_loss: 0.9938
[8/31] d_loss: 1.223, g_loss: 0.8326
[8/31] d_loss: 1.28, g_loss: 0.8587
[8/31] d_loss: 1.17, g_loss: 0.6742
[8/31] d_loss: 1.206, g_loss: 0.9813
[8/31] d_loss: 1.143, g_loss: 0.7806
[8/31] d_loss: 1.222, g_loss: 0.9713
[8/31] d_loss: 1.202, g_loss: 0.9147
[8/31] d_loss: 1.281, g_loss: 1.002
[8/31] d_loss: 1.212, g_loss: 0.8013
[8/31] d_loss: 1.217, g_loss: 0.854
[8/31] d_loss: 1.125, g_loss: 0.893
[8/31] d_loss: 1.319, g_loss: 0.7781
[8/31] d_loss: 1.222, g_loss: 0.7796
[8/31] d_loss: 1.182, g_loss: 0.8683
[8/31] d_loss: 1.152, g_loss: 0.874
[8/31] d_loss: 1.175, g_loss: 0.8353
[8/31] d_loss: 1.179, g_loss: 0.8328
[8/31] d_loss: 1.24, g_loss: 0.7503
[8/31] d_loss: 1.252, g_loss: 0.8175
[8/31] d_loss: 1.276, g_loss: 0.8656
[8/31] d_loss: 1.176, g_loss: 0.8899
[8/31] d_loss: 1.222, g_loss: 0.7926
[8/31] d_

[10/31] d_loss: 1.342, g_loss: 0.9504
[10/31] d_loss: 1.27, g_loss: 0.938
[11/31] d_loss: 1.204, g_loss: 1.03
[11/31] d_loss: 1.125, g_loss: 0.8553
[11/31] d_loss: 1.093, g_loss: 0.8413
[11/31] d_loss: 1.085, g_loss: 1.056
[11/31] d_loss: 1.149, g_loss: 0.886
[11/31] d_loss: 1.251, g_loss: 1.086
[11/31] d_loss: 1.133, g_loss: 1.171
[11/31] d_loss: 1.186, g_loss: 0.9828
[11/31] d_loss: 1.083, g_loss: 0.9804
[11/31] d_loss: 1.358, g_loss: 0.9308
[11/31] d_loss: 1.251, g_loss: 0.9764
[11/31] d_loss: 1.275, g_loss: 1.258
[11/31] d_loss: 1.13, g_loss: 0.7556
[11/31] d_loss: 1.118, g_loss: 0.7682
[11/31] d_loss: 1.109, g_loss: 0.8736
[11/31] d_loss: 1.082, g_loss: 1.018
[11/31] d_loss: 1.215, g_loss: 1.015
[11/31] d_loss: 1.173, g_loss: 0.952
[11/31] d_loss: 1.149, g_loss: 0.8843
[11/31] d_loss: 1.086, g_loss: 0.9109
[11/31] d_loss: 1.089, g_loss: 0.7846
[11/31] d_loss: 1.217, g_loss: 0.8489
[11/31] d_loss: 1.302, g_loss: 1.002
[11/31] d_loss: 1.068, g_loss: 1.039
[11/31] d_loss: 1.039, g_lo

[14/31] d_loss: 1.044, g_loss: 1.179
[14/31] d_loss: 1.123, g_loss: 0.8679
[14/31] d_loss: 1.138, g_loss: 0.8478
[14/31] d_loss: 1.173, g_loss: 1.197
[14/31] d_loss: 1.083, g_loss: 0.7456
[14/31] d_loss: 1.034, g_loss: 1.138
[14/31] d_loss: 1.056, g_loss: 0.7937
[14/31] d_loss: 1.253, g_loss: 1.443
[14/31] d_loss: 1.088, g_loss: 0.7752
[14/31] d_loss: 1.092, g_loss: 0.8937
[14/31] d_loss: 1.14, g_loss: 0.6488
[14/31] d_loss: 1.034, g_loss: 0.8801
[14/31] d_loss: 1.234, g_loss: 0.883
[14/31] d_loss: 1.168, g_loss: 0.9667
[14/31] d_loss: 1.055, g_loss: 1.014
[14/31] d_loss: 1.324, g_loss: 1.211
[14/31] d_loss: 1.298, g_loss: 1.232
[14/31] d_loss: 1.014, g_loss: 0.9984
[14/31] d_loss: 1.001, g_loss: 1.004
[14/31] d_loss: 1.151, g_loss: 1.152
[14/31] d_loss: 1.074, g_loss: 0.8722
[14/31] d_loss: 1.111, g_loss: 1.04
[14/31] d_loss: 1.194, g_loss: 0.9597
[14/31] d_loss: 1.064, g_loss: 0.9697
[14/31] d_loss: 1.178, g_loss: 0.8055
[14/31] d_loss: 1.085, g_loss: 1.08
[14/31] d_loss: 0.9136, g_l

[19/31] d_loss: 0.9399, g_loss: 0.9098
[19/31] d_loss: 1.158, g_loss: 1.377
[19/31] d_loss: 1.107, g_loss: 1.212
[19/31] d_loss: 1.007, g_loss: 0.841
[19/31] d_loss: 0.8585, g_loss: 1.149
[19/31] d_loss: 1.178, g_loss: 0.7606
[19/31] d_loss: 1.007, g_loss: 1.295
[19/31] d_loss: 0.8847, g_loss: 1.008
[19/31] d_loss: 0.971, g_loss: 1.126
[19/31] d_loss: 0.9842, g_loss: 1.339
[19/31] d_loss: 1.11, g_loss: 1.025
[19/31] d_loss: 0.9911, g_loss: 1.226
[19/31] d_loss: 1.128, g_loss: 1.335
[19/31] d_loss: 1.027, g_loss: 1.202
[19/31] d_loss: 1.295, g_loss: 1.045
[19/31] d_loss: 1.084, g_loss: 0.9312
[19/31] d_loss: 1.017, g_loss: 1.059
[19/31] d_loss: 1.186, g_loss: 1.176
[19/31] d_loss: 1.241, g_loss: 1.185
[19/31] d_loss: 1.266, g_loss: 1.295
[19/31] d_loss: 1.031, g_loss: 1.268
[19/31] d_loss: 1.135, g_loss: 1.478
[19/31] d_loss: 1.052, g_loss: 0.9422
[19/31] d_loss: 1.183, g_loss: 1.093
[19/31] d_loss: 1.211, g_loss: 1.462
[19/31] d_loss: 1.111, g_loss: 1.184
[19/31] d_loss: 0.9113, g_loss

[21/31] d_loss: 1.02, g_loss: 1.417
[21/31] d_loss: 1.052, g_loss: 0.8299
[21/31] d_loss: 0.9843, g_loss: 1.227
[21/31] d_loss: 1.063, g_loss: 1.505
[21/31] d_loss: 1.068, g_loss: 1.357
[21/31] d_loss: 1.036, g_loss: 0.9884
[21/31] d_loss: 1.198, g_loss: 1.258
[21/31] d_loss: 1.31, g_loss: 1.33
[21/31] d_loss: 1.202, g_loss: 0.7854
[22/31] d_loss: 1.161, g_loss: 1.43
[22/31] d_loss: 1.056, g_loss: 0.9322
[22/31] d_loss: 1.137, g_loss: 0.9848
[22/31] d_loss: 1.0, g_loss: 1.589
[22/31] d_loss: 1.065, g_loss: 1.625
[22/31] d_loss: 1.066, g_loss: 1.345
[22/31] d_loss: 1.115, g_loss: 1.162
[22/31] d_loss: 1.062, g_loss: 0.9506
[22/31] d_loss: 1.042, g_loss: 1.186
[22/31] d_loss: 1.09, g_loss: 1.417
[22/31] d_loss: 0.8563, g_loss: 1.143
[22/31] d_loss: 0.9072, g_loss: 1.276
[22/31] d_loss: 1.076, g_loss: 1.249
[22/31] d_loss: 1.252, g_loss: 1.341
[22/31] d_loss: 0.9024, g_loss: 1.395
[22/31] d_loss: 0.9512, g_loss: 1.009
[22/31] d_loss: 0.8229, g_loss: 1.19
[22/31] d_loss: 0.8935, g_loss: 1.

[24/31] d_loss: 0.9679, g_loss: 1.781
[24/31] d_loss: 1.004, g_loss: 1.391
[24/31] d_loss: 0.9574, g_loss: 1.17
[24/31] d_loss: 1.039, g_loss: 1.155
[25/31] d_loss: 1.218, g_loss: 1.144
[25/31] d_loss: 1.035, g_loss: 1.317
[25/31] d_loss: 0.8819, g_loss: 0.8083
[25/31] d_loss: 0.9569, g_loss: 1.013
[25/31] d_loss: 1.009, g_loss: 1.207
[25/31] d_loss: 1.173, g_loss: 0.7842
[25/31] d_loss: 1.223, g_loss: 1.263
[25/31] d_loss: 1.188, g_loss: 1.503
[25/31] d_loss: 1.04, g_loss: 1.457
[25/31] d_loss: 1.091, g_loss: 0.9753
[25/31] d_loss: 0.9268, g_loss: 1.13
[25/31] d_loss: 1.168, g_loss: 0.9334
[25/31] d_loss: 1.128, g_loss: 0.8974
[25/31] d_loss: 1.081, g_loss: 1.463
[25/31] d_loss: 0.9511, g_loss: 1.153
[25/31] d_loss: 1.214, g_loss: 1.244
[25/31] d_loss: 1.085, g_loss: 1.147
[25/31] d_loss: 1.234, g_loss: 1.039
[25/31] d_loss: 0.8932, g_loss: 1.21
[25/31] d_loss: 0.9643, g_loss: 1.243
[25/31] d_loss: 1.102, g_loss: 1.01
[25/31] d_loss: 1.05, g_loss: 1.01
[25/31] d_loss: 0.863, g_loss: 1

[27/31] d_loss: 1.08, g_loss: 1.225
[27/31] d_loss: 0.8824, g_loss: 0.8855
[27/31] d_loss: 1.14, g_loss: 1.1
[27/31] d_loss: 1.098, g_loss: 1.287
[27/31] d_loss: 0.9196, g_loss: 1.061
[27/31] d_loss: 0.9979, g_loss: 1.475
[27/31] d_loss: 0.9963, g_loss: 1.123
[27/31] d_loss: 1.112, g_loss: 1.043
[27/31] d_loss: 1.124, g_loss: 1.345
[27/31] d_loss: 0.7993, g_loss: 1.139
[27/31] d_loss: 1.028, g_loss: 1.475
[27/31] d_loss: 0.9395, g_loss: 1.228
[27/31] d_loss: 0.8074, g_loss: 0.9929
[27/31] d_loss: 1.082, g_loss: 1.13
[27/31] d_loss: 0.9978, g_loss: 1.47
[27/31] d_loss: 0.96, g_loss: 0.9335
[27/31] d_loss: 0.9494, g_loss: 1.583
[27/31] d_loss: 0.9595, g_loss: 0.899
[28/31] d_loss: 1.186, g_loss: 0.7919
[28/31] d_loss: 0.894, g_loss: 1.253
[28/31] d_loss: 1.066, g_loss: 1.73
[28/31] d_loss: 1.266, g_loss: 1.471
[28/31] d_loss: 1.078, g_loss: 1.44
[28/31] d_loss: 1.183, g_loss: 1.42
[28/31] d_loss: 0.9625, g_loss: 1.478
[28/31] d_loss: 1.034, g_loss: 1.036
[28/31] d_loss: 0.9067, g_loss: 1

[30/31] d_loss: 0.7868, g_loss: 1.379
[30/31] d_loss: 0.7696, g_loss: 1.201
[30/31] d_loss: 1.11, g_loss: 0.9477
[30/31] d_loss: 1.069, g_loss: 0.7453
[30/31] d_loss: 1.024, g_loss: 2.184
[30/31] d_loss: 0.9063, g_loss: 1.51
[30/31] d_loss: 0.9751, g_loss: 1.444
[30/31] d_loss: 1.053, g_loss: 1.564
[30/31] d_loss: 1.18, g_loss: 1.362




Number of parameter for the Generator and discriminator respectively:

14825

13697


Number of Epochs and steps for each epoch:

epochs:  31    batches:  7347


Time Taken:
00:31:53


In [18]:
    #gpu details:
    ################----------------------------
from six.moves import cStringIO as StringIO
import gpustat

gpustats = gpustat.new_query()
fp = StringIO()
gpustats.print_formatted(
     fp=fp, no_color=False, show_user=False,
     show_cmd=False, show_pid=False, show_power=False, show_fan_speed=False)

result = fp.getvalue()
print('\n\n')
print(result)
    ################----------------------------




[1m[37mcbrc-All-Series[m  Mon May  6 14:40:03 2019
[36m[0][m [34mGeForce GTX 1080 Ti[m |[1m[31m 53'C[m, [32m  0 %[m | [36m[1m[33m 6283[m / [33m11172[m MB | [1m[30mcbrc[m([33m5961M[m) [1m[30mgdm[m([33m16M[m) [1m[30mgdm[m([33m50M[m) [1m[30mcbrc[m([33m114M[m) [1m[30mcbrc[m([33m136M[m)



In [None]:
generator.load_weights("e400_generator.h5".format(e))
discriminator.load_weights("e400_discriminator.h5".format(e))

In [None]:
for i in range(100):
    x= np.array((i/100,0,0,0,0)).reshape(1,5)
    image = generator.predict(x)
    image = np.array(image)
    draw_pose(image.reshape(105),'output',"e{0}".format(i))