In [None]:
# import the libraries
import numpy as np
import os
import tensorflow as tf
import pandas as pd
import multiprocessing as mp
import pickle
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, RepeatVector, TimeDistributed, Bidirectional, Lambda, Masking, Multiply
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.losses import mean_squared_error
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Conv2D, BatchNormalization, LayerNormalization, Input, GlobalAveragePooling2D
from tensorflow.keras.layers import Conv2DTranspose, Reshape
import glob
from sklearn.model_selection import train_test_split
from IPython.display import clear_output
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Layer
from tensorflow.keras import backend as K
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.layers import Reshape
from keras.constraints import Constraint
from scipy import stats
from sklearn.decomposition import PCA
import seaborn as sns

# Restrict TensorFlow to GPUs 0 and 1 and let it grow GPU memory on demand
# instead of reserving it all up front. Setting CUDA_VISIBLE_DEVICES only takes
# effect if TensorFlow has not yet initialised CUDA in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
gpu_devices = tf.config.list_physical_devices('GPU')
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

# Seed every RNG the notebook draws from: TensorFlow (weight initialisation)
# and NumPy (noise vectors and real-batch sampling in the training loop).
SEED = 22
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [None]:
# Load the cleaned EHR export; the CSV's first column becomes the DataFrame index.
# TODO: absolute, user-specific path — move to a configurable DATA_DIR.
path = '/home/chengli/dataset/df_varsAll_cleaned_withHBA1c_withBMI_pred_imp_v6.csv'
df_EHRs = pd.read_csv(filepath_or_buffer=path, index_col=0)

# Lab/exam variables modelled here (an earlier experiment used only ['HBA1C']).
variables = [
    'TG', 'CREAT', 'CAC', 'COLHDL', 'COLTOT',
    'COLLDL', 'HBA1C', 'EK201', 'EK202', 'TT103',
]
# Column names of the two prediction variants stored for every variable.
variables_pred = [f'{name}_pred_death_missings' for name in variables]
variables_pred2 = [f'{name}_pred' for name in variables]
# Columns that are normalised and extracted per patient (alias, not a copy).
col = variables_pred2

def replace_na(x, src_cols=None, dst_cols=None):
    """Copy the ``*_pred_death_missings`` values into the ``*_pred`` columns of one row.

    Meant for ``DataFrame.apply(replace_na, axis=1)``. A row with fewer than
    ``len(src_cols) + 2`` missing values (counted over *all* of the row's
    columns) gets ``x[dst] = x[src]`` for each ``(src, dst)`` pair. Rows with
    more missing values are returned unchanged.

    The original loop assigned ``x[v] = x[v]`` (source and destination were the
    same label), so it never changed anything. NOTE(review): this fix assumes
    the intent was ``*_pred_death_missings`` -> ``*_pred`` (the only direction
    that affects ``col`` downstream). Confirm with the data owner.

    Parameters
    ----------
    x : pandas.Series
        One row of the EHR table. It is modified in place and returned.
    src_cols, dst_cols : list of str, optional
        Parallel lists of source and destination labels. They default to the
        module-level ``variables_pred`` and ``variables_pred2``.

    Returns
    -------
    pandas.Series
        The (possibly updated) row.
    """
    if src_cols is None:
        src_cols = variables_pred
    if dst_cols is None:
        dst_cols = variables_pred2
    # Only rows that are not mostly empty are rewritten. The "+ 2" slack is the
    # original author's threshold; its rationale is undocumented.
    if x.isna().sum() < len(src_cols) + 2:
        for src, dst in zip(src_cols, dst_cols):
            x[dst] = x[src]
    return x

# Time-to-death (`ttd`): clamp positive values to 0 and fill missing with -6
# (the pipeline's sentinel). Then flip the sign so every value is >= 0.
df_EHRs['ttd'] = -df_EHRs['ttd'].clip(upper=0).fillna(-6)
# Drop the raw age at death and apply the per-row prediction-column fix-up.
df_EHRs = (
    df_EHRs
    .drop(columns=['age_death'])
    .apply(replace_na, axis=1)
)

