In [93]:
import os
import sys
import math
import logging
import yaml
import h5py
from pathlib import Path

import kerastuner as kt

import numpy as np
import scipy as sp
import tensorflow as tf


%load_ext autoreload
%autoreload 2

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns
sns.set_context("poster")
sns.set(rc={'figure.figsize': (16, 9.)})
sns.set_style("whitegrid")

import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

import began
from began.logging import setup_vae_run_logging
from began.vae import inception_module

from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Concatenate, Dense, Reshape, Input, UpSampling2D, Conv2DTranspose
from tensorflow.keras.models import Model

logging.basicConfig(level=logging.INFO, stream=sys.stdout)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [95]:
# function for creating a naive inception block
def inception_module(layer_in, f1, f2, f3):
    """ An implementation of the Inception module.
    """
    # 1x1 conv
    conv1 = Conv2D(f1, (1,1), padding='same', activation='relu')(layer_in)
    # 3x3 conv
    conv3 = Conv2D(f2, (3,3), padding='same', activation='relu')(layer_in)
    # 5x5 conv
    conv5 = Conv2D(f3, (5,5), padding='same', activation='relu')(layer_in)
    # 3x3 max pooling
    pool = MaxPooling2D((3,3), strides=(1,1), padding='same')(layer_in)
    # concatenate filters, assumes filters/channels last
    layer_out = Concatenate(axis=-1)([conv1, conv3, conv5, pool])
    return layer_out

In [89]:
lat_dim = 128
latz = Input((lat_dim,))

dlr0 = Dense(units=(16 * 16 * 32), activation=tf.nn.relu)(latz)
res0 = Reshape(target_shape=(16, 16, 32))(dlr0)
assert res0.shape.as_list() == [None, 16, 16, 32]

btn0 = BatchNormalization(momentum=0.9)(res0)
cvt0 = Conv2DTranspose(filters=128, kernel_size=5, strides=(2, 2), padding="SAME", activation='relu')(btn0)
assert cvt0.shape.as_list() == [None, 32, 32, 128]

btn1 = BatchNormalization(momentum=0.9)(cvt0)
cvt1 = Conv2DTranspose(filters=128, kernel_size=5, strides=(2, 2), padding="SAME", activation='relu')(btn1)
assert cvt1.shape.as_list() == [None, 64, 64, 128]

btn2 = BatchNormalization(momentum=0.9)(cvt1)
cvt2 = Conv2DTranspose(filters=128, kernel_size=5, strides=(2, 2), padding="SAME", activation='relu')(btn2)
assert cvt2.shape.as_list() == [None, 128, 128, 128]

btn3 = BatchNormalization(momentum=0.9)(cvt2)
cvt3 = Conv2DTranspose(filters=128, kernel_size=5, strides=(2, 2), padding="SAME", activation='relu')(btn3)
assert cvt3.shape.as_list() == [None, 256, 256, 128]

cvt4 = Conv2DTranspose(filters=1, kernel_size=5, strides=(1, 1), padding="SAME")(cvt3)

model = Model([latz], cvt4)

model.summary()

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_34 (InputLayer)        [(None, 128)]             0         
_________________________________________________________________
dense_25 (Dense)             (None, 8192)              1056768   
_________________________________________________________________
reshape_22 (Reshape)         (None, 16, 16, 32)        0         
_________________________________________________________________
batch_normalization_82 (Batc (None, 16, 16, 32)        128       
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 32, 32, 128)       102528    
_________________________________________________________________
batch_normalization_83 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
conv2d_transpose_9 (Conv2DTr (None, 64, 64, 128)       4097

In [115]:
lat_dim = 128
latz = Input((lat_dim,))

dlr0 = Dense(units=(16 * 16 * 32), activation=tf.nn.relu)(latz)
res0 = Reshape(target_shape=(16, 16, 32))(dlr0)
assert res0.shape.as_list() == [None, 16, 16, 32]

btn0 = BatchNormalization(momentum=0.9)(res0)
icp0 = inception_module(btn0, 32, 128, 64)
ups0 = UpSampling2D((2, 2), interpolation='nearest')(icp0)
assert ups0.shape.as_list() == [None, 32, 32, 256]

btn1 = BatchNormalization(momentum=0.9)(ups0)
icp1 = inception_module(btn1, 32, 128, 64)
ups1 = UpSampling2D((2, 2), interpolation='nearest')(icp1)
assert ups1.shape.as_list() == [None, 64, 64, 480]

btn2 = BatchNormalization(momentum=0.9)(ups1)
icp2 = inception_module(btn2, 32, 128, 64)
ups2 = UpSampling2D((2, 2), interpolation='nearest')(icp2)
assert ups2.shape.as_list() == [None, 128, 128, 704]

btn3 = BatchNormalization(momentum=0.9)(ups2)
icp3 = inception_module(btn3, 32, 128, 64)
ups3 = UpSampling2D((2, 2), interpolation='nearest')(icp3)
print(ups3.shape)
assert ups3.shape.as_list() == [None, 256, 256, 128]

out = Dense()

model = Model([latz], cvt4)

model.summary()

(None, 128, 128, 704)
Model: "model_16"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_55 (InputLayer)           [(None, 128)]        0                                            
__________________________________________________________________________________________________
dense_46 (Dense)                (None, 8192)         1056768     input_55[0][0]                   
__________________________________________________________________________________________________
reshape_43 (Reshape)            (None, 16, 16, 32)   0           dense_46[0][0]                   
__________________________________________________________________________________________________
batch_normalization_134 (BatchN (None, 16, 16, 32)   128         reshape_43[0][0]                 
_____________________________________________________________________

In [103]:
icp0.

TensorShape([None, 16, 16, 256])