In [1]:
import tensorflow as tf 
import tensorflow_io as tfio
import tensorflow_probability as tfp

# Report library versions and the GPUs TensorFlow can see, each followed by a blank line.
print('Tensorflow Version:', tf.__version__, '', sep='\n')
print('Tensorflow-Probability Version:', tfp.__version__, '', sep='\n')
print('Listing all GPU resources:',
      tf.config.experimental.list_physical_devices('GPU'),
      '', sep='\n')

import tensorflow.keras as keras
import numpy as np
import datetime
import time
import matplotlib.pyplot as plt
import matplotlib as mpl
import pickle
import os
from tqdm import trange
import sys
import git
import importlib

# Default colormap for image plots; calls that pass an explicit cmap (e.g. cmap='gray') override it.
mpl.rcParams['image.cmap'] = 'coolwarm'

Tensorflow Version:
2.2.0

Tensorflow-Probability Version:
0.10.0

Listing all GPU resources:
[]



In [8]:
# ---- Experiment configuration -------------------------------------------
# Layer variant name; it selects the sub-directories and file prefixes below.
LAYER_NAME = 'all_layers'

# Network width and prior settings. DATA_SIZE is the factor applied to the
# mean cross-entropy in sum_binary_crossentropy.
FILTERS = 32
DATA_SIZE = 60000
PRIOR_MU = 0
PRIOR_SIGMA = 10

# Batch size used when building the test dataset; EPOCHS and VERBOSE are not
# referenced by the evaluation cells visible in this notebook.
BATCH_SIZE = 128
EPOCHS = 200
VERBOSE = 2

# Number of stochastic forward passes collected per test sample.
N_PREDICTIONS = 100

# ---- Paths, all derived from the top level of the enclosing git repository ----
ROOT_PATH = git.Repo("", search_parent_directories=True).git.rev_parse("--show-toplevel")
DATA_PATH = "{}/data/".format(ROOT_PATH)
SMALL_DATA_PATH = "{}/load_trained_models/data_small/".format(ROOT_PATH)
LAYER_PATH = "{}/layers/{}/".format(ROOT_PATH, LAYER_NAME)
SAVE_PATH = "{}{}_bayesian_model.h5".format(LAYER_PATH, LAYER_NAME)
PICKLE_PATH = "{}{}_hist.pkl".format(LAYER_PATH, LAYER_NAME)
MODEL_PATH = "{}{}_model".format(LAYER_PATH, LAYER_NAME)
IMAGE_PATH = "{}/images/{}/".format(ROOT_PATH, LAYER_NAME)

In [9]:
# Build the model by executing the layer-specific definition file
# `<MODEL_PATH>.py` as a module and calling its `make_model()` factory.
print("-" * 30)
print("Constructing model...")
print("-" * 30)
spec = importlib.util.spec_from_file_location(MODEL_PATH, MODEL_PATH + ".py")
ModelLoader = importlib.util.module_from_spec(spec)
spec.loader.exec_module(ModelLoader)
model = ModelLoader.make_model()
# summary() prints the table itself and returns None; wrapping it in print()
# would append a stray "None" line to the output.
model.summary()

------------------------------
Constructing model...
------------------------------

