In [None]:
import tensorflow as tf
from tensorflow.keras.backend import batch_flatten

import os
from tqdm import tqdm

BATCH_SIZE = 1

# Adam step size; passed to the optimizer so there is a single source of truth.
learning_rate = 1e-4
optimizer = tf.optimizers.Adam(learning_rate)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale raw pixel values by 1/255 and store as float32 to match the float32
# weights of the Keras Dense layers (plain `/ 255` would yield float64).
x_train = (x_train / 255).astype("float32")
x_test = (x_test / 255).astype("float32")
trainset = tf.data.Dataset.from_tensor_slices(x_train)
testset = tf.data.Dataset.from_tensor_slices(x_test)

trainset = trainset.batch(BATCH_SIZE)

class Vqvae1(tf.keras.Model):
    """Small fully-connected autoencoder for 28x28 inputs.

    The encoder flattens its input and maps it to a 2-d latent vector; the
    decoder maps a latent vector back to a (28, 28) array.

    NOTE(review): despite the name, no vector-quantization codebook is visible
    here, and no Dense layer has an activation, so encoder and decoder are
    each a single affine map and the output is not constrained to the input's
    value range -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # Flatten -> Dense(4) -> Dense(2): produces a 2-d latent code.
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(4),
                tf.keras.layers.Dense(2),
            ]
        )

        # Dense(4) -> Dense(784) -> reshape to (28, 28): the reconstruction.
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(4),
                tf.keras.layers.Dense(28 * 28),
                tf.keras.layers.Reshape(target_shape=(28, 28)),
            ]
        )

    @tf.function
    def encode(self, x):
        """Run the encoder on a batch ``x`` and return its latent codes."""
        return self.encoder(x)

    @tf.function
    def decode(self, z):
        """Run the decoder on latent codes ``z`` and return (28, 28) outputs."""
        return self.decoder(z)

    def saver(self, tag):
        """Save encoder and decoder as HDF5 files under ``./saved/<tag>/``.

        Deliberately NOT a ``tf.function``: it performs filesystem I/O and
        ``Model.save``, which would only run at trace time (or fail) inside a
        traced graph.

        Args:
            tag: Name of the sub-directory of ``./saved`` to write into.
        """
        directory = './saved/{0}'.format(tag)
        # makedirs also creates ./saved on first use, where os.mkdir would fail.
        os.makedirs(directory, exist_ok=True)
        self.encoder.save(directory + '/inf', save_format='h5')
        self.decoder.save(directory + '/gen', save_format='h5')

@tf.function
def mse(input, output):
    """Mean squared error computed separately for each example in a batch.

    Both tensors are flattened to (batch, features) first, so the mean is
    taken over all non-batch elements and one value per example is returned.
    """
    flat_target = batch_flatten(input)
    flat_prediction = batch_flatten(output)
    per_example_error = tf.losses.MSE(flat_target, flat_prediction)
    return per_example_error

@tf.function
def train_step(input, model):
    """Apply one optimizer update to ``model`` using a single batch.

    Args:
        input: A batch of images, used both as the encoder input and as the
            reconstruction target.
        model: Object exposing ``encode``, ``decode`` and
            ``trainable_variables``.

    Returns:
        Scalar mean reconstruction loss of the batch, computed before the
        update is applied.
    """
    with tf.GradientTape() as tape:
        z = model.encode(input)
        output = model.decode(z)
        # mse() returns one value per example; averaging gives a scalar
        # objective whose gradient does not grow with the batch size
        # (a vector target would be implicitly summed by tape.gradient).
        loss = tf.reduce_mean(mse(input, output))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

def train(model, trainset, save_every=None):
    """Run one pass over ``trainset`` and checkpoint the encoder/decoder.

    Weight files are named ``<ModelClass>_{encoder,decoder}_weights<step>.hdf5``,
    where ``<step>`` is the 1-based number of batches processed.

    Args:
        model: Model with ``encoder`` and ``decoder`` Keras sub-models,
            updated through ``train_step``.
        trainset: Batched ``tf.data.Dataset`` of training images.
        save_every: Positive int -> save every ``save_every`` batches (plus a
            final save if the last batch was not a multiple). ``None`` or 0
            (default) -> save once, after the pass.
    """
    prefix = type(model).__name__

    def _save(step):
        # Write both sub-models' weights tagged with the batch counter.
        model.encoder.save_weights('%s_encoder_weights%04d.hdf5' % (prefix, step))
        model.decoder.save_weights('%s_decoder_weights%04d.hdf5' % (prefix, step))

    step = 0
    # A single progress bar measured in examples (nested bars garble output).
    with tqdm(total=x_train.shape[0]) as pbar:
        for step, batch in enumerate(trainset, start=1):
            loss = train_step(batch, model)
            # Advance by the real batch size so a short final batch is exact.
            pbar.update(int(batch.shape[0]))
            if loss is not None:
                pbar.set_postfix(loss=float(loss), refresh=False)
            if save_every and step % save_every == 0:
                _save(step)

    # Final checkpoint, unless nothing was trained or it was just written.
    if step and (not save_every or step % save_every != 0):
        _save(step)

if __name__ == "__main__":
    # Build a fresh model and run a single training pass over `trainset`.
    model1 = Vqvae1()
    train(model=model1, trainset=trainset)

  0%|          | 0/60000 [00:00<?, ?it/s]
  0%|          | 1/60000 [00:01<31:56:48,  1.92s/it]
  0%|          | 84/60000 [00:02<17:12, 58.04it/s]  
  0%|          | 161/60000 [00:02<08:10, 122.05it/s]
  0%|          | 236/60000 [00:02<05:08, 193.43it/s]
  1%|          | 313/60000 [00:02<03:37, 274.04it/s]
  1%|          | 384/60000 [00:02<02:53, 343.43it/s]
  1%|          | 454/60000 [00:02<02:24, 410.70it/s]
  1%|          | 526/60000 [00:02<02:04, 476.30it/s]
  1%|          | 603/60000 [00:02<01:49, 543.85it/s]
  1%|          | 682/60000 [00:02<01:38, 605.19it/s]
  1%|▏         | 762/60000 [00:02<01:30, 655.69it/s]
  1%|▏         | 838/60000 [00:03<01:27, 678.60it/s]
  2%|▏         | 914/60000 [00:03<01:24, 695.94it/s]
  2%|▏         | 990/60000 [00:03<01:22, 712.97it/s]
  2%|▏         | 1070/60000 [00:03<01:20, 736.05it/s]
  2%|▏         | 1151/60000 [00:03<01:17, 756.68it/s]
  2%|▏         | 1229/60000 [00:03<01:17, 761.52it/s]
  2%|▏         | 1307/60000 [00:03<01:16, 762.48it/s]


In [None]:
import tensorflow as tf
from tensorflow.keras.backend import batch_flatten

import os
from tqdm import tqdm

BATCH_SIZE = 10

# Adam step size; passed to the optimizer so there is a single source of truth.
learning_rate = 1e-4
optimizer = tf.optimizers.Adam(learning_rate)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale raw pixel values by 1/255 and store as float32 to match the float32
# weights of the Keras Dense layers (plain `/ 255` would yield float64).
x_train = (x_train / 255).astype("float32")
x_test = (x_test / 255).astype("float32")
trainset = tf.data.Dataset.from_tensor_slices(x_train)
testset = tf.data.Dataset.from_tensor_slices(x_test)

trainset = trainset.batch(BATCH_SIZE)

class Vqvae2(tf.keras.Model):
    """Small fully-connected autoencoder for 28x28 inputs.

    The encoder flattens its input and maps it to a 2-d latent vector; the
    decoder maps a latent vector back to a (28, 28) array.

    NOTE(review): despite the name, no vector-quantization codebook is visible
    here, and no Dense layer has an activation, so encoder and decoder are
    each a single affine map and the output is not constrained to the input's
    value range -- confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # Flatten -> Dense(4) -> Dense(2): produces a 2-d latent code.
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(4),
                tf.keras.layers.Dense(2),
            ]
        )

        # Dense(4) -> Dense(784) -> reshape to (28, 28): the reconstruction.
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(4),
                tf.keras.layers.Dense(28 * 28),
                tf.keras.layers.Reshape(target_shape=(28, 28)),
            ]
        )

    @tf.function
    def encode(self, x):
        """Run the encoder on a batch ``x`` and return its latent codes."""
        return self.encoder(x)

    @tf.function
    def decode(self, z):
        """Run the decoder on latent codes ``z`` and return (28, 28) outputs."""
        return self.decoder(z)

    def saver(self, tag):
        """Save encoder and decoder as HDF5 files under ``./saved/<tag>/``.

        Deliberately NOT a ``tf.function``: it performs filesystem I/O and
        ``Model.save``, which would only run at trace time (or fail) inside a
        traced graph.

        Args:
            tag: Name of the sub-directory of ``./saved`` to write into.
        """
        directory = './saved/{0}'.format(tag)
        # makedirs also creates ./saved on first use, where os.mkdir would fail.
        os.makedirs(directory, exist_ok=True)
        self.encoder.save(directory + '/inf', save_format='h5')
        self.decoder.save(directory + '/gen', save_format='h5')

