# Evaluation Notebook

This notebook shows how all the evaluation metrics were calculated. To run it one must import the 'evaluation_functions.py' as it contains all the metric functions.

Set up Google colab & download data

In [0]:
import tensorflow as tf
import tensorflow.keras as keras

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

import time
import pickle

from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics.classification import accuracy_score
from sklearn.metrics import mean_squared_error
from sklearn.utils.testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning

In [0]:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [0]:
downloaded = drive.CreateFile({'id':'1VuZGdAGxGDwQT10sg9z4sPvlFk0Xmupq'}) 
downloaded.GetContentFile('sim_av_tumour.csv')  
df = pd.read_csv('sim_av_tumour.csv',dtype=str)

In [0]:
#import packages
from google.colab import files
files.upload()

In [0]:
#set the drive to save outputs
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

### Data Preparation

In [0]:
from models import Generator, Discriminator, cGenerator, MultiCategorical, CategoricalActivation
from evaluation_functions import gen_output, get_accuracy_metrics, compare_accuracy, probability_lists, compare_probs, compare_within_site_probs

In [0]:
np.random.seed(123)
test_df = df.drop(['TUMOURID', 'PATIENTID', 'DIAGNOSISDATEBEST', 'LINKNUMBER', 'DATE_FIRST_SURGERY'], axis=1).iloc[np.random.permutation(len(df))].iloc[:round(len(df)*0.2)]
test_df[test_df.isna()]= "NA"
X_train, X_test = train_test_split(test_df, test_size=0.1, random_state=42)

In [0]:
print(len(X_train))
len(X_test)

252506


28057

In [0]:
ohe = OneHotEncoder(sparse=False)
ohe.fit(X_train)
ohe_df = ohe.transform(X_train)

variable_sizes = [len(X_train[column].unique()) for column in X_train.columns]
print("X train variable sizes:")
print(variable_sizes)

#remove rows that contain values that do not occur in X_train
for column in X_test.columns:
  X_test = X_test[X_test[column].isin(X_train[column].unique())]
print("X test variable sizes:")
print([len(X_test[column].unique()) for column in X_test.columns])

X train variable sizes:
[526, 116, 405, 9, 40, 18, 8, 35, 8, 6, 100, 2, 8, 33, 7, 10, 7, 10, 6, 7, 11, 13, 9, 6, 6, 4, 8, 7, 5]
X test variable sizes:
[412, 111, 294, 9, 33, 16, 7, 33, 7, 6, 99, 2, 8, 23, 7, 9, 7, 9, 6, 7, 11, 11, 9, 5, 5, 4, 7, 7, 5]


In [0]:
#D size = 1
learning_rate = 1e-4
temperature = 1/3

G = Generator([100,100,100], temperature, variable_sizes[1:], conditional=True)

G_optimizer = tf.keras.optimizers.Adam(learning_rate, name = "g_optimiser")
D_optimizer = tf.keras.optimizers.Adam(learning_rate, name = "d_optimiser")

path = '/content/drive/My Drive/ST449 Project/checkpoints/c-wgan-3/'
ckpt_G = tf.train.Checkpoint(step=tf.Variable(1), optimizer=G_optimizer, net=G)
manager_G = tf.train.CheckpointManager(ckpt_G, path + 'G', max_to_keep=None)


ckpt_G.restore(manager_G.latest_checkpoint)
#sample = pd.DataFrame(gen_output(G, X_test.shape[0], 100, ohe, X_test, conditional=True, site = 'C44'))
sample = pd.DataFrame(gen_output(G, X_test.shape[0], 100, ohe, X_test, conditional=True))

