# Conditional Generative Adversarial Network for gene expression inference

In [0]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
from keras.models import Model
from keras.utils import Sequence
from keras.optimizers import SGD
from keras.metrics import mean_absolute_error
from keras.layers import Input, Dense
from keras.regularizers import l1

Using TensorFlow backend.


# Data Preperation

Download the dataset from [here](https://drive.google.com/file/d/18Et3ewt471dN78tZxDHg7g-a8WI9z00H/view?usp=sharing).

### Load data to Google Colab notebook

In [3]:
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
!pip install -U -q PyDrive
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)


# choose a local (colab) directory to store the data.
local_download_path = os.path.expanduser('')
try:
  os.makedirs(local_download_path)
except: pass

# Download a file based on its file ID.
# A file ID looks like: laggVyWshwcyP6kEI-y_W3P8D26sz
file_id = '18Et3ewt471dN78tZxDHg7g-a8WI9z00H'
downloaded = drive.CreateFile({'id': file_id})
fname = os.path.join(local_download_path, downloaded['title'])
# print('Downloaded content "{}"'.format(downloaded.GetContentString()))

print('title: %s, id: %s' % (downloaded['title'], downloaded['id']))
downloaded.GetContentFile(fname)
print(fname)

[?25l[K     |▎                               | 10kB 22.8MB/s eta 0:00:01[K     |▋                               | 20kB 3.3MB/s eta 0:00:01[K     |█                               | 30kB 4.8MB/s eta 0:00:01[K     |█▎                              | 40kB 3.1MB/s eta 0:00:01[K     |█▋                              | 51kB 3.8MB/s eta 0:00:01[K     |██                              | 61kB 4.5MB/s eta 0:00:01[K     |██▎                             | 71kB 5.2MB/s eta 0:00:01[K     |██▋                             | 81kB 5.9MB/s eta 0:00:01[K     |███                             | 92kB 6.5MB/s eta 0:00:01[K     |███▎                            | 102kB 5.0MB/s eta 0:00:01[K     |███▋                            | 112kB 5.0MB/s eta 0:00:01[K     |████                            | 122kB 5.0MB/s eta 0:00:01[K     |████▎                           | 133kB 5.0MB/s eta 0:00:01[K     |████▋                           | 143kB 5.0MB/s eta 0:00:01[K     |█████                     

### Read .gctx data

In [4]:
!pip install -q cmapPy
# https://github.com/cmap/cmapPy/blob/master/tutorials/cmapPy_pandasGEXpress_tutorial.ipynb

[?25l[K     |██▏                             | 10kB 21.8MB/s eta 0:00:01[K     |████▍                           | 20kB 3.5MB/s eta 0:00:01[K     |██████▌                         | 30kB 5.1MB/s eta 0:00:01[K     |████████▊                       | 40kB 3.2MB/s eta 0:00:01[K     |███████████                     | 51kB 3.9MB/s eta 0:00:01[K     |█████████████                   | 61kB 4.6MB/s eta 0:00:01[K     |███████████████▎                | 71kB 5.3MB/s eta 0:00:01[K     |█████████████████▌              | 81kB 6.0MB/s eta 0:00:01[K     |███████████████████▋            | 92kB 6.7MB/s eta 0:00:01[K     |█████████████████████▉          | 102kB 5.1MB/s eta 0:00:01[K     |████████████████████████        | 112kB 5.1MB/s eta 0:00:01[K     |██████████████████████████▏     | 122kB 5.1MB/s eta 0:00:01[K     |████████████████████████████▍   | 133kB 5.1MB/s eta 0:00:01[K     |██████████████████████████████▌ | 143kB 5.1MB/s eta 0:00:01[K     |██████████████████████████

In [0]:
# Load and reshape data so each row corresponds to a sample
from cmapPy.pandasGEXpress.parse import parse
gctx_data = parse("HW6-Data.gctx")

In [6]:
data_df = gctx_data.data_df.T
print(data_df.shape)
data_df.head()

(2921, 5503)


rid,ENSG00000175063.12,ENSG00000171174.9,ENSG00000160326.9,ENSG00000204209.6,ENSG00000087460.18,ENSG00000105968.14,ENSG00000163686.9,ENSG00000079739.11,ENSG00000134057.10,ENSG00000196154.7,ENSG00000148400.9,ENSG00000164362.14,ENSG00000176171.7,ENSG00000119185.8,ENSG00000115738.5,ENSG00000171453.13,ENSG00000214063.6,ENSG00000147383.6,ENSG00000103876.7,ENSG00000109654.10,ENSG00000137843.7,ENSG00000104365.9,ENSG00000123130.12,ENSG00000100225.13,ENSG00000010810.12,ENSG00000126602.6,ENSG00000152601.13,ENSG00000149658.13,ENSG00000169598.11,ENSG00000087088.15,ENSG00000095066.7,ENSG00000161204.7,ENSG00000198369.5,ENSG00000145293.10,ENSG00000065534.14,ENSG00000072062.9,ENSG00000165704.10,ENSG00000170502.8,ENSG00000105699.12,ENSG00000172175.8,...,ENSG00000125753.9,ENSG00000125741.4,ENSG00000064692.14,ENSG00000010310.4,ENSG00000125743.6,ENSG00000011478.7,ENSG00000177051.5,ENSG00000177045.6,ENSG00000064652.6,ENSG00000170604.3,ENSG00000151292.13,ENSG00000064199.2,ENSG00000154146.7,ENSG00000149548.10,ENSG00000182013.13,ENSG00000160013.4,ENSG00000134910.8,ENSG00000155324.5,ENSG00000090372.10,ENSG00000181027.6,ENSG00000105281.8,ENSG00000110060.4,ENSG00000113368.7,ENSG00000173926.5,ENSG00000130748.6,ENSG00000064309.10,ENSG00000130749.5,ENSG00000142230.7,ENSG00000105327.11,ENSG00000105321.8,ENSG00000064651.9,ENSG00000110063.4,ENSG00000063169.6,ENSG00000105499.9,ENSG00000142227.6,ENSG00000161558.6,ENSG00000043039.5,ENSG00000105438.4,ENSG00000105464.3,ENSG00000182324.5
cid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1
GTEX-N7MS-0007-SM-2D7W1,6.258,1.57004,12.819815,25.197998,114.751381,19.435654,1.811746,13.255496,3.265682,272.174011,18.31282,0.342227,7.193198,6.058734,35.885273,5.476271,4.680382,5.160254,5.301565,0.404902,0.447544,19.715021,12.691573,363.121033,35.115589,10.6437,18.738678,12.50847,3.961152,44.450912,5.175422,19.583858,1.780967,7.288062,1.285417,38.46994,8.974514,4.648564,8.893391,9.216136,...,127.530022,2.516836,0.008248,4.131555,20.181681,9.524962,9.139047,0.583649,1.100281,7.93745,2.745857,1.039201,60.472664,1.185207,0.084997,5.044806,14.393388,2.090897,23.079409,4.407234,19.984457,4.492527,17.26272,3.341017,5.359549,0.09262,6.270971,17.616255,13.506384,26.032234,1.507246,16.141371,8.745433,0.626008,143.632782,4.309718,0.097738,37.715435,0.264591,0.008248
GTEX-N7MS-0008-SM-4E3JI,38.256783,1.417273,3.846615,13.019768,65.975456,21.775644,2.026841,21.197948,56.756538,343.272125,2.461825,0.007284,60.472664,18.91296,29.144449,13.602143,34.434437,11.517919,6.471162,2.251328,0.007284,7.389894,12.720635,28.596855,22.336266,16.432726,28.066408,13.506384,2.540541,54.584293,3.453465,15.831875,6.571239,19.810352,3.938659,25.686693,28.365864,6.58014,1.645448,6.144692,...,38.674606,2.775685,0.404902,0.950917,90.587181,6.771498,3.824857,4.283224,4.111968,8.198557,6.31073,9.505074,5.807346,1.890565,1.013015,2.915665,47.71122,11.408774,15.766845,5.973472,67.406731,4.293163,25.409655,1.880815,15.609934,0.660897,6.270971,47.598354,2.918443,5.741512,3.179466,15.544182,3.222655,0.492695,139.820328,2.678623,0.007284,113.112526,2.248742,0.587866
GTEX-N7MS-0011-R10A-SM-2HMJK,0.018755,0.958147,9.762014,11.459002,227.624619,35.348461,8.654851,13.895129,3.640244,2.26649,2.868886,0.046608,37.521179,21.95883,46.217178,8.708959,2.130838,7.83899,1.977618,25.197998,3.091491,5.194419,3.725604,18.756191,25.409655,16.71736,9.929487,14.231618,2.824677,9.29339,7.752186,22.701349,13.101074,43.441166,3.741034,39.98505,60.657154,12.527876,3.94509,1.513837,...,4.363213,3.400148,0.699935,0.137225,49.061504,6.991272,3.069155,0.510642,5.990269,17.886543,5.495917,5.217256,977.106323,0.128216,58.799564,0.17497,6.122733,7.323215,24.21722,6.058734,0.293274,2.689233,1.145243,0.296909,38.18969,1.137451,5.081839,33.039825,2.716285,5.587399,5.5914,6.231666,3.787384,7.064671,5.62324,3.376416,0.472907,20.261278,2.430775,0.44992
GTEX-N7MS-0011-R11A-SM-2HMJS,0.579502,1.871301,4.256632,15.027001,256.823669,64.796654,10.815572,9.714371,4.323125,1.752681,0.841857,0.002584,23.311316,10.36511,5.856566,8.517663,9.79667,4.822447,2.608894,25.288555,1.106072,11.892198,3.54921,17.051687,19.198412,13.895129,7.035053,23.57078,7.614133,12.026003,9.993617,28.21661,6.457727,24.756954,0.634676,39.909748,35.764259,13.369477,1.885782,3.737905,...,3.248432,2.113203,4.253413,0.547026,42.564186,6.890395,3.251218,0.446398,3.906823,21.330336,6.364108,0.820239,1.020474,0.43929,39.243355,0.059889,8.775948,1.169072,11.150365,2.160473,0.253535,3.719441,8.482278,0.637625,11.812765,11.158457,8.733294,20.503635,7.981469,5.092885,7.415754,7.426057,7.64589,10.552812,2.984973,4.967032,0.069738,19.79162,1.013015,0.88655
GTEX-N7MS-0011-R1a-SM-2HMJG,0.06581,0.737554,11.014378,11.349816,199.399826,27.737179,21.662554,9.68714,3.774905,2.225861,4.427747,0.020815,34.379059,13.422124,14.555308,7.405584,2.529936,15.978411,3.326518,35.585335,0.72811,5.456933,2.373802,29.427275,26.656343,13.173176,12.849655,14.277672,3.54921,12.518096,5.259296,22.5779,6.771498,51.644218,4.00586,44.836121,58.799564,11.006389,4.131555,1.999842,...,4.471974,2.918443,0.515728,3.46827,55.189522,8.58879,3.309234,0.769298,6.355098,13.7203,6.575679,6.475808,184.620148,0.240404,43.265656,0.233248,6.621542,6.484859,36.313843,6.904549,0.64773,3.515989,0.510642,0.563046,36.377235,1.039201,6.399546,22.217848,1.897856,5.811523,9.41801,5.111451,4.118488,8.733294,8.905753,3.787384,0.205818,28.441481,1.129635,0.301559


### Prepare data for training

In [0]:
# Split data of landmark and target genes so each 
data = data_df.to_numpy()
n_landmark = 943  # number of landmark genes
n_est = np.size(data, axis=1) - n_landmark  # number of genes to be estimated
X = data[:, :n_landmark]
Y = data[:, n_landmark:]

In [0]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.25, random_state=42)

In [0]:
# Normalize data using train mean and std for each gene
x_train_mean = x_train.mean(axis=0)
x_train_std = x_train.std(axis=0)
y_train_mean = y_train.mean(axis=0)
y_train_std = y_train.std(axis=0)
# print("Mean: ", train_mean)
# print("Std: ", train_std)

# normalization
def normalize(x, mean, std):
  return (x - mean) / std

x_train = normalize(x_train, x_train_mean, x_train_std)
x_test = normalize(x_test, x_train_mean, x_train_std)
y_train = normalize(y_train, y_train_mean, y_train_std)
y_test = normalize(y_test, y_train_mean, y_train_std)

# Base Model

In [10]:
def build_predictor():
    input_data = Input(shape=(n_landmark,))
    h = Dense(3000, activation='linear')(input_data)
    out = Dense(n_est, activation='linear')(h)
    model = Model(input_data, out, name='generator')
    return model

base_model = build_predictor()
base_model.summary()

Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 943)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 3000)              2832000   
_________________________________________________________________
dense_2 (Dense)              (None, 4560)              13684560  
Total params: 16,516,560
Trainable params: 16,516,560
Non-trainable params: 0
_________________________________________________________________


In [11]:
base_model.compile(loss='mean_absolute_error', optimizer=SGD(lr=0.1))
history = base_model.fit(x_train, y_train, 
                         epochs=100, batch_size=16, 
                         validation_data=(x_test, y_test), verbose=1)

Instructions for updating:
Use tf.cast instead.
Train on 2190 samples, validate on 731 samples
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoc

In [12]:
# report MAE on train and test data
print("Train MAE: ", base_model.evaluate(x_train, y_train))
print("Test MAE: ", base_model.evaluate(x_test, y_test))

Train MAE:  0.22059063235102178
Test MAE:  0.25253316225985983


# Adversarial Model

In [0]:
# parameters
g_lr = 0.05
d_lr = 0.1
reg_param = 0.01

In [0]:
def build_discriminator():
    input_data = Input(shape=(n_est,))
    hidden_layer = Dense(700, activation='linear', kernel_regularizer=l1(reg_param))(input_data) 
#                          bias_regularizer=l2(0.01))
    output = Dense(1, activation='sigmoid')(hidden_layer)
    model = Model(input_data, output, name='discriminator')
    return model

In [15]:
discriminator = build_discriminator()
discriminator_optimizer = SGD(lr=d_lr)
print(discriminator.summary())

discriminator.compile(loss='binary_crossentropy', optimizer=discriminator_optimizer, metrics=['accuracy'])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 4560)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 700)               3192700   
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 701       
Total params: 3,193,401
Trainable params: 3,193,401
Non-trainable params: 0
_________________________________________________________________
None


In [16]:
generator = build_predictor()

# Freeze discriminator weights using `trainable` parameter
discriminator.trainable = False

# then create the combined model using generator and discriminator models.
gan_in = Input(shape=(n_landmark,), name='input')
y_gen = generator(gan_in)
valid = discriminator(y_gen)

gan_model = Model(input=gan_in, output=[y_gen, valid], name='gan')
print(gan_model.summary())
# Finally compile the combined model
gan_optimizer = SGD(lr=g_lr)
gan_model.compile(loss={'generator': 'mean_absolute_error', 'discriminator': 'binary_crossentropy'}, 
                  optimizer=gan_optimizer, 
                  metrics={'discriminator': 'accuracy'})

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           (None, 943)               0         
_________________________________________________________________
generator (Model)            (None, 4560)              16516560  
_________________________________________________________________
discriminator (Model)        (None, 1)                 3193401   
Total params: 19,709,961
Trainable params: 16,516,560
Non-trainable params: 3,193,401
_________________________________________________________________
None


  # This is added back by InteractiveShellApp.init_path()


In [0]:
class DataGenerator(Sequence):
    """
        Generates batches of data. Useful for training manually in Keras.
        If necessary, you may update this class to fit your needs.
        Adopted from: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
    """
    def __init__(self, train_X, train_y, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.train_X = train_X
        self.train_y = train_y
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.train_X) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        return self.train_X[indexes], self.train_y[indexes]


    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.train_X))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)


In [0]:
batch_size = 16
epochs = 500

In [0]:
data_generator = DataGenerator(x_train, y_train, batch_size=16)

In [20]:
print('GAN metrics: ', gan_model.metrics_names)
print('discriminator metrics: ', discriminator.metrics_names)

GAN metrics:  ['loss', 'generator_loss', 'discriminator_loss', 'discriminator_acc']
discriminator metrics:  ['loss', 'acc']


In [21]:
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    
    n = 0
    g_total_loss = 0
    g_total_mae = 0
    d_total_loss = 0
    d_total_accuracy = 0
    
    for X, y in data_generator:
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        
        # Train the discriminator and calculate loss and accuracy
        y_fake = generator.predict(X)
        
        # then use discriminator train_on_batch twice.
        # once with y as input and once with the output of generator.predict as input.
        d_loss_real = discriminator.train_on_batch(y, valid)
        d_loss_fake = discriminator.train_on_batch(y_fake, fake)
        
        d_loss = d_loss_real[0] + d_loss_fake[0]
        d_accuracy = d_loss_real[1] + d_loss_fake[1]

        # ---------------------
        #  Train Generator
        # ---------------------

        # Train the generator and calculate loss and MAE
        g_losses = gan_model.train_on_batch(X, [y, valid])
        
        g_loss = g_losses[0]
        g_mae = g_losses[2]
        
        d_total_loss = (d_loss + d_total_loss) / 2
        d_total_accuracy = (d_accuracy + d_total_accuracy) / 2
        g_total_loss = g_loss + g_total_loss
        g_total_mae = g_mae + g_total_mae
        n += 1
        
#         print("%d Train: [D loss: %f, acc.: %.2f%%] [G loss: %f G MAE: %f]" % 
#               (epoch, d_total_loss / n, d_total_accuracy / n, g_total_loss / n, g_total_mae / n))
        
    y_test_fake = generator.predict(x_test)
    test_score_real = gan_model.evaluate(x_test, [y_test, np.ones((y_test.shape[0], 1))], verbose=0)
    test_score_fake = gan_model.evaluate(x_test, [y_test_fake, np.zeros((y_test.shape[0], 1))], verbose=0)
    test_score = np.add(test_score_real, test_score_fake) / 2
    
    test_d_loss = test_score[2]
    test_d_accuracy = test_score[3]
    test_g_loss = test_score[0]
    test_g_mae = test_score_real[1]
    
    
    print ("%d Train: [D loss: %f, acc.: %.2f%%] [G loss: %f G MAE: %f]" % 
           (epoch, d_total_loss / n, d_total_accuracy * 100 / n, g_total_loss / n, g_total_mae / n))
    print ("%d Test: [D loss: %f, acc.: %.2f%%] [G loss: %f G MAE: %f]" % 
           (epoch, test_d_loss, test_d_accuracy * 100, test_g_loss, test_g_mae))
    print('=================================================================================')
    
    data_generator.on_epoch_end()


  'Discrepancy between trainable weights and collected trainable'


0 Train: [D loss: 0.245199, acc.: 0.00%] [G loss: 37.272769 G MAE: 1.024674]
0 Test: [D loss: 0.693236, acc.: 50.00%] [G loss: 16.952491 G MAE: 0.703014]
1 Train: [D loss: 0.245162, acc.: 0.00%] [G loss: 17.270460 G MAE: 0.708476]
1 Test: [D loss: 0.693236, acc.: 50.00%] [G loss: 16.888998 G MAE: 0.650990]
2 Train: [D loss: 0.245132, acc.: 0.00%] [G loss: 17.188124 G MAE: 0.707990]
2 Test: [D loss: 0.693234, acc.: 50.00%] [G loss: 16.844282 G MAE: 0.610117]
3 Train: [D loss: 0.245114, acc.: 0.00%] [G loss: 17.127031 G MAE: 0.707596]
3 Test: [D loss: 0.693236, acc.: 50.00%] [G loss: 16.798920 G MAE: 0.576442]
4 Train: [D loss: 0.245110, acc.: 0.00%] [G loss: 17.061786 G MAE: 0.707177]
4 Test: [D loss: 0.693236, acc.: 50.00%] [G loss: 16.750573 G MAE: 0.548374]
5 Train: [D loss: 0.245115, acc.: 0.00%] [G loss: 17.004796 G MAE: 0.707184]
5 Test: [D loss: 0.693236, acc.: 50.00%] [G loss: 16.706819 G MAE: 0.525019]
6 Train: [D loss: 0.245108, acc.: 0.00%] [G loss: 16.947038 G MAE: 0.707197]

In [22]:
# report MAE on train and test data

test_score = gan_model.evaluate(x_test, [y_test, np.ones((y_test.shape[0], 1))])
train_score = gan_model.evaluate(x_train, [y_train, np.ones((y_train.shape[0], 1))])

print ("Train MAE: %f" % test_score[1])
print ("Test MAE: %f" % train_score[1])

Train MAE: 0.224171
Test MAE: 0.180115