@tf.function
def mse(input, output):
    """Mean squared error computed separately for each example in a batch.

    Both tensors are flattened to (batch, features) first, so the mean is
    taken over all non-batch elements and one value per example is returned.
    """
    flat_target = batch_flatten(input)
    flat_prediction = batch_flatten(output)
    per_example_error = tf.losses.MSE(flat_target, flat_prediction)
    return per_example_error

@tf.function
def train_step(input, model):
    """Apply one optimizer update to ``model`` using a single batch.

    Args:
        input: A batch of images, used both as the encoder input and as the
            reconstruction target.
        model: Object exposing ``encode``, ``decode`` and
            ``trainable_variables``.

    Returns:
        Scalar mean reconstruction loss of the batch, computed before the
        update is applied.
    """
    with tf.GradientTape() as tape:
        z = model.encode(input)
        output = model.decode(z)
        # mse() returns one value per example; averaging gives a scalar
        # objective whose gradient does not grow with the batch size
        # (a vector target would be implicitly summed by tape.gradient).
        loss = tf.reduce_mean(mse(input, output))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

def train(model, trainset, save_every=None):
    """Run one pass over ``trainset`` and checkpoint the encoder/decoder.

    Weight files are named ``<ModelClass>_{encoder,decoder}_weights<step>.hdf5``,
    where ``<step>`` is the 1-based number of batches processed.

    Args:
        model: Model with ``encoder`` and ``decoder`` Keras sub-models,
            updated through ``train_step``.
        trainset: Batched ``tf.data.Dataset`` of training images.
        save_every: Positive int -> save every ``save_every`` batches (plus a
            final save if the last batch was not a multiple). ``None`` or 0
            (default) -> save once, after the pass.
    """
    prefix = type(model).__name__

    def _save(step):
        # Write both sub-models' weights tagged with the batch counter.
        model.encoder.save_weights('%s_encoder_weights%04d.hdf5' % (prefix, step))
        model.decoder.save_weights('%s_decoder_weights%04d.hdf5' % (prefix, step))

    step = 0
    # A single progress bar measured in examples (nested bars garble output).
    with tqdm(total=x_train.shape[0]) as pbar:
        for step, batch in enumerate(trainset, start=1):
            loss = train_step(batch, model)
            # Advance by the real batch size so a short final batch is exact.
            pbar.update(int(batch.shape[0]))
            if loss is not None:
                pbar.set_postfix(loss=float(loss), refresh=False)
            if save_every and step % save_every == 0:
                _save(step)

    # Final checkpoint, unless nothing was trained or it was just written.
    if step and (not save_every or step % save_every != 0):
        _save(step)

