In [1]:
import os
# Restrict this process to GPU index 0. This is done in the first cell, before
# TensorFlow is imported in the next cell.
# NOTE(review): the variable is presumably only honoured if set before the
# framework initialises CUDA, so keep this cell first — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from utils.models import create_vae_model
from utils.losses import reconstruction_loss
from utils.callbacks import SaveDecoderOutput, SaveDecoderModel

In [3]:
# Turn on memory growth for every GPU TensorFlow can see, so GPU memory is
# allocated as needed rather than claimed at start-up.
# Looping over the device list (instead of indexing [0]) keeps this cell from
# raising IndexError on a machine where no GPU is visible.
physical_gpus = tf.config.list_physical_devices('GPU')
for gpu in physical_gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [4]:
def parse_fn(dataset, input_size=(28, 28)):
    """Convert one dataset example into an (input, target) pair.

    Args:
        dataset: Example mapping; only its ``'image'`` entry is read.
        input_size: Size passed to ``tf.image.resize``; defaults to (28, 28).

    Returns:
        A 2-tuple containing the same preprocessed image tensor twice, so
        the image serves as both the model input and the target.
    """
    image = tf.image.resize(tf.cast(dataset['image'], tf.float32), input_size)
    # Divide by 255 (assumes pixel values are in [0, 255] — TODO confirm).
    image = image / 255.
    return image, image

In [21]:
# Run configuration: everything a reader might want to tune lives here.
dataset = 'fashion_mnist'     # 'cifar10', 'fashion_mnist', 'mnist'
# Derive the log directory from the dataset name so that switching datasets
# does not mix logs/checkpoints of different runs in one folder.
log_dirs = 'logs/' + dataset
batch_size = 16
latent_dim = 2
# NOTE(review): input_shape must match the selected dataset; parse_fn only
# resizes to its default (28, 28) and does not change the channel count.
input_shape = (28, 28, 1)

In [22]:
# Load the train and test splits of the configured dataset in a single call;
# passing a list of splits returns one dataset per requested split.
train_data, test_data = tfds.load(
    dataset,
    split=[tfds.Split.TRAIN, tfds.Split.TEST],
)

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to ~\tensorflow_datasets\fashion_mnist\3.0.1...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling ~\tensorflow_datasets\fashion_mnist\3.0.1.incompleteS5MNWG\fashion_mnist-train.tfrecord*...:   0%|  …

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling ~\tensorflow_datasets\fashion_mnist\3.0.1.incompleteS5MNWG\fashion_mnist-test.tfrecord*...:   0%|   …

[1mDataset fashion_mnist downloaded and prepared to ~\tensorflow_datasets\fashion_mnist\3.0.1. Subsequent calls will reuse this data.[0m


In [23]:
# Build the input pipelines.
AUTOTUNE = tf.data.experimental.AUTOTUNE  # automatic tuning mode

# Training: shuffle (buffer of 1000) -> preprocess -> batch -> prefetch.
train_data = (
    train_data
    .shuffle(1000)
    .map(parse_fn, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

# Evaluation: same steps as training but without shuffling.
test_data = (
    test_data
    .map(parse_fn, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

In [24]:
# Callbacks used during training.
# Model files are written to a "models" sub-directory of the log directory.
model_dir = log_dirs + '/models'
os.makedirs(model_dir, exist_ok=True)  # no error if the directory already exists
# TensorBoard logs are written to the log directory itself.
model_tb = keras.callbacks.TensorBoard(log_dir=log_dirs)
# Project callback from utils.callbacks; presumably saves the decoder to the
# given .h5 path based on the monitored 'val_loss' — confirm in utils.callbacks.
model_sdw = SaveDecoderModel(model_dir + '/best_model.h5', monitor='val_loss')
# NOTE(review): 28 is presumably the image side length and should agree with
# input_shape — confirm against SaveDecoderOutput's signature.
model_testd = SaveDecoderOutput(28, log_dir=log_dirs)

In [25]:
# Build the VAE with the project helper, using the input shape and latent
# dimension chosen in the configuration cell.
# NOTE(review): the helper appears to print a model summary when called (see
# this cell's output) — confirm in utils.models.
vae_model = create_vae_model(input_shape, latent_dim)

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 28, 28, 32)   320         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 14, 14, 64)   18496       conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 7, 7, 64)     36928       conv2d_6[0][0]                   
____________________________________________________________________________________________

In [26]:
# Training: compile with an RMSprop optimizer (default settings) and the
# project's reconstruction loss, then fit with the test split as validation.
optimizer = tf.keras.optimizers.RMSprop()
vae_model.compile(optimizer, loss=reconstruction_loss)
vae_model.fit(
    train_data,
    validation_data=test_data,
    epochs=100,
    callbacks=[model_tb, model_sdw, model_testd],
)

Epoch 1/100








Epoch 2/100




Epoch 3/100




Epoch 4/100




Epoch 5/100
Epoch 6/100




Epoch 7/100




Epoch 8/100
Epoch 9/100




Epoch 10/100




Epoch 11/100
Epoch 12/100




Epoch 13/100




Epoch 14/100
Epoch 15/100




Epoch 16/100
Epoch 17/100




Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100




Epoch 26/100




Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100




Epoch 34/100




Epoch 35/100




Epoch 36/100




Epoch 37/100




Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100




Epoch 43/100
Epoch 44/100




Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100




Epoch 69/100
Epoch 70/100
Epoch 71/100




Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100