In [None]:
import cv2
import tqdm
import os
os.chdir('/vol/medic01/users/bh1511/PyCharm_Deployment/ResNetAE')
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'   # see issue #152
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp

import matplotlib.pyplot as plt

from ResNetAE import ResNetVQVAE


In [None]:
##### ChestXray14 ######################################################################################################

# data_root = '/vol/biomedic/users/bh1511/datasets/cxr/ChestXray-NIHCC'
data_root = '/data/datasets/chest_xray/ChestXray-NIHCC/'

# Split lists: each CSV is read for its 'Image Index' column of image file names.
train_val_df, test_df = (
    pd.read_csv(os.path.join(data_root, split_csv))
    for split_csv in ('train_val_list.csv', 'test_list.csv')
)

# Prefix every file name with the image directory to obtain the full paths.
image_prefix = data_root + '/images/'
train_images_df = image_prefix + train_val_df['Image Index']
val_images_df = image_prefix + test_df['Image Index']  # validation paths come from the test list


In [None]:
##### HParams ##########################################################################################################

# --- VQ-VAE codebook ---
NUM_LATENT_K = 512                # Number of codebook entries (also the number of PixelCNN output logits per cell)
NUM_LATENT_D = 64                 # Dimension of each codebook entry
BETA = 0.25                       # Weight for the commitment loss

# --- Input / latent geometry ---
# (height, width, channels) of the images handed to the model; images are resized to
# 256x256 and converted to a single grayscale channel by the dataset pipeline.
INPUT_SHAPE = (256, 256, 1)
SIZE = 16                         # Spatial size of latent embedding (side of the SIZE x SIZE code-index grid)

# --- VQ-VAE training ---
VQVAE_BATCH_SIZE = 16             # Batch size for training the VQVAE
VQVAE_NUM_EPOCHS = 20             # Number of epochs
VQVAE_LEARNING_RATE = 1e-4        # Learning rate

# Directory that receives one weights file per training epoch.
SAVE_DIR = 'checkpoint4'

# --- PixelCNN prior ---
PIXELCNN_BATCH_SIZE = 128         # Batch size for training the PixelCNN prior
PIXELCNN_NUM_EPOCHS = 100         # Number of epochs
PIXELCNN_LEARNING_RATE = 3e-4     # Learning rate
PIXELCNN_NUM_BLOCKS = 12          # Number of Gated PixelCNN blocks in the architecture
PIXELCNN_NUM_FEATURE_MAPS = 32    # Width of each PixelCNN block

# When True, later cells load previously saved weights into the model before using it.
LOAD_MODEL = False


In [None]:
##### Tensorflow Dataset ###############################################################################################