if __name__ == "__main__":
    # Build a fresh model and run a single training pass over `trainset`.
    model2 = Vqvae2()
    train(model=model2, trainset=trainset)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


  0%|          | 0/60000 [00:00<?, ?it/s]
  0%|          | 10/60000 [00:01<1:50:09,  9.08it/s]
  1%|▏         | 800/60000 [00:01<01:04, 913.95it/s]
  3%|▎         | 1560/60000 [00:01<00:31, 1851.51it/s]
  4%|▍         | 2270/60000 [00:01<00:21, 2726.35it/s]
  5%|▌         | 3080/60000 [00:01<00:15, 3761.05it/s]
  5%|▌         | 308/6000 [00:01<00:15, 368.06it/s][A
  6%|▋         | 3780/60000 [00:01<00:14, 3920.20it/s]
  7%|▋         | 4390/60000 [00:01<00:13, 4181.02it/s]
  8%|▊         | 4970/60000 [00:01<00:12, 4471.85it/s]
 10%|▉         | 5780/60000 [00:02<00:10, 5324.21it/s]
 11%|█         | 6610/60000 [00:02<00:08, 6079.31it/s]
 12%|█▏        | 7400/60000 [00:02<00:08, 6527.03it/s]
 14%|█▎        | 8230/60000 [00:02<00:07, 6942.72it/s]
 15%|█▌        | 9050/60000 [00:02<00:06, 7284.56it/s]
 16%|█▋        | 9830/60000 [00:02<00:06, 7388.81it/s]
 18%|█▊        | 10600/60000 [00:02<00:06, 7383.93it/s]
 19%|█▉        | 11410/60000 [00:02<00:06, 7567.05it/s]
 20%|██        | 12190/60