In [None]:
import logging
import os
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, SpatialDropout2D
from tensorflow.keras.layers import Conv2D, Flatten, Lambda
from tensorflow.keras.layers import LocallyConnected2D, ZeroPadding2D
from tensorflow.keras.layers import MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import Reshape, Conv2DTranspose
from tensorflow.keras.layers import ActivityRegularization
from tensorflow.keras.regularizers import L1L2
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.losses import mse, binary_crossentropy
from tensorflow.keras import backend as K

def sampling(args):
    """
    Draw a latent sample from Q(z|X) with the reparameterization trick.

    Instead of sampling Q(z|X) directly, draw eps ~ N(0, I) and return
        z = z_mean + exp(0.5 * z_log_var) * eps    (i.e. z_mean + sqrt(var) * eps)
    so the sample stays differentiable w.r.t. z_mean and z_log_var.
    # Arguments
        args (tensor tuple): (z_mean, z_log_var) of Q(z|X)
    # Returns
        z (tensor): sampled latent vector, shape (batch, dim) where dim is
            the static size of axis 1 of z_mean
    """
    mean, log_var = args
    noise_shape = (K.shape(mean)[0], K.int_shape(mean)[1])
    # K.random_normal defaults to mean=0.0, stddev=1.0
    noise = K.random_normal(shape=noise_shape)
    std = K.exp(0.5 * log_var)
    return mean + std * noise

def sampling_mean_log_var(z_mean_log_var):
    """
    Reparameterized sample from a packed (mean, log-variance) tensor.

    The last axis of `z_mean_log_var` holds the mean at index 0 and the
    log-variance at index 1 (the encoder feeds it from `Reshape((latent_dim, 2))`).
    # Arguments
        z_mean_log_var (tensor): packed statistics of Q(z|X), last axis of size 2
    # Returns
        z (tensor): mean + exp(0.5 * log_var) * eps with eps ~ N(0, 1) of shape
            (batch, dim), where dim is the static size of axis 1
    """
    n_batch = K.shape(z_mean_log_var)[0]
    n_latent = K.int_shape(z_mean_log_var)[1]
    eps = K.random_normal(shape=(n_batch, n_latent))
    mean = z_mean_log_var[..., 0]
    log_var = z_mean_log_var[..., 1]
    return mean + K.exp(0.5 * log_var) * eps

In [None]:
def vae_loss(z_mean, z_log_var, y_true, y_pred, kl_weight=1e-4):
    """
    Compute the VAE loss: mean-squared reconstruction error plus a weighted
    KL divergence between Q(z|X) = N(z_mean, exp(z_log_var)) and N(0, I).
    # Arguments
        z_mean: mean of Q(z|X); the KL term is summed over its last axis
        z_log_var: log variance of Q(z|X), same shape as z_mean
        y_true, y_pred: target and predicted values; both are flattened, so the
            reconstruction term is a single mean over every element
        kl_weight: scale applied to the KL term (default 1e-4, the value that
            was previously hard-coded)
    # Returns
        loss value: the scalar reconstruction term added to the weighted KL
            term, giving one value per row of z_mean (reduce before use if a
            scalar is needed)
    """
    match_loss = mse(K.flatten(y_true), K.flatten(y_pred))
    # closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return match_loss + kl_weight * kl_loss

In [None]:
act_func = 'softplus' # or 'relu'
locally_connected_channels = 2
latent_dim = 8
sz = 128

