In [15]:
# Automatically re-import edited project modules (e.g. `model.py`) before each cell executes.
# NOTE: re-running this cell in a live kernel prints "already loaded" — harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [116]:
import os
from dataclasses import dataclass

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from sklearn.datasets import fetch_california_housing

In [117]:
# Define training arguments
@dataclass
class TrainParameters:
    """Training-loop settings handed to ``WGAN_GP.fit``.

    Implemented as a dataclass so instances get a readable ``repr`` when
    displayed or logged; the constructor signature (positional order and
    defaults) is unchanged.

    Attributes:
        epochs: Number of training epochs to run.
        sample_interval: Interval value forwarded to the model's training loop
            (its exact meaning is defined in model.py — TODO confirm).
        cache_prefix: Prefix string forwarded to the model (presumably used to
            name cached artifacts — verify in model.py).
    """

    epochs: int = 100
    sample_interval: int = 10
    cache_prefix: str = "wgan"

In [118]:
from model import WGAN_GP

In [119]:
# Minimal hand-built frame for a quick smoke test of the column setup.
# NOTE: `df`, `num_cols` and `cat_cols` are all reassigned by the dataset cells below,
# so nothing later in the notebook depends on this cell's values.
toy_data = {
    "numeric1": [1, 2, 3],
    "numeric2": [4, 5, 6],
}
df = pd.DataFrame(toy_data)
# Every column is numeric; there are no categorical columns.
num_cols = ["numeric1", "numeric2"]
cat_cols = []
df.shape

(3, 2)


In [120]:
# Generate a reproducible synthetic table: two independent uniform columns on different ranges.
DATA_SEED = 0  # fixed seed so every Restart & Run All produces the same frame
num_rows = 10000
rng = np.random.default_rng(DATA_SEED)
data = {
    "numeric1": rng.uniform(low=0, high=10, size=num_rows),  # uniform on [0, 10)
    "numeric2": rng.uniform(low=10, high=100, size=num_rows),  # uniform on [10, 100)
}

# Create DataFrame; all columns are numeric, none categorical.
df = pd.DataFrame(data)
num_cols = ["numeric1", "numeric2"]
cat_cols = []
# Display the first few rows of the DataFrame
df.head()

Unnamed: 0,numeric1,numeric2
0,6.996134,88.979488
1,2.082914,92.275953
2,8.483031,57.31074
3,3.387147,51.971794
4,9.993107,20.975259


In [121]:
# Guard against hidden state: `df` is rebound several times in this notebook, so confirm
# it is still the uniform frame from the cell above before summarising it.
assert list(df.columns) == ["numeric1", "numeric2"], "df is not the uniform frame"
assert df["numeric1"].between(0, 10).all(), "numeric1 outside [0, 10]"
assert df["numeric2"].between(10, 100).all(), "numeric2 outside [10, 100]"
df.describe()

Unnamed: 0,numeric1,numeric2
count,10000.0,10000.0
mean,4.995182,54.910363
std,2.884602,25.851078
min,3e-06,10.003906
25%,2.511669,32.676699
50%,4.946133,55.262207
75%,7.525112,76.919229
max,9.999255,99.96137


In [122]:
# Fetch the California Housing data bundle (features, feature names and target array).
california_dataset = fetch_california_housing()

# Build one frame holding the features followed by a `target` column; every column is
# numeric, so all of them are declared as numeric and none as categorical.
df = pd.DataFrame(
    california_dataset.data, columns=california_dataset.feature_names
).assign(target=california_dataset.target)
num_cols = df.columns.tolist()
cat_cols = []
df.head()

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,target
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


In [151]:
# Seed the global RNGs before building/training the model so runs are repeatable.
# NOTE(review): assumes WGAN_GP draws its randomness from torch/numpy global RNGs — confirm in model.py.
TRAIN_SEED = 42
np.random.seed(TRAIN_SEED)
torch.manual_seed(TRAIN_SEED)

# Hyper-parameters forwarded verbatim to WGAN_GP (keys are interpreted by model.py).
model_parameters = {
    "batch_size": 64,
    "noise_dim": 32,
    "layers_dim": 128,
    "g_lr": 0.0001,
    "d_lr": 0.0001,
    "beta_1": 0.5,
    "beta_2": 0.9,
}

# The recorded progress bar ran at ~1.85 s/epoch, i.e. roughly 10 minutes for 300 epochs.
num_epochs = 300
# Instantiate the WGAN_GP model
wgan_gp_model = WGAN_GP(model_parameters)
train_arguments = TrainParameters(
    epochs=num_epochs, sample_interval=10, cache_prefix="wgan"
)
# Train on `df` using the numeric/categorical split set by the dataset cell above.
wgan_gp_model.fit(df, train_arguments, num_cols, cat_cols)

[ 3.87067100e+00  2.86394864e+01  5.42899974e+00  1.09667515e+00
  1.42547674e+03  3.07065516e+00  3.56318614e+01 -1.19569704e+02
  2.06855817e+00]
[1.89977569e+00 1.25852527e+01 2.47411320e+00 4.73899376e-01
 1.13243469e+03 1.03857980e+01 2.13590065e+00 2.00348319e+00
 1.15392820e+00]


Epochs:  21%|██        | 63/300 [02:52<07:19,  1.85s/it, Epoch=63/300, Generator Loss=-6.533632755279541, Critic Loss=10.582010269165039]   