Input size: (None, 144, 144, 4)
Output size: (None, 144, 144, 1)
Model: "model_all_layers"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_layer (InputLayer)        [(None, 144, 144, 4) 0                                            
__________________________________________________________________________________________________
encoder_1_a (Conv2DFlipout)     (None, 144, 144, 32) 4736        input_layer[0][0]                
__________________________________________________________________________________________________
encoder_1_b (Conv2DFlipout)     (None, 144, 144, 32) 36992       encoder_1_a[0][0]                
__________________________________________________________________________________________________
downsample_1 (MaxPooling2D)     

In [32]:
class KLDivergence:
    """Deferred KL divergence between a stored pair of distributions.

    The pair is captured at construction time; the divergence is only
    evaluated when `call` is invoked, which lets the bound method be handed
    to `layer.add_loss` as a zero-argument callable.
    """

    def __init__(self, q_dist, p_dist):
        """Remember the two distributions (`q_dist` first, `p_dist` second)."""
        self.q_dist, self.p_dist = q_dist, p_dist

    def call(self):
        """Return `tfp.distributions.kl_divergence(q_dist, p_dist)`."""
        kl_divergence = tfp.distributions.kl_divergence
        return kl_divergence(self.q_dist, self.p_dist)

def mean_binary_crossentropy(y, y_pred):
    """Binary cross-entropy between `y` and `y_pred`, reduced to a single mean."""
    bce = keras.losses.binary_crossentropy(y, y_pred)
    return tf.reduce_mean(bce)

def sum_binary_crossentropy(y, y_pred):
    """Mean binary cross-entropy multiplied by the module constant DATA_SIZE."""
    mean_bce = mean_binary_crossentropy(y, y_pred)
    return DATA_SIZE * mean_bce

def likelihood_loss(y, y_pred):
    """Likelihood term used as the compiled loss; delegates to sum_binary_crossentropy."""
    scaled_bce = sum_binary_crossentropy(y, y_pred)
    return scaled_bce


def prior_fn(dtype, shape, name, trainable, add_variable_fn):
    """Build a fixed Normal(PRIOR_MU, PRIOR_SIGMA) prior over a tensor of `shape`.

    All batch dimensions of the element-wise Normal are reinterpreted as event
    dimensions, so the result is a single distribution over the whole tensor.
    `name`, `trainable` and `add_variable_fn` are accepted but not used here;
    presumably they exist to match the signature TFP layers call prior
    functions with (confirm against the TFP layer documentation).
    """
    loc = PRIOR_MU * tf.ones(shape, dtype)
    scale = PRIOR_SIGMA * tf.ones(shape, dtype)
    elementwise_normal = tfp.distributions.Normal(loc=loc, scale=scale)

    n_batch_dims = tf.size(elementwise_normal.batch_shape_tensor())
    return tfp.distributions.Independent(elementwise_normal,
                                         reinterpreted_batch_ndims=n_batch_dims)

# Mean-field normal posterior. The untransformed scale is initialised around
# log(exp(PRIOR_SIGMA/50) - 1), i.e. the inverse softplus of PRIOR_SIGMA/50;
# presumably so the initial posterior scale is about PRIOR_SIGMA/50
# (confirm against TFP's default_mean_field_normal_fn).
_loc_initializer = tf.random_normal_initializer(mean=PRIOR_MU, stddev=0.05)
_untransformed_scale_initializer = tf.random_normal_initializer(
    mean=np.log(np.exp(PRIOR_SIGMA/50.) - 1), stddev=0.05)
posterior_fn = tfp.layers.default_mean_field_normal_fn(
    loc_initializer=_loc_initializer,
    untransformed_scale_initializer=_untransformed_scale_initializer)

# Prior / posterior / divergence settings shared by every Flipout layer.
# The divergence functions are None; the KL terms are attached to the layers
# explicitly after the model has been built.
_bayesian_kwargs = dict(kernel_prior_fn=prior_fn,
                        bias_prior_fn=prior_fn,
                        kernel_posterior_fn=posterior_fn,
                        bias_posterior_fn=posterior_fn,
                        kernel_divergence_fn=None,
                        bias_divergence_fn=None)

# 3x3 ReLU Flipout convolutions for the encoder / decoder stages.
flipout_params = dict(kernel_size=(3, 3), activation="relu", padding="same",
                      **_bayesian_kwargs)

# 1x1 sigmoid Flipout convolution for the output layer.
flipout_params_final = dict(kernel_size=(1, 1), activation="sigmoid", padding="same",
                            **_bayesian_kwargs)

# Settings with a he_uniform kernel initializer and no prior/posterior
# functions; not referenced by the cells visible in this notebook.
params_final = dict(kernel_size=(1, 1), activation="sigmoid", padding="same",
                    data_format="channels_last",
                    kernel_initializer="he_uniform")

params = dict(kernel_size=(3, 3), activation="relu",
              padding="same", data_format="channels_last",
              kernel_initializer="he_uniform")


def _flipout_conv_pair(inputs, filters, prefix):
    """Apply two consecutive Flipout convolutions configured by `flipout_params`.

    The layers are named `<prefix>_a` and `<prefix>_b` and are created in that
    order. Returns the outputs of both layers as `(a_output, b_output)`.
    """
    conv_a = tfp.layers.Convolution2DFlipout(
        filters, name=prefix + '_a', **flipout_params)(inputs)
    conv_b = tfp.layers.Convolution2DFlipout(
        filters, name=prefix + '_b', **flipout_params)(conv_a)
    return conv_a, conv_b


def _upsample_and_concat(inputs, skip, level):
    """Upsample `inputs` 2x (bilinear) and concatenate it with `skip`.

    The layers are named `upsample_<level>` and `concat_<level>`. Returns
    `(upsampled, concatenated)`.
    """
    upsampled = keras.layers.UpSampling2D(
        name='upsample_%d' % level, size=(2, 2), interpolation="bilinear")(inputs)
    merged = keras.layers.concatenate([upsampled, skip], name='concat_%d' % level)
    return upsampled, merged


# Layer names and creation order are kept exactly as before so that weights
# saved from the original definition still line up with this graph.
input_layer = keras.layers.Input(shape=(144, 144, 4), name="input_layer")

# Encoder: four conv-pair + max-pool stages, doubling the filters each time.
encoder_1_a, encoder_1_b = _flipout_conv_pair(input_layer, FILTERS, 'encoder_1')
downsample_1 = keras.layers.MaxPool2D(name='downsample_1')(encoder_1_b)

encoder_2_a, encoder_2_b = _flipout_conv_pair(downsample_1, FILTERS*2, 'encoder_2')
downsample_2 = keras.layers.MaxPool2D(name='downsample_2')(encoder_2_b)

encoder_3_a, encoder_3_b = _flipout_conv_pair(downsample_2, FILTERS*4, 'encoder_3')
downsample_3 = keras.layers.MaxPool2D(name='downsample_3')(encoder_3_b)

encoder_4_a, encoder_4_b = _flipout_conv_pair(downsample_3, FILTERS*8, 'encoder_4')
downsample_4 = keras.layers.MaxPool2D(name='downsample_4')(encoder_4_b)

# Bottleneck.
encoder_5_a, encoder_5_b = _flipout_conv_pair(downsample_4, FILTERS*16, 'encoder_5')

# Decoder: upsample, concatenate with the matching encoder output, conv pair.
upsample_4, concat_4 = _upsample_and_concat(encoder_5_b, encoder_4_b, 4)
decoder_4_a, decoder_4_b = _flipout_conv_pair(concat_4, FILTERS*8, 'decoder_4')

upsample_3, concat_3 = _upsample_and_concat(decoder_4_b, encoder_3_b, 3)
decoder_3_a, decoder_3_b = _flipout_conv_pair(concat_3, FILTERS*4, 'decoder_3')

upsample_2, concat_2 = _upsample_and_concat(decoder_3_b, encoder_2_b, 2)
decoder_2_a, decoder_2_b = _flipout_conv_pair(concat_2, FILTERS*2, 'decoder_2')

upsample_1, concat_1 = _upsample_and_concat(decoder_2_b, encoder_1_b, 1)
decoder_1_a, decoder_1_b = _flipout_conv_pair(concat_1, FILTERS, 'decoder_1')

# Single-filter output convolution configured by `flipout_params_final`.
output_layer = tfp.layers.Convolution2DFlipout(name="output_layer",
                                filters=1, **flipout_params_final)(decoder_1_b)

print()
print('Input size:', input_layer.shape)
print('Output size:', output_layer.shape)

model = keras.models.Model(inputs=input_layer, outputs=output_layer,
                           name='model_' + LAYER_NAME)

# The Flipout layers are constructed with kernel_divergence_fn=None and
# bias_divergence_fn=None, so the KL(posterior || prior) terms for kernel and
# bias are attached here as deferred (callable) losses.
# isinstance against the public class replaces an exact-type comparison with
# the private `tfp.python...` module path; subclasses are matched as well.
for layer in model.layers:
    if isinstance(layer, tfp.layers.Convolution2DFlipout):
        layer.add_loss(KLDivergence(layer.kernel_posterior, layer.kernel_prior).call)
        layer.add_loss(KLDivergence(layer.bias_posterior, layer.bias_prior).call)

# The compiled loss is the scaled cross-entropy; both the scaled and the mean
# cross-entropy are also tracked as metrics.
model.compile(optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
              loss=likelihood_loss,
              metrics=[likelihood_loss, mean_binary_crossentropy],
              )


Input size: (None, 144, 144, 4)
Output size: (None, 144, 144, 1)


In [33]:
# Show the shape of every weight array held by the layer at index 1.
layer_1_weights = model.layers[1].get_weights()
for weight_array in layer_1_weights:
    print(weight_array.shape)

(3, 3, 4, 32)
(3, 3, 4, 32)
(32,)
(32,)


In [None]:
# Restore the weights stored at SAVE_PATH into the model built above.
model.load_weights(SAVE_PATH)
print("Model weights loaded successfully\n")

In [None]:
# Load the small test split and keep only the first `n_test` samples.
# Previously `n_test` appeared only in the message and the arrays were never
# limited, so the printed count could disagree with the data actually loaded.
n_test = 200
msks_test = np.load(SMALL_DATA_PATH + 'msks_test.npy')[:n_test]
imgs_test = np.load(SMALL_DATA_PATH + 'imgs_test.npy')[:n_test]
y_test = msks_test  # alias kept for code that refers to the masks as labels

# Report the number of samples actually available (may be fewer than n_test).
print("First " + str(len(imgs_test)) + " test samples loaded\n")

In [None]:
# Pair each test image with its mask, then cache, batch and prefetch the pairs.
imgs_dataset = tf.data.Dataset.from_tensor_slices(imgs_test)
msks_dataset = tf.data.Dataset.from_tensor_slices(msks_test)
Xy_test = (
    tf.data.Dataset.zip((imgs_dataset, msks_dataset))
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(8)
)

In [None]:
# Buffer for the Monte Carlo predictions: one leading axis of length
# N_PREDICTIONS in front of the mask array's own shape.
prediction_size = [N_PREDICTIONS] + list(msks_test.shape)
prediction_test = np.zeros(prediction_size)

In [None]:
# Run the model N_PREDICTIONS times over the test set and store each pass.
# Presumably every pass draws fresh weight perturbations in the Flipout
# layers, so the stored predictions differ from pass to pass.
print("Getting Monte Carlo samples of test predictions...")
for sample_index in trange(N_PREDICTIONS):
    prediction_test[sample_index] = model.predict(Xy_test)

In [None]:
def _plot_panel(position, title, image, clim=None, **imshow_kwargs):
    """Draw `image` with `title` in subplot `position` of the current figure.

    If `clim` is given as `(low, high)`, the colour limits are fixed to that
    range and a colorbar is attached; otherwise neither is applied. Extra
    keyword arguments are forwarded to `plt.imshow`.
    """
    plt.subplot(position)
    plt.title(title)
    plt.imshow(image, **imshow_kwargs)
    if clim is not None:
        plt.clim(*clim)
        plt.colorbar()


# savefig does not create missing directories, so make sure the target exists.
os.makedirs(IMAGE_PATH, exist_ok=True)

# Titles for input channels 0..3, in subplot order.
input_channel_titles = ('Test Data T2-FLAIR', 'Test Data T1',
                        'Test Data T1-Contrast', 'Test Data T2')

# Plot every 20th test sample; bounded by the data actually loaded rather
# than a hard-coded 200.
for i in trange(0, min(n_test, len(imgs_test)), 20):
    file_tag = str(i).zfill(4)
    true_mask = msks_test[i, :, :, 0]

    # Reduce over the Monte Carlo axis for this sample only. The previous code
    # called prediction_test.mean(0) / .std(0) on the full array inside the
    # loop, repeating a reduction over every sample on every iteration.
    mc_samples = prediction_test[:, i, :, :, 0]
    pred_mean = mc_samples.mean(0)
    pred_std = mc_samples.std(0)
    pct_10, pct_50, pct_90 = np.percentile(mc_samples, [10, 50, 90], axis=0)

    # Figure 1: the four input channels.
    plt.figure(dpi=100)
    for channel, channel_title in enumerate(input_channel_titles):
        _plot_panel(221 + channel, channel_title,
                    imgs_test[i, :, :, channel], cmap='gray')
    plt.tight_layout()
    plt.savefig(IMAGE_PATH + 'input_images_' + file_tag + '.png')

    # Figure 2: predictive mean / stddev, ground truth and their difference.
    plt.figure(dpi=200)
    _plot_panel(221, '\nPredicted Label Mean', pred_mean,
                clim=(0, 1), interpolation='nearest')
    _plot_panel(222, '\nPredicted Label Stddev', pred_std,
                clim=(0, 0.2), interpolation='nearest')
    _plot_panel(223, 'True Label', true_mask,
                clim=(0, 1), interpolation='nearest')
    _plot_panel(224, '\nTruth-Prediction Discrepency', true_mask - pred_mean,
                clim=(-1, 1), interpolation='nearest')
    plt.tight_layout()
    plt.savefig(IMAGE_PATH + 'prediction_images_' + file_tag + '.png')

    # Figure 3: 10th / 50th / 90th percentiles of the Monte Carlo predictions.
    plt.figure(dpi=100)
    _plot_panel(221, '\nPredicted Label 10th percentile', pct_10,
                clim=(0, 1), interpolation='nearest')
    _plot_panel(222, '\nPredicted Label 50th percentile', pct_50,
                clim=(0, 1), interpolation='nearest')
    _plot_panel(223, 'Predicted Label 90th percentile', pct_90,
                clim=(0, 1), interpolation='nearest')
    _plot_panel(224, 'True Label', true_mask,
                clim=(0, 1), interpolation='nearest')
    plt.tight_layout()
    plt.savefig(IMAGE_PATH + 'prediction_percentile_images_' + file_tag + '.png')

In [None]:
# Load the saved training history. The context manager closes the file
# handle, which the bare open() call previously left open.
# NOTE: pickle can execute arbitrary code while loading; only open history
# files that were produced by this project.
with open(PICKLE_PATH, 'rb') as history_file:
    history = pickle.load(history_file)

In [None]:
# Training and validation loss curves on a logarithmic y-axis, saved to IMAGE_PATH.
fig, ax = plt.subplots(dpi=100)
ax.semilogy(history['loss'], label='training loss (ELBO)')
ax.semilogy(history['val_loss'], label='testing loss')
ax.legend()
fig.savefig(IMAGE_PATH + 'training_history.png')

In [None]:
# For every layer, print its repr followed by the shape of each weight array.
for layer in model.layers:
    print(layer)
    for weight_shape in (w.shape for w in layer.get_weights()):
        print(weight_shape)