# 

# pythae synthetic data generation

In [1]:
pip install pythae

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install --upgrade pythae


Note: you may need to restart the kernel to use updated packages.


## vaegan


# @article{chadebec2022pythae,
	title={Pythae: Unifying Generative Autoencoders in Python -- A Benchmarking Use Case},
  	author={Chadebec, Clément and Vincent, Louis J. and Allassonnière, Stéphanie},
  	journal={arXiv preprint arXiv:2206.08309},
	url = {https://arxiv.org/abs/2206.08309},
  	year = {2022}
}

In [1]:
import numpy as np
import pandas as pd
import sdv
import seaborn as sns
import matplotlib.pyplot as plt

#libraries needed for pythae
from pythae.models import VAEGAN, VAEGANConfig
from pythae.trainers import CoupledOptimizerAdversarialTrainerConfig
from pythae.pipelines.training import TrainingPipeline

#pytorch
import torch


device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
sdv.__version__

'1.0.0b1'

In [3]:
real_data1=pd.read_csv('german_credit_data.csv')

#data transforming to numerical values using sdv
from rdt import HyperTransformer
ht = HyperTransformer()

# do not transform any categorical columns in the dataset
ht.remove_transformers_by_sdtype(sdtype='numerical')
ht.detect_initial_config(data=real_data1)
config=ht.get_config()
ht.fit(real_data1)

real_data_t=ht.transform(real_data1)

#changing from pd to torch verison
real_data=torch.Tensor(real_data_t.values)

In [4]:
real_data1.shape

(1000, 21)

In [5]:
real_data

tensor([[ 0.5310,  6.0000,  0.6765,  ...,  0.7980,  0.4815,  0.3500],
        [ 0.8025, 48.0000,  0.2650,  ...,  0.2980,  0.4815,  0.8500],
        [ 0.1970, 12.0000,  0.6765,  ...,  0.2980,  0.4815,  0.3500],
        ...,
        [ 0.1970, 12.0000,  0.2650,  ...,  0.2980,  0.4815,  0.3500],
        [ 0.5310, 45.0000,  0.2650,  ...,  0.7980,  0.4815,  0.8500],
        [ 0.8025, 45.0000,  0.6765,  ...,  0.2980,  0.4815,  0.3500]])

In [6]:
#splitting datasets into train and test
from torch.utils.data import DataLoader, random_split


dataset = real_data
#splitting in 80:20 order
train_size = int(0.8 * len(dataset))  # 80% for train
test_size = len(dataset) - train_size  # remaining 20% for test

# Using random_split to randomly split the dataset
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Creating data loaders for train and test datasets (optional)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)



In [7]:
#VAEGAN model

#VAEGANConfig
model_config=VAEGANConfig(input_dim=( 1000, 21),
    latent_dim=16,
    adversarial_loss_scale=0.8,
    reconstruction_layer= 0,
    margin=0.4,
    equilibrium= 0.68
)

model = VAEGAN(
    model_config=model_config
)

#training_config
config = CoupledOptimizerAdversarialTrainerConfig(
    output_dir='my_model',
    learning_rate=1e-4,
    per_device_train_batch_size=200,
    per_device_eval_batch_size=200,
    num_epochs=10, # Change this to train the model a bit more
)

In [8]:
pipeline = TrainingPipeline(
    training_config=config,
    model=model
)

In [9]:
pipeline(
    train_data=train_dataset,
    eval_data=test_dataset
)

Checking train dataset...


AttributeError: 'Tensor' object has no attribute 'keys'

In [11]:
import os
from pythae.models import AutoModel
from pythae.samplers import NormalSampler

In [12]:
# create normal sampler
normal_samper = NormalSampler(
    model=model
)

In [13]:
gen_data = normal_samper.sample(
    num_samples=30000
)

In [14]:
gen_data