In [None]:
# Preview a handful of samples straight from the generator via the private `_sample_unscaled`.
# NOTE(review): judging by the recorded output (values centred near 0), these appear to be in
# the model's scaled space rather than original units — confirm in model.py.
# Dedicated names keep this preview from overwriting `num_synthetic_samples` /
# `synthetic_samples_df`, which the comparison cells below rely on.
num_preview_samples = 10
scaled_preview_df = wgan_gp_model._sample_unscaled(num_preview_samples)
scaled_preview_df.head()

     MedInc  HouseAge  AveRooms  AveBedrms  Population  AveOccup  Latitude  \
0  0.324535 -0.033545 -0.472236  -0.417614   -1.323773 -1.006246  0.907207   
1  0.445087 -0.043398 -0.563320  -0.546439   -1.287114 -1.227023  1.195833   
2  0.543033 -0.121922 -0.457937  -0.569761   -1.332052 -1.108448  0.998705   
3  0.696007 -0.103056 -0.468329  -0.436845   -1.293441 -1.191832  1.041232   
4  0.309677 -0.132213 -0.485635  -0.458193   -1.425542 -1.002061  0.898244   

   Longitude    target  
0  -1.953671  1.543230  
1  -2.623505  1.872169  
2  -2.295481  1.794193  
3  -2.215866  2.212290  
4  -2.160011  1.535393  


In [142]:
# Draw as many synthetic rows as the real table has, so the summary statistics below compare like-for-like.
num_synthetic_samples = len(df)
synthetic_samples_df = wgan_gp_model.sample(num_synthetic_samples)
# Rich (HTML) display instead of print().
synthetic_samples_df.head()

     MedInc   HouseAge  AveRooms  AveBedrms  Population   AveOccup   Latitude  \
0  4.342243  28.954988  4.489199   0.865319  -12.960693  -4.162019  37.169914   
1  5.031654  26.534416  4.027882   0.813071 -409.011719  -8.112635  38.284523   
2  5.348135  27.118765  4.201306   0.847741 -196.399902 -10.508178  38.218647   
3  4.656732  28.681334  4.173567   0.882333  101.896118  -8.780711  37.531891   
4  4.954137  29.012817  3.684471   0.878966  100.472168 -10.920546  37.768398   

    Longitude    target  
0 -122.404274  3.128103  
1 -124.938957  4.149595  
2 -124.645485  4.793169  
3 -123.782852  4.065839  
4 -124.222008  4.602161  


In [143]:
# First five real rows, for eyeballing against the synthetic rows shown above.
df.head(5)

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,target
0,8.3252,41.0,6.984127,1.02381,322.0,2.555556,37.88,-122.23,4.526
1,8.3014,21.0,6.238137,0.97188,2401.0,2.109842,37.86,-122.22,3.585
2,7.2574,52.0,8.288136,1.073446,496.0,2.80226,37.85,-122.24,3.521
3,5.6431,52.0,5.817352,1.073059,558.0,2.547945,37.85,-122.25,3.413
4,3.8462,52.0,6.281853,1.081081,565.0,2.181467,37.85,-122.25,3.422


In [144]:
# Real-data summary statistics (standard quartiles) — compare column by column with the synthetic summary.
df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,target
count,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0
mean,3.870671,28.639486,5.429,1.096675,1425.476744,3.070655,35.631861,-119.569704,2.068558
std,1.899822,12.585558,2.474173,0.473911,1132.462122,10.38605,2.135952,2.003532,1.153956
min,0.4999,1.0,0.846154,0.333333,3.0,0.692308,32.54,-124.35,0.14999
25%,2.5634,18.0,4.440716,1.006079,787.0,2.429741,33.93,-121.8,1.196
50%,3.5348,29.0,5.229129,1.04878,1166.0,2.818116,34.26,-118.49,1.797
75%,4.74325,37.0,6.052381,1.099526,1725.0,3.282261,37.71,-118.01,2.64725
max,15.0001,52.0,141.909091,34.066667,35682.0,1243.333333,41.95,-114.31,5.00001


In [145]:
# Synthetic-data summary statistics (standard quartiles), directly comparable with df.describe().
# NOTE(review): the recorded output (In [145]) predates the latest training run (In [151]).
# In it every synthetic AveOccup value and most Population values are negative, and per-column
# stds are far below the real ones (e.g. MedInc 0.35 vs 1.90) — suggests an unconverged or
# collapsed generator. Re-check after retraining.
synthetic_samples_df.describe(percentiles=[0.25, 0.5, 0.75])

Unnamed: 0,MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,target
count,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0
mean,4.994797,27.544823,4.087272,0.860306,-103.635239,-9.869694,37.922867,-124.470154,4.488219
std,0.353258,1.359112,0.204035,0.029027,150.80719,2.338162,0.49167,1.082845,0.530817
min,3.878432,19.860031,2.911714,0.67705,-851.622925,-22.455664,36.38633,-131.358902,2.791552
25%,4.747663,26.656943,3.963946,0.84283,-199.376373,-11.353495,37.582131,-125.139582,4.112327
50%,4.968167,27.586483,4.108115,0.862558,-96.440552,-9.720552,37.894236,-124.392662,4.449674
75%,5.21148,28.472967,4.231881,0.880062,-1.70047,-8.232004,38.227569,-123.709435,4.823098
max,6.658224,32.849251,4.737225,0.968995,456.957886,-2.077996,40.570518,-121.08754,7.263553