In [None]:
df_EHRs.loc[:, variables_pred2].head(10)

In [None]:
# Precomputed kernel matrix. Its index supplies the patient ids (`idp`) whose
# sequences are extracted below.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
file = '~/send_chengli/kernel_matrix_allVars_cleaned_v7.pkl'
K_matrix = pd.read_pickle(filepath_or_buffer=file)
idps_list = K_matrix.index

In [None]:
# Min-max normalise each model column to [0, 1]. min/max skip NaNs, so missing
# values stay NaN. A constant column has zero range, and dividing by it would
# turn the whole column into NaN (and later mask every visit). Its range is
# therefore replaced by 1, so the column becomes all zeros.
col_min = df_EHRs[col].min()
col_range = (df_EHRs[col].max() - col_min).replace(0, 1)
df_EHRs[col] = (df_EHRs[col] - col_min) / col_range

In [None]:
# Columns pulled per patient: the model columns plus the time axis used for ordering.
col_pre = [*col, 'months_from_diag']
# Encode sex numerically: 'H' -> 1, 'D' -> 0.
df_EHRs['sex'] = df_EHRs['sex'].replace({'H': 1, 'D': 0})
def get_real_df(idp, data=None, columns=None, mask_value=-0.1):
    """Return one patient's visits as a float32 array ordered by ``months_from_diag``.

    Any visit with a missing value in ``columns`` is overwritten entirely with
    ``mask_value``. Presumably this sentinel is what the masked loss/masking
    expects; confirm. ``months_from_diag`` is used only for ordering and is
    dropped, so the result has ``len(columns) - 1`` features (10 with the
    default ``col_pre``, matching ``define_critic``'s ``inp_dim``).

    Fixes vs. the original:
    * the missing-visit mask came from the whole ``df_EHRs`` frame (every
      column), not this patient's selected columns;
    * ``df.drop('months_from_diag', axis=1)`` discarded its result, so the
      time column leaked into the features.

    Parameters
    ----------
    idp : hashable
        Patient id matched against the ``idp`` column.
    data : pandas.DataFrame, optional
        Source table. Defaults to the module-level ``df_EHRs``.
    columns : list of str, optional
        Columns to extract. Must include ``'months_from_diag'``. Defaults to
        the module-level ``col_pre``.
    mask_value : float, default -0.1
        Value written over every feature of an incomplete visit.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_visits, len(columns) - 1)``, dtype float32.
    """
    if data is None:
        data = df_EHRs
    if columns is None:
        columns = col_pre
    visits = data.loc[data['idp'] == idp, columns].sort_values(by='months_from_diag')
    # Decide which visits are incomplete before dropping the time column, so a
    # missing timestamp still marks the visit.
    incomplete = visits.isna().any(axis=1).to_numpy()
    values = visits.drop(columns='months_from_diag').to_numpy(dtype=np.float32)
    values[incomplete] = mask_value
    return values
# Build the per-patient arrays in parallel. Pool.map preserves the order of
# idps_list. NOTE(review): workers read the notebook's globals (df_EHRs,
# col_pre), which presumably relies on the 'fork' start method (Linux default).
n_cpu = 10
with mp.Pool(processes=n_cpu) as pool:
    df_ehrs_real = pool.map(func=get_real_df, iterable=idps_list)

In [None]:
# Stacked shape of the real data. On recent NumPy this raises if patients have
# different numbers of visits (ragged list).
np.asarray(df_ehrs_real).shape

