In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization

# Load only the MNIST training images; labels and the test split are discarded.
(X_train, _), (_, _) = mnist.load_data()
# Flatten each 28x28 image into a 784-long vector for the Dense layers, then
# map pixel values 0..255 onto [-1, 1] so real images share the range of the
# generator's tanh output.
X_train = X_train.reshape(X_train.shape[0], 784).astype(np.float32)
X_train = (X_train - 127.5) / 127.5

# Generator: maps a 100-dim latent vector to a flattened 28x28 image.
# Hidden widths grow 128 -> 256 -> 512, each Dense followed by LeakyReLU(0.2)
# and BatchNormalization; the final tanh keeps pixels in [-1, 1].
generator = Sequential()
generator.add(Dense(128, input_dim=100))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())
generator.add(Dense(256))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())
generator.add(Dense(784, activation='tanh'))

# Discriminator: takes a flattened 784-pixel image and outputs a single
# sigmoid score, trained below with 1 = real and 0 = generated.
discriminator = Sequential()
discriminator.add(Dense(512, input_dim=784))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dense(1, activation='sigmoid'))

# Compile the discriminator while it is still trainable: the training loop
# calls discriminator.train_on_batch directly to update its weights.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

# Combined network: latent noise -> generator -> discriminator score. It is
# used only to train the generator (targets of 1 = "looks real").
# The discriminator is frozen AFTER its own compile() and BEFORE gan.compile(),
# so that gan.train_on_batch updates only the generator's weights. Keras keeps
# the trainable state captured at each model's compile() call (see the Keras
# transfer-learning guide — confirm for the installed Keras version), so the
# standalone discriminator still trains. Do not reorder these statements.
discriminator.trainable = False
gan_input = generator.input
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Create the output directory for sample images. exist_ok=True makes this a
# no-op when the directory already exists and avoids the check-then-create race.
os.makedirs('generated_images', exist_ok=True)

# Training.
# NOTE: despite the name, `epochs` counts training iterations. Each pass of
# the loop draws one random mini-batch; it is not a full sweep over X_train.
epochs, batch_size, half_batch = 10000, 64, 32

# Target labels never change, so build them once instead of every iteration.
real_labels = np.ones((half_batch, 1))
fake_labels = np.zeros((half_batch, 1))
# The generator is trained to make the frozen discriminator answer "real".
gen_labels = np.ones((batch_size, 1))

for epoch in range(epochs):
    # --- Discriminator step: half a batch of real images, half of fakes. ---
    idx = np.random.randint(0, X_train.shape[0], half_batch)
    real_images = X_train[idx]
    noise = np.random.normal(0, 1, (half_batch, 100))
    # verbose=0: otherwise predict() prints a progress bar on every iteration
    # and floods the notebook output.
    fake_images = generator.predict(noise, verbose=0)

    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # --- Generator step: a fresh, full batch of latent vectors. ---
    # Fresh noise keeps the generator update independent of the exact samples
    # the discriminator was just trained on, and uses the declared batch_size.
    gen_noise = np.random.normal(0, 1, (batch_size, 100))
    gan_loss = gan.train_on_batch(gen_noise, gen_labels)

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, D Loss: {d_loss}, G Loss: {gan_loss}")

    # Periodically save one generated sample for visual inspection.
    if epoch % 1000 == 0:
        sample_noise = np.random.normal(0, 1, (1, 100))
        img = generator.predict(sample_noise, verbose=0).reshape(28, 28)
        plt.imshow(img, cmap='gray')
        plt.axis('off')
        plt.savefig(f"generated_images/gan_image_{epoch}.png")
        plt.close()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 5100, D Loss: 0.3287529945373535, G Loss: 3.7087011337280273
Epoch: 5200, D Loss: 0.20773239247500896, G Loss: 3.8390653133392334
Epoch: 5300, D Loss: 0.3756309449672699, G Loss: 2.7029871940612793
Epoch: 5400, D Loss: 0.22317977249622345, G Loss: 3.6635758876800537
Epoch: 5500, D Loss: 0.17744439840316772, G Loss: 5.642303466796875
Epoch: 5600, D Loss: 0.500893846154213, G Loss: 2.0116899013519287
Epoch: 5700, D Loss: 0.20697693526744843, G Loss: 3.700016975402832
Epoch: 5800, D Loss: 0.15547256916761398, G Loss: 3.6307578086853027
Epoch: 5900, D Loss: 0.5055046677589417, G Loss: 2.5152554512023926
Epoch: 6000, D Loss: 0.7254768908023834, G Loss: 3.307469367980957
Epoch: 6100, D Loss: 0.3960254639387131, G Loss: 2.1069583892822266
Epoch: 6200, D Loss: 0.17806068062782288, G Loss: 3.0512547492980957
Epoch: 6300, D Loss: 0.31383535265922546, G Loss: 3.0124142169952393
Epoch: 6400, D Loss: 0.3351738229393959, G Loss: