In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras

In [2]:
# Resolve the project root: Google Drive when running on Colab, the parent
# directory otherwise. Then load the cleaned dataset.
on_colab = 'google.colab' in str(get_ipython())
if on_colab:
  from google.colab import drive
  drive.mount('/content/drive')
proj_dir = "/content/drive/MyDrive/ece884_project/" if on_colab else "../"

df = pd.read_csv(f"{proj_dir}data_clean/clean.csv")
column_names = df.columns  # keep the header before df becomes a bare array
df = df.to_numpy()         # later cells index df positionally (df[idx, :])

Mounted at /content/drive


In [None]:
import os
import pickle  # was used below but never imported anywhere in the notebook

# Persist the lists of trained generators / discriminators under an index.
# NOTE(review): generators_saved and discriminators_saved are not defined in
# this notebook (presumably produced by gan.ipynb), so this cell fails on a
# fresh kernel unless they are loaded or created first.
model_number = 1
gen_dir = f"{proj_dir}saved_models/list_of_models/gen"
disc_dir = f"{proj_dir}saved_models/list_of_models/disc"
# open(..., "wb") does not create parent directories; make sure they exist.
os.makedirs(gen_dir, exist_ok=True)
os.makedirs(disc_dir, exist_ok=True)

with open(f"{gen_dir}/generators{model_number}", "wb") as fp:
    pickle.dump(generators_saved, fp)

with open(f"{disc_dir}/discriminators{model_number}", "wb") as fp:
    pickle.dump(discriminators_saved, fp)

In [None]:
def generated_data_filter(gen, desc, threashold, dims=df.shape):
    """Keep only generated rows that the discriminators score above a threshold.

    Stage 1: for each (generator, discriminator) pair, draw Gaussian noise of
    shape ``dims`` and run it through the generator. Keep the generated rows
    whose paired discriminator output is strictly greater than ``threashold``.

    Stage 2: pool the survivors from all pairs and pass them through every
    discriminator in ``desc`` in turn. Each one drops rows it scores at or
    below ``threashold``, so only rows that clear all of them are returned.

    Parameters
    ----------
    gen : list
        Generator models (the GANs trained in gan.ipynb).
    desc : list
        Discriminator models, paired positionally with ``gen``. Their output
        is treated as the probability that a row is real.
    threashold : float
        Minimum discriminator output a row must exceed to be kept. With
        ``threashold=0.99``, every row the discriminator gives a 0.99 or lower
        chance of being real is dropped. This needs tuning.
    dims : tuple
        Shape of the noise batch fed to each generator. ``dims[1]`` is also
        the column count of the result. Defaults to ``df.shape``, captured
        once when this function is defined.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_kept, dims[1])``. It may have zero rows.
    """
    # Start with an empty float32 batch so the result has the right width
    # even when nothing survives.
    kept_batches = [np.empty((0, dims[1]), np.float32)]
    for generator, discriminator in zip(gen, desc):
        noise = tf.random.normal(shape=dims)
        generated_data = generator(noise)
        judgement = discriminator(generated_data)  # probability data is real
        kept_batches.append(
            np.compress(np.ravel(judgement) > threashold, generated_data, axis=0))
    # Concatenate once; np.append inside the loop would re-copy every time.
    quality_data = np.concatenate(kept_batches, axis=0)

    for discriminator in desc:
        if quality_data.shape[0] == 0:
            break  # nothing left to judge; don't call a model on an empty batch
        judgement = discriminator(quality_data)
        quality_data = np.compress(np.ravel(judgement) > threashold, quality_data, axis=0)
    return quality_data

In [None]:
# Filter synthetic rows with the saved discriminators (see generated_data_filter).
# NOTE(review): 0.03 is far more permissive than the 0.99 example in the
# function's docstring — confirm this threshold is intended.
results = generated_data_filter(
    generators_saved,
    discriminators_saved,
    threashold=0.03,
)

In [None]:
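# TODO: integrate these two approaches: the discriminator-based filter above and the per-GAN evaluation below (confirm intent).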

In [None]:
import os
import re

# Saved GANs are directories named gan1, gan2, ... under saved_models/.
# Skip every other entry (e.g. the list_of_models/ folder written by the
# pickling cell); otherwise int() raises ValueError on its name.
gans_saved = [x for x in os.listdir(f"{proj_dir}saved_models/")
              if re.fullmatch(r"gan\d+", x)]
model_number = [int(re.sub("gan", "", x)) for x in gans_saved]
if not model_number:
    raise ValueError(f"no saved models named gan<N> found in {proj_dir}saved_models/")
last_model = max(model_number)

In [None]:
# Size of the synthetic batch drawn from each saved GAN:
#   rows    -> number of noise vectors (generated rows) per GAN
#   columns -> width of each noise vector fed to the generator
#              NOTE(review): presumably the GAN's input/latent size — confirm.
rows_in_generated, columns_in_generated = 500, 98

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, mean_squared_error

In [None]:
def logistic_reg_confution(X, y):
    """Fit a LogisticRegression on (X, y) and return its predictions for X.

    The predictions are in-sample: the same X is used for training and
    prediction. The name suggests they are meant for a confusion matrix
    against y.
    """
    classifier = LogisticRegression()
    classifier.fit(X, y)
    return classifier.predict(X)

In [None]:
def logistic_reg_single_col(X, y):
  """Score how well each column of X, on its own, predicts y.

  For every column, a separate LogisticRegression is fit on that single
  feature and evaluated on the same data it was trained on.

  Args:
    X: 2-D array with one row per sample.
    y: 1-D labels, one per row of X.

  Returns:
    A list with one entry per column of X. Each entry is the mean squared
    error between y and the in-sample predictions. For 0/1 labels this
    equals the misclassification rate.
  """
  def _column_score(col_idx):
    # Single-feature design matrix of shape (n_samples, 1).
    single_feature = X[:, col_idx].reshape((len(y), 1))
    predictions = LogisticRegression().fit(single_feature, y).predict(single_feature)
    return np.sum(mean_squared_error(y, predictions))

  return [_column_score(col_idx) for col_idx in range(X.shape[1])]

In [None]:
# For each saved GAN, mix rows_in_generated synthetic rows (label 1) with the
# same number of real rows (label 0). Then score every column with a
# one-feature logistic regression (logistic_reg_single_col).
# A score is the in-sample error rate:
#   ~0.5 -> chance level for these balanced classes
#   ~0   -> that column alone separates generated rows from real ones
results = []
for i in range(1, last_model + 1):
  gan_path = f"{proj_dir}saved_models/gan{i}"

  last_gan = tf.saved_model.load(gan_path)
  noise = tf.random.normal(shape=[rows_in_generated, columns_in_generated])
  generated_data = last_gan(noise)

  # Draw exactly rows_in_generated distinct real rows, rather than copying the
  # whole shuffled dataset, so both classes have the same size.
  random_index = np.random.permutation(df.shape[0])[:rows_in_generated]
  real_data = df[random_index, :]
  assert real_data.shape[0] == rows_in_generated, (
      "dataset has fewer rows than rows_in_generated")

  # Rows 0..rows_in_generated-1 are generated (1); the rest are real (0).
  generated_and_real = np.concatenate([generated_data, real_data], axis=0)
  generated_label = np.concatenate([np.ones((rows_in_generated, )),
                                    np.zeros((rows_in_generated, ))], axis=0)

  random_index = np.random.permutation(2 * rows_in_generated)
  generated_and_real_shuffeled = generated_and_real[random_index, :]
  generated_label_shuffeled = generated_label[random_index]

  rows = logistic_reg_single_col(generated_and_real_shuffeled, generated_label_shuffeled)
  results.append(rows)

In [None]:
# One row per saved GAN; one column per data column (per-column scores).
pd.DataFrame(data=results, columns=column_names)

Unnamed: 0,hospnum,rdelay,sex,age,rsleep,ratrial,rct,rvisinf,rhep24,rasp3,...,dead8,h14,isc14,nk14,strk14,hti14,pe14,dvt14,tran14,ncb14
0,0.0,0.0,0.225,0.0,0.361,0.398,0.152,0.33,0.491,0.381,...,0.49,0.497,0.487,0.494,0.478,0.5,0.497,0.499,0.496,0.491
1,0.002,0.001,0.236,0.0,0.368,0.401,0.158,0.342,0.485,0.384,...,0.492,0.497,0.488,0.494,0.479,0.498,0.494,0.5,0.495,0.485
2,0.005,0.001,0.244,0.0,0.338,0.413,0.166,0.321,0.009,0.387,...,0.487,0.498,0.492,0.492,0.479,0.501,0.497,0.499,0.499,0.489
3,0.005,0.003,0.237,0.0,0.356,0.413,0.18,0.341,0.494,0.397,...,0.496,0.498,0.489,0.496,0.483,0.5,0.497,0.5,0.496,0.48
4,0.006,0.003,0.235,0.0,0.35,0.413,0.313,0.366,0.49,0.372,...,0.495,0.497,0.485,0.495,0.477,0.5,0.498,0.001,0.498,0.487
5,0.003,0.002,0.241,0.0,0.346,0.101,0.159,0.328,0.487,0.103,...,0.005,0.252,0.49,0.496,0.484,0.5,0.001,0.499,0.493,0.482
6,0.016,0.002,0.262,0.0,0.351,0.084,0.181,0.345,0.49,0.386,...,0.004,0.497,0.491,0.494,0.482,0.499,0.497,0.499,0.498,0.489
7,0.006,0.003,0.245,0.0,0.346,0.406,0.171,0.342,0.483,0.39,...,0.496,0.0,0.01,0.004,0.486,0.0,0.491,0.499,0.498,0.483
8,0.01,0.0,0.222,0.0,0.344,0.403,0.34,0.331,0.49,0.384,...,0.495,0.496,0.011,0.495,0.481,0.499,0.495,0.5,0.499,0.488
9,0.008,0.0,0.244,0.0,0.361,0.407,0.167,0.348,0.492,0.385,...,0.496,0.496,0.492,0.492,0.48,0.5,0.497,0.5,0.489,0.481