In [None]:
# clip model weights to a given hypercube
class ClipConstraint(tf.keras.constraints.Constraint):
    """Clip weights element-wise into ``[-clip_value, clip_value]``.

    Used as the ``kernel_constraint`` of the critic's LSTM layers
    (WGAN-style weight clipping). It subclasses ``tf.keras.constraints.Constraint``
    so it comes from the same Keras as the ``tensorflow.keras`` layers it is
    attached to. The original subclassed the standalone ``keras`` package's
    ``Constraint``.
    """

    def __init__(self, clip_value):
        # Half-width of the allowed hypercube.
        self.clip_value = clip_value

    def __call__(self, weights):
        return K.clip(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        # Allows the constraint to be serialised together with the model.
        return {'clip_value': self.clip_value}

In [None]:
def matrices_loss(k_pred, k_true):
    """Distance between two matrices after scaling each to unit norm.

    Each input is divided by its own ``tf.norm`` (Euclidean norm over all
    entries), so only the matrices' directions are compared, not their
    magnitudes. Returns the ``tf.norm`` of the difference of the scaled
    matrices.
    """
    unit_pred, unit_true = (m / tf.norm(m) for m in (k_pred, k_true))
    return tf.norm(unit_pred - unit_true)

def vectors_loss(mask_value):
    """Build a masked MSE loss that ignores target entries equal to ``mask_value``.

    Parameters
    ----------
    mask_value : float
        Sentinel marking missing entries in ``y_true``.

    Returns
    -------
    callable
        ``masked_mse(y_true, y_pred)`` returning a scalar tensor: the sum of
        squared errors over unmasked entries divided by the number of unmasked
        entries. A fully masked batch yields 0 instead of NaN.
    """
    mask_value = K.variable(mask_value)

    def masked_mse(y_true, y_pred):
        # 1.0 where the target holds a real value, 0.0 where it equals the sentinel.
        mask = K.cast(K.not_equal(y_true, mask_value), K.floatx())
        # Squared error, zeroed at masked positions.
        squared_error = K.square((y_true - y_pred) * mask)
        # Average over unmasked entries only. The count is an integer >= 1
        # whenever anything is unmasked, so clamping at epsilon only changes the
        # all-masked case, where it turns 0/0 = NaN into 0.
        n_valid = K.maximum(K.sum(mask), K.epsilon())
        return K.sum(squared_error) / n_valid

    return masked_mse

In [None]:
# calculate wasserstein loss
def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: the mean of ``y_true * y_pred``.

    With -1/+1 labels (run_training uses -1 for real and +1 for fake),
    minimising it drives scores up for -1-labelled samples and down for
    +1-labelled ones.
    """
    weighted_scores = y_true * y_pred
    return K.mean(weighted_scores)

In [None]:
# define the self-attention layer
class SelfAttentionLayer(tf.keras.layers.Layer):
    """Attention pooling over axis 1 with one learned query and one learned key vector.

    For a rank-3 input ``(batch, time, features)``, each step ``t`` gets the
    scalar score ``(x_t . query) * (x_t . key)``. Scores are softmax-normalised
    over axis 1 and used to weight the inputs. The output has shape
    ``(batch, features, 1)``: the time axis is reduced and the feature axis
    moves to position 1.

    NOTE(review): this is not position-to-position self-attention. Because of
    the output layout, the LSTMs that follow in ``define_critic`` iterate over
    *features*, not time steps. Confirm this is intended.

    Keyword arguments (``name``, ``dtype``, ``trainable``, ...) are forwarded to
    ``Layer`` so the layer round-trips through ``get_config``/``from_config``.
    The original accepted none and rebuilt itself with defaults.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Two (features, 1) projection vectors.
        self.query = self.add_weight(shape=[input_shape[-1], 1], initializer="glorot_uniform", trainable=True, name="query")
        self.key = self.add_weight(shape=[input_shape[-1], 1], initializer="glorot_uniform", trainable=True, name="key")

    def call(self, inputs):
        # (batch, time, features) x (features, 1) -> (batch, time, 1)
        query_scores = tf.einsum('btk,kc->btc', inputs, self.query)
        key_scores = tf.einsum('btk,kc->btc', inputs, self.key)

        # Normalise the per-step scores over axis 1 (time).
        attention_scores = tf.nn.softmax(query_scores * key_scores, axis=1)

        # Weighted sum over time: -> (batch, features, 1)
        output = tf.einsum('btc,btk->bkc', attention_scores, inputs)
        return output

    def get_config(self):
        # The layer has no constructor arguments of its own; the base config suffices.
        return super().get_config()

    @classmethod
    def from_config(cls, config):
        # Restore the base-layer settings saved by get_config.
        return cls(**config)

In [None]:
# define the critic
def define_critic(disc_dim=200, inp_dim=10, time_step=10):
    """Build and compile the WGAN critic.

    Architecture: ``SelfAttentionLayer`` -> 2 x (``LSTM(disc_dim,
    return_sequences=True)`` + ``LayerNormalization``) -> ``Dense(1)``. Both
    LSTM input kernels share one N(0, 0.02) initializer and one
    ``ClipConstraint(0.5)``, i.e. they are clipped into [-0.5, 0.5].

    NOTE(review): only the LSTM input kernels are clipped. The recurrent
    kernels, LayerNormalization and the final Dense layer are unconstrained,
    so WGAN weight clipping does not cover the whole critic. The output also
    keeps a sequence axis (the last LSTM returns sequences). Confirm both are
    intended.

    Returns
    -------
    tf.keras.Model
        Named 'discriminator', compiled with ``wasserstein_loss`` and
        RMSprop(learning_rate=5e-5).
    """
    init = tf.keras.initializers.RandomNormal(stddev=0.02)
    const = ClipConstraint(0.5)

    inp = Input(shape=(time_step, inp_dim,))
    hidden = SelfAttentionLayer()(inp)
    # Two identical LSTM + LayerNorm stages, created in the same order as before.
    for _ in range(2):
        hidden = LSTM(disc_dim, return_sequences=True,
                      kernel_initializer=init, kernel_constraint=const)(hidden)
        hidden = LayerNormalization()(hidden)
    score = Dense(1)(hidden)

    model = Model(inputs=inp, outputs=score, name='discriminator')
    model.compile(loss=wasserstein_loss,
                  optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00005))
    return model