def parse_function(filename):
    """Load one image file as a 256x256 single-channel float image.

    Args:
        filename: Path of the image file to read (string tensor).

    Returns:
        The decoded image, padded/resized to 256x256 and converted to grayscale,
        as float32 values in [0, 1].
    """
    raw_bytes = tf.io.read_file(filename)

    # decode_jpeg is used instead of tf.image.decode_image because the latter
    # leaves the output shape undefined.
    rgb = tf.io.decode_jpeg(raw_bytes, channels=3)

    # Rescale the integer pixel values to floats in [0, 1].
    rgb = tf.image.convert_image_dtype(rgb, tf.float32)

    # Pad to a square and resize to 256x256 using nearest-neighbour sampling.
    rgb = tf.image.resize_with_pad(rgb, 256, 256, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    # Collapse the three colour channels into one.
    return tf.image.rgb_to_grayscale(rgb)

# Training pipeline: shuffle the (cheap) file paths with a buffer covering the whole
# list, decode in parallel, batch, and prefetch so that image decoding overlaps with
# the training step instead of stalling it.
train_ds = (
    tf.data.Dataset.from_tensor_slices(train_images_df)
    .shuffle(len(train_images_df))
    .map(parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(VQVAE_BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Evaluation pipeline: same decoding, but in file-list order (no shuffle).
test_ds = (
    tf.data.Dataset.from_tensor_slices(val_images_df)
    .map(parse_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(VQVAE_BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)


In [None]:
##### Define Model, Optimizer and Loss #################################################################################

model = ResNetVQVAE(input_shape=INPUT_SHAPE)

optimizer = tf.keras.optimizers.Adam(learning_rate=VQVAE_LEARNING_RATE)
mse_loss = tf.keras.losses.MSE

# Running means of the three loss terms, tracked separately for training ...
train_total_loss, train_VecQuant_loss, train_reconstruction_loss = (
    tf.keras.metrics.Mean() for _ in range(3)
)

# ... and for validation.
val_total_loss, val_VecQuant_loss, val_reconstruction_loss = (
    tf.keras.metrics.Mean() for _ in range(3)
)


In [None]:
##### ResNetVQVAE Training Loop ########################################################################################

# Writing an .h5 weights file into a directory that does not exist fails, so create
# the checkpoint directory up front (skipped when SAVE_DIR is the current directory).
if SAVE_DIR:
    os.makedirs(SAVE_DIR, exist_ok=True)

for epoch in range(VQVAE_NUM_EPOCHS):

    # Restart all running means so every epoch reports its own averages.
    for metric in (train_total_loss, train_reconstruction_loss, train_VecQuant_loss,
                   val_total_loss, val_reconstruction_loss, val_VecQuant_loss):
        metric.reset_states()

    print('=' * 50, f'Training EPOCH {epoch}', '=' * 50)
    start = time.time()

    ## Train Step
    t = tqdm.tqdm(enumerate(train_ds), total=len(train_ds))
    for step, data in t:
        with tf.GradientTape() as tape:
            # training=True: layers whose behaviour depends on the mode (e.g. batch
            # normalisation, dropout) must run in training mode inside the tape.
            vq_loss, data_recon, perplexity = model(data, training=True)
            # tf.keras.losses.MSE averages over the last axis only, so recon_err (and
            # therefore loss) is a per-pixel tensor rather than a scalar.
            recon_err = mse_loss(data, data_recon)
            loss = vq_loss + recon_err

        # For a non-scalar target, tape.gradient differentiates the sum of its elements.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        train_total_loss(loss)
        train_reconstruction_loss(recon_err)
        train_VecQuant_loss(vq_loss)

        t.set_description('>%d, t_loss=%.5f, recon_loss=%.5f, VecQuant_loss=%.5f' % (step,
                                                                                     train_total_loss.result(),
                                                                                     train_reconstruction_loss.result(),
                                                                                     train_VecQuant_loss.result()))

    ## Evaluation Step
    for step, data in tqdm.tqdm(enumerate(test_ds), total=len(test_ds)):
        # training=False: evaluate with the layers in inference mode.
        vq_loss, data_recon, perplexity = model(data, training=False)
        recon_err = mse_loss(data, data_recon)
        loss = vq_loss + recon_err

        val_total_loss(loss)
        val_reconstruction_loss(recon_err)
        val_VecQuant_loss(vq_loss)

    print(f'Epoch {epoch} complete in: {time.time() - start:.5f}')
    print('t_loss={:.5f}, recon_loss={:.5f}, VecQuant_loss={:.5f}'.format(val_total_loss.result(),
                                                                          val_reconstruction_loss.result(),
                                                                          val_VecQuant_loss.result()))

    # One weights file per epoch.
    model.save_weights(os.path.join(SAVE_DIR, f'model_{epoch}.h5'))


In [None]:
##### ResNetVQVAE Evaluation Loop ######################################################################################

if LOAD_MODEL:
    # Call the model once on a dummy batch before loading (presumably so that its
    # variables exist for load_weights to fill - confirm against ResNetVQVAE).
    model(np.zeros((1,) + tuple(INPUT_SHAPE), dtype='float32'))
    # Read from SAVE_DIR so loading stays in sync with where training writes checkpoints.
    model.load_weights(os.path.join(SAVE_DIR, 'model_42.h5'))

## Inference
n_inference = 100  # number of test batches to reconstruct
input_samples = []
output_samples = []
print('Running Inference')
# take() stops after exactly n_inference batches (or earlier if the dataset is shorter).
for data in tqdm.tqdm(test_ds.take(n_inference), total=min(n_inference, len(test_ds))):
    # The training loop unpacks three values from model(...), so index the
    # reconstruction (second output) instead of unpacking a fixed number of values.
    gen_sample = model(data, training=False)[1]
    input_samples.append(data.numpy())
    output_samples.append(gen_sample.numpy())

input_samples = np.concatenate(input_samples, axis=0)
output_samples = np.concatenate(output_samples, axis=0)

# assert os.path.exists(SAVE_DIR), "Directory does not exist"
# print('Saving Images')
# for i in tqdm.tqdm(range(n_inference)):
#     cv2.imwrite(os.path.join(SAVE_DIR, f'sample_{i}_real.jpg'), input_samples[i]*255.)
#     cv2.imwrite(os.path.join(SAVE_DIR, f'sample_{i}_gen.jpg'), output_samples[i]*255.)

fig = plt.figure(figsize=(30, 20))

# 4 x 6 grid: 12 original/reconstruction pairs, each pair in adjacent panels.
for i in tqdm.tqdm(range(12)):
    ax = fig.add_subplot(4, 6, i*2 + 1)
    ax.set_title('Original')
    ax.imshow(np.squeeze(input_samples[i]*255.), cmap='gray')
    ax = fig.add_subplot(4, 6, i*2 + 2)
    ax.set_title('Reconstruction')
    ax.imshow(np.squeeze(output_samples[i]*255.), cmap='gray')


In [None]:
##### Quantised Encodings ##############################################################################################

if LOAD_MODEL:
    # Call the model once on a dummy batch before loading (presumably so that its
    # variables exist for load_weights to fill - confirm against ResNetVQVAE).
    model(np.zeros((1,) + tuple(INPUT_SHAPE), dtype='float32'))
    # Read from SAVE_DIR so loading stays in sync with where training writes checkpoints.
    model.load_weights(os.path.join(SAVE_DIR, 'model_42.h5'))

# Encode the whole training set batch by batch and keep the fourth output of the
# vector quantiser as the encoding of each batch.
z_train = []
for i, data in enumerate(tqdm.tqdm(train_ds)):
    encodings = model.vq_vae(model.pre_vq_conv(model.encoder(data)))[3]
    # Store plain NumPy arrays rather than holding on to one tensor per batch.
    z_train.append(encodings.numpy())
    # if i == 100: break
z_train = np.concatenate(z_train, axis=0)


In [None]:
##### Visualise Codes ##############################################################################################

fig = plt.figure(figsize=(40, 20))

# Show the first 32 entries of z_train on a 4 x 8 grid, one entry per panel.
grid_rows, grid_cols = 4, 8
for i in tqdm.tqdm(range(grid_rows * grid_cols)):
    code_map = np.squeeze(z_train[i, ...])
    ax = fig.add_subplot(grid_rows, grid_cols, i + 1)
    img = ax.imshow(code_map)


In [None]:
##### PixelCNN #########################################################################################################

# References:
# https://www.kaggle.com/ameroyer/keras-vq-vae-for-image-generation
# https://github.com/anantzoid/Conditional-PixelCNN-decoder/blob/master/layers.py
# https://github.com/ritheshkumar95/pytorch-vqvae

def gate(inputs):
    """Gated activation.

    Splits `inputs` into two equal halves along the last axis and returns
    tanh(first half) * sigmoid(second half).
    """
    tanh_half, sigmoid_half = tf.split(inputs, num_or_size_splits=2, axis=-1)
    activation = tf.tanh(tanh_half)
    gate_value = tf.sigmoid(sigmoid_half)
    return activation * gate_value


class MaskedConv2D(tf.keras.layers.Layer):
    """2-D convolution whose kernel is multiplied by a fixed binary mask.

    Args:
        kernel_size: (height, width) of the convolution kernel.
        out_dim: Number of output channels.
        direction: "h" or "v" (horizontal or vertical stack); selects the mask layout.
        mode: Mask type, "a" or "b". Type "a" additionally zeroes the kernel tap at
            (kernel_height // 2, kernel_width // 2).
        **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`).

    The convolution uses stride 1 and the backend's default padding; callers pad the
    input themselves.
    """

    def __init__(self, kernel_size, out_dim, direction, mode, **kwargs):
        super(MaskedConv2D, self).__init__(**kwargs)
        self.direction = direction  # Horizontal or vertical
        self.mode = mode  # Mask type "a" or "b"
        self.kernel_size = kernel_size
        self.out_dim = out_dim

    def build(self, input_shape):
        """Create the binary mask plus the trainable kernel and bias."""
        filter_mid_y = self.kernel_size[0] // 2
        filter_mid_x = self.kernel_size[1] // 2
        in_dim = int(input_shape[-1])
        w_shape = [self.kernel_size[0], self.kernel_size[1], in_dim, self.out_dim]
        mask_filter = np.ones(w_shape, dtype=np.float32)
        # Build the mask
        if self.direction == "h":
            # Zero every row below the middle row, and the middle row right of the middle tap.
            mask_filter[filter_mid_y + 1:, :, :, :] = 0.
            mask_filter[filter_mid_y, filter_mid_x + 1:, :, :] = 0.
        elif self.direction == "v":
            if self.mode == 'a':
                # Type "a": zero the middle row and everything below it.
                mask_filter[filter_mid_y:, :, :, :] = 0.
            elif self.mode == 'b':
                # Type "b": keep the middle row, zero everything below it.
                mask_filter[filter_mid_y + 1:, :, :, :] = 0.0
        if self.mode == 'a':
            mask_filter[filter_mid_y, filter_mid_x, :, :] = 0.0

        # Keep the constant mask and the trainable kernel separate. Multiplying them
        # here would store a one-off product: when build() executes eagerly (the TF2
        # default) that product is a plain tensor, the variable is no longer part of
        # the computation in call(), and the kernel receives no gradient.
        self._mask = mask_filter
        # Keyword arguments keep add_weight independent of its positional order.
        self.kernel = self.add_weight(name="W_{}".format(self.direction), shape=w_shape, trainable=True)
        self.b = self.add_weight(name="v_b", shape=[self.out_dim, ], trainable=True)
        super(MaskedConv2D, self).build(input_shape)

    @property
    def W(self):
        """Masked kernel used by the convolution (kept so existing `layer.W` reads still work)."""
        return self.kernel * self._mask

    def call(self, inputs):
        """Convolve `inputs` with the masked kernel and add the bias."""
        # The mask is re-applied on every call so gradients reach self.kernel while
        # masked taps stay at exactly zero.
        masked_kernel = self.kernel * self._mask
        return tf.keras.backend.conv2d(inputs, masked_kernel, strides=(1, 1)) + self.b

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from its config."""
        config = super(MaskedConv2D, self).get_config()
        config.update({
            "kernel_size": self.kernel_size,
            "out_dim": self.out_dim,
            "direction": self.direction,
            "mode": self.mode,
        })
        return config


def gated_masked_conv2d(v_stack_in, h_stack_in, out_dim, kernel, mask='b', residual=True, i=0):
    """Basic Gated-PixelCNN block.

    This is an improvement over PixelRNN to avoid "blind spots", i.e. pixels missing
    from the field of view. It works by having two parallel stacks, for the vertical
    and horizontal direction, each being masked to only see the appropriate context
    pixels.

    Args:
        v_stack_in: Input feature map of the vertical stack.
        h_stack_in: Input feature map of the horizontal stack.
        out_dim: Number of output channels of each stack.
        kernel: Nominal kernel size; the vertical stack uses a
            (kernel // 2 + 1, kernel) kernel, the horizontal one (1, kernel // 2 + 1).
        mask: Mask type ('a' or 'b') handed to both masked convolutions.
        residual: If True, `h_stack_in` is added to the horizontal output.
        i: Block index, used only to build the layer names.

    Returns:
        Tuple `(v_stack_out, h_stack_out)`.
    """
    half = kernel // 2

    # --- Vertical stack: pad, masked conv, crop back to the input height, gate ---
    v_kernel = (half + 1, kernel)
    v_padded = tf.keras.layers.ZeroPadding2D(padding=(half, half), name="v_pad_{}".format(i))(v_stack_in)
    v_conv = MaskedConv2D(v_kernel, out_dim * 2, "v", mask, name="v_masked_conv_{}".format(i))(v_padded)
    v_stack = v_conv[:, :int(v_stack_in.get_shape()[-3]), :, :]
    v_stack_out = tf.keras.layers.Lambda(lambda inputs: gate(inputs), name="v_gate_{}".format(i))(v_stack)

    # --- Horizontal stack: pad, masked conv, crop back to the input width ---
    h_kernel = (1, half + 1)
    h_padded = tf.keras.layers.ZeroPadding2D(padding=(0, half), name="h_pad_{}".format(i))(h_stack_in)
    h_conv = MaskedConv2D(h_kernel, out_dim * 2, "h", mask, name="h_masked_conv_{}".format(i))(h_padded)
    h_stack = h_conv[:, :, :int(h_stack_in.get_shape()[-2]), :]

    # The (pre-gate) vertical features feed the horizontal stack through a 1x1 convolution.
    h_stack_1 = tf.keras.layers.Conv2D(filters=out_dim * 2, kernel_size=1, strides=(1, 1),
                                       name="v_to_h_{}".format(i))(v_stack)
    h_stack_out = tf.keras.layers.Lambda(lambda inputs: gate(inputs), name="h_gate_{}".format(i))(h_stack + h_stack_1)

    # 1x1 convolution on the gated horizontal features, followed by the optional skip connection.
    h_stack_out = tf.keras.layers.Conv2D(filters=out_dim, kernel_size=1, strides=(1, 1),
                                         name="res_conv_{}".format(i))(h_stack_out)
    if residual:
        h_stack_out = h_stack_out + h_stack_in
    return v_stack_out, h_stack_out


In [None]:
##### PixelCNN Prior Network ###########################################################################################

# The prior operates on SIZE x SIZE grids of integer codebook indices.
pixelcnn_prior_inputs = tf.keras.layers.Input(shape=(SIZE, SIZE), name='pixelcnn_prior_inputs', dtype=tf.int32)
z_q = model.vq_vae.quantize_encoding(pixelcnn_prior_inputs)  # maps indices to the actual codebook

v_stack_in = h_stack_in = z_q
for i in range(PIXELCNN_NUM_BLOCKS):
    is_first_block = (i == 0)
    # NOTE(review): `mask` is computed but not forwarded to gated_masked_conv2d, so
    # every block runs with that function's default mask='b'. Forwarding it would
    # change the first block's masking - check the resulting receptive field before
    # doing so.
    mask = 'a' if is_first_block else 'b'
    # The first block uses a wider kernel and no residual connection.
    kernel_size = 7 if is_first_block else 3
    residual = not is_first_block
    v_stack_in, h_stack_in = gated_masked_conv2d(v_stack_in, h_stack_in, PIXELCNN_NUM_FEATURE_MAPS,
                                                 kernel=kernel_size, residual=residual, i=i + 1)

# Two 1x1 convolutions turn the horizontal-stack features into one logit per
# codebook entry for every cell.
fc1 = tf.keras.layers.Conv2D(filters=PIXELCNN_NUM_FEATURE_MAPS, kernel_size=1, name="fc1")(h_stack_in)
fc2 = tf.keras.layers.Conv2D(filters=NUM_LATENT_K, kernel_size=1, name="fc2")(fc1)

pixelcnn_prior = tf.keras.Model(inputs=pixelcnn_prior_inputs, outputs=fc2, name='pixelcnn-prior')

# Second model over the same graph that draws indices from the categorical
# distribution defined by the logits.
samples = tfp.distributions.Categorical(logits=fc2).sample()
prior_sampler = tf.keras.Model(inputs=pixelcnn_prior_inputs, outputs=samples, name='pixelcnn-prior-sampler')


In [None]:
##### Train PixelCNN Prior Network #####################################################################################

# Use the learning rate declared with the other PixelCNN hyper-parameters instead of
# a separate hard-coded value, and pass the metrics as a list as compile() documents.
pixelcnn_prior.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                       metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
                       optimizer=tf.keras.optimizers.Adam(learning_rate=PIXELCNN_LEARNING_RATE))

# The code-index maps serve as both input and target: the network outputs, for every
# cell, logits over the codebook indices.
prior_history = pixelcnn_prior.fit(z_train, z_train,
                                   epochs=PIXELCNN_NUM_EPOCHS,
                                   batch_size=PIXELCNN_BATCH_SIZE,
                                   verbose=1)


In [None]:
##### Loss and Accuracy Graphs #####################################################################################

fig = plt.figure(figsize=(30, 10))

# Per-epoch curves recorded by fit().
history = prior_history.history
y1 = history['loss']
y2 = history['sparse_categorical_accuracy']
x = np.arange(0, len(y1))

# One panel per curve: (values, title, y-axis label).
panels = (
    (y1, 'Loss', 'loss'),
    (y2, 'Sparse Categorical Accuracy', 'sparse_categorical_accuracy'),
)
for position, (values, title, ylabel) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position)
    ax.plot(x, values)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)


In [None]:
##### Generate Samples #################################################################################################

fig = plt.figure(figsize=(40, 30))
id=42  # index of the training code map to sample around (name kept as-is; note it shadows the builtin)

# NOTE(review): each sample is a single forward pass of the prior conditioned on
# z_train[id], not a cell-by-cell autoregressive rollout - confirm this is intended.
for i in tqdm.tqdm(range(12)):
    # Slice rather than index so the code map keeps a leading batch axis of size 1,
    # matching the (batch, SIZE, SIZE) input the sampler model was built with.
    out = prior_sampler(z_train[id:id + 1, ...])
    out = model.vq_vae.quantize_encoding(out)
    X = model.decoder(out).numpy()

    # Min-max normalise to [0, 255]; a constant image maps to zeros instead of
    # dividing by zero.
    value_range = X.max() - X.min()
    if value_range > 0:
        X = (X - X.min()) / value_range * 255
    else:
        X = np.zeros_like(X)

    ax = fig.add_subplot(3, 4, i + 1)
    ax.imshow(np.squeeze(X), cmap='gray')
    ax.set_title(f'Sample {i}')

    # cv2.imwrite(os.path.join(SAVE_DIR, f'code_{id}_sample_{i}_gen.jpg'), X[0,...]*255)
