In [16]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.sparsity.keras import UpdatePruningStep

In [17]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

In [18]:
def create_cnn():
    model = models.Sequential([
        layers.Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),
        layers.MaxPooling2D(2,2),
        layers.Conv2D(64, (3,3), activation='relu'),
        layers.MaxPooling2D(2,2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

In [19]:
def build_generator():
    noise = tf.keras.layers.Input(shape=(100,))
    x = tf.keras.layers.Dense(256, activation='relu')(noise)
    x = tf.keras.layers.Dense(784, activation='tanh')(x)
    x = tf.keras.layers.Reshape((28, 28, 1))(x)
    return tf.keras.models.Model(noise, x)

In [20]:
def build_critic():
    image = tf.keras.layers.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Flatten()(image)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    x = tf.keras.layers.Dense(256, activation='relu')(x)
    x = tf.keras.layers.Dense(1)(x)  # No sigmoid!
    return tf.keras.models.Model(image, x)

In [21]:
def train_wgan(generator, critic, epochs=10000, batch_size=64, clip_value=0.01, n_critic=5):
    # Optimizers (use RMSprop in vanilla WGAN)
    optimizer = tf.keras.optimizers.legacy.RMSprop(learning_rate=0.00005)

    for epoch in range(epochs):
        for _ in range(n_critic):
            # Train Critic
            noise = np.random.normal(0, 1, (batch_size, 100))
            generated_images = generator.predict(noise)
            real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]

            with tf.GradientTape() as tape:
                real_output = critic(real_images, training=True)
                fake_output = critic(generated_images, training=True)
                d_loss = -tf.reduce_mean(real_output) + tf.reduce_mean(fake_output)

            grads = tape.gradient(d_loss, critic.trainable_variables)
            optimizer.apply_gradients(zip(grads, critic.trainable_variables))

            # Weight clipping
            for var in critic.trainable_variables:
                var.assign(tf.clip_by_value(var, -clip_value, clip_value))

        # Train Generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        with tf.GradientTape() as tape:
            generated_images = generator(noise, training=True)
            fake_output = critic(generated_images, training=True)
            g_loss = -tf.reduce_mean(fake_output)

        grads = tape.gradient(g_loss, generator.trainable_variables)
        optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Log
        if epoch % 1000 == 0:
            print(f"Epoch {epoch} | Critic Loss: {d_loss.numpy()} | Generator Loss: {g_loss.numpy()}")


In [22]:
generator = build_generator()
critic = build_critic()

# Train the DCGAN
train_wgan(generator, critic, epochs=1000, batch_size=8, clip_value=0.01, n_critic=5)

Epoch 0 | Critic Loss: -0.0016213785856962204 | Generator Loss: 2.719638723647222e-05


In [23]:
def generate_images(generator, num_samples=10000):
    noise = np.random.normal(0, 1, (num_samples, 100))
    generated_images = generator.predict(noise)
    return np.concatenate((x_train, generated_images)), np.concatenate((y_train, y_train[:num_samples]))

In [24]:
# Train CNN on original data
cnn = create_cnn()
history_orig = cnn.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [25]:
# Train CNN on augmented data
# Generate new images after training and augment the training data
z_train_aug, y_train_aug = generate_images(generator, num_samples=10000)
cnn_aug = create_cnn()
history_aug = cnn_aug.fit(z_train_aug, y_train_aug, epochs=5, validation_data=(x_test, y_test))

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [None]:
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5, 
        begin_step=0, end_step=np.ceil(len(x_train) / 32).astype(np.int32) * 5)
}

# Apply pruning to the model
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(cnn, **pruning_params)

# Unfreeze some layers (for fine-tuning) if necessary
for layer in pruned_model.layers[-4:]:  # Example: Unfreeze the last 4 layers
    layer.trainable = True

# Recompile the model after unfreezing layers
pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Create the pruning callback
pruning_callback = UpdatePruningStep()

# Fine-tune the model
history_pruned_aug = pruned_model.fit(z_train_aug, y_train_aug, epochs=5, validation_data=(x_test, y_test), callbacks=[pruning_callback])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5

In [None]:
plt.figure(figsize=(12, 6))

# Plot validation accuracy for each model
plt.plot(history_orig.history['val_accuracy'], label='Original')
plt.plot(history_aug.history['val_accuracy'], label='Model with WGAN Augmentation')
plt.plot(history_pruned_aug.history['val_accuracy'], label='Pruned with WGAN Augmentation')

# Labeling the axes and adding a title
plt.xlabel('Epochs')
plt.ylabel('Validation Accuracy')
plt.title('Comparison of Validation Accuracy Across Models')

# Adding a legend to differentiate the lines
plt.legend()

# Display the plot
plt.show()