In [None]:
def gen_decoder(vec_dim=100, latent_dim=200, autoencoder=None, decoder_start=7):
    """Build the generator: two trainable Dense layers feeding a frozen pretrained decoder.

    The decoder is the part of a pretrained autoencoder from the input of layer
    ``decoder_start`` to the autoencoder's first output. Its weights are frozen,
    so only the two Dense layers are learned.

    Parameters
    ----------
    vec_dim : int, default 100
        Length of the noise vector. Must match ``get_noise``.
    latent_dim : int, default 200
        Width of both Dense layers. Must match the decoder's input size.
    autoencoder : tf.keras.Model, optional
        Pretrained autoencoder. Defaults to the global ``model_autoencoder``,
        which no earlier cell of this notebook defines. Pass it explicitly to
        make the dependency visible.
    decoder_start : int, default 7
        Index of the autoencoder layer whose input is the decoder input (the
        original hard-coded 7).

    Returns
    -------
    tf.keras.Model
        Named 'generator_decoder'. NOTE(review): its output shape must equal
        the critic's ``(time_step, inp_dim)``; this is not checked here.

    Raises
    ------
    NameError
        If no autoencoder is passed and ``model_autoencoder`` is undefined.
    """
    if autoencoder is None:
        try:
            autoencoder = model_autoencoder
        except NameError as err:
            raise NameError(
                "gen_decoder needs a pretrained autoencoder: define "
                "`model_autoencoder` or pass `autoencoder=`"
            ) from err

    decoder = Model(inputs=autoencoder.layers[decoder_start].input,
                    outputs=autoencoder.outputs[0])
    decoder.trainable = False

    inp = Input(shape=(vec_dim,))
    x = Dense(latent_dim)(inp)
    x = Dense(latent_dim)(x)
    x = decoder(x)

    return Model(inputs=inp, outputs=x, name='generator_decoder')

