In [1]:
import numpy as np
from keras.layers import *
from keras.layers.merge import concatenate as concat
from keras.models import Model
from keras import backend as K
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
from keras.optimizers import Adam
import matplotlib.pyplot as plt
from keras.losses import mse, binary_crossentropy

Using TensorFlow backend.


In [2]:
# --- Data / conditioning ---------------------------------------------------
# Number of conditioning classes, i.e. the one-hot width of the label input.
num_labels = 4
# Images are resized to a square of this side length (pixels).
image_size = 128

# --- Network hyper-parameters ----------------------------------------------
input_shape = (image_size,) * 2 + (1,)   # single-channel images
label_shape = (num_labels,)
kernel_size = 3
filters = 16        # base number of conv filters; the model cell scales it per layer
latent_dim = 64

# --- Training --------------------------------------------------------------
batch_size, epochs = 32, 30

In [3]:
def sampling(args):
    """Draw latent samples z ~ N(z_mean, exp(z_log_var)) with the reparameterization trick.

    The random draw is expressed as a deterministic function of the
    distribution parameters plus independent standard-normal noise, which
    moves the sampling outside the path that gradients take to ``z_mean`` and
    ``z_log_var``.

    Arguments:
        args: pair ``(z_mean, z_log_var)`` of tensors holding the mean and
            log-variance of Q(z|X); presumably shaped (batch, latent_dim).

    Returns:
        Tensor of sampled latent vectors, ``z_mean + exp(z_log_var / 2) * eps``.
    """
    mu, log_var = args
    # Noise takes its batch size dynamically and its width statically from mu.
    noise_shape = (K.shape(mu)[0], K.int_shape(mu)[1])
    noise = K.random_normal(shape=noise_shape)  # mean 0, stddev 1 by default
    sigma = K.exp(0.5 * log_var)                 # std-dev from log-variance
    return mu + sigma * noise

In [4]:
# ---------------------------------------------------------------------------
# Encoder Q(z | X, y)
# ---------------------------------------------------------------------------
inputs = Input(shape=input_shape, name='encoder_input')
y_labels = Input(shape=label_shape, name='class_labels')

# Project the one-hot label to a full-resolution "label image" and stack it
# onto the input as an extra channel so every conv layer sees the condition.
x = Dense(image_size * image_size)(y_labels)
x = Reshape((image_size, image_size, 1))(x)
x = concatenate([inputs, x])

# Per-layer filter counts are derived from the base `filters` WITHOUT mutating
# it, so this cell does not depend on (or corrupt) kernel state on re-runs.
# Encoder widths: filters*2, *4, *8, *16; the decoder mirrors them.
num_conv_layers = 4
encoder_filters = [filters * 2 ** (i + 1) for i in range(num_conv_layers)]
decoder_filters = encoder_filters[::-1]

for n_filters in encoder_filters:
    x = Conv2D(filters=n_filters,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,          # each layer downsamples height and width by 2
               padding='same')(x)

# shape info needed to build decoder model
shape = K.int_shape(x)

# generate latent vector Q(z|X)
x = Flatten()(x)
# NOTE(review): this 16-unit layer is narrower than latent_dim — confirm intended.
x = Dense(16, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# use reparameterization trick to push the sampling out as input
# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling,
           output_shape=(latent_dim,),
           name='z')([z_mean, z_log_var])

# instantiate encoder model
encoder = Model([inputs, y_labels],
                [z_mean, z_log_var, z],
                name='encoder')
encoder.summary()

# ---------------------------------------------------------------------------
# Decoder P(X | z, y)
# ---------------------------------------------------------------------------
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = concatenate([latent_inputs, y_labels])
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(x)
x = Reshape((shape[1], shape[2], shape[3]))(x)

for n_filters in decoder_filters:
    x = Conv2DTranspose(filters=n_filters,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,  # each layer upsamples height and width by 2
                        padding='same')(x)

outputs = Conv2DTranspose(filters=1,
                          kernel_size=kernel_size,
                          activation='sigmoid',
                          padding='same',
                          name='decoder_output')(x)

# instantiate decoder model
decoder = Model([latent_inputs, y_labels],
                outputs,
                name='decoder')
decoder.summary()

# ---------------------------------------------------------------------------
# Full CVAE: encode, take the sampled z (encoder output index 2), decode.
# ---------------------------------------------------------------------------
outputs = decoder([encoder([inputs, y_labels])[2], y_labels])
cvae = Model([inputs, y_labels], outputs, name='cvae')

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
class_labels (InputLayer)       (None, 4)            0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 16384)        81920       class_labels[0][0]               
__________________________________________________________________________________________________
encoder_input (InputLayer)      (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 128, 128, 1)  0           dense_1[0][0]                    
____________________________________________________________________________________________

In [5]:
# Weight on the KL term; beta = 1.0 gives the standard (C)VAE objective.
beta = 1.0

