In [None]:
# Unzip the dataset archive (presumably provides the ./x-ray/ image folders used below)
!unzip x-ray.zip

## path 설정

In [None]:
import os
# Output folders used throughout the notebook.
images_path = './images/'   # sample images written during WGAN training
models_path = './models/'   # saved model files
npy_path = './npy/'         # cached numpy arrays
path_category_ = [images_path, models_path, npy_path]

# Create each folder unless it already exists, and report which case applied.
for folder in path_category_:
    if not os.path.isdir(folder):
        os.mkdir(folder)
        print(folder, 'Make')
    else:
        print(folder, 'Made')

In [None]:
import shutil

# Source image folders
fracture_path = './x-ray/fracture_resize_reverse_crop/'
normal_path = './x-ray/Normal_resize_reverse_crop/'

# Destination folders for the augmented images
generate_fracture_path = './x-ray/generate_fracture/'
generate_normal_path ='./x-ray/generate_normal/'
path_category = [generate_fracture_path, generate_normal_path]

# Always start from empty augmentation folders: an existing folder is removed
# first, then the folder is (re)created.
for folder in path_category:
    if os.path.isdir(folder):
        shutil.rmtree(folder)
    os.mkdir(folder)
    print(folder, 'Folders Make!!')

## Data 전처리 
* 이미지 증강 X10
* 이미지 로드하면서 equalization 
* 로드된 이미지 numpy 배열변환
* -1~1 normalization 
* fracture only train - fracture_train.npy 
* normal only train - normal_train.npy 
* classification train - classify_train.npy (X_train_c, X_test, Y_train_c, Y_test)

In [None]:
import os, glob, sys, numpy as np
import cv2
import shutil

from PIL import Image
from tqdm import tqdm
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img


def equalize_images(images):
    """Histogram-equalize every element of ``images`` and stack the results.

    Each item obtained by iterating over ``images`` is passed to
    ``cv2.equalizeHist``; the outputs are collected into one numpy array.

    NOTE(review): the previous comment described the input as a batch of shape
    (None, 224, 244) with values 0~225 (probably meant 224x224 and 0~255).
    The loading cells of this notebook pass a single 2-D grayscale array, in
    which case the iteration runs over the image rows -- confirm that this
    per-row equalization (and the resulting output shape) is intended.
    """
    equalized = list(map(cv2.equalizeHist, images))
    return np.array(equalized)
    
def generate_images_for_data_augmentation(original_path, output_path, prefix, max_gen_count):
    """Copy every image of ``original_path`` into ``output_path`` and add augmented variants.

    Args:
        original_path: folder holding the source images.
        output_path: existing folder receiving the copies and the generated images.
        prefix: ``save_prefix`` handed to ``ImageDataGenerator.flow`` for generated files.
        max_gen_count: number of augmented images to generate per source image.
            (The previous ``count > max_gen_count`` check produced one image
            more than requested.)
    """
    datagen = ImageDataGenerator(
                rotation_range=10,
                width_shift_range=0.01,
                height_shift_range=0.01,
                zoom_range=0.1,
                horizontal_flip=True,
                fill_mode="nearest")

    for filename in tqdm(os.listdir(original_path)):
        source_file = os.path.join(original_path, filename)
        # Skip sub-directories; only regular files can be copied and loaded.
        if not os.path.isfile(source_file):
            continue

        # Keep the untouched original next to its augmented variants.
        shutil.copyfile(source_file, os.path.join(output_path, filename))

        # Load as a PIL image and convert to a batch of one: (1, height, width, channels).
        img_data = img_to_array(load_img(source_file))
        img_data = img_data.reshape((1,) + img_data.shape)

        # flow() yields randomly transformed batches forever and writes each one
        # to ``output_path`` as it is produced, so pull exactly max_gen_count batches.
        flow = datagen.flow(img_data, batch_size=1, save_to_dir=output_path,
                            save_prefix=prefix, save_format="jpg")
        for _ in range(max_gen_count):
            next(flow)

