In [1]:
# Reload edited project modules automatically before each cell executes,
# so local code changes are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

import os  # noqa: E402
# Select the PyTorch backend. This has to happen before `import keras` below,
# because Keras reads KERAS_BACKEND when it is first imported.
os.environ['KERAS_BACKEND'] = 'torch'
# Debugging aid: makes CUDA kernel launches synchronous so a GPU error is
# reported at the call that caused it.
# NOTE(review): this slows down GPU execution - consider removing it once
# debugging is finished.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import keras  # noqa: E402
# Show complete stack traces (including Keras-internal frames) on errors.
keras.config.disable_traceback_filtering()

In [3]:
import numpy as np
import torch

from src.multimodal.models.conditional_gan import ConditionalGAN

# Fetch the MNIST digits as (images, integer labels) pairs for train and test.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Convert pixels to float32 in [0, 1] and append a trailing channel axis,
# giving each image the (28, 28, 1) layout the convolutional model expects.
x_train = (x_train.astype('float32') / 255)[..., np.newaxis]
x_test = (x_test.astype('float32') / 255)[..., np.newaxis]

# Quick sanity report of what was loaded.
print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')


# Seed the Python, NumPy and backend RNGs so weight initialisation, dropout
# and batch shuffling are as repeatable as the backend allows.
keras.utils.set_random_seed(42)

num_classes = 10
input_shape = (28, 28, 1)

# Small CNN classifier: two convolutional stages (64 then 128 filters),
# global average pooling, dropout, and a softmax head over the digit classes.
model = keras.Sequential(
    [
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu'),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation='relu'),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation='softmax'),
    ]
)

# model = ConditionalGAN()

# Use the GPU when one is visible to PyTorch; otherwise fall back to the CPU
# so the notebook still runs on CPU-only machines instead of raising.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Training configuration: Adam at a 1e-3 learning rate, sparse categorical
# cross-entropy (labels are integer class ids, not one-hot vectors), and
# accuracy reported under the short name 'acc'.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name='acc')],
)

batch_size = 128
epochs = 20

callbacks = [
    # Save a full-model snapshot after every epoch; Keras fills in `{epoch}`.
    keras.callbacks.ModelCheckpoint(filepath='keras_logs/model_at_epoch_{epoch}.keras'),
    # Stop once val_loss has failed to improve for 2 consecutive epochs, and
    # roll the model back to the best epoch's weights so the test evaluation
    # below does not score the weaker weights from the final epochs.
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=2,
        restore_best_weights=True,
    ),
]

# Train with 15% of the training data held out for validation; the
# validation metrics drive the early-stopping callback above.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.15,
    callbacks=callbacks,
)

# evaluate() returns the loss followed by the compiled metrics, i.e.
# [loss, acc] for this model. Report it so the result is visible in the cell.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

x_train shape: (60000, 28, 28, 1)
y_train shape: (60000,)
60000 train samples
10000 test samples
Epoch 1/20
[1m399/399[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step - acc: 0.5353 - loss: 1.3032 - val_acc: 0.9609 - val_loss: 0.1337
Epoch 2/20
[1m399/399[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step - acc: 0.9309 - loss: 0.2309 - val_acc: 0.9797 - val_loss: 0.0759
Epoch 3/20
[1m399/399[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step - acc: 0.9542 - loss: 0.1563 - val_acc: 0.9832 - val_loss: 0.0603
Epoch 4/20
[1m399/399[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step - acc: 0.9628 - loss: 0.1283 - val_acc: 0.9879 - val_loss: 0.0448
Epoch 5/20
[1m399/399[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step - acc: 0.9697 - loss: 0.1007 - val_acc: 0.9834 - val_loss: 0.0524
Epoch 6/20
[1m399/399[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 25ms/step - acc: 0.9721 - loss: 0.0951 - val_acc: 0.989