In [None]:
def con_model(verbose=False, vec_dim=100):
    """Assemble the WGAN: noise -> generator -> frozen critic.

    The critic is compiled on its own inside ``define_critic`` while still
    trainable. It is then marked non-trainable before the combined ``gan``
    model is compiled (wasserstein_loss, RMSprop lr=5e-5), so that training
    ``gan`` is meant to update only the generator.

    Parameters
    ----------
    verbose : bool
        Print summaries of generator, critic and combined model.
    vec_dim : int
        Noise vector length fed to the combined model.

    Returns
    -------
    tuple
        ``(generator, discriminator, gan)``.
    """
    critic = define_critic()
    critic.trainable = False

    gen = gen_decoder()

    noise_in = Input(shape=(vec_dim,))
    critic_score = critic(gen(noise_in))
    gan = Model(inputs=noise_in, outputs=critic_score)
    gan.compile(loss=wasserstein_loss,
                optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.00005))

    if verbose:
        for part in (gen, critic, gan):
            part.summary()

    return gen, critic, gan

# Build generator, critic and combined model, printing their summaries.
(generator_con, discriminator_con, gan_con) = con_model(verbose=True)

In [None]:
def get_noise(BATCH_SIZE=128, noise_dim=100):
    """Sample generator input noise from NumPy's global RNG, uniform on [0, 1).

    Parameters
    ----------
    BATCH_SIZE : int, default 128
        Number of noise vectors. The parameter name is kept for existing
        keyword callers.
    noise_dim : int, default 100
        Length of each vector. Must equal the generator's ``vec_dim``
        (previously hard-coded to 100).

    Returns
    -------
    numpy.ndarray
        ``(BATCH_SIZE, noise_dim)`` array cast to float32.
    """
    return np.random.rand(BATCH_SIZE, noise_dim).astype(np.float32)

In [None]:
# Hyper-parameters of the alternating WGAN training loop.
# COMM (open review question): have you tried to play a bit with DISC_UPDATES /
# GEN_UPDATES? What changes when increasing the generator steps?
DISC_UPDATES = 10   # critic updates per alternating iteration
GEN_UPDATES = 1     # generator updates per alternating iteration
SAVE_INTERVAL = 20  # save interval (currently not referenced by run_training)
BATCH_SIZE = 128    # generator batch; each critic step uses half real + half fake

# Per-epoch loss history: critic on real, critic on fake, generator.
c1_hist, c2_hist, g_hist = [], [], []

def _sample_real_batch(df_real, n):
    """Draw ``n`` real sequences uniformly at random (with replacement) along axis 0."""
    idx = np.random.randint(0, df_real.shape[0], size=n)
    return df_real[idx]


def _train_critic_step(generator, discriminator, df_real, half_batch):
    """One critic update on a real half-batch (label -1) and a fake half-batch (label +1).

    Returns ``(loss_real, loss_fake)`` as reported by ``train_on_batch``.
    """
    real = _sample_real_batch(df_real, half_batch)
    fake = generator.predict(get_noise(BATCH_SIZE=half_batch), verbose=0)
    loss_real = discriminator.train_on_batch(real, -np.ones([half_batch, 1], dtype=np.float32))
    loss_fake = discriminator.train_on_batch(fake, np.ones([half_batch, 1], dtype=np.float32))
    return loss_real, loss_fake


