Skip to content

Commit d2395e4

Browse files
committed
fixed type error, division returns ints now, not floats
1 parent 991f4cc commit d2395e4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

gan.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __discriminator(self):
6464
model.add(Flatten(input_shape=self.shape))
6565
model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape))
6666
model.add(LeakyReLU(alpha=0.2))
67-
model.add(Dense((self.width * self.height * self.channels)/2))
67+
model.add(Dense(np.int64((self.width * self.height * self.channels)/2)))
6868
model.add(LeakyReLU(alpha=0.2))
6969
model.add(Dense(1, activation='sigmoid'))
7070
model.summary()
@@ -86,21 +86,21 @@ def train(self, X_train, epochs=20000, batch = 32, save_interval = 100):
8686
for cnt in range(epochs):
8787

8888
## train discriminator
89-
random_index = np.random.randint(0, len(X_train) - batch/2)
90-
legit_images = X_train[random_index : random_index + batch/2].reshape(batch/2, self.width, self.height, self.channels)
89+
random_index = np.random.randint(0, len(X_train) - np.int64(batch/2))
90+
legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels)
9191

92-
gen_noise = np.random.normal(0, 1, (batch/2, 100))
92+
gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100))
9393
syntetic_images = self.G.predict(gen_noise)
9494

9595
x_combined_batch = np.concatenate((legit_images, syntetic_images))
96-
y_combined_batch = np.concatenate((np.ones((batch/2, 1)), np.zeros((batch/2, 1))))
96+
y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1))))
9797

9898
d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
9999

100100

101101
# train generator
102102

103-
noise = np.random.normal(0, 1, (batch, 100))
103+
noise = np.random.normal(0, 1, (batch, 10git@github.com:daymos/simple_keras_GAN.git0))
104104
y_mislabled = np.ones((batch, 1))
105105

106106
g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled)

0 commit comments

Comments
 (0)