In [None]:
# Image augmentation: fill both generate_* folders from the source folders
# (empty filename prefix, max_gen_count=10 for each source image)
generate_images_for_data_augmentation(fracture_path, generate_fracture_path, '', 10)
generate_images_for_data_augmentation(normal_path, generate_normal_path, '', 10)

In [None]:
# GAN data set: fracture images only
X = []

for file_name in tqdm(glob.glob(generate_fracture_path + '*.jpg')):
    # Load as 8-bit grayscale, equalize, and collect
    gray = Image.open(file_name).convert('L')
    X.append(equalize_images(np.array(gray)))

# Map pixel values 0..255 onto -1..1
fracture_train = (np.array(X) - 127.5) / 127.5
print(fracture_train.shape)

In [None]:
# Save fracture_train so the WGAN section can reload it
np.save('./npy/fracture_train.npy',fracture_train)

In [None]:
# Normal (non-fracture) data set
X_ = []

for file_name in tqdm(glob.glob(generate_normal_path + '*.jpg')):
    # Load as 8-bit grayscale, equalize, and collect
    gray = Image.open(file_name).convert('L')
    X_.append(equalize_images(np.array(gray)))

# Map pixel values 0..255 onto -1..1
normal_train = (np.array(X_) - 127.5) / 127.5
print(normal_train.shape)

In [None]:
# Save normal_train so the "whole" model section can reload it
np.save('./npy/normal_train.npy',normal_train)

In [None]:
# Build the data set used for classification
# with a 1:1 ratio of fracture and normal samples.

# Use the size of the smaller class for BOTH classes. Sizing everything by the
# normal set alone leaves fewer fracture images than fracture labels whenever
# the fracture set is the smaller one.
train_len = min(normal_train.shape[0], fracture_train.shape[0])
normal_train_x = normal_train[:train_len]
fracture_train_x = fracture_train[:train_len]

# Labels: one row per sample
normal_train_y = np.zeros((train_len, 1)) # not fractured -> 0
fracture_train_y = np.ones((train_len, 1)) # fractured -> 1

print('normal data',normal_train_x.shape, normal_train_y.shape)
print('facture data',fracture_train_x.shape, fracture_train_y.shape)

In [None]:
# Stack the normal samples first, then the fracture samples
# (images and labels in the same order).
x = np.concatenate((normal_train_x, fracture_train_x), axis=0)
y = np.concatenate((normal_train_y, fracture_train_y), axis=0)

# One random permutation applied to both arrays keeps every image with its label.
shuffled_index = np.random.permutation(x.shape[0])
x, y = x[shuffled_index,:,:], y[shuffled_index]

# First 90% for training, remaining 10% for testing
split_index = int(x.shape[0] * 0.9)
X_train_c, Y_train_c = x[:split_index], y[:split_index]
X_test, Y_test = x[split_index:], y[split_index:]

for label, arr in (('X_train', X_train_c), ('X_test ', X_test),
                   ('Y_train', Y_train_c), ('Y_test ', Y_test)):
    print(label, arr.shape)

print(Y_test[:5]) # a mix of 0s and 1s here means the shuffle worked

In [None]:
# Save the four classification arrays in a single file.
# The arrays have different shapes, so they are stored in an explicit 1-D
# object array: handing the bare tuple to np.save relies on implicit
# ragged-array creation, which recent NumPy releases reject with a ValueError.
# The file still unpacks into four arrays when loaded with allow_pickle=True.
parts = (X_train_c, X_test, Y_train_c, Y_test)
xy = np.empty(len(parts), dtype=object)
for slot, part in enumerate(parts):
    xy[slot] = part
np.save('./npy/classify_train.npy', xy)

## WGAN 학습

In [None]:
import numpy as np 

# Reload the fracture-only training array written by the preprocessing section
fracture_train = np.load('./npy/fracture_train.npy', allow_pickle=True)

print(fracture_train.shape)

In [None]:
from __future__ import print_function, division

from keras.layers.merge import _Merge
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
from functools import partial

import keras.backend as K
import matplotlib.pyplot as plt
import sys
import numpy as np

