In [1]:
%config Completer.use_jedi = False

import os

# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"] = ""

# training a conditional GAN on a three-class ECG dataset (labels 0, 1, 2)
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, LearningRateScheduler
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Reshape, multiply, Embedding, merge, Concatenate, Conv1D, BatchNormalization
from keras.layers import Dense, Flatten, Multiply
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import UpSampling1D
from keras.layers.core import Activation
import numpy as np
import os
from numpy import asarray
import matplotlib.pyplot as plt
from evaluation_metrics import *
# Metric names handed to evaluate() (from evaluation_metrics) during training; also used to build the Stats.csv header columns.
metric_to_calculate = ['FID', 'MMD', 'DTW', 'PC', 'RMSE', 'TWED']

In [2]:
def discriminator(data_dim, input_classes=3):
    """Build and compile the label-conditioned 1-D CNN discriminator.

    Inputs:
        - a signal of shape (data_dim, 1)
        - an integer class label of shape (1,)

    The label is embedded (30-dim), projected to `data_dim` values and reshaped into one
    extra channel, then concatenated with the signal. Five strided Conv1D -> BatchNorm ->
    LeakyReLU stages (32, 64, 128, 256, 512 filters) follow, then Flatten and a single
    sigmoid unit.

    Compiled with binary cross-entropy, Adam(lr=2e-4, beta_1=0.5) and an accuracy metric.

    Args:
        data_dim: length of each input signal.
        input_classes: number of distinct class labels the embedding accepts.

    Returns:
        The compiled Keras Model with inputs [signal, label].
    """
    # Label branch: class id -> embedding -> one value per time step -> (data_dim, 1) channel.
    label_in = Input(shape=(1,))
    label_map = Embedding(input_classes, 30)(label_in)
    label_map = Dense(data_dim)(label_map)
    label_map = Reshape((data_dim, 1))(label_map)

    # Signal branch, stacked with the label channel along the feature axis.
    signal_in = Input(shape=[data_dim, 1])
    h = Concatenate()([signal_in, label_map])

    # Each stage halves the length (stride 2, 'same') and doubles the filter count.
    for mult in (1, 2, 4, 8, 16):
        h = Conv1D(filters=32 * mult, kernel_size=4, strides=2, padding='same',
                   kernel_initializer='he_normal')(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)

    h = Flatten()(h)
    verdict = Dense(1, activation='sigmoid')(h)

    model = Model(inputs=[signal_in, label_in], outputs=verdict)
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(lr=0.0002, beta_1=0.5),
                  metrics=['accuracy'])
    return model

# d_model = discriminator(data_dim=186, input_classes=3)
# plot_model(d_model, to_file='disc.png', show_shapes=True)

