In [13]:
import os
import argparse
from datetime import datetime
import tensorflow as tf
import tensorflow_datasets as tfds

from model import PixelCNN, bits_per_dim_loss
from utils import PlotSamplesCallback

In [14]:
# Short aliases for the TensorFlow namespaces used throughout the notebook.
tfk = tf.keras
tfkl = tfk.layers
# Sentinel telling tf.data to pick parallelism / prefetch depth at runtime.
# Uses the core-API name (available since TF 2.4) instead of the older
# tf.data.experimental.AUTOTUNE spelling; both refer to the same value.
AUTOTUNE = tf.data.AUTOTUNE

In [15]:
# Training hyperparameters
EPOCHS = 75         # number of epochs handed to model.fit
BATCH_SIZE = 64     # batch size applied to both input pipelines
BUFFER_SIZE = 1024  # shuffle buffer size for the training pipeline

# Fetch MNIST through TensorFlow Datasets, keeping the metadata object
# (`info`) so split sizes can be queried later.
dataset, info = tfds.load('mnist', with_info=True)
train_ds = dataset['train']
test_ds = dataset['test']

In [16]:
def prepare(element):
    """Pull the ``'image'`` entry out of a dataset element as float32.

    Args:
        element: mapping that contains an ``'image'`` key.

    Returns:
        The image cast to ``tf.float32``. Only the dtype changes; the values
        are not rescaled or normalized here.
    """
    return tf.cast(element['image'], tf.float32)

# PixelCNN training requires target = input
def duplicate(element):
    """Return ``(element, element)`` so the same object serves as input and target."""
    inputs = targets = element
    return inputs, targets

# Build both pipelines from the raw splits held in `dataset` instead of from
# the current values of train_ds / test_ds. That keeps this cell idempotent:
# re-running it no longer stacks a second shuffle/batch/map chain on top of an
# already-prepared dataset (which would fail inside `prepare`).
train_ds = (dataset['train'].shuffle(BUFFER_SIZE)   # shuffle single examples before batching
                            .batch(BATCH_SIZE)
                            .map(prepare, num_parallel_calls=AUTOTUNE)  # applied per batch
                            .map(duplicate)         # (inputs, targets) with targets == inputs
                            .prefetch(AUTOTUNE))

# Same chain for evaluation, without shuffling.
test_ds = (dataset['test'].batch(BATCH_SIZE)
                          .map(prepare, num_parallel_calls=AUTOTUNE)
                          .map(duplicate)
                          .prefetch(AUTOTUNE))

In [18]:
# Instantiate the PixelCNN imported from model.py and compile it with the
# Adam optimizer and the project's bits-per-dimension loss.
model = PixelCNN(n_res=4, hidden_dim=64)
# NOTE(review): the traceback recorded under the model.fit cell
# ("None values not supported" in compute_weighted_loss) suggests the loss
# callable returned None -- confirm bits_per_dim_loss in model.py returns
# its result.
model.compile(
    loss=bits_per_dim_loss,
    optimizer='adam',
)

In [19]:
# Learning rate scheduler
# Batches per epoch. The training pipeline batches without dropping the
# remainder, so a trailing partial batch is still a step: round up rather
# than floor.
num_train_examples = info.splits['train'].num_examples
steps_per_epochs = (num_train_examples + BATCH_SIZE - 1) // BATCH_SIZE
# A per-step decay factor of 0.999995 compounded over one epoch of steps.
decay_per_epoch = 0.999995 ** steps_per_epochs
# With decay_steps=1, calling the schedule with index k evaluates to
# 0.001 * decay_per_epoch ** k.
schedule = tfk.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001,
    decay_rate=decay_per_epoch,
    decay_steps=1,
)

In [20]:
# Callbacks
# Each run logs into its own timestamped directory: ./logs/pixelcnn/<timestamp>
time = datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join('.', 'logs', 'pixelcnn', time)

# Applies `schedule` through Keras' LearningRateScheduler callback.
scheduler_clbk = tfk.callbacks.LearningRateScheduler(schedule)
# TensorBoard logging and the project's sample-plotting callback share log_dir.
tensorboard_clbk = tfk.callbacks.TensorBoard(log_dir=log_dir)
sample_clbk = PlotSamplesCallback(logdir=log_dir)

callbacks = [
    tensorboard_clbk,
    sample_clbk,
    scheduler_clbk,
]

In [21]:
# Train for EPOCHS epochs, validating on the test split and running the
# callbacks configured above.
model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=test_ds,
    callbacks=callbacks,
)

Epoch 1/75


ValueError: in user code:

    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 1051, in train_function  *
        return step_function(self, iterator)
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 1040, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 1030, in run_step  **
        outputs = model.train_step(data)
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 890, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\training.py", line 948, in compute_loss
        return self.compiled_loss(
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\engine\compile_utils.py", line 201, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\losses.py", line 140, in __call__
        return losses_utils.compute_weighted_loss(
    File "C:\Users\user\miniconda3\envs\tf_2.9\lib\site-packages\keras\utils\losses_utils.py", line 310, in compute_weighted_loss
        losses = tf.convert_to_tensor(losses)

    ValueError: None values not supported.
