CelebAGenerator.py
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import BatchNormalization, Activation, Dense, Conv2DTranspose, Reshape
from keras.models import Sequential
from keras import initializers

class CelebAGenerator():
    def __init__(self):
        self.latent_dim = 100  # dimension of the latent representation
        self.GAN = None
        self.weights_path = './model weights/celeba.h5'

    def GenerateModel(self):
        gf_dim = 64  # base number of generator filters
        init = initializers.RandomNormal(stddev=0.02)  # DCGAN-style weight init

        gan = Sequential()
        # Project the latent vector to an 8192-d vector and reshape it into a
        # 4x4x(gf_dim*8) feature map.
        gan.add(Dense(8192, use_bias=True, bias_initializer='zeros',
                      input_dim=self.latent_dim))
        gan.add(Reshape([4, 4, gf_dim * 8]))
        gan.add(BatchNormalization(epsilon=1e-5, momentum=0.9, scale=True))
        gan.add(Activation('relu'))
        # Four stride-2 transposed convolutions upsample 4x4 -> 64x64 while
        # halving the channel count at each step.
        gan.add(Conv2DTranspose(gf_dim * 4, 5, strides=(2, 2), padding='same',
                                use_bias=True, kernel_initializer=init,
                                bias_initializer='zeros'))
        gan.add(BatchNormalization(epsilon=1e-5, momentum=0.9, scale=True))
        gan.add(Activation('relu'))
        gan.add(Conv2DTranspose(gf_dim * 2, 5, strides=(2, 2), padding='same',
                                use_bias=True, kernel_initializer=init,
                                bias_initializer='zeros'))
        gan.add(BatchNormalization(epsilon=1e-5, momentum=0.9, scale=True))
        gan.add(Activation('relu'))
        gan.add(Conv2DTranspose(gf_dim, 5, strides=(2, 2), padding='same',
                                use_bias=True, kernel_initializer=init,
                                bias_initializer='zeros'))
        gan.add(BatchNormalization(epsilon=1e-5, momentum=0.9, scale=True))
        gan.add(Activation('relu'))
        # Final transposed convolution maps to 3 RGB channels; tanh bounds the
        # output to [-1, 1].
        gan.add(Conv2DTranspose(3, 5, strides=(2, 2), padding='same',
                                use_bias=True, kernel_initializer=init,
                                bias_initializer='zeros'))
        gan.add(Activation('tanh'))
        self.GAN = gan

    def LoadWeights(self):
        self.GAN.load_weights(self.weights_path)

    def GetModels(self):
        return self.GAN
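
# A minimal sketch, not part of the original file: decode a whole batch of
# latent vectors in a single predict call and show the results as a grid,
# which avoids one forward pass per image. `show_grid`, `rows`, and `cols`
# are hypothetical names introduced here for illustration.
def show_grid(gan, latent_dim=100, rows=5, cols=5):
    zs = np.random.randn(rows * cols, latent_dim)  # batch of latent vectors
    imgs = (gan.predict(zs) + 1) / 2               # tanh output -> [0, 1]
    _, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
    for img, ax in zip(imgs, axes.flat):
        ax.imshow(img)
        ax.axis('off')
    plt.show()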
if __name__ == '__main__':
    celeba = CelebAGenerator()
    celeba.GenerateModel()
    celeba.LoadWeights()
    gan = celeba.GetModels()
    for _ in range(100):
        # Decode a random latent vector and map the tanh output from
        # [-1, 1] to [0, 1] for display.
        pred = gan.predict(np.random.randn(1, celeba.latent_dim))[0, :, :, :]
        pred = (pred + 1) / 2
        plt.imshow(pred)
        plt.show()
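
# A second illustrative sketch, also an assumption rather than original code:
# linear interpolation between two latent vectors, which visualises how the
# 100-d latent space maps smoothly to images. `interpolate` and `n_steps`
# are hypothetical names introduced here.
def interpolate(gan, latent_dim=100, n_steps=8):
    z0 = np.random.randn(1, latent_dim)
    z1 = np.random.randn(1, latent_dim)
    alphas = np.linspace(0.0, 1.0, n_steps).reshape(-1, 1)
    zs = (1 - alphas) * z0 + alphas * z1  # evenly spaced blends of z0 and z1
    imgs = (gan.predict(zs) + 1) / 2      # tanh output -> [0, 1]
    _, axes = plt.subplots(1, n_steps, figsize=(2 * n_steps, 2))
    for img, ax in zip(imgs, axes.flat):
        ax.imshow(img)
        ax.axis('off')
    plt.show()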