# Reconstruction term: MSE over the flattened tensors (K.flatten also collapses
# the batch axis), rescaled by the number of pixels per image.
reconstruction_loss = mse(K.flatten(inputs), K.flatten(outputs))
reconstruction_loss *= image_size * image_size

# Per-sample KL(Q(z|X,y) || N(0, I)) for a diagonal Gaussian:
#   KL = -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
# The -0.5 factor is essential: without it the term equals -2*KL, and
# minimizing the total loss would *maximize* the divergence from the prior.
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5 * beta

cvae_loss = K.mean(reconstruction_loss + kl_loss)
# NOTE(review): add_loss presumably appends to the model's losses; re-running
# this cell without rebuilding the model may attach the loss twice — verify.
cvae.add_loss(cvae_loss)
cvae.compile(optimizer='rmsprop')
cvae.summary()

Model: "cvae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
class_labels (InputLayer)       (None, 4)            0                                            
__________________________________________________________________________________________________
encoder (Model)                 [(None, 64), (None,  734384      encoder_input[0][0]              
                                                                 class_labels[0][0]               
__________________________________________________________________________________________________
decoder (Model)                 (None, 128, 128, 1)  2108161     encoder[1][2]                 

  'be expecting any data to be passed to {0}.'.format(name))


In [6]:
from Dental_Tool.Data_processing import *
from Dental_Tool.Dental_Model import *
from Dental_Tool.Process_results import *
from Dental_Tool.Dataloader import *
from Dental_Tool.KFold_v3 import *

In [7]:
# Dataset roots, each containing a mapping.json index. Judging by the names,
# these are original / flipped / CLAHE / CLAHE+flipped variants — TODO confirm.
data_roots = (
    "Dental_Data/PBL/10_20200901",
    "Dental_Data/PBL/10_20200901_Flip",
    "Dental_Data/PBL/10_clahe_20200901",
    "Dental_Data/PBL/10_clahe_20200901_Flip",
)
directory = [root + "/mapping.json" for root in data_roots]
# NOTE(review): the factor 20 is undocumented and argscale_num is not used in the visible cells.
argscale_num = 20 * len(directory)

data = load_json(directory, interdental=False)
dataset = json_2_dataframe_PBL(data)
# Keep only rows whose Class is 2 — TODO: document what class 2 denotes.
dataset = dataset.loc[dataset["Class"] == 2]

In [8]:
def load_images(path_list, resize):
    """Load image files as single-channel float32 arrays scaled by 1/255.

    Arguments:
        path_list: iterable of image file paths (e.g. a pandas Series).
        resize: target size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        ``np.ndarray`` of shape ``(len(path_list), height, width, 1)``, dtype
        float32, with pixel values divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the paths.
    """
    images = []
    for path in tqdm(path_list):
        image = cv2.imread(path, 0)  # flag 0 -> load as grayscale
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            raise FileNotFoundError("Could not read image: {!r}".format(path))
        image = cv2.resize(image, resize)
        image = image.astype("float32") / 255.0
        image = np.expand_dims(image, axis=2)  # add trailing channel axis
        images.append(image)
    return np.array(images)

# Reproducible 60/20/20 train/validation/test split.
# NOTE: this cell shuffles `dataset` in place by reassignment; re-run the
# data-loading cell first if you need to re-run this one.
split_seed = 42
classes = num_labels  # one-hot width must match the model's label input

dataset_size = len(dataset)

dataset = dataset.sample(frac=1, random_state=split_seed).reset_index(drop=True)
train_idx = int(dataset_size * 0.6)
valid_idx = int(dataset_size * 0.2)

train = dataset.iloc[:train_idx]
valid = dataset.iloc[train_idx: train_idx + valid_idx]
test  = dataset.iloc[train_idx + valid_idx:]
assert len(train) + len(valid) + len(test) == dataset_size


def _load_split(frame):
    """Return (images, one-hot labels) for one split of `dataset`."""
    images = load_images(frame["Path"], (image_size, image_size))
    labels = to_categorical(frame["tooth_type"], classes)
    return images, labels


x_train, y_train = _load_split(train)
x_valid, y_valid = _load_split(valid)
x_test, y_test = _load_split(test)

assert len(x_train) == len(y_train)
assert len(x_valid) == len(y_valid)
assert len(x_test) == len(y_test)

print(x_train.shape)
print(y_train.shape)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=19248.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6416.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=6416.0), HTML(value='')))


(19248, 128, 128, 1)
(19248, 4)


In [9]:
# Monitor on the held-out validation split; the test split is reserved for
# final evaluation and must not influence training decisions.
# Targets are None because the loss is attached via cvae.add_loss.
history = cvae.fit([x_train, y_train],
                   epochs=epochs,
                   batch_size=batch_size,
                   validation_data=([x_valid, y_valid], None))

Train on 19248 samples, validate on 6416 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<keras.callbacks.callbacks.History at 0x244febda208>