In [0]:
#Vanilla GAN 
#1 Discriminator Hidden Layer
#3 places to specify GAN name
import pandas as pd
vanilla1_samples = get_samples(Generator, variable_sizes, conditional = False, checkpoint_name='vanilla-4/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(vanilla1_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = False) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


vanilla1 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/vanilla1.pickle', 'wb') as handle:
    pickle.dump(vanilla1, handle, protocol=pickle.HIGHEST_PROTOCOL)


Restored from /content/drive/My Drive/ST449 Project/checkpoints/vanilla-4/G/ckpt-57


In [0]:
print('MSEs for Vanilla GAN with 1 Discriminator Layer')
print('---------------------------------------------------')
print('                    MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilties by Cancer Type: {:.6}'.format(np.mean([np.mean(i) for i in mse_within_p_av])))
print('                     MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                     MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for Vanilla GAN with 1 Discriminator Layer
---------------------------------------------------
                    MSE for Probabilties: 0.00692559
Weighted MSE Probabilties by Cancer Type: 0.0399762
                     MSE for LR Accuracy: 0.0519579
                     MSE for RF Accuracy: 0.052775



In [0]:
#Vanilla GAN 
#2 Discriminator Hidden Layer
#6 places to specify GAN name
vanilla2_samples = get_samples(Generator, variable_sizes, conditional = False, checkpoint_name='vanilla-3/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(vanilla2_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = False) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


vanilla2 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/vanilla2.pickle', 'wb') as handle:
    pickle.dump(vanilla2, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [0]:
with open('/content/drive/My Drive/ST449 Project/evaluation/wgan1.pickle', 'rb') as handle:
    loaded = pickle.load(handle)

In [0]:
mse_p_av = loaded['mse'][0]
mse_within_p_av= loaded['mse'][1]
mse_lr_av= loaded['mse'][2]
mse_rf_av= loaded['mse'][3]

In [0]:
print('MSEs for Vanilla GAN with 2 Discriminator Layers')
print('---------------------------------------------------')
print('                    MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilties by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                     MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                     MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for Vanilla GAN with 2 Discriminator Layers
---------------------------------------------------
                    MSE for Probabilties: 0.00385195
Weighted MSE Probabilties by Cancer Type: 0.0285657
                     MSE for LR Accuracy: 0.259183
                     MSE for RF Accuracy: 0.251436



In [0]:
#WGAN 
#1 Discriminator Hidden Layer
#6 places to specify GAN name
wgan1_samples = get_samples(Generator, variable_sizes, conditional = False, checkpoint_name='wgan-2/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(wgan1_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = False) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


wgan1 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/wgan1.pickle', 'wb') as handle:
    pickle.dump(wgan1, handle, protocol=pickle.HIGHEST_PROTOCOL)

Restored from /content/drive/My Drive/ST449 Project/checkpoints/wgan-2/G/ckpt-60


In [0]:
print('MSEs for WGAN with 1 Discriminator Layers')
print('---------------------------------------------------')
print('                    MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilties by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                     MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                     MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for WGAN with 1 Discriminator Layers
---------------------------------------------------
                    MSE for Probabilties: 0.000336559
Weighted MSE Probabilties by Cancer Type: 0.0232472
                     MSE for LR Accuracy: 0.0340192
                     MSE for RF Accuracy: 0.0349088



In [0]:
#WGAN 
#1 Discriminator Hidden Layer
#6 places to specify GAN name
wgan2_samples = get_samples(Generator, variable_sizes, conditional = False, checkpoint_name='wgan-1/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(wgan2_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = False) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


wgan2 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/wgan2.pickle', 'wb') as handle:
    pickle.dump(wgan2, handle, protocol=pickle.HIGHEST_PROTOCOL)

Restored from /content/drive/My Drive/ST449 Project/checkpoints/wgan-1/G/ckpt-60


In [0]:
wgan2_samples = get_samples(Generator, variable_sizes, conditional = False, checkpoint_name='wgan-1/', repress_warnings = True)
mse_within_p_av = []
for i, sample in enumerate(wgan2_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = False)
  mse_within_p_av.append(mse_within_p) 



In [0]:
print('MSEs for WGAN with 2 Discriminator Layers')
print('---------------------------------------------------------')
print('                     MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilities by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                      MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                      MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for WGAN with 2 Discriminator Layers
-------------------------------------------------------
                     MSE for Probabilties: 0.000363674
Weighted MSE Probabilities by Cancer Type: 0.0158364
                      MSE for LR Accuracy: 0.0244896
                      MSE for RF Accuracy: 0.0251833



### Conditional Generators

In [0]:
#re-prepare the data
np.random.seed(123)
test_df = df.drop(['TUMOURID', 'PATIENTID', 'DIAGNOSISDATEBEST', 'LINKNUMBER', 'DATE_FIRST_SURGERY'], axis=1).iloc[np.random.permutation(len(df))].iloc[:round(len(df)*0.2)]
test_df[test_df.isna()]= "NA"
X_train, X_test = train_test_split(test_df, test_size=0.1, random_state=42)

cols = list(X_train)
cols[1], cols[0] = cols[0], cols[1]
X_train = X_train[cols]

ohe = OneHotEncoder(sparse=False)
ohe.fit(X_train)
ohe_df = ohe.transform(X_train)

cols = list(X_test)
cols[1], cols[0] = cols[0], cols[1]
X_test = X_test[cols]

variable_sizes = [len(X_train[column].unique()) for column in X_train.columns]
print(variable_sizes)

#remove rows that contain values that do not occur in X_train
for column in X_test.columns:
  X_test = X_test[X_test[column].isin(X_train[column].unique())]

print([len(X_test[column].unique()) for column in X_test.columns])

#get site mapping
ohe_mapper = ohe.transform(X_test)[:, :116]
X_mapper = X_test.reset_index()
site_mapping = {}

for site in X_mapper['SITE_ICD10_O2_3CHAR'].unique():
  site_mapping[site] = ohe_mapper[X_mapper[X_mapper['SITE_ICD10_O2_3CHAR'] == site].iloc[0:1].index[0],:]

[116, 526, 405, 9, 40, 18, 8, 35, 8, 6, 100, 2, 8, 33, 7, 10, 7, 10, 6, 7, 11, 13, 9, 6, 6, 4, 8, 7, 5]
[111, 412, 294, 9, 33, 16, 7, 33, 7, 6, 99, 2, 8, 23, 7, 9, 7, 9, 6, 7, 11, 11, 9, 5, 5, 4, 7, 7, 5]


In [0]:
#c-Vanilla GAN 
#1 Discriminator Hidden Layer
#6 places to specify GAN name
c_vanilla1_samples = get_samples(Generator, variable_sizes[1:], conditional = True, checkpoint_name='c-vanilla-3/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(c_vanilla1_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = True) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


c_vanilla1 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/c_vanilla1.pickle', 'wb') as handle:
    pickle.dump(c_vanilla1, handle, protocol=pickle.HIGHEST_PROTOCOL)

Restored from /content/drive/My Drive/ST449 Project/checkpoints/c-vanilla-3/G/ckpt-26


In [0]:
print('MSEs for Conditonal Vanilla with 1 Discriminator Layers')
print('---------------------------------------------------------')
print('                     MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilities by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                      MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                      MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for Conditonal Vanilla with 1 Discriminator Layers
---------------------------------------------------------
                     MSE for Probabilties: 0.00852635
Weighted MSE Probabilities by Cancer Type: 0.0521075
                      MSE for LR Accuracy: 0.275968
                      MSE for RF Accuracy: 0.294099



In [0]:
#c-Vanilla GAN 
#2 Discriminator Hidden Layer
#6 places to specify GAN name
c_vanilla2_samples = get_samples(Generator, variable_sizes[1:], conditional = True, checkpoint_name='c-vanilla-2d2/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(c_vanilla2_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = True) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


c_vanilla2 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/c_vanilla2.pickle', 'wb') as handle:
    pickle.dump(c_vanilla2, handle, protocol=pickle.HIGHEST_PROTOCOL)

Restored from /content/drive/My Drive/ST449 Project/checkpoints/c-vanilla-2d2/G/ckpt-60


In [0]:
print('MSEs for Conditonal Vanilla with 2 Discriminator Layers')
print('---------------------------------------------------------')
print('                     MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilities by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                      MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                      MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for Conditonal Vanilla with 2 Discriminator Layers
---------------------------------------------------------
                     MSE for Probabilties: 0.00800333
Weighted MSE Probabilities by Cancer Type: 0.0444197
                      MSE for LR Accuracy: 0.144916
                      MSE for RF Accuracy: 0.150394



In [0]:
#c-WGAN 
#1 Discriminator Hidden Layer
#6 places to specify GAN name
c_wgan1_samples = get_samples(Generator, variable_sizes[1:], conditional = True, checkpoint_name='c-wgan-3/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(c_wgan1_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = True) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


c_wgan1 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/c_wgan1.pickle', 'wb') as handle:
    pickle.dump(c_wgan1, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [0]:
print('MSEs for Conditonal WGAN with 1 Discriminator Layers')
print('---------------------------------------------------------')
print('                     MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilities by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                      MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                      MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for Conditonal WGAN with 1 Discriminator Layers
---------------------------------------------------------
                     MSE for Probabilties: 0.000475133
Weighted MSE Probabilities by Cancer Type: 0.00795714
                      MSE for LR Accuracy: 0.0239637
                      MSE for RF Accuracy: 0.0245771



In [0]:
#c-WGAN 
#2 Discriminator Hidden Layer
#6 places to specify GAN name
c_wgan2_samples = get_samples(Generator, variable_sizes[1:], conditional = True, checkpoint_name='c-wgan-2d2/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(c_wgan2_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = True) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


c_wgan2 = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/c_wgan2.pickle', 'wb') as handle:
    pickle.dump(c_wgan2, handle, protocol=pickle.HIGHEST_PROTOCOL)

Restored from /content/drive/My Drive/ST449 Project/checkpoints/c-wgan-2d2/G/ckpt-60


In [0]:
print('MSEs for Conditonal WGAN with 2 Discriminator Layers')
print('---------------------------------------------------------')
print('                     MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilities by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                      MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                      MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))


MSEs for Conditonal WGAN with 2 Discriminator Layers
---------------------------------------------------------
                     MSE for Probabilties: 0.000141389
Weighted MSE Probabilities by Cancer Type: 0.00413341
                      MSE for LR Accuracy: 0.0106915
                      MSE for RF Accuracy: 0.0139164
                      


In [0]:
#c-WGAN extended
#2 Discriminator Hidden Layer
#6 places to specify GAN name
c_wgan2_ext_samples = get_samples(cGenerator, variable_sizes[1:], conditional = True, checkpoint_name='c-wgan-ext2/', repress_warnings = True)

mse_p_av = []
mse_within_p_av = []
mse_lr_av = []
mse_rf_av = []
num_samples = 0

for i, sample in enumerate(c_wgan2_ext_samples):
  mse_p, sample_column_probs, test_column_probs = compare_probs(sample, X_test)
  mse_within_p, mse_list, sample_sizes, test_sizes, cancer_types = compare_within_site_probs(sample, X_test, conditional = True) 
  mse_lr, mse_rf, lr_sample, lr_x_test, rf_sample, rf_x_test = compare_accuracy(sample, X_test)

  mse_p_av.append(mse_p)
  mse_within_p_av.append(mse_within_p)
  mse_lr_av.append(mse_lr)
  mse_rf_av.append(mse_rf)
  num_samples += 1


c_wgan2_ext = {'mse': [mse_p_av, mse_within_p_av, mse_lr_av, mse_rf_av],
            'prob': [sample_column_probs, test_column_probs],
            'within_prob': [mse_list, sample_sizes, test_sizes, cancer_types],
            'accuracy': [lr_sample, lr_x_test, rf_sample, rf_x_test],
            'sample': sample}

with open('/content/drive/My Drive/ST449 Project/evaluation/c_wgan2_ext.pickle', 'wb') as handle:
    pickle.dump(c_wgan2_ext, handle, protocol=pickle.HIGHEST_PROTOCOL)

Restored from /content/drive/My Drive/ST449 Project/checkpoints/c-wgan-ext2/G/ckpt-60


In [0]:
print('MSEs for Extended Conditonal WGAN with 2 Discriminator Layers')
print('---------------------------------------------------------------')
print('                     MSE for Probabilties: {:.6}'.format(np.mean(mse_p_av)))
print('Weighted MSE Probabilities by Cancer Type: {:.6}'.format(np.mean(mse_within_p_av)))
print('                      MSE for LR Accuracy: {:.6}'.format(np.mean(mse_lr_av)))
print('                      MSE for RF Accuracy: {:.6}'.format(np.mean(mse_rf_av)))

MSEs for Extended Conditonal WGAN with 2 Discriminator Layers
---------------------------------------------------------------
                     MSE for Probabilties: 0.000111725
Weighted MSE Probabilities by Cancer Type: 0.076191
                      MSE for LR Accuracy: 0.0130744
                      MSE for RF Accuracy: 0.0157422
