In [4]:
# Stretch the notebook container to the full browser width.
# `display` and `HTML` come from the public IPython.display module; importing
# them through IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# data.py

In [1]:
import tensorflow as tf
from tensorflow.keras import datasets

# Number of examples per batch for both the train and test datasets.
BATCH_SIZE = 64
# Buffer size passed to Dataset.shuffle() for the training dataset.
SHUFFLE_BUFFER_SIZE = 1000

def load_mnist(
    flatten=True, batch_size=None, shuffle_buffer_size=None
) -> "tuple[tf.data.Dataset, tf.data.Dataset]":
    """Load MNIST and wrap it in batched ``tf.data`` pipelines.

    Args:
        flatten: If True, every image is reshaped to a flat vector
            (28 * 28 = 784 values). If False, a trailing channel axis is
            appended instead, giving images of shape (28, 28, 1).
        batch_size: Batch size for both datasets. ``None`` (the default)
            means "use the module-level ``BATCH_SIZE``".
        shuffle_buffer_size: Buffer size used to shuffle the training set.
            ``None`` (the default) means "use ``SHUFFLE_BUFFER_SIZE``".

    Returns:
        ``(train_ds, test_ds)``. The training dataset is shuffled and then
        batched; the test dataset is only batched. Pixel values and labels
        are passed through exactly as the Keras loader returns them (no
        rescaling or dtype conversion is applied here).
    """
    # Resolve the defaults at call time so that editing the module constants
    # in a later cell still affects subsequent calls.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if shuffle_buffer_size is None:
        shuffle_buffer_size = SHUFFLE_BUFFER_SIZE

    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

    def _reshape(images):
        # Take the sample count and image size from the array itself instead
        # of hard-coding 60000 / 10000 / 28, so the function keeps working if
        # the loaded arrays ever have a different length.
        if flatten:
            return images.reshape((images.shape[0], -1))
        return images.reshape(images.shape + (1,))

    train_ds = tf.data.Dataset.from_tensor_slices((_reshape(train_images), train_labels))
    test_ds = tf.data.Dataset.from_tensor_slices((_reshape(test_images), test_labels))

    # Only the training data is shuffled; evaluation order stays fixed.
    train_ds = train_ds.shuffle(shuffle_buffer_size).batch(batch_size)
    test_ds = test_ds.batch(batch_size)

    return train_ds, test_ds

In [2]:
# Build the batched train/test datasets with each image flattened to a vector.
train_ds, test_ds = load_mnist(flatten=True)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# Pull a single (images, labels) batch to inspect its shapes and dtypes.
next(iter(train_ds.take(1)))

(<tf.Tensor: shape=(64, 784), dtype=uint8, numpy=
 array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)>,
 <tf.Tensor: shape=(64,), dtype=uint8, numpy=
 array([0, 8, 0, 6, 8, 7, 6, 1, 9, 5, 1, 1, 3, 2, 3, 2, 4, 0, 3, 8, 9, 5,
        7, 7, 9, 2, 0, 8, 1, 2, 6, 8, 8, 4, 8, 8, 0, 4, 0, 5, 7, 4, 5, 9,
        2, 0, 5, 6, 5, 5, 5, 6, 3, 8, 3, 7, 5, 6, 1, 8, 7, 1, 9, 6],
       dtype=uint8)>)

# AE.py

In [6]:
import os

# Resolve the working directory once and reuse it for both printed lines.
cur_dir = os.path.abspath(".")
name = "AE-stacked"

print(cur_dir)
# Where the weights of a model called `name` would be stored.
print(os.path.join(cur_dir, f"weights/{name}.tf"))


/tf/tensorflow/models/1.AE
/tf/tensorflow/models/1.AE/weights/AE-stacked


In [8]:
import os
import tensorflow as tf

# Absolute path of the working directory at the time this cell is executed.
cur_dir = os.path.abspath(".")

class AE(tf.keras.Model):
    """Base class for the autoencoders in this notebook; adds weight persistence.

    Each instance stores its weights at ``<working dir>/weights/<name>.tf``,
    where the working directory is resolved when the instance is created.
    Subclasses define the layers and ``call``.
    """

    def __init__(self, name):
        """Create the model and compute where its weights are stored.

        Args:
            name: Keras model name; also used as the stem of the weight file.
        """
        super().__init__(name=name)
        self.cur_dir = os.path.abspath(".")
        # Build the path from this instance's own directory rather than a
        # notebook-level `cur_dir` variable, so the class does not depend on
        # which cells happen to have been executed.
        self.filepath = os.path.join(self.cur_dir, f"weights/{name}.tf")

    def save_weight(self):
        """Write the model's weights to ``self.filepath``, overwriting any existing file."""
        # The weights are written in the TensorFlow checkpoint format, so the
        # former `except ImportError` branch (which printed a message about
        # h5py / HDF5) did not describe this code path; failures now surface
        # as normal exceptions instead of being printed and ignored.
        self.save_weights(self.filepath, overwrite=True, save_format="tf")

    def load_weight(self):
        """Restore the model's weights from ``self.filepath``."""
        self.load_weights(self.filepath, by_name=False)

# AE-undercomplete.py

In [None]:
# Layer sizes for the undercomplete autoencoder.
n_inputs = 784  # 28 * 28, the length of a flattened image
n_hidden = 100  # units in the hidden Dense layer
n_outputs = n_inputs  # the output layer has as many units as the input has values

class AE_UC(AE):
    """Autoencoder made of two stacked Dense layers: n_hidden units, then n_outputs units."""

    def __init__(self):
        """Register the model as "AE-undercomplete" and create its two layers."""
        super().__init__(name="AE-undercomplete")

        # No activation argument is passed to either layer.
        dense = tf.keras.layers.Dense
        self.hidden = dense(n_hidden)
        # NOTE(review): `outputs` may also be an attribute name used by
        # tf.keras.Model itself — confirm this assignment does not clash in
        # the TensorFlow version in use.
        self.outputs = dense(n_outputs)

    def call(self, x):
        """Pass `x` through the hidden layer and then the output layer."""
        encoded = self.hidden(x)
        reconstruction = self.outputs(encoded)
        return reconstruction

# train.py

In [None]:
# Walk over every autoencoder variant defined as a direct subclass of AE.
# NOTE(review): the original cell stopped after the `for` header (no call
# parentheses, colon, or body) and could not run; the training step that
# belongs in this loop is still TODO — for now it only lists the class names.
for model_class in AE.__subclasses__():
    print(model_class.__name__)

In [19]:
# Instantiate the class by name instead of via AE.__subclasses__()[0]: the
# subclass list depends on definition order, and after re-running the AE_UC
# cell it can still contain an earlier, stale definition in first position.
ae_uc = AE_UC()

In [20]:
# Show the Keras model name given to the instance.
ae_uc.name

'AE-undercomplete'