In [None]:
class RandomWeightedAverage(_Merge):
    """Provides a (random) weighted average between real and generated image samples"""
    def _merge_function(self, inputs):
        # One uniform interpolation weight per sample. The batch size is read
        # from the incoming tensor instead of being hard-coded to 32, so the
        # layer also works when training uses a different batch_size.
        batch_size = K.shape(inputs[0])[0]
        alpha = K.random_uniform((batch_size, 1, 1, 1))
        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])

class WGANGP():
    """WGAN with gradient penalty for 128x128 single-channel images.

    Holds a generator, a critic, and two compiled training graphs:
    ``critic_model`` (generator frozen, three outputs: real, fake, interpolated)
    and ``generator_model`` (critic frozen).
    """

    def __init__(self):
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Following parameter and optimizer set as recommended in paper
        self.n_critic = 5
        optimizer = RMSprop(lr=0.00005)

        # Build the generator and critic
        self.generator = self.build_generator()
        self.critic = self.build_critic()

        #-------------------------------
        # Construct Computational Graph
        #       for the Critic
        #-------------------------------

        # Freeze generator's layers while training critic
        self.generator.trainable = False

        # Image input (real sample)
        real_img = Input(shape=self.img_shape)

        # Noise input
        z_disc = Input(shape=(self.latent_dim,))
        # Generate image based of noise (fake sample)
        fake_img = self.generator(z_disc)

        # Discriminator determines validity of the real and fake images
        fake = self.critic(fake_img)
        valid = self.critic(real_img)

        # Construct weighted average between real and fake images
        interpolated_img = RandomWeightedAverage()([real_img, fake_img])
        # Determine validity of weighted sample
        validity_interpolated = self.critic(interpolated_img)

        # Use Python partial to provide loss function with additional
        # 'averaged_samples' argument
        partial_gp_loss = partial(self.gradient_penalty_loss, averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty' # Keras requires function names

        self.critic_model = Model(inputs=[real_img, z_disc],outputs=[valid, fake, validity_interpolated])
        # The gradient penalty term is weighted 10x relative to the two Wasserstein terms
        self.critic_model.compile(loss=[self.wasserstein_loss, self.wasserstein_loss, partial_gp_loss],
                                  optimizer=optimizer,
                                  loss_weights=[1, 1, 10])
        #-------------------------------
        # Construct Computational Graph
        #         for Generator
        #-------------------------------

        # For the generator we freeze the critic's layers
        self.critic.trainable = False
        self.generator.trainable = True

        # Sampled noise for input to generator
        z_gen = Input(shape=(self.latent_dim,))
        # Generate images based of noise
        img = self.generator(z_gen)
        # Discriminator determines validity
        valid = self.critic(img)
        # Defines generator model
        self.generator_model = Model(z_gen, valid)
        self.generator_model.compile(loss=self.wasserstein_loss, optimizer=optimizer)


    def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
        """
        Computes gradient penalty based on prediction and weighted real / fake samples

        ``y_true`` is unused (a dummy target is supplied during training).
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        #   ... summing over every axis except the batch axis ...
        gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
        #   ... and sqrt
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute (1 - ||grad||)^2 still for each single sample
        gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)


    def wasserstein_loss(self, y_true, y_pred):
        """Mean of target * prediction; the sign of the target selects the direction."""
        return K.mean(y_true * y_pred)

    def build_generator(self):
        """Build the generator: latent vector -> 16x16x256 -> three 2x upsamplings -> tanh image."""

        model = Sequential()

        model.add(Dense(256 * 16 * 16, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((16, 16, 256)))
        model.add(UpSampling2D())

        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())

        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())

        model.add(Conv2D(32, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_critic(self):
        """Build the critic: strided conv stack -> flatten -> single tanh score."""

        model = Sequential()

        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(256, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(512, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1))
        model.add(Activation('tanh'))

        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, train_data, epochs, batch_size, sample_interval=50):
        """Alternate ``n_critic`` critic updates with one generator update per epoch.

        Args:
            train_data: real images, shaped (N, H, W, C) or (N, H, W); a missing
                channel axis is appended.
            epochs: number of generator updates.
            batch_size: samples drawn (with replacement) for every update.
            sample_interval: a sample grid is saved whenever
                ``epoch % sample_interval == 0``.
        """

        # Train on the data that was passed in (previously the global
        # ``fracture_train`` was read and the argument was ignored).
        X_train = np.asarray(train_data)
        if X_train.ndim == 3:
            X_train = X_train[..., np.newaxis]

        # Adversarial ground truths
        valid = -np.ones((batch_size, 1))
        fake =  np.ones((batch_size, 1))
        dummy = np.zeros((batch_size, 1)) # Dummy gt for gradient penalty
        for epoch in range(epochs):

            for _ in range(self.n_critic):

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]
                # Sample generator input
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                # Train the critic
                d_loss = self.critic_model.train_on_batch([imgs, noise], [valid, fake, dummy])

            # ---------------------
            #  Train Generator
            # ---------------------

            # Reuses the noise batch of the last critic step
            g_loss = self.generator_model.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f] [G loss: %f]" % (epoch, d_loss[0], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 3x3 grid of generated images to images/whole_<epoch>.png."""
        r, c = 3, 3
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images from the tanh range -1..1 to 0..1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/whole_%d.png" % epoch)
        plt.close()

In [None]:
# Build the generator, the critic and both compiled training graphs
wgan = WGANGP()

In [None]:
# Train the WGAN on the fracture images; a sample grid is saved every 100 epochs
wgan.train(fracture_train, epochs=15000, batch_size=32, sample_interval=100)

In [None]:
# Save only the critic; it is reloaded as d_model in the "whole" training section
wgan.critic.save("./models/gan_d_model_tanh_.h5")

## Classification 학습

In [None]:
from keras.models import Model, Sequential
from keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Flatten, Dense, Activation, BatchNormalization 
from keras.layers import concatenate, Conv2DTranspose, Reshape
from os import path

def whole_model():
    """Build a classifier and an image-to-image generator sharing one encoder.

    The encoder has six stages of four Conv2D(3x3, relu) + BatchNormalization
    pairs with 16, 32, 64, 128, 256 and 512 filters; each stage is followed by
    2x2 max pooling and Dropout(0.25).

    Returns:
        (c_model_, g_model):
        c_model_ -- 'classify_model': encoder -> Flatten -> Dense(512, relu)
                    -> BatchNormalization -> Dense(1, tanh).
        g_model  -- encoder plus a decoder of Conv2DTranspose upsamplings that
                    are concatenated with the matching encoder outputs, ending
                    in a 1x1 Conv2D with tanh activation.
    """
    inputs = Input((128,128,1))
    depth = 16

    def conv_stack(tensor, filters):
        # Four consecutive Conv2D + BatchNormalization pairs.
        for _ in range(4):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
        return tensor

    def downsample(tensor):
        # Halve the spatial size, then apply dropout.
        tensor = MaxPooling2D(pool_size=(2, 2))(tensor)
        return Dropout(0.25)(tensor)

    # ---- encoder stages 1-5: remember each pre-pooling output for the decoder
    skips = []
    tensor = inputs
    for multiplier in (1, 2, 4, 8, 16):
        tensor = conv_stack(tensor, int(depth * multiplier))
        skips.append(tensor)
        tensor = downsample(tensor)

    # ---- encoder stage 6
    bottleneck = conv_stack(tensor, int(depth * 32))

    # ---- classification head
    c_conv_output = downsample(bottleneck)
    features = Flatten()(c_conv_output)
    features = Dense(int(depth*32), activation='relu')(features)
    features = BatchNormalization()(features)
    # change the unit count from 1 to 2 when switching to softmax
    c_outputs = Dense(1, activation='tanh')(features)
    c_model_ = Model(inputs=[inputs], outputs=[c_outputs], name='classify_model')

    # ---- decoder: upsample, concatenate with the matching encoder output, refine
    tensor = bottleneck
    for multiplier in (16, 8, 4, 2, 1):
        skip = skips.pop()
        upsampled = Conv2DTranspose(int(depth * multiplier * 2), (2, 2), strides=(2, 2), padding='same')(tensor)
        merged = concatenate([upsampled, skip], axis=3)
        tensor = conv_stack(merged, int(depth * multiplier))

    generated = Conv2D(1, (1, 1), activation='tanh')(tensor)
    g_model = Model(inputs=[inputs], outputs=[generated])

    return c_model_, g_model

In [None]:
# Build the classifier (c_model_) and the generator (g_model)
c_model_, g_model = whole_model()

In [None]:
from IPython.display import clear_output
from keras.callbacks import Callback
import matplotlib.pyplot as plt

class PlotLosses(Callback):
    """Keras callback that redraws the loss and accuracy curves after every epoch."""

    def on_train_begin(self, logs=None):
        """Reset the recorded history at the start of training."""
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []

        self.x_ = []
        self.accuracy = []
        self.val_accuracy = []

        # The figure is created per epoch in on_epoch_end; creating one here
        # only produced an empty figure that was never drawn on.
        self.fig = None

        self.logs = []

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and redraw both plots."""
        # Avoid a shared mutable default argument.
        logs = logs or {}

        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        # Fall back to the short key names in case the running Keras version
        # reports 'acc' / 'val_acc' instead of 'accuracy' / 'val_accuracy'.
        self.accuracy.append(logs.get('accuracy', logs.get('acc')))
        self.val_accuracy.append(logs.get('val_accuracy', logs.get('val_acc')))

        self.i += 1

        # Replace the previous cell output instead of stacking one plot per epoch.
        clear_output(wait=True)
        self.fig = plt.figure(figsize=(15,5))

        plt.subplot(1,2,1)
        plt.ylim(0, 1)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.legend()

        plt.subplot(1,2,2)
        plt.ylim(0, 1)
        plt.plot(self.x, self.accuracy, label="accuracy")
        plt.plot(self.x, self.val_accuracy, label="val_accuracy")
        plt.legend()
        plt.show()
        # Release the figure so figures do not accumulate over a long run.
        plt.close(self.fig)

        print("loss = ", self.losses[-1], ", val_loss = ", self.val_losses[-1])
        print("accuracy = ", self.accuracy[-1], ", val_accuracy = ", self.val_accuracy[-1])


In [None]:
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Plot the loss / accuracy curves while training
plot_losses = PlotLosses()

# Stop training once val_loss has not improved for 50 epochs (guards against overfitting)
# NOTE(review): early_stopping is not part of the `callbacks` list built in the compile cell -- confirm intended
early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=50)

# Reduce the learning rate (factor 0.2) when val_loss stalls for 5 epochs, with min_lr=0.001
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)


In [None]:
from keras import optimizers
from keras.optimizers import Adam
from os import path

# Callbacks used by the classifier fit below
callbacks = [plot_losses, reduce_lr]
c_model_.summary()
# NOTE(review): the classifier's last layer uses tanh while the loss is
# binary_crossentropy with 0/1 labels -- confirm this pairing is intended
c_model_.compile(optimizer=optimizers.Adadelta(), loss='binary_crossentropy', metrics=["accuracy"])
# c_model_.compile(optimizer=optimizers.Adadelta(), loss='sparse_categorical_crossentropy', metrics=["accuracy"])  # Activation=softmax 

In [None]:
# Reload the four classification arrays saved by the preprocessing section
# (allow_pickle is required because they are stored together as objects)
X_train_c, X_test, Y_train_c, Y_test = np.load('./npy/classify_train.npy', allow_pickle=True)

In [None]:
# Train the classifier; 10% of the training data is held out for validation
c_model_.fit(X_train_c, Y_train_c, 
            epochs = 100, 
            verbose = 1, 
            batch_size = 32, 
            validation_split = 0.1, 
            shuffle=True, 
            callbacks=callbacks)

In [None]:
# Save the trained classifier (reloaded as c_model in the "whole" training section)
c_model_.save('./models/class_c_model_tanh_.h5')

# Evaluate on the held-out test split
loss, acc = c_model_.evaluate(X_test, Y_test)
print("loss =", loss)
print("acc =", acc)

# Whole 학습

In [None]:
import keras 

# Reload the trained classifier and freeze it (model-level and per layer)
c_model  = keras.models.load_model('./models/class_c_model_tanh_.h5')
c_model.trainable = False

for layer in c_model.layers:
    layer.trainable = False

# Reload the trained WGAN critic and freeze it the same way
d_model  = keras.models.load_model('./models/gan_d_model_tanh_.h5')
d_model.trainable = False

for layer in d_model.layers:
    layer.trainable = False

# Only the generator is left trainable
g_model.trainable = True

In [None]:
# Wire the combined graph: input -> generator -> (critic, classifier)
inputs = Input((128,128,1))
generated = g_model(inputs)
d_output = d_model(generated)
c_output = c_model(generated)

In [None]:
import keras.backend as K

# Weights of the three loss terms
d_r = 0.0005  # critic output on the generated image
u_r = 0.999   # reconstruction error (input vs. generated)
c_r = 0.0005  # classifier output on the generated image

def custom_loss(y_true, y_pred):
    """Weighted sum of classifier output, reconstruction MSE and critic output.

    y_true and y_pred are ignored: the value is computed from the graph
    tensors `inputs`, `generated`, `c_output` and `d_output` defined above.
    """
    return c_r * K.mean(c_output) + u_r * K.mean(keras.losses.mean_squared_error(inputs, generated)) + d_r * K.mean(d_output)

# NOTE(review): this assignment rebinds the name `whole_model`, which until now
# referred to the model-building function defined earlier
whole_model = Model(inputs=[inputs], outputs=[d_output, c_output])

# The same loss function is applied to both outputs
whole_model.compile(loss=custom_loss, optimizer='adam', metrics=['accuracy'])
whole_model.summary()

In [None]:
# Reload the normal-only training array written by the preprocessing section
normal_train = np.load('./npy/normal_train.npy', allow_pickle=True)

In [None]:
# Targets real = 0 and normal = 0, one entry per normal_train sample
real = np.zeros(normal_train.shape[0])
normal= np.zeros(normal_train.shape[0])

In [None]:
# Train on normal data only
hist = whole_model.fit(normal_train, [real, normal], validation_split=0.1, epochs=100, batch_size=32, verbose=1)

In [None]:
# # Alternative (disabled): train on the mixed normal + fracture data

# # Targets real = 0 and normal = 0, one entry per X_train_c sample
# real = np.zeros(X_train_c.shape[0])
# normal= np.zeros(X_train_c.shape[0])

# hist = whole_model.fit(X_train_c, [real, normal], validation_split=0.1, epochs=20, batch_size=32, verbose=1)

In [None]:
# Save the combined model
whole_model.save('./models/whole_model.h5')

## Whole 학습 결과

* Input is classified as (Normal, Fracture) : Input Image의 정답
* classified as (Normal, Fracture) : Input Image의 분류 (0:Normal, 1:Fracture)
* Generated is Discriminated as (Real, Fake) : Generated Image의 판별 (0:Real, 1:Fake)
* Generated is classified as (Normal, Fracture) : Generated Image의 분류 (0:Normal, 1:Fracture)


* 최종목표 : Input이 Fracture이든 Normal이든 Generated Image는 Normal로 출력되어야 한다. 
input = S
1. S(골절) -> G(비골절) = Fracture area (area 출력)
2. S(비골절) -> G(비골절) = Fracture area (미표기) 

In [None]:
# Utility function

def show_result(target_data=X_test, count=10, threshold=0.2):
    """Print model verdicts and plot input / generated / difference images.

    For each of the first ``count`` samples this prints the true label, the
    classifier score for the input, the critic score for the generated image
    and the classifier score for the generated image, then shows the input,
    the generated image and their thresholded absolute difference.

    Args:
        target_data: images to inspect, each reshapeable to 128x128. Labels are
            read from the global ``Y_test`` at the same index.
        count: maximum number of samples to show (clamped to the data size).
        threshold: absolute differences below this value are set to 0.
    """
    model_input_shape = (1, 128, 128, 1)
    image_shape = (128, 128)

    # Never index past the end of the images or of the labels.
    count = min(count, len(target_data), len(Y_test))

    for i in range(count):
        label = float(np.ravel(Y_test[i])[0])

        # Work on a plain 2-D image so the subtraction below can never
        # broadcast (H, W, 1) against (H, W).
        target_input = np.asarray(target_data[i]).reshape(image_shape)
        model_input = target_input.reshape(model_input_shape)

        c_out = c_model.predict(model_input)
        g_out = g_model.predict(model_input)
        d_out = d_model.predict(g_out)
        c_out_g_in = c_model.predict(g_out)

        generated = g_out[0].reshape(image_shape)
        diff = np.absolute(generated - target_input)
        diff[diff < threshold] = 0.0

        # Take the last output unit: the only unit of a 1-unit head, or the
        # "fracture" column if the classifier is switched to a 2-unit softmax.
        # (Indexing [0][1] failed on the 1-unit head.)
        c_score = float(np.ravel(c_out)[-1])
        c_score_generated = float(np.ravel(c_out_g_in)[-1])
        # The critic emits a single score; argmax over one column is always 0,
        # so report the score itself.
        # NOTE(review): the 0.5 cut-off and the Real/Fake naming are kept from
        # the original code -- confirm against how the critic was trained.
        d_score = float(np.ravel(d_out)[-1])

        # Ground-truth label of the input image
        if label == 0 :
            print('Input is classified as Normal Image %.1f' % label, '(0 : Normal, 1 : Fracture)')
        else :
            print('Input is classified as Fracture Image %.1f' % label, '(0 : Normal, 1 : Fracture)')

        # Classifier on the input: 0 normal, 1 fracture
        if c_score <= 0.5 :
            print('Classified as Normal Image %.2f' % c_score, '(0 : Normal, 1 : Fracture)')
        else :
            print('Classified as Fractrue Image %.2f' % c_score, '(0 : Normal, 1 : Fracture)')

        # Critic on the generated image: 0 real, 1 fake
        if d_score <= 0.5 :
            print('Generated is Discriminated as Real Image %.2f' % d_score, '(0 : Real, 1 : Fake)')
        else :
            print('Generated is Discriminated as Fake Image %.2f' % d_score, '(0 : Real, 1 : Fake)')

        # Classifier on the generated image: 0 normal, 1 fracture
        if c_score_generated <= 0.5 :
            print('Generated is classified as Normal Image %.2f' % c_score_generated, '(0 : Normal, 1 : Fracture)')
        else :
            print('Generated is classified as Fracture Image %.2f' % c_score_generated, '(0 : Normal, 1 : Fracture)')

        plt.figure(figsize=(10,10))
        plt.subplot(1,3,1)
        plt.title('Input img')
        plt.axis('off')
        # Images are normalized to -1..1, so show that full range
        # (vmin=0 clipped the darker half to black).
        plt.imshow(target_input, cmap='gray', vmin=-1, vmax=1)

        plt.subplot(1,3,2)
        plt.title('Generated img')
        plt.axis('off')
        plt.imshow(generated, cmap='gray', vmin=-1, vmax=1)

        plt.subplot(1,3,3)
        plt.title('Fracture area')
        plt.axis('off')
        plt.imshow(diff, cmap='gray', vmin=0, vmax=1)
        plt.show()
        plt.close()
        print()
#cmap=plt.cm.binary
# show_result(X_test)

In [None]:
# NOTE(review): show_result zeroes every difference below `threshold`; with
# images normalized to -1..1 and threshold=2.0 the 'Fracture area' panel will
# presumably be (almost) entirely blank -- confirm, the default is 0.2
show_result(X_test, threshold=2.0)