In [1]:
import tensorflow as tf
from tensorflow.keras import layers as tl
from tensorflow.data import Dataset
import tensorflow_datasets as tfds
from pathlib import Path
import numpy as np

In [51]:
# Fashion-MNIST with the official train split carved 90/10 into train/validation;
# the official test split is kept as-is.
ds_tuple, ds_info = tfds.load(
    name='fashion_mnist',
    split=['train[:90%]', 'train[90%:]', 'test'],
    with_info=True,
    as_supervised=True,  # each element is an (image, label) tuple, not a feature dict
)

print(ds_info)

train_ds_origin, valid_ds_origin, test_ds_origin = ds_tuple

def preprocess_ds(ds, shuffle=False, seed=42, batch_size=128, buffer_size=30000):
    """Turn an `(image, label)` dataset into batched `(image, image)` autoencoder pairs.

    Each image is cast to float32 and scaled by 1/255. The trailing size-1
    channel axis is then dropped, so examples become (28, 28). That matches the
    encoder's `Input(shape=[28, 28])` and the decoder's `Reshape([28, 28])`
    output. Labels are discarded because the reconstruction target is the input
    itself.

    Args:
        ds: dataset yielding `(image, label)` tuples (e.g. `tfds.load(...,
            as_supervised=True)`). Assumes images carry a trailing channel axis
            of size 1, as fashion_mnist's `Image(shape=(28, 28, 1))` does.
        shuffle: if True, shuffle examples before batching (use for training).
        seed: seed passed to `Dataset.shuffle`.
        batch_size: number of examples per batch.
        buffer_size: shuffle buffer size (only used when `shuffle` is True).

    Returns:
        A batched, prefetched dataset of `(X, X)` pairs.
    """
    def to_autoencoder_pair(image, _label):
        # Cast/scale once and reuse the same tensor as both input and target.
        x = tf.squeeze(tf.cast(image, tf.float32), axis=-1) / 255.
        return x, x

    ds = ds.map(to_autoencoder_pair)
    if shuffle:
        ds = ds.shuffle(buffer_size, seed=seed)

    return ds.batch(batch_size).prefetch(1)
    


# Shuffle only the training split (reshuffled every epoch by tf.data by default);
# evaluation order does not affect validation/test metrics.
train_ds = preprocess_ds(train_ds_origin, shuffle=True)
valid_ds = preprocess_ds(valid_ds_origin)
test_ds = preprocess_ds(test_ds_origin)



tfds.core.DatasetInfo(
    name='fashion_mnist',
    full_name='fashion_mnist/3.0.1',
    description="""
    Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.
    """,
    homepage='https://github.com/zalandoresearch/fashion-mnist',
    data_dir='/Users/yunhongmin/tensorflow_datasets/fashion_mnist/3.0.1',
    file_format=tfrecord,
    download_size=29.45 MiB,
    dataset_size=36.42 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{DBLP:journals/corr/abs-1708-07747,
     

In [52]:
class Sampling(tl.Layer):
    """Draw z ~ N(mean, exp(log_var)) using the reparameterization trick.

    `inputs` is a `(mean, log_var)` pair of same-shaped tensors. Standard-normal
    noise of that shape is scaled by sigma = exp(log_var / 2) and shifted by
    mean, so the sample stays a differentiable function of mean and log_var.
    """

    def call(self, inputs):
        mean, log_var = inputs
        eps = tf.random.normal(tf.shape(mean))
        sigma = tf.exp(log_var / 2.)  # log_var = log(sigma^2)
        return eps * sigma + mean


codings_size = 10  # dimensionality of the latent space

# Encoder: 28x28 image -> flattened -> two ReLU hidden layers (150, 100 units).
inputs = tl.Input(shape=[28, 28])
Z = tl.Flatten()(inputs)
for units in (150, 100):
    Z = tl.Dense(units, activation='relu')(Z)

# Two linear heads give the mean and log-variance of the latent Gaussian;
# Sampling draws the actual codings from them.
codings_mean = tl.Dense(codings_size)(Z)
codings_log_var = tl.Dense(codings_size)(Z)
codings = Sampling()([codings_mean, codings_log_var])

encoder = tf.keras.Model(inputs=[inputs], outputs=[codings_mean, codings_log_var, codings])



In [53]:
# Decoder mirrors the encoder: codings -> ReLU layers (100, 150 units) -> 784 linear
# outputs reshaped back to a 28x28 image.
decoder_inputs = tl.Input(shape=[codings_size])
x = decoder_inputs
for units in (100, 150):
    x = tl.Dense(units, activation='relu')(x)
x = tl.Dense(28 * 28)(x)
outputs = tl.Reshape([28, 28])(x)

decoder = tf.keras.Model(inputs=[decoder_inputs], outputs=[outputs])



In [54]:
# End-to-end VAE: encode, keep only the sampled codings, decode.
_, _, codings = encoder(inputs)
reconstructions = decoder(codings)
auto_encoder = tf.keras.Model(inputs=[inputs], outputs=[reconstructions])

# Closed-form KL divergence between N(mean, exp(log_var)) and the N(0, I) prior,
# summed over latent dimensions. It is divided by 28*28, presumably to match the
# per-pixel averaging of the 'mse' reconstruction loss.
kl_terms = 1. + codings_log_var - tf.exp(codings_log_var) - tf.square(codings_mean)
latent_loss = -0.5 * tf.reduce_sum(kl_terms, axis=-1) / (28 * 28)
latent_loss_for_batch = tf.reduce_mean(latent_loss)
auto_encoder.add_loss(latent_loss_for_batch)

In [55]:
# 'mse' is the reconstruction term; the KL term is already attached via add_loss.
auto_encoder.compile(optimizer='nadam', loss='mse')
history = auto_encoder.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=25,
)

Epoch 1/25


2023-12-01 12:00:52.326426: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25

KeyboardInterrupt: 