<h1><center>PixelCNN for RGB Images</center></h1>

[The original ***Pixel Recurrent Neural Networks*** paper](https://arxiv.org/abs/1601.06759)

[The original ***Conditional Image Generation with PixelCNN Decoders*** paper](https://arxiv.org/abs/1606.05328)

[PixelCNN](http://sergeiturukin.com/2017/02/22/pixelcnn.html)

[Gated PixelCNN](http://sergeiturukin.com/2017/02/24/gated-pixelcnn.html)

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
import os
import io



# Short aliases for the Keras namespaces used throughout this file.
tfk = tf.keras
tfkl = tf.keras.layers
# Used for `num_parallel_calls` and `prefetch` in the input pipelines below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Show which TensorFlow version this notebook is running against.
print(tf.__version__)

2.0.0


# Model

In [2]:
class MaskedConv2D(tfkl.Layer):
    """2D convolution whose kernel is multiplied by a fixed binary mask.

    The mask zeroes the kernel rows below the center row and the entries at or
    right of the center on the center row, and additionally restricts which
    color group of input channels each color group of output channels may
    read. Channels are laid out in `n_colors` contiguous groups of equal size
    on both the input and the output side.

    Args:
        type: 'A' or 'B'. With 'A' an output color group does not see the
            center pixel of its own input color group; with 'B' it does.
        n_colors: number of color groups the channels are split into.
        filters: total number of output channels (must be divisible by
            `n_colors`).
        kernel_size: int, or a `(k_y, k_x)` tuple/list.
        strides: strides forwarded to `tf.nn.conv2d`.
        padding: padding mode forwarded to `tf.nn.conv2d`.
        name: layer name.

    Raises:
        ValueError: if `type` is not 'A' or 'B', or (at build time) if the
            input or output channel count is not divisible by `n_colors`.
    """

    def __init__(self, type, n_colors, filters, kernel_size, strides=1,
                 padding='SAME', name='masked_conv'):
        super(MaskedConv2D, self).__init__(name=name)

        if type not in {'A', 'B'}:
            raise ValueError("MaskedConv2D type should be in (A, B), "
                            f"got {type}")

        self.type = type
        self.n_colors = n_colors
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

    def build(self, input_shape):
        """Create the kernel, the bias and the constant mask."""
        _, _, _, in_ch = input_shape
        out_ch = self.filters

        # The mask is assembled from n_colors equal-sized channel groups, so
        # an uneven split would yield a mask that does not match the kernel.
        if in_ch % self.n_colors or out_ch % self.n_colors:
            raise ValueError(
                "MaskedConv2D input and output channels must be divisible "
                f"by n_colors={self.n_colors}, got in_ch={in_ch} and "
                f"out_ch={out_ch}"
            )

        if isinstance(self.kernel_size, (tuple, list)):
            k_y, k_x = self.kernel_size
        else:
            k_y = self.kernel_size
            k_x = self.kernel_size

        # Instantiate variables
        initializer = tfk.initializers.GlorotUniform()
        self.kernel = tf.Variable(
            initializer((k_y, k_x, in_ch, out_ch), dtype=tf.float32),
            trainable=True,
            aggregation=tf.VariableAggregation.MEAN,
            name='kernel'
        )

        self.bias = tf.Variable(
            initializer((1, 1, 1, out_ch), dtype=tf.float32),
            trainable=True,
            aggregation=tf.VariableAggregation.MEAN,
            name='bias'
        )

        # Create the mask
        mid_x, mid_y = k_x // 2, k_y // 2
        in_ch_per_color = in_ch // self.n_colors
        out_ch_per_color = out_ch // self.n_colors

        rows_above = [k_x] * mid_y            # rows above the center: keep all
        rows_below = [0] * (k_y - mid_y - 1)  # rows below the center: keep none
        columns = tf.expand_dims(tf.range(k_x), axis=0)

        def group_mask(center_row_width):
            """Boolean mask (k_y, k_x, in_ch_per_color, out_ch_per_color)
            keeping the first `center_row_width` pixels of the center row."""
            kept_per_row = rows_above + [center_row_width] + rows_below
            flat = tf.less(columns, tf.expand_dims(kept_per_row, axis=1))
            return tf.tile(
                flat[:, :, None, None],
                [1, 1, in_ch_per_color, out_ch_per_color]
            )

        mask_A = group_mask(mid_x)      # center pixel excluded
        mask_B = group_mask(mid_x + 1)  # center pixel included
        mask_0 = tf.zeros_like(mask_A, dtype=tf.bool)

        # Mask applied between an output color and the same input color.
        mask_same_color = mask_B if self.type == 'B' else mask_A

        # Input-group patterns, one tuple per output color (3 colors shown):
        #   type B : (B, 0, 0), (B, B, 0), (B, B, B)
        #   type A : (A, 0, 0), (B, A, 0), (B, B, A)
        mask_colors = []
        for i in range(self.n_colors):
            masks = (
                [mask_B] * i
                + [mask_same_color]
                + [mask_0] * (self.n_colors - i - 1)
            )
            # Stack the input color groups along the in-channel axis.
            mask_colors.append(tf.concat(masks, axis=2))

        # Stack the output color groups along the out-channel axis.
        self.mask = tf.concat(mask_colors, axis=3)
        self.mask = tf.cast(self.mask, tf.float32)

    def call(self, x):
        """Convolve `x` with the masked kernel and add the bias."""
        h = tf.nn.conv2d(
            input=x,
            filters=self.kernel * self.mask,
            strides=self.strides,
            padding=self.padding,
        )
        return h + self.bias

class ResidualBlock(tfkl.Layer):
    """Residual block made of three type-'B' masked convolutions.

    The block applies ReLU -> 1x1 conv (half the channels) -> ReLU -> 3x3 conv
    (half the channels) -> ReLU -> 1x1 conv (all channels) and adds the result
    to its input.

    Args:
        n_colors: number of color groups, forwarded to every `MaskedConv2D`.
        **kwargs: forwarded to the Keras `Layer` constructor.
    """

    def __init__(self, n_colors, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        self.n_colors = n_colors

    def build(self, input_shape):
        """Create the three convolutions from the input channel count."""
        # input_shape is (batch_size, height, width, channels)
        channels = input_shape[-1]
        bottleneck = channels // 2
        shared = dict(type='B', n_colors=self.n_colors)

        self.conv1 = MaskedConv2D(
            filters=bottleneck, kernel_size=1, name='conv1x1_1', **shared
        )
        self.conv2 = MaskedConv2D(
            filters=bottleneck, kernel_size=3, padding='SAME',
            name='conv3x3', **shared
        )
        self.conv3 = MaskedConv2D(
            filters=channels, kernel_size=1, name='conv1x1_2', **shared
        )

    def call(self, x):
        """Return `x` plus the output of the ReLU/conv stack applied to `x`."""
        residual = x
        for conv in (self.conv1, self.conv2, self.conv3):
            residual = conv(tf.nn.relu(residual))
        return x + residual

class PixelCNN(tfk.Model):
    """PixelCNN built from masked convolutions and residual blocks.

    Args:
        hidden_dim: hidden dimension per color; the first convolution outputs
            `2 * n_colors * hidden_dim` channels.
        n_res: number of residual blocks.
        n_output: number of possible values per sub-pixel, i.e. the size of
            the categorical distribution predicted for each (pixel, color).
        **kwargs: forwarded to the Keras `Model` constructor.
    """

    def __init__(self, hidden_dim, n_res=5, n_output=256, **kwargs):
        super(PixelCNN, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.n_res = n_res
        self.n_output = n_output  # number of possible pixel values

    def build(self, input_shape):
        """Create the layers; the number of colors is the last input dim."""
        # Save image_shape for generation
        self.image_shape = input_shape[1:]

        n_colors = input_shape[-1]
        self.n_colors = n_colors

        self.conv_a = MaskedConv2D(
            type='A',
            n_colors=n_colors,
            kernel_size=7,
            filters=2 * n_colors * self.hidden_dim,
            padding='SAME',
            name='conv_a'
        )

        self.res_blocks = [
            ResidualBlock(n_colors=n_colors, name=f'res_block{i}')
            for i in range(self.n_res)
        ]

        self.conv_b_1 = MaskedConv2D(
            type='B',
            n_colors=n_colors,
            kernel_size=1,
            filters=n_colors * self.n_output,
            name='conv_b_1'
        )

        self.conv_b_2 = MaskedConv2D(
            type='B',
            n_colors=n_colors,
            kernel_size=1,
            filters=n_colors * self.n_output,
            name='conv_b_2'
        )

    def call(self, x):
        """Return logits of shape (batch, height, width, n_colors, n_output)."""
        h = self.conv_a(x)

        for res_block in self.res_blocks:
            h = res_block(h)

        h = self.conv_b_1(tf.nn.relu(h))
        h = self.conv_b_2(tf.nn.relu(h))

        # Format output: one contiguous block of n_output logits per color.
        h = tf.split(h, num_or_size_splits=self.n_colors, axis=-1)
        outputs = tf.stack(h, axis=3)  # (batch_size, height, width, n_colors, n_output)

        return outputs

    def sample(self, n):
        """Sample `n` images, one sub-pixel at a time.

        The model must already have been built (called on data at least once)
        so that `image_shape` is known.

        Args:
            n: number of images to generate.

        Returns:
            A float32 tensor of shape (n, height, width, channels) holding the
            sampled pixel values.
        """
        height, width, channels = self.image_shape
        n_pixels = height * width * channels

        # Initial content is only a placeholder: the loop below overwrites
        # every entry.
        logits = tf.ones((n_pixels, self.n_output))
        flat_samples = tf.cast(tf.random.categorical(logits, n), tf.float32)
        samples = tf.reshape(flat_samples, (n, height, width, channels))

        # Sample each pixel sequentially and feed it back
        for pos in tqdm(range(n_pixels), desc="Sampling PixelCNN"):
            # Raster order: channel varies fastest, then column, then row.
            c = pos % channels
            h, w = divmod(pos // channels, width)
            logits = self(samples)[:, h, w, c]
            draws = tf.random.categorical(logits, 1)  # (n, 1)
            # Squeeze only the last axis so the batch axis survives n == 1.
            updates = tf.cast(tf.squeeze(draws, axis=-1), tf.float32)
            indices = tf.constant([[i, h, w, c] for i in range(n)])
            samples = tf.tensor_scatter_nd_update(samples, indices, updates)

        return samples

def bits_per_dim_loss(y_true, y_pred):
    """Return the bits per dim value of the predicted distribution.

    Args:
        y_true: rank-4 tensor of target values; it is cast to int32 and used
            as the index into the last axis of `y_pred`.
        y_pred: rank-5 tensor of unnormalized logits whose first four axes
            match `y_true` and whose last axis is the categorical
            distribution.

    Returns:
        A tensor with one value per leading (batch) entry: the negative
        log2-likelihood averaged over axes 1, 2 and 3.
    """
    log_probs = tf.math.log_softmax(y_pred, axis=-1)
    targets = tf.cast(y_true, tf.int32)
    # Pick the log-probability assigned to the true value of every entry.
    log_probs = tf.gather(log_probs, targets, axis=-1, batch_dims=4)
    # Averaging equals sum / (H * W * C) without needing the static shape,
    # so this also works when those dimensions are not statically known.
    nll_per_dim = - tf.reduce_mean(log_probs, axis=[1, 2, 3])
    bits_per_dim = nll_per_dim / tf.math.log(2.)
    return bits_per_dim

# Utilities

In [3]:
def plot_to_image(figure):
    """Convert a matplotlib figure to a batched PNG image tensor.

    The supplied figure is closed and inaccessible after this call.

    Args:
        figure: the matplotlib figure to render.

    Returns:
        The decoded PNG as a tensor with 4 channels and a leading batch
        dimension of size 1.
    """
    # Save the plot to a PNG in memory. Saving through the figure itself
    # (rather than pyplot's "current figure") guarantees the right one is used.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image

class PlotSamplesCallback(tfk.callbacks.Callback):
    """Write `nex` images sampled from the model to TensorBoard every epoch.

    Args:
        logdir: base log directory; summaries go to its `samples` subfolder.
        nex: number of images to sample and log.
    """
    def __init__(self, logdir: str, nex: int=4):
        super(PlotSamplesCallback, self).__init__()
        self.nex = nex
        self.file_writer = tf.summary.create_file_writer(
            logdir=os.path.join(logdir, 'samples')
        )

    def plot_img(self, image):
        """Draw one image on a fresh, axis-less figure and return the figure."""
        fig, ax = plt.subplots(nrows=1, ncols=1)
        pixels = tf.cast(image, tf.int32)

        # Drop a trailing single-channel axis before handing it to imshow.
        single_channel = pixels.shape[-1] == 1
        if single_channel:
            pixels = tf.squeeze(pixels, axis=-1)

        ax.imshow(pixels, vmin=0, vmax=255, cmap=plt.cm.Greys)
        ax.axis('off')

        return fig

    def on_epoch_end(self, epoch, logs=None):
        """Sample from the model, render each sample and log the batch."""
        samples = self.model.sample(self.nex)

        rendered = [
            plot_to_image(self.plot_img(samples[idx]))
            for idx in range(self.nex)
        ]
        batch = tf.concat(rendered, axis=0)

        with self.file_writer.as_default():
            tf.summary.image(
                name='Samples',
                data=batch,
                step=epoch,
                max_outputs=self.nex
            )

# Arguments

In [None]:
class Args:
    """Hyper-parameters of the experiment, stored as class attributes."""

    # Data
    dataset = 'mnist'  # cifar10 or mnist
    batch = 64
    buffer = 1024  # Buffer size for shuffling

    # Model
    hidden_dim = 64  # Hidden dimension per channel
    n_res = 4  # Number of res blocks

    # Optimization
    epochs = 10
    learning_rate = 0.001
    lr_decay = 0.999995


args = Args()

# Training


In [None]:
# Training parameters
EPOCHS = args.epochs
BATCH_SIZE = args.batch
BUFFER_SIZE = args.buffer  # for shuffling

# Load dataset
dataset, info = tfds.load(args.dataset, with_info=True)
train_ds, test_ds = dataset['train'], dataset['test']

def prepare(element):
    """Extract the image from a dataset element and cast it to float32."""
    image = element['image']
    image = tf.cast(image, tf.float32)
    # The image is not normalized: the loss uses the raw values as class
    # indices.
    return image

# PixelCNN training requires target = input
def duplicate(element):
    """Return `(element, element)` so the image is both input and target."""
    return element, element

train_ds = (train_ds.shuffle(BUFFER_SIZE)
                    .batch(BATCH_SIZE)
                    .map(prepare, num_parallel_calls=AUTOTUNE)
                    .map(duplicate)
                    .prefetch(AUTOTUNE))

test_ds = (test_ds.batch(BATCH_SIZE)
                   .map(prepare, num_parallel_calls=AUTOTUNE)
                   .map(duplicate)
                   .prefetch(AUTOTUNE))

# Define model
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = PixelCNN(
        hidden_dim=args.hidden_dim,
        n_res=args.n_res
    )
    # The learning rate is driven by the LearningRateScheduler callback below.
    model.compile(optimizer='adam', loss=bits_per_dim_loss)

# Learning rate scheduler
# `lr_decay` is a per-step factor; compound it into a per-epoch factor.
steps_per_epochs = info.splits['train'].num_examples // args.batch
decay_per_epoch = args.lr_decay ** steps_per_epochs

def schedule(epoch, lr=None):
    """Return the learning rate for `epoch` as a plain Python float.

    LearningRateScheduler expects a callable taking the epoch index (and
    optionally the current rate) and returning a float. This function yields
    `learning_rate * decay_per_epoch ** epoch`, the same values as an
    exponential decay with one decay step per epoch.
    """
    return float(args.learning_rate * decay_per_epoch ** epoch)

# Callbacks
time = datetime.now().strftime('%Y%m%d-%H%M%S')
log_dir = os.path.join('.', 'logs', 'pixelcnn', time)
tensorboard_clbk = tfk.callbacks.TensorBoard(log_dir=log_dir)
sample_clbk = PlotSamplesCallback(logdir=log_dir)
scheduler_clbk = tfk.callbacks.LearningRateScheduler(schedule)
callbacks = [tensorboard_clbk, sample_clbk, scheduler_clbk]

In [None]:
# Fit the model for EPOCHS epochs, using the test split as validation data
# and running the TensorBoard, sampling and learning-rate callbacks.
model.fit(train_ds, validation_data=test_ds, epochs=EPOCHS, callbacks=callbacks)