In [None]:
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Image as IPImage
from keras.datasets import mnist
from PIL import Image

from neuralnetlib.layers import Input, Dense, Conv2D, Reshape, Flatten, UpSampling2D
from neuralnetlib.models import Sequential, GAN
from neuralnetlib.preprocessing import one_hot_encode

In [None]:
# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
n_classes = np.unique(y_train).shape[0]

# Reshape images to include channel dimension
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

# Normalize pixel values
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Labels to categorical
y_train = one_hot_encode(y_train, n_classes)
y_test = one_hot_encode(y_test, n_classes)

In [None]:
# Sanity check: display one randomly chosen training digit with its decoded label.
sample_idx = random.randint(0, len(x_train) - 1)
sample_label = np.argmax(y_train[sample_idx])  # undo the one-hot encoding

plt.imshow(x_train[sample_idx].reshape(28, 28), cmap='gray')
plt.title(f'Class: {sample_label}')
plt.show()

In [None]:
# Size of the latent noise vector; also passed to GAN(latent_dim=...) when compiling.
noise_dim = 32

# Generator: project the noise vector to a 7x7x128 feature map, upsample twice
# (7 -> 14 -> 28) with ReLU convolutions, and end in a 1-channel sigmoid layer so that
# generated images lie in [0, 1], like the normalized training pixels.
generator_layers = [
    Input(noise_dim),
    Dense(7 * 7 * 128),
    Reshape((7, 7, 128)),
    UpSampling2D(size=(2, 2)),  # 14x14
    Conv2D(64, kernel_size=3, padding='same', activation='relu'),
    UpSampling2D(size=(2, 2)),  # 28x28
    Conv2D(32, kernel_size=3, padding='same', activation='relu'),
    Conv2D(1, kernel_size=3, padding='same', activation='sigmoid'),
]

generator = Sequential()
for layer in generator_layers:
    generator.add(layer)

In [None]:
# Discriminator: two stride-2 convolutions shrink a 28x28x1 image (28 -> 14 -> 7),
# followed by a small dense head with one sigmoid unit (real-vs-fake score).
discriminator = Sequential()
for layer in (
    Input((28, 28, 1)),
    Conv2D(32, kernel_size=3, strides=2, padding='same', activation='relu'),  # 14x14
    Conv2D(64, kernel_size=3, strides=2, padding='same', activation='relu'),  # 7x7
    Flatten(),
    Dense(128, activation='relu'),
    Dense(1, activation='sigmoid'),
):
    discriminator.add(layer)

In [None]:
# Training configuration shared by both sides of the GAN: 'adam' for the generator and
# the discriminator, and the 'bce' loss (presumably binary cross-entropy).
gan_compile_options = {
    'generator_optimizer': 'adam',
    'discriminator_optimizer': 'adam',
    'loss_function': 'bce',
    'verbose': True,
}

gan = GAN(latent_dim=noise_dim)
gan.compile(generator, discriminator, **gan_compile_options)

In [None]:
# Adversarial training on the real images only (no labels are passed).
# NOTE(review): plot_generated=True presumably writes the video<N>.png frames that the
# GIF cell collects - confirm against neuralnetlib's GAN.fit.
gan_epochs = 40
gan_batch_size = 128

history = gan.fit(x_train, epochs=gan_epochs, batch_size=gan_batch_size, plot_generated=True)

In [None]:
# Stitch the frames saved during training (files named video<N>.png, N an integer
# index) into an animated GIF. Frames are ordered numerically, so video10 comes
# after video9 rather than after video1.
frame_name = re.compile(r'video(\d+)\.png')

# fullmatch skips look-alikes such as 'video.png' or 'video_final.png', whose
# non-numeric index would otherwise make int() raise inside the sort key.
image_files = sorted(
    (f for f in os.listdir() if frame_name.fullmatch(f)),
    key=lambda f: int(frame_name.fullmatch(f).group(1)),
)

# Copy each frame into memory inside a `with` block so its file handle is closed
# immediately instead of staying open until garbage collection.
images = []
for path in image_files:
    with Image.open(path) as frame:
        images.append(frame.copy())

if images:
    # duration is the per-frame display time in milliseconds; loop=0 repeats forever.
    images[0].save('output.gif', save_all=True, append_images=images[1:], duration=100, loop=0)
    print(f"GIF 'output.gif' successfully created from {len(images)} frames!")
else:
    # Don't claim success when nothing was written.
    print("No 'video<N>.png' frames found; 'output.gif' was not created.")

In [None]:
# Display the animation built by the GIF cell. If no GIF exists (e.g. no frames were
# found), print a hint instead of raising FileNotFoundError.
IPImage(filename="output.gif") if os.path.exists("output.gif") else print("'output.gif' not found; run the GIF cell after training.")