# this is the V1 filter bank output: a (sz//2, sz//2) grid with 4 channels
v1_filtered_output = Input(shape=(sz//2, sz//2, 4), name='v1_{}'.format(sz//2))

# Encoder: three conv / 2x2 max-pool stages (64 -> 32 -> 16 -> 8 spatially),
# a 3x3 locally-connected layer, then a dense layer producing latent_dim
# (mean, log-variance) pairs that the final Lambda samples from.
encoder = \
    Sequential(name='v1_to_pulvinar_encoder', layers=
    [
        v1_filtered_output,
        Conv2D(64, (1,1), name='v2_conv2d', activation=act_func, padding='same'),
        MaxPooling2D((2,2), name='v2_maxpool', padding='same'),
        Conv2D(64, (3,3), name='v4_conv2d', activation=act_func, padding='same'),
        MaxPooling2D((2,2), name='v4_maxpool', padding='same'),
        Conv2D(64, (1,1), name='pit_conv2d', activation=act_func, padding='same'),
        MaxPooling2D((2,2), name='pit_maxpool', padding='same'),
        Conv2D(64, (3,3), name='cit_conv2d', activation=act_func, padding='same'),
        LocallyConnected2D(locally_connected_channels, (3,3),
                           name='ait_local', activation=act_func,
                           kernel_regularizer=L1L2(l1=0.01, l2=0.01)),
        Flatten(name='pulvinar_flatten'),
        # Linear output: the mean must be free to go negative and the
        # log-variance below zero. A softplus here forced mean >= 0 and
        # variance >= 1, so the posterior could never be tighter than the prior.
        Dense((latent_dim*2), activation=None, name='pulvinar_dense'),
        # last axis: index 0 = mean, index 1 = log-variance
        # (the layout sampling_mean_log_var indexes with [..., 0] / [..., 1])
        Reshape((latent_dim,2), name='pulvinar_reshape'),
        # sampling_mean_log_var returns (batch, latent_dim), so the per-sample
        # output shape is (latent_dim,), not (latent_dim, 1)
        Lambda(sampling_mean_log_var, output_shape=(latent_dim,), name='z')
    ])
encoder_output = encoder(v1_filtered_output)
print(encoder_output)
encoder.summary()

# Layer output_shape values include the batch axis (None) as their first
# entry; Input/Reshape/Dense below need per-sample shapes, so drop it with [1:].
local_shape = encoder.get_layer('ait_local').output_shape
z_shape = encoder.get_layer('z').output_shape
# Decoder: undoes the encoder stages (dense back to the locally-connected grid,
# then three 2x upsampling stages 8 -> 16 -> 32 -> 64) ending in 4 channels.
decoder = \
    Sequential(name='pulvinar_to_v1_decoder', layers=
    [
        Input(shape=z_shape[1:], name='z_sampling'),
        Dense(int(np.prod(local_shape[1:])), name='pulvinar_dense_back', activation=act_func),
        Reshape(local_shape[1:], name='pulvinar_antiflatten'),
        ZeroPadding2D(padding=(1,1), name='ait_padding_back'),
        LocallyConnected2D(locally_connected_channels, (3,3),
                           name='ait_local_back', activation=act_func,
                           kernel_regularizer=L1L2(l1=0.01, l2=0.01)),
        ZeroPadding2D(padding=(1,1), name='cit_padding_back'),
        Conv2DTranspose(64, (3,3), name='cit_conv2d_trans', activation=act_func, padding='same'),
        UpSampling2D((2,2), name='cit_upsample_back'),
        Conv2DTranspose(64, (1,1), name='pit_conv2d_trans', activation=act_func, padding='same'),
        UpSampling2D((2,2), name='pit_upsample_back'),
        Conv2DTranspose(64, (3,3), name='v4_conv2d_trans', activation=act_func, padding='same'),
        UpSampling2D((2,2), name='v4_upsample_back'),
        Conv2DTranspose(4, (1,1), name='v2_conv2d_trans', activation=act_func, padding='same')
    ])
decoder_output = decoder(encoder_output)
decoder.summary()

In [None]:
import pprint
# Full VAE: V1 filter-bank input -> encoder (with sampling) -> decoder.
autoencoder = Model(v1_filtered_output, decoder_output, name='v1_to_pulvinar_vae')
autoencoder.summary()
# List trainable variables by name and shape; pprint-ing the variables
# themselves dumps every weight value into the notebook output.
pprint.pprint([(w.name, tuple(w.shape)) for w in autoencoder.trainable_weights])

In [None]:
batch_size = 64

# set up paths
nih_cxr8_csv = Path(os.environ['DATA_NIH_CXR8']) / 'Data_Entry_2017_v2020.csv'
logging.info(f"Reading from {nih_cxr8_csv}")

# read the csv dataset; the file pattern is passed as a plain str
cxr_batched_ds = tf.data.experimental.make_csv_dataset(
    file_pattern=str(nih_cxr8_csv),
    batch_size=batch_size, num_epochs=1, num_parallel_reads=20,
    shuffle=False)

# look at the first 50 *batches* (50 * batch_size rows), not 50 rows
cxr_batched_ds = cxr_batched_ds.take(50)

# filter for No Finding. Unbatch first so every row is tested on its own:
# filtering the batched dataset on item['Finding Labels'][0] kept or dropped
# whole batches based only on their first row. Re-batch afterwards.
nofinding_cxrs = (
    cxr_batched_ds
    .unbatch()
    .filter(lambda item: tf.equal(item['Finding Labels'], 'No Finding'))
    .batch(batch_size))

# show a few
# NOTE(review): no logging configuration is visible in this notebook; with
# Python's default root level (WARNING) these logging.info calls print
# nothing. Confirm logging is configured at INFO level.
for batch in nofinding_cxrs.take(3):
    logging.info("-----------")
    for item in batch:
        logging.info(f"{item}: {batch[item]}")

In [None]:
# perform single epoch
# NOTE(review): `train_dataset` and `optimizer` are not defined anywhere in this
# notebook. `train_dataset` must yield image batches shaped like
# `v1_filtered_output` (the CSV dataset above yields metadata dicts, not
# images). Define both in an earlier cell before running this one.


def encode_mean_log_var(x):
    """Run every encoder layer except the final sampling Lambda ('z').

    Returns the packed tensor whose last axis holds (z_mean, z_log_var), i.e.
    the values the encoder's Lambda samples from. It reuses the encoder's own
    layer objects, so it shares weights with `autoencoder`.
    """
    for layer in encoder.layers[:-1]:
        x = layer(x)
    return x


for batch_x_in in train_dataset:
    with tf.GradientTape() as tape:
        z_mean_log_var = encode_mean_log_var(batch_x_in)
        z_mean, z_log_var = z_mean_log_var[..., 0], z_mean_log_var[..., 1]
        batch_x_pred = autoencoder(batch_x_in)
        # vae_loss returns one value per sample; reduce to a scalar explicitly
        loss = tf.reduce_mean(vae_loss(z_mean, z_log_var, batch_x_in, batch_x_pred))
        # custom training loops must add the layers' regularization penalties
        # (the L1L2 kernel regularizers on the locally-connected layers) by hand
        if autoencoder.losses:
            loss += tf.add_n(autoencoder.losses)
    grads = tape.gradient(loss, autoencoder.trainable_variables)
    optimizer.apply_gradients(zip(grads, autoencoder.trainable_variables))
    # report the scalar loss instead of dumping every gradient tensor
    print(f"loss: {float(loss):.6f}")