In [3]:
def generator(noise_dim=186, input_classes=3, out_dim=186):
    """Build the (uncompiled) label-conditioned 1-D generator.

    Inputs:
        - a noise signal of shape (noise_dim, 1)
        - an integer class label of shape (1,)

    The label is embedded (30-dim), projected to `noise_dim` values and concatenated as an
    extra channel. Seven UpSampling1D -> Conv1D(stride 2) -> BatchNorm -> LeakyReLU stages
    follow. Each stage doubles and then halves the length, so the length stays `noise_dim`
    throughout. A final upsample + single-filter Conv1D + sigmoid gives an output of shape
    (noise_dim, 1) with values in (0, 1).

    Args:
        noise_dim: length of the noise input (and of the generated signal).
        input_classes: number of distinct class labels the embedding accepts.
        out_dim: accepted for interface compatibility but not used; the output length
            is determined by noise_dim.

    Returns:
        A Keras Model with inputs [noise, label].
    """
    # Label branch: class id -> embedding -> one value per time step -> (noise_dim, 1) channel.
    label_in = Input(shape=(1,))
    label_map = Embedding(input_classes, 30)(label_in)
    label_map = Dense(noise_dim)(label_map)
    label_map = Reshape((noise_dim, 1))(label_map)

    noise_in = Input(shape=[noise_dim, 1])
    h = Concatenate()([noise_in, label_map])

    # (filters, kernel_size, padding) per stage. The first stage uses kernel 2 / 'valid'
    # which, on the doubled length 2L, still yields (2L - 2) / 2 + 1 = L.
    stages = (
        (32 * 16, 2, 'valid'),
        (32 * 8, 4, 'same'),
        (32 * 8, 4, 'same'),
        (32 * 4, 4, 'same'),
        (32 * 4, 4, 'same'),
        (32 * 2, 4, 'same'),
        (32, 4, 'same'),
    )
    for filters, ksize, pad in stages:
        h = UpSampling1D()(h)
        h = Conv1D(filters=filters, kernel_size=ksize, strides=2, padding=pad,
                   kernel_initializer='he_normal')(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(alpha=0.2)(h)

    # Output head: back to a single channel, squashed to (0, 1).
    h = UpSampling1D()(h)
    h = Conv1D(filters=1, kernel_size=4, strides=2, padding='same',
               kernel_initializer='he_normal')(h)
    signal_out = Activation('sigmoid')(h)

    return Model(inputs=[noise_in, label_in], outputs=signal_out)

# g_model = generator(noise_dim=186, input_classes=3, out_dim=186)
# plot_model(g_model, to_file='gen.png', show_shapes=True)

In [4]:
def create_gan(d_model, g_model):
    """Chain generator and discriminator into the combined model used to train the generator.

    Side effect: sets `d_model.trainable = False` on the passed discriminator object, so its
    weights are frozen inside the combined model. NOTE(review): this relies on d_model having
    been compiled before the flag is flipped; confirm for the installed Keras version.

    The generator's own label input is routed to the discriminator as well, so the
    discriminator judges each generated signal against the label it was generated for.

    Args:
        d_model: discriminator taking [signal, label].
        g_model: generator taking [noise, label].

    Returns:
        A Model [noise, label] -> discriminator score, compiled with binary cross-entropy
        and Adam(lr=2e-4, beta_1=0.5).
    """
    d_model.trainable = False

    noise_in, label_in = g_model.input
    validity = d_model([g_model.output, label_in])

    combined = Model([noise_in, label_in], validity)
    combined.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
    return combined

# gan_model = create_gan(d_model, g_model)
# plot_model(gan_model, to_file='gan.png', show_shapes=True)

In [5]:
def reshape(X):
    """Return X with a trailing singleton (channel) axis.

    - 1-D input of shape (n,) becomes (n, 1).
    - Input with ndim >= 2 whose last axis already has size 1 is returned unchanged
      (the same object).
    - Any other input gets a new trailing axis, e.g. (n, m) -> (n, m, 1) or
      (a, b, c) -> (a, b, c, 1). The original version raised on the 3-D+ case.

    Array-likes (e.g. lists) are converted with np.asarray first. The result is a view
    of the input where possible.

    Args:
        X: array-like.

    Returns:
        np.ndarray with a trailing axis of size 1.
    """
    X = np.asarray(X)
    if X.ndim >= 2 and X.shape[-1] == 1:
        return X
    return X[..., np.newaxis]

In [6]:
def load_real_samples(data_dir='Data'):
    """Load the ECG dataset and split it by class label.

    Reads `<data_dir>/X.npy` (signals) and `<data_dir>/y.npy` (class labels). Each class
    (labels 0, 1, 2) is selected with a boolean mask, so X and y must agree in their first
    dimension.

    Args:
        data_dir: directory containing X.npy and y.npy (defaults to the original 'Data').

    Returns:
        (X_N, y_N, X_S, y_S, X_V, y_V) for labels 0, 1 and 2 respectively. Each X is passed
        through reshape(), so it carries a trailing channel axis.
    """
    X = np.load(os.path.join(data_dir, 'X.npy'))
    y = np.load(os.path.join(data_dir, 'y.npy'))

    splits = []
    for label in (0, 1, 2):
        mask = y == label
        splits.extend((reshape(X[mask]), y[mask]))
    return tuple(splits)

In [7]:
def generate_real_samples(X_N, y_N, X_S, y_S, X_V, y_V, n_samples):
    """Sample a class-balanced batch of real ECGs with soft "real" targets.

    Draws n_samples // 3 indices (uniformly, with replacement) from each of the three
    classes.

    Args:
        X_N, y_N, X_S, y_S, X_V, y_V: per-class signals and labels (classes 0, 1, 2).
        n_samples: requested batch size. Only 3 * (n_samples // 3) samples are returned
            when it is not a multiple of 3.

    Returns:
        [X, labels], y:
            X: class-0 rows, then class-1 rows, then class-2 rows.
            labels: the matching entries from y_N / y_S / y_V.
            y: an (m, 1) array of targets drawn uniformly from [0.8, 1), with m = len(labels).
    """
    per_class = n_samples // 3

    # choose random instances
    i_N = randint(0, y_N.shape[0], per_class)
    i_S = randint(0, y_S.shape[0], per_class)
    i_V = randint(0, y_V.shape[0], per_class)

    # select ECG and labels
    X = np.vstack((X_N[i_N], X_S[i_S], X_V[i_V]))
    labels = np.hstack((y_N[i_N], y_S[i_S], y_V[i_V]))

    # Smoothed "real" targets, one per returned sample. Sized from labels rather than
    # n_samples: otherwise a non-multiple-of-3 request gives more targets than samples,
    # which Keras rejects as a batch-size mismatch.
    y = np.random.uniform(0.8, 1, (labels.shape[0], 1))
    return [X, labels], y

In [8]:
# generate points in latent space as input for the generator
# normal noise
def generate_latent_points(latent_dim, n_samples, n_classes=3):
    """Sample generator inputs: Gaussian noise plus shuffled, balanced class labels.

    Args:
        latent_dim: length of each noise vector.
        n_samples: number of (noise, label) pairs to produce.
        n_classes: number of class ids (0 .. n_classes - 1) to draw labels from.

    Returns:
        [X_fake, labels_fake]:
            X_fake: shape (n_samples, latent_dim, 1), drawn from N(0, 1).
            labels_fake: float64 labels of shape (n_samples, 1) in random order.
        When n_samples is a multiple of n_classes, each class appears exactly
        n_samples / n_classes times. Otherwise the remainder goes to the lowest class ids,
        so the labels always match the noise length.
    """
    # generate points in the latent space
    X_fake = np.random.normal(0, 1.0, (n_samples, latent_dim))
    # Balanced labels sorted as 0..0, 1..1, ... before shuffling. This gives the same array
    # as the former zeros/ones/2*ones stack for 3 classes, and it respects n_classes and
    # non-divisible n_samples.
    labels_fake = np.sort(np.arange(n_samples) % n_classes).astype(float)
    np.random.shuffle(labels_fake)
    return [X_fake.reshape(n_samples, latent_dim, 1), labels_fake.reshape(n_samples, 1)]

In [9]:
# use the generator to generate n fake examples, with class labels
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate a batch of fake ECGs, their conditioning labels and soft "fake" targets.

    Args:
        generator: the conditional generator; called as generator.predict([noise, labels]).
        latent_dim: length of each noise vector.
        n_samples: requested batch size, forwarded to generate_latent_points.

    Returns:
        [ecgs, labels_input], y:
            ecgs: the generator output.
            labels_input: the class labels used.
            y: shape (len(labels_input), 1), drawn uniformly from [0, 0.2).
    """
    # generate points in latent space
    z_input, labels_input = generate_latent_points(latent_dim, n_samples)
    # predict outputs
    ecgs = generator.predict([z_input, labels_input])
    # One target per generated sample. Sized from the labels actually produced rather than
    # n_samples, so y always lines up with ecgs even when n_samples is not a multiple of
    # the class count.
    y = np.random.uniform(0, 0.2, (labels_input.shape[0], 1))
    return [ecgs, labels_input], y

In [10]:
# create and save an n-by-n grid of line plots (one series per panel)
def save_plot(X, n, name):
    """Save an n-by-n grid of line plots of X[k, :, 0], k = 0 .. n*n - 1, to the file `name`.

    X is indexed as X[k, :, 0], so it must be at least 3-D with at least n*n rows.
    The figure is closed after saving.
    """
    plt.figure(figsize=(10,3))
    for pos in range(1, n * n + 1):
        # one axis-free panel per series, filled row by row
        plt.subplot(n, n, pos)
        plt.axis('off')
        plt.plot(X[pos - 1, :, 0])
    plt.savefig(name, dpi=100)
    plt.close()

In [11]:
def get_real_samples(X_N, y_N, X_S, y_S, X_V, y_V):
    """Pick one random real sample from each class and stack them.

    The three random indices are drawn in class order (0, 1, 2) before any indexing.

    Returns:
        An array with three rows: one sample each from X_N, X_S and X_V, in that order.
    """
    picks = [randint(0, labels.shape[0], 1) for labels in (y_N, y_S, y_V)]
    return np.vstack([data[idx] for data, idx in zip((X_N, X_S, X_V), picks)])

In [12]:
def save_new_plot(X_R, z_input, n_batch, name):
    """Save a 4x3 grid: real samples on top, then generated samples per class.

    Layout (n = 3):
        - The first row shows the rows of X_R (expected to be 3 rows, e.g. from
          get_real_samples).
        - The remaining rows show the first n generated samples from each of the three
          class blocks of z_input.
    z_input is assumed to be ordered as n_batch // 3 class-0 rows, then class-1 rows, then
    class-2 rows. This matches the labels the training loop builds before calling this.
    The figure is written to `name` at 75 dpi and closed.
    """
    n = 3
    per_class = n_batch // 3
    # real samples first, then the leading n rows of each class block
    panels = np.vstack([X_R] + [z_input[c * per_class:c * per_class + n, :, :] for c in range(3)])

    plt.figure(figsize=(15,5))
    for idx in range(n * (n + 1)):
        plt.subplot(n + 1, n, idx + 1)
        plt.axis('off')
        plt.plot(panels[idx, :, :])
    plt.savefig(name, dpi=75)
    plt.close()

In [13]:
# size of the latent space (length of the generator's noise input)
latent_dim = 186
# size of the data (samples per ECG)
data = 186
# classes (labels 0, 1, 2)
classes = 3
n=3

n_epochs=10

# Batch size. Must be a multiple of 6: the discriminator is trained on two half batches
# (n_batch // 2 real, n_batch // 2 fake) and each half is split evenly over the three
# classes. Keep it below 24000 as well.
n_batch=300
assert n_batch % 6 == 0, 'n_batch must be a multiple of 6 (three classes x two half batches)'

# create the discriminator
d_model = discriminator(data_dim=data, input_classes=classes)
# create the generator
g_model = generator(noise_dim=latent_dim, input_classes=classes, out_dim=data)
# create the gan (this also marks d_model as non-trainable inside the combined model)
gan_model = create_gan(d_model, g_model)

folder_name = 'CGAN_BN_Sigmoid/'
os.makedirs(folder_name, exist_ok=True)

plot_model(d_model, to_file=folder_name+'disc.pdf', show_shapes=True)
plot_model(g_model, to_file=folder_name+'gen.pdf', show_shapes=True)
plot_model(gan_model, to_file=folder_name+'gan.pdf', show_shapes=True)

# load ECG data, split by class
X_N, y_N, X_S, y_S, X_V, y_V = load_real_samples()

# The number of batches per "epoch" is derived from the class-1 (y_S) sample count only.
bat_per_epo = int(y_S.shape[0] / n_batch)
assert bat_per_epo > 0, 'n_batch is larger than the number of class-1 samples; no batches would run'
half_batch = int(n_batch / 2)

# Lowest generator loss seen so far; a checkpoint is written whenever it improves.
# (np.inf rather than the np.infty alias, which NumPy 2.0 removed.)
G_L = np.inf
plt.ioff()

# directory for the periodic real-vs-generated sample plots
filename = folder_name + 'Plots'
os.makedirs(filename, exist_ok=True)

# directory for generator checkpoints
model_name = folder_name + 'Model/'
os.makedirs(model_name, exist_ok=True)

with open(folder_name + 'Loss.csv', 'w') as f:
    f.write('d_loss1, d_loss2, g_loss \n')

# Stats.csv header: one column per (class, metric) pair, e.g. FID_0, ..., TWED_2
with open(folder_name + 'Stats.csv', 'w') as f:
    for i in range(classes):
        for mtc in metric_to_calculate:
            f.write(str(mtc)+'_'+str(i)+',')
    f.write('\n')

# manually enumerate epochs
for i in range(n_epochs):
    # enumerate batches over the training set
    for j in range(bat_per_epo):

        # 1) discriminator step on a half batch of real ECGs (soft targets in [0.8, 1))
        [X_real, labels_real], y_real = generate_real_samples(X_N, y_N, X_S, y_S, X_V, y_V, half_batch)
        d_loss1, _ = d_model.train_on_batch([X_real, labels_real], y_real)

        # 2) discriminator step on a half batch of generated ECGs (soft targets in [0, 0.2))
        [X_fake, labels], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_loss2, _ = d_model.train_on_batch([X_fake, labels], y_fake)

        # 3) generator step through the frozen discriminator, using "real" (inverted) targets
        [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
        y_gan = reshape(np.random.uniform(0.8, 1, n_batch))
        g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)

        # checkpoint the generator whenever its loss reaches a new minimum
        if g_loss < G_L:
            G_L = g_loss
            g_model.save(model_name + str(i*1000 + j) + '_cgan_generator.h5')

        # append this step's losses; the file is closed each step, so the log is on disk
        # even if training is interrupted
        with open(folder_name + 'Loss.csv', 'a') as f:
            f.write(str(d_loss1)+','+str(d_loss2)+','+str(g_loss)+'\n')

        # every 2nd step: report losses and save a plot of real vs generated samples
        if (j+1)%2 == 0:
            print('>%d, %d/%d, d1=%.5f, d2=%.5f g=%.5f' %(i, j, bat_per_epo, d_loss1, d_loss2, g_loss))
            name = filename+'/'+str(i*1000 + j)+'.jpg'
            # generate ECGs
            latent_points, _ = generate_latent_points(latent_dim, n_batch)
            # labels ordered 0..0, 1..1, 2..2 so save_new_plot can slice out each class block
            labels = np.hstack((np.zeros(n_batch//3), np.ones(n_batch//3), 2*np.ones(n_batch//3)))
            # NOTE: z_input is overwritten here with generated ECGs (not latent noise);
            # the metric evaluation below relies on this value.
            z_input  = g_model.predict([latent_points, labels])
            X_R = get_real_samples(X_N, y_N, X_S, y_S, X_V, y_V)
            save_new_plot(X_R, z_input, n_batch, name)

        # every 10th step (always also a plotting step, since 10 is a multiple of 2):
        # compare a fresh class-ordered real batch against the generated batch above
        if (j+1)%10 == 0:
            [X, _], _ = generate_real_samples(X_N, y_N, X_S, y_S, X_V, y_V, n_batch)
            evaluate(X, z_input, classes, metric_to_calculate, n_batch, folder_name, samples=4)

>0, 1/80, d1=0.36862, d2=0.69752 g=0.75239
>0, 3/80, d1=0.36956, d2=0.45253 g=0.73932
>0, 5/80, d1=0.36555, d2=0.38284 g=0.80587
>0, 7/80, d1=0.34115, d2=0.36173 g=0.76966
>0, 9/80, d1=0.35308, d2=0.37614 g=0.72965
>0, 11/80, d1=0.35226, d2=0.36072 g=0.70216
>0, 13/80, d1=0.33964, d2=0.35588 g=0.65455
>0, 15/80, d1=0.32742, d2=0.36307 g=0.58004
>0, 17/80, d1=0.34007, d2=0.34378 g=0.55585
>0, 19/80, d1=0.32977, d2=0.34627 g=0.49626
>0, 21/80, d1=0.36353, d2=0.34891 g=0.46837
>0, 23/80, d1=0.35036, d2=0.37476 g=0.44798
>0, 25/80, d1=0.32489, d2=0.34316 g=0.41608
>0, 27/80, d1=0.34139, d2=0.35226 g=0.44176
>0, 29/80, d1=0.32630, d2=0.35190 g=0.37773
>0, 31/80, d1=0.33390, d2=0.35731 g=0.36594
>0, 33/80, d1=0.36384, d2=0.34522 g=0.34120
>0, 35/80, d1=0.34164, d2=0.33599 g=0.34846
>0, 37/80, d1=0.32856, d2=0.34640 g=0.34656
>0, 39/80, d1=0.32783, d2=0.33636 g=0.34188
>0, 41/80, d1=0.33465, d2=0.34104 g=0.33752
>0, 43/80, d1=0.32403, d2=0.34476 g=0.34060
>0, 45/80, d1=0.33111, d2=0.34116 g=0

>4, 55/80, d1=0.31643, d2=0.33592 g=0.34939
>4, 57/80, d1=0.32202, d2=0.35380 g=0.33405
>4, 59/80, d1=0.32888, d2=0.31729 g=0.35299
>4, 61/80, d1=0.33893, d2=0.33316 g=0.33905
>4, 63/80, d1=0.31263, d2=0.34497 g=0.33035
>4, 65/80, d1=0.33668, d2=0.33330 g=0.32106
>4, 67/80, d1=0.33562, d2=0.35959 g=0.68625
>4, 69/80, d1=0.31899, d2=0.31844 g=0.35359
>4, 71/80, d1=0.33515, d2=0.35086 g=0.33540
>4, 73/80, d1=0.34633, d2=0.34368 g=0.34519
>4, 75/80, d1=0.35403, d2=0.37116 g=3.10318
>4, 77/80, d1=0.35044, d2=0.32316 g=2.13748
>4, 79/80, d1=0.34040, d2=0.34723 g=1.64416
>5, 1/80, d1=0.33749, d2=0.34307 g=1.66255
>5, 3/80, d1=0.32036, d2=0.33414 g=1.38169
>5, 5/80, d1=0.34176, d2=0.36106 g=1.34881
>5, 7/80, d1=0.33544, d2=0.34854 g=1.56775
>5, 9/80, d1=0.33711, d2=0.33747 g=1.27036
>5, 11/80, d1=0.33557, d2=0.33485 g=1.31801
>5, 13/80, d1=0.32503, d2=0.34597 g=1.47014
>5, 15/80, d1=0.33697, d2=0.33540 g=1.28561
>5, 17/80, d1=0.31539, d2=0.32220 g=1.22854
>5, 19/80, d1=0.34338, d2=0.33504 g=1

>9, 29/80, d1=0.32400, d2=0.34766 g=0.40017
>9, 31/80, d1=0.32998, d2=0.32939 g=0.39578
>9, 33/80, d1=0.33579, d2=0.34743 g=0.48161
>9, 35/80, d1=0.33419, d2=0.35506 g=0.34734
>9, 37/80, d1=0.34312, d2=0.38467 g=0.70286
>9, 39/80, d1=0.32628, d2=0.37032 g=0.55401
>9, 41/80, d1=0.31994, d2=0.34839 g=0.41835
>9, 43/80, d1=0.32501, d2=0.35023 g=0.43845
>9, 45/80, d1=0.32179, d2=0.33309 g=0.39581
>9, 47/80, d1=0.32009, d2=0.33211 g=0.41011
>9, 49/80, d1=0.32344, d2=0.32705 g=0.46200
>9, 51/80, d1=0.30822, d2=0.36697 g=0.59852
>9, 53/80, d1=0.31394, d2=0.34031 g=0.38340
>9, 55/80, d1=0.32191, d2=0.34408 g=0.35241
>9, 57/80, d1=0.33180, d2=0.34953 g=0.36024
>9, 59/80, d1=0.34447, d2=0.34803 g=0.35136
>9, 61/80, d1=0.31139, d2=0.32852 g=0.37071
>9, 63/80, d1=0.32635, d2=0.31325 g=0.40362
>9, 65/80, d1=0.32093, d2=0.32424 g=0.43088
>9, 67/80, d1=0.31900, d2=0.32857 g=0.36516
>9, 69/80, d1=0.32748, d2=0.34482 g=0.39843
>9, 71/80, d1=0.32144, d2=0.32553 g=0.38794
>9, 73/80, d1=0.33545, d2=0.3338