tensor([[[0.5554, 0.6214, 0.5679,  ..., 0.4331, 0.5897, 0.5300],
         [0.4595, 0.4685, 0.4262,  ..., 0.5752, 0.5149, 0.3880],
         [0.5027, 0.5122, 0.4196,  ..., 0.4170, 0.4895, 0.4722],
         ...,
         [0.5420, 0.5846, 0.4785,  ..., 0.5071, 0.5138, 0.4855],
         [0.5788, 0.5166, 0.5338,  ..., 0.6035, 0.4866, 0.4619],
         [0.4785, 0.5659, 0.4902,  ..., 0.5759, 0.3976, 0.5021]],

        [[0.5098, 0.5176, 0.5007,  ..., 0.3846, 0.5380, 0.4235],
         [0.4850, 0.3729, 0.3904,  ..., 0.4923, 0.4912, 0.5326],
         [0.5450, 0.4638, 0.5035,  ..., 0.4814, 0.4457, 0.4599],
         ...,
         [0.5883, 0.5336, 0.5568,  ..., 0.5329, 0.4557, 0.4306],
         [0.4763, 0.4674, 0.5315,  ..., 0.6503, 0.5274, 0.4057],
         [0.4555, 0.4938, 0.4688,  ..., 0.5475, 0.5139, 0.4669]],

        [[0.5409, 0.5184, 0.5576,  ..., 0.5153, 0.4430, 0.5600],
         [0.4660, 0.4957, 0.4188,  ..., 0.5330, 0.5546, 0.4967],
         [0.4014, 0.3932, 0.5047,  ..., 0.5396, 0.4969, 0.

In [15]:
gen_data.shape

torch.Size([30000, 1000, 21])

In [16]:

import pandas as pd
import torch


# Convert the PyTorch tensor to a NumPy array
synthetic_data = gen_data[:, -1, :]


# Convert the NumPy array to a pandas DataFrame
synthetic_df = pd.DataFrame(synthetic_data)

# Specify column names for the DataFrame
column_names = ['checking_status', 'duration', 'credit_history', 'purpose',
               'credit_amount', 'savings_status', 'employment',
               'installment_commitment', 'personal_status', 'other_parties',
               'residence_since', 'property_magnitude', 'age', 'other_payment_plans',
               'housing', 'existing_credits', 'job', 'num_dependents', 'own_telephone',
               'foreign_worker', 'class']
synthetic_df.columns = column_names

# Now, synthetic_df contains the reshaped synthetic data in a pandas DataFrame with shape (30000, 21)



In [17]:
synthetic_df

Unnamed: 0,checking_status,duration,credit_history,purpose,credit_amount,savings_status,employment,installment_commitment,personal_status,other_parties,...,property_magnitude,age,other_payment_plans,housing,existing_credits,job,num_dependents,own_telephone,foreign_worker,class
0,0.478505,0.565905,0.490211,0.448236,0.472696,0.521699,0.556739,0.504936,0.445567,0.474532,...,0.498832,0.440030,0.502045,0.575150,0.536195,0.552032,0.447722,0.575852,0.397606,0.502111
1,0.455474,0.493797,0.468848,0.423789,0.614694,0.450945,0.484393,0.501624,0.475776,0.536397,...,0.554665,0.581273,0.449968,0.566813,0.486232,0.458741,0.398875,0.547526,0.513937,0.466883
2,0.494384,0.475684,0.506650,0.457652,0.485349,0.443305,0.512519,0.428163,0.553646,0.510674,...,0.436791,0.497064,0.518429,0.518192,0.450595,0.516134,0.520200,0.498477,0.510671,0.484485
3,0.533222,0.443696,0.442589,0.345123,0.540093,0.411210,0.515047,0.495991,0.564177,0.492443,...,0.404047,0.591348,0.419443,0.519134,0.469775,0.573944,0.462546,0.516776,0.566685,0.523864
4,0.491692,0.464167,0.475775,0.458381,0.501096,0.437926,0.520942,0.547313,0.590374,0.503895,...,0.432216,0.554620,0.524405,0.537387,0.473601,0.484209,0.517816,0.546572,0.528586,0.514994
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,0.502836,0.474452,0.443233,0.367224,0.419281,0.435557,0.634173,0.429748,0.525056,0.504594,...,0.460728,0.491926,0.469438,0.586156,0.464904,0.462628,0.427678,0.574569,0.494020,0.322354
29996,0.523987,0.560227,0.420824,0.502473,0.623851,0.435940,0.487619,0.501224,0.470839,0.550358,...,0.486309,0.595213,0.446906,0.479065,0.477742,0.535207,0.439763,0.589492,0.471735,0.473615
29997,0.438200,0.527874,0.468171,0.444992,0.589658,0.391490,0.508548,0.540120,0.400107,0.520032,...,0.504002,0.506881,0.518777,0.574175,0.428081,0.522217,0.484277,0.630850,0.477703,0.421251
29998,0.436174,0.493608,0.453452,0.472750,0.537618,0.458524,0.527670,0.420408,0.528684,0.510869,...,0.437890,0.577642,0.507049,0.558102,0.441728,0.476728,0.445763,0.562974,0.469395,0.462535


In [18]:
real_data_t

Unnamed: 0,checking_status,duration,credit_history,purpose,credit_amount,savings_status,employment,installment_commitment,personal_status,other_parties,...,property_magnitude,age,other_payment_plans,housing,existing_credits,job,num_dependents,own_telephone,foreign_worker,class
0,0.5310,6.0,0.6765,0.1400,1169.0,0.6945,0.4655,4.0,0.274,0.4535,...,0.473,67.0,0.407,0.3565,2.0,0.315,1.0,0.798,0.4815,0.35
1,0.8025,48.0,0.2650,0.1400,5951.0,0.3015,0.1695,2.0,0.703,0.4535,...,0.473,22.0,0.407,0.3565,1.0,0.315,1.0,0.298,0.4815,0.85
2,0.1970,12.0,0.6765,0.9200,2096.0,0.3015,0.6790,2.0,0.274,0.4535,...,0.473,49.0,0.407,0.3565,1.0,0.730,2.0,0.298,0.4815,0.35
3,0.5310,42.0,0.2650,0.6045,7882.0,0.3015,0.6790,2.0,0.274,0.9330,...,0.730,45.0,0.407,0.9460,1.0,0.315,2.0,0.298,0.4815,0.35
4,0.5310,24.0,0.8670,0.3970,4870.0,0.3015,0.1695,3.0,0.274,0.4535,...,0.923,53.0,0.407,0.9460,2.0,0.315,2.0,0.298,0.4815,0.85
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,0.1970,12.0,0.2650,0.6045,1736.0,0.3015,0.6790,3.0,0.703,0.4535,...,0.473,31.0,0.407,0.3565,1.0,0.730,1.0,0.298,0.4815,0.35
996,0.5310,30.0,0.2650,0.7465,3857.0,0.3015,0.1695,4.0,0.975,0.4535,...,0.730,40.0,0.407,0.3565,1.0,0.904,1.0,0.798,0.4815,0.35
997,0.1970,12.0,0.2650,0.1400,804.0,0.3015,0.4655,4.0,0.274,0.4535,...,0.166,38.0,0.407,0.3565,1.0,0.315,1.0,0.298,0.4815,0.35
998,0.5310,45.0,0.2650,0.1400,1845.0,0.3015,0.1695,4.0,0.274,0.4535,...,0.923,23.0,0.407,0.9460,1.0,0.315,1.0,0.798,0.4815,0.85