def _train_generator_step(gan):
    """Run GEN_UPDATES generator updates through ``gan`` and return their mean loss.

    Fakes are labelled -1 (the critic's "real" label), so the generator is
    pushed towards samples the critic scores as real.
    """
    target = -np.ones([BATCH_SIZE, 1], dtype=np.float32)
    total = 0.0
    for _ in range(GEN_UPDATES):
        total += gan.train_on_batch(get_noise(BATCH_SIZE=BATCH_SIZE), target)
    return total / GEN_UPDATES


def _report_epoch(epoch, generator):
    """Redraw the loss curves and print one generated sample (row 10 of a fresh batch)."""
    clear_output(True)
    print('Epoch', epoch)
    fig, ax = plt.subplots()
    ax.plot(range(len(c1_hist)), c1_hist)
    ax.plot(range(len(c2_hist)), c2_hist)
    ax.plot(range(len(g_hist)), g_hist)
    ax.legend(['disc_real_loss', 'disc_fake_loss', 'gen_loss'])
    ax.set(xlabel='epoch', ylabel='loss')
    plt.show()
    plt.close(fig)  # avoid accumulating figures over many epochs
    sample = generator.predict(get_noise(), verbose=0)
    print(sample[10, :, :])


def run_training(gan, generator, discriminator, num_epochs=100, save_freq=10,
                 iterations_per_epoch=200):
    """Train the WGAN with alternating critic / generator updates.

    Each epoch runs ``iterations_per_epoch`` iterations. Each iteration does
    DISC_UPDATES critic steps (half a batch of real and half of fake sequences)
    followed by GEN_UPDATES generator steps. After every epoch, the mean losses
    are appended to the global ``c1_hist`` / ``c2_hist`` / ``g_hist`` and plotted.

    Fixes vs. the original:
    * the ``for epoch in range(num_epochs)`` line was missing, which caused an
      IndentationError, an undefined ``epoch`` and an unused ``num_epochs``;
    * loss buffers were reset every iteration, so the epoch means covered only
      the last iteration;
    * real batches were one contiguous slice with an exclusive-end off-by-one.
      They are now random rows, matching the "random set" intent;
    * the ``nosie`` typo made the sample printout reuse a stale noise batch.

    Parameters
    ----------
    gan, generator, discriminator : tf.keras.Model
        As returned by ``con_model``.
    num_epochs : int, default 100
        Number of epochs.
    save_freq : int, default 10
        Accepted for compatibility; not used (no checkpointing is performed).
    iterations_per_epoch : int, default 200
        Alternating iterations per epoch (previously hard-coded to 200).

    Returns
    -------
    tuple
        The same ``(generator, discriminator, gan)`` objects, now trained.

    Raises
    ------
    ValueError
        If ``df_ehrs_real`` is empty.
    """
    half_batch = BATCH_SIZE // 2
    df_real = np.array(df_ehrs_real, dtype=np.float32)
    if df_real.shape[0] == 0:
        raise ValueError("no real EHR sequences to train on")

    for epoch in range(num_epochs):
        c1_tmp, c2_tmp, g_tmp = [], [], []
        for _ in range(iterations_per_epoch):
            # NOTE(review): kept from the original. Changing `trainable` after
            # compile() may have no effect until recompilation; confirm with
            # the Keras version in use.
            discriminator.trainable = True
            for _ in range(DISC_UPDATES):
                loss_real, loss_fake = _train_critic_step(
                    generator, discriminator, df_real, half_batch)
                c1_tmp.append(loss_real)
                c2_tmp.append(loss_fake)
            g_tmp.append(_train_generator_step(gan))

        c1_hist.append(np.mean(c1_tmp))
        c2_hist.append(np.mean(c2_tmp))
        g_hist.append(np.mean(g_tmp))
        _report_epoch(epoch, generator)

    return generator, discriminator, gan

In [None]:
# Train for 200 epochs. run_training returns the same model objects it was
# given, now trained.
generator_EHRs_trained, discriminator_EHRs_trained, gan_EHRs_trained = run_training(
    gan_con, generator_con, discriminator_con, num_epochs=200,
)