# CTGAN Model

In this guide we will go through a series of steps that will let you discover the
functionalities of the `CTGAN` model, including how to:

* Create an instance of `CTGAN`.
* Fit the instance to your data.
* Generate synthetic versions of your data.
* Use `CTGAN` to anonymize PII.
* Customize the data transformations to improve the learning process.
* Specify model hyperparameters to improve the output quality.

## What is CTGAN?

The `sdv.tabular.CTGAN` model from `SDV` is based on the `ctgan.CTGANSynthesizer` class
from the [CTGAN library](https://github.com/sdv-dev/CTGAN), a deep-learning-based data
synthesizer that uses Generative Adversarial Networks to generate tabular data. It was
presented at the NeurIPS 2019 conference in the paper titled [Modeling Tabular data
using Conditional GAN](https://arxiv.org/abs/1907.00503). For more details about the model,
please read the linked paper and visit the [CTGAN library](https://github.com/sdv-dev/CTGAN).

Let's now see how to fit the `CTGAN` class from SDV to a dataset and then generate synthetic
data with the same format and statistical properties.

## Quick Usage

We will start by loading one of our demo datasets, the `student_placements`, which contains information
about MBA students that applied for placements during the year 2020.

<div class="alert alert-warning">

**WARNING**

In order to follow this guide you need to have `ctgan` installed on your system.
If you have not done it yet, please install `ctgan` now by executing the command
`pip install sdv[ctgan]` in a terminal.

</div>

In [1]:
# Setup logging and warnings - change ERROR to INFO for increased verbosity
import logging
logging.basicConfig(level=logging.ERROR)

logging.getLogger().setLevel(level=logging.ERROR)
logging.getLogger('sdv').setLevel(level=logging.ERROR)

import warnings
warnings.simplefilter("ignore")

In [2]:
from sdv.demo import load_tabular_demo

data = load_tabular_demo('student_placements')
data.head().T

|  | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| student_id | 17264 | 17265 | 17266 | 17267 | 17268 |
| gender | M | M | M | M | M |
| second_perc | 67 | 79.33 | 65 | 56 | 85.8 |
| high_perc | 91 | 78.33 | 68 | 52 | 73.6 |
| high_spec | Commerce | Science | Arts | Science | Commerce |
| degree_perc | 58 | 77.48 | 64 | 52 | 73.3 |
| degree_type | Sci&Tech | Sci&Tech | Comm&Mgmt | Sci&Tech | Comm&Mgmt |
| work_experience | False | True | False | False | False |
| experience_years | 0 | 1 | 0 | 0 | 0 |
| employability_perc | 55 | 86.5 | 75 | 66 | 96.8 |


As you can see, this table contains information about students which includes, among other things:

- Their id and gender
- Their grades and specializations
- Their work experience
- The salary that they were offered
- The duration and dates of their placement

You will notice that there is data with the following characteristics:

- There are float, integer, boolean, categorical and datetime values.
- There are some variables that have missing data. In particular, all the data related to the
  placement details is missing in the rows where the student was not placed.

Let us use `CTGAN` to learn this data and then sample synthetic data about new students
to see how well the model captures the characteristics indicated above. In order to do this you will
need to:

- Import the `sdv.tabular.CTGAN` class and create an instance of it.
- Call its `fit` method passing our table.
- Call its `sample` method indicating the number of synthetic rows that you want to generate.

In [3]:
from sdv.tabular import CTGAN

model = CTGAN()
model.fit(data)

<div class="alert alert-info">

**NOTE**

Notice that the model fitting process took care of transforming the different fields using the
appropriate [Reversible Data Transforms](http://github.com/sdv-dev/RDT) to ensure that the data has
a format that the `CTGANSynthesizer` class can handle.

</div>

### Generate synthetic data from the model

Once the modeling has finished you are ready to generate new synthetic data by calling the `sample` method
of your model, passing the number of rows that you want to generate.

In [4]:
new_data = model.sample(200)

This will return a table with the same columns and format as the one the model was fitted on,
but filled with new data that resembles the original.

In [5]:
new_data.head()

|  | student_id | gender | second_perc | high_perc | high_spec | degree_perc | degree_type | work_experience | experience_years | employability_perc | mba_spec | mba_perc | salary | placed | start_date | end_date | duration |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17441 | M | 73.690152 | 69.673688 | Commerce | 51.052557 | Comm&Mgmt | False | 0 | 70.588885 | Mkt&HR | 63.997868 | 30936.144158 | False | 2020-01-02 23:16:15.017963776 | NaT | 3.0 |
| 1 | 17538 | F | 66.212947 | 70.874036 | Science | 65.108024 | Comm&Mgmt | False | 0 | 60.770266 | Mkt&HR | 62.272543 | 26538.42044 | True | 2020-01-14 01:34:05.305196544 | NaT | NaN |
| 2 | 17319 | F | 69.882207 | 93.998807 | Commerce | 53.245815 | Comm&Mgmt | False | 0 | 57.880872 | Mkt&HR | 73.016935 | 78386.180922 | False | 2020-08-12 02:51:35.667222784 | 2020-11-11 15:57:13.245723648 | 12.0 |
| 3 | 17450 | F | 72.539752 | 66.945866 | Commerce | 87.550494 | Comm&Mgmt | True | 0 | 68.509602 | Mkt&Fin | 52.755707 | 36540.572434 | False | 2020-02-18 04:21:34.996754176 | NaT | 6.0 |
| 4 | 17467 | M | 44.049549 | 64.884668 | Commerce | 69.509229 | Comm&Mgmt | False | 0 | 46.244579 | Mkt&HR | 61.413745 | NaN | False | NaT | NaT | 3.0 |


<div class="alert alert-info">

**Note**

You can control the number of rows by passing it as the argument of
`model.sample(<num_rows>)`. To test, try `model.sample(10000)`. Note that the original
table only had ~200 rows.

</div>

### Save and Load the model

In many scenarios it will be convenient to generate synthetic versions of your data
directly in systems that do not have access to the original data source. For example,
you may want to generate testing data on the fly inside a testing environment that
does not have access to your production database. In these scenarios, fitting the
model with real data every time that you need to generate new data is not feasible, so you
will need to fit a model in your production environment, save the fitted model into a
file, send this file to the testing environment and then load it there to be able to
`sample` from it.

Let's see how this process works.

#### Save and share the model

Once you have fitted the model, all you need to do is call its `save` method passing the
name of the file in which you want to save the model. Note that the extension of the filename
is not relevant, but we will be using the `.pkl` extension to highlight that the serialization
protocol used is [pickle](https://docs.python.org/3/library/pickle.html).

In [6]:
model.save('my_model.pkl')

This will have created a file called `my_model.pkl` in the same directory in which you are
running SDV.

<div class="alert alert-info">

**IMPORTANT**
    
If you inspect the generated file you will notice that its size is much smaller
than the size of the data that you used to generate it. This is because the serialized model
contains **no information about the original data**, other than the parameters it needs to
generate synthetic versions of it. This means that you can safely share this `my_model.pkl`
file without the risk of disclosing any of your real data!
    
</div>

#### Load the model and generate new data

The file you just generated can be sent over to the system where the synthetic data will be
generated. Once it is there, you can load it using the `CTGAN.load` method, and
then you are ready to sample new data from the loaded instance:

In [7]:
loaded = CTGAN.load('my_model.pkl')
new_data = loaded.sample(200)

<div class="alert alert-warning">
    
**WARNING**
    
Notice that the system where the model is loaded needs to also have `sdv` and `ctgan`
installed, otherwise it will not be able to load the model and use it.
    
</div>
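The save/load round trip above relies on pickle. The following is a minimal stdlib sketch of that idea, not SDV's implementation: `ToyGaussianModel`, `save` and `load` are hypothetical stand-ins, and the toy model deliberately keeps only learned parameters (a mean and a standard deviation per column) so that the pickled file carries no raw rows. The three input rows reuse the `second_perc` and `high_perc` values of the first three students shown earlier.

```python
import os
import pickle
import random
import statistics
import tempfile


class ToyGaussianModel:
    """Hypothetical stand-in for a fitted model. It keeps only learned
    parameters (mean, standard deviation) per column, not the rows."""

    def fit(self, rows):
        self.params = {}
        for name in rows[0]:
            values = [row[name] for row in rows]
            self.params[name] = (statistics.mean(values), statistics.stdev(values))
        return self

    def sample(self, num_rows, seed=0):
        # Draw independent Gaussian values per column with a fixed seed.
        rng = random.Random(seed)
        return [
            {name: rng.gauss(mu, sigma) for name, (mu, sigma) in self.params.items()}
            for _ in range(num_rows)
        ]


def save(model, path):
    # Serialize the whole fitted object with pickle.
    with open(path, "wb") as f:
        pickle.dump(model, f)


def load(path):
    # Rebuild the fitted object from the file.
    with open(path, "rb") as f:
        return pickle.load(f)


# "Production" side: fit on real rows, then write the model to disk.
real_rows = [
    {"second_perc": 67.0, "high_perc": 91.0},
    {"second_perc": 79.33, "high_perc": 78.33},
    {"second_perc": 65.0, "high_perc": 68.0},
]
model = ToyGaussianModel().fit(real_rows)
path = os.path.join(tempfile.mkdtemp(), "my_model.pkl")
save(model, path)

# "Testing" side: load the file and sample without touching real_rows.
loaded = load(path)
new_rows = loaded.sample(5)
```

As with the real `CTGAN.load`, the loading side needs the class definition available (here, the same script; for SDV, `sdv` and `ctgan` installed), because pickle stores a reference to the class rather than its code.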

### Specifying the Primary Key of the table

One of the first things that you may have noticed when looking at the demo data
is that there is a `student_id` column which acts as the primary key of the table,
and which is supposed to have unique values. Indeed, if we look at the number of
times that each value appears, we see that each of them appears only once:

In [8]:
data.student_id.value_counts().max()

1

However, if we look at the synthetic data that we generated, we observe that there
are some values that appear more than once:

In [9]:
new_data.student_id.value_counts().max()

4

In [10]:
new_data[new_data.student_id == new_data.student_id.value_counts().index[0]]

|  | student_id | gender | second_perc | high_perc | high_spec | degree_perc | degree_type | work_experience | experience_years | employability_perc | mba_spec | mba_perc | salary | placed | start_date | end_date | duration |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 17466 | F | 66.120204 | 68.977991 | Science | 58.766869 | Sci&Tech | True | 0 | 76.883011 | Mkt&HR | 61.869871 | 32326.346265 | True | 2019-12-28 04:13:46.796109568 | NaT | 6.0 |
| 24 | 17466 | F | 77.944572 | 91.30787 | Commerce | 55.883138 | Sci&Tech | False | 0 | 67.3786 | Mkt&HR | 59.579028 | 22911.391337 | False | 2020-01-12 11:04:01.895049216 | 2020-05-14 07:09:07.497003008 | NaN |
| 131 | 17466 | F | 54.45764 | 71.54411 | Science | 73.673943 | Comm&Mgmt | True | 0 | 46.854488 | Mkt&HR | 72.22356 | 30063.879104 | True | 2019-12-19 23:05:25.145376512 | 2020-12-05 06:28:04.171285248 | 12.0 |
| 191 | 17466 | M | 61.799714 | 58.562872 | Science | 51.892368 | Comm&Mgmt | True | 0 | 43.525589 | Mkt&HR | 68.949925 | NaN | False | 2020-01-19 07:37:45.552702208 | 2020-03-22 23:50:56.617063424 | 6.0 |


This happens because the model was never told that `student_id` had to be unique,
so when it generates new data it will produce collisions sooner or later. In order to
solve this, we can pass the argument `primary_key` to our model when we create it,
indicating the name of the column that is the primary key of the table.

In [11]:
model = CTGAN(
    primary_key='student_id'
)
model.fit(data)
new_data = model.sample(200)

As a result, the model will learn that this column must be unique and generate a unique
sequence of values for the column:

In [12]:
new_data.head()

|  | student_id | gender | second_perc | high_perc | high_spec | degree_perc | degree_type | work_experience | experience_years | employability_perc | mba_spec | mba_perc | salary | placed | start_date | end_date | duration |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | F | 44.023595 | 89.18338 | Science | 70.297832 | Comm&Mgmt | True | 0 | 108.294297 | Mkt&HR | 72.756088 | 22129.511982 | False | 2020-05-26 22:27:06.678130688 | NaT | 12.0 |
| 1 | 1 | F | 83.407289 | 38.3513 | Commerce | 48.800467 | Comm&Mgmt | True | 0 | 47.070753 | Mkt&HR | 57.660073 | 26571.849716 | False | NaT | 2020-02-21 18:43:17.294135552 | NaN |
| 2 | 2 | M | 84.120641 | 72.486169 | Science | 50.828555 | Comm&Mgmt | False | 0 | 64.641102 | Mkt&HR | 67.17083 | 20808.880193 | True | 2020-02-17 02:40:33.249085696 | 2020-08-25 20:33:43.158897152 | NaN |
| 3 | 3 | M | 74.627655 | 60.684835 | Commerce | 77.626732 | Sci&Tech | True | 0 | 72.917882 | Mkt&Fin | 50.690796 | NaN | True | 2020-02-21 11:38:14.343960576 | NaT | 12.0 |
| 4 | 4 | M | 77.310701 | 70.918237 | Commerce | 67.685123 | Sci&Tech | True | 0 | 49.458468 | Mkt&Fin | 49.155378 | 56774.073666 | True | 2020-01-14 21:12:54.684945152 | 2020-12-12 02:17:26.096853760 | 3.0 |


In [13]:
new_data.student_id.value_counts().max()

1
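Without `primary_key`, duplicate IDs were close to certain rather than bad luck. As a back-of-the-envelope sketch, assume (purely for illustration; this is not how CTGAN actually models the column) that each generated `student_id` behaves like an independent, uniform draw from a few hundred possible integers. The chance of at least one repeat is then the classic birthday problem, computed below in log space; `collision_probability` is a hypothetical helper.

```python
import math


def collision_probability(num_samples, num_values):
    """Chance that at least two of `num_samples` independent, uniform draws
    from `num_values` equally likely IDs coincide (the birthday problem)."""
    if num_samples > num_values:
        return 1.0  # pigeonhole principle: a repeat is guaranteed
    # Sum log(1 - i/n) instead of multiplying many factors, to avoid underflow.
    log_all_unique = sum(math.log1p(-i / num_values) for i in range(num_samples))
    return 1.0 - math.exp(log_all_unique)


# Sanity check against the textbook case: 23 people, 365 birthdays.
print(round(collision_probability(23, 365), 3))

# 200 synthetic rows drawing IDs from roughly 300 plausible values.
print(collision_probability(200, 300))
```

The first line prints 0.507, the well-known birthday-problem result; the second is indistinguishable from 1.0, which is why the model needs to be told explicitly that the column is a primary key.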

### Anonymizing Personally Identifiable Information (PII) 

There will be many cases where the data will contain Personally Identifiable Information
which we cannot disclose. In these cases, we will want our Tabular Models to replace the
information within these fields with fake, simulated data that looks similar to the real
one but does not contain any of the original values.

Let's load a new dataset that contains a PII field, the `student_placements_pii` demo, and
try to generate synthetic versions of it that do not disclose any of the real PII values.

<div class="alert alert-info">
    
**NOTE**
    
The `student_placements_pii` dataset is a modified version of the `student_placements`
dataset with one new field, `address`, which contains PII about the students.
Notice that this additional `address` field has been simulated and does not correspond to data
from the real students.

</div>

In [14]:
data_pii = load_tabular_demo('student_placements_pii')

In [15]:
data_pii.head().T

|  | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| student_id | 17264 | 17265 | 17266 | 17267 | 17268 |
| address | 70304 Baker Turnpike\nEricborough, MS 15086 | 805 Herrera Avenue Apt. 134\nMaryview, NJ 36510 | 3702 Bradley Island\nNorth Victor, FL 12268 | Unit 0879 Box 3878\nDPO AP 42663 | 96493 Kelly Canyon Apt. 145\nEast Steven, NC 3... |
| gender | M | M | M | M | M |
| second_perc | 67 | 79.33 | 65 | 56 | 85.8 |
| high_perc | 91 | 78.33 | 68 | 52 | 73.6 |
| high_spec | Commerce | Science | Arts | Science | Commerce |
| degree_perc | 58 | 77.48 | 64 | 52 | 73.3 |
| degree_type | Sci&Tech | Sci&Tech | Comm&Mgmt | Sci&Tech | Comm&Mgmt |
| work_experience | False | True | False | False | False |
| experience_years | 0 | 1 | 0 | 0 | 0 |


If we use our tabular model on this new data we will see how the synthetic
data that it generates discloses the addresses from the real students:

In [16]:
model = CTGAN(
    primary_key='student_id',
)
model.fit(data_pii)

In [17]:
new_data_pii = model.sample(200)
new_data_pii.head()

|  | student_id | address | gender | second_perc | high_perc | high_spec | degree_perc | degree_type | work_experience | experience_years | employability_perc | mba_spec | mba_perc | salary | placed | start_date | end_date | duration |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 21378 Russell Neck\nLake Robertstad, KS 49747 | M | 75.044847 | 71.576348 | Commerce | 83.096791 | Comm&Mgmt | True | 0 | 79.055771 | Mkt&HR | 64.658703 | 33693.705264 | True | NaT | 2020-06-29 15:44:45.912531456 | 3.0 |
| 1 | 1 | 33435 Vazquez Via\nSouth Kristinaberg, FL 98070 | M | 86.947312 | 60.604005 | Science | 67.344151 | Comm&Mgmt | False | 0 | 82.515924 | Mkt&HR | 62.960067 | NaN | True | 2020-02-29 22:14:10.607650048 | 2020-10-26 19:38:02.735934208 | NaN |
| 2 | 2 | Unit 9984 Box 1462\nDPO AA 02334 | M | 58.17755 | 56.086358 | Science | 63.923179 | Comm&Mgmt | True | 1 | 58.084807 | Mkt&HR | 60.695632 | 19502.759486 | True | 2020-01-19 00:34:44.470542336 | 2020-03-05 12:27:09.552780544 | 3.0 |
| 3 | 3 | 5449 Evans Well\nPort Brett, DE 56274 | F | 69.313716 | 32.146236 | Arts | 68.991617 | Comm&Mgmt | True | 0 | 92.849033 | Mkt&Fin | 52.999346 | 18857.791024 | True | 2020-01-20 19:26:35.442548992 | 2020-12-05 19:34:21.231077632 | 12.0 |
| 4 | 4 | 46861 Hanson Ridges Suite 587\nNorth Timstad, ... | F | 81.886067 | 57.003139 | Science | 58.693238 | Sci&Tech | False | 1 | 87.34737 | Mkt&HR | 78.187187 | 26438.166833 | False | 2020-02-17 06:47:48.959166720 | 2020-11-05 18:49:39.082294784 | NaN |


In [18]:
new_data_pii.address.isin(data_pii.address).sum()

200

In order to solve this, we can pass an additional argument `anonymize_fields` to
our model when we create the instance. This `anonymize_fields` argument needs
to be a dictionary whose:

- Keys are the names of the fields that we want to anonymize.
- Values are the categories to use when generating fake values for each field.

The complete list of possible categories can be seen on the
[Faker Providers](https://faker.readthedocs.io/en/master/providers.html) page, and it contains a huge
list of concepts such as:

- name
- address
- country
- city
- ssn
- credit_card_number
- credit_card_expire
- credit_card_security_code
- email
- telephone
- ...

In this case, since the field is a postal address, we will pass a dictionary indicating
the category `address`:

In [19]:
model = CTGAN(
    primary_key='student_id',
    anonymize_fields={
        'address': 'address'
    }
)
model.fit(data_pii)

As a result, we can see how the real `address` values have been replaced by other fake
addresses that were not taken from the real data that we learned.

In [20]:
new_data_pii = model.sample(200)
new_data_pii.head()

|  | student_id | address | gender | second_perc | high_perc | high_spec | degree_perc | degree_type | work_experience | experience_years | employability_perc | mba_spec | mba_perc | salary | placed | start_date | end_date | duration |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 662 Amanda Roads\nJasmineview, VA 69949 | M | 74.266541 | 54.266827 | Commerce | 81.964649 | Comm&Mgmt | False | 0 | 67.143385 | Mkt&HR | 56.880193 | 51752.841184 | True | 2020-06-01 12:43:17.050192384 | 2021-01-30 06:59:09.015085824 | 12.0 |
| 1 | 1 | 7276 Williams Crescent Suite 953\nRichardmouth... | M | 84.995207 | 57.042033 | Science | 76.940656 | Others | True | 0 | 80.944589 | Mkt&Fin | 61.678602 | NaN | False | NaT | 2020-10-20 12:15:40.566909696 | NaN |
| 2 | 2 | 167 Amanda Viaduct Apt. 638\nJefferymouth, TN ... | M | 89.675094 | 50.594566 | Science | 67.424508 | Comm&Mgmt | False | 0 | 106.845877 | Mkt&Fin | 64.586527 | 30381.936029 | True | NaT | NaT | 3.0 |
| 3 | 3 | 8420 Pierce Mission Suite 496\nBaileyview, WY ... | F | 85.881457 | 44.975642 | Commerce | 79.230513 | Comm&Mgmt | False | 1 | 64.628816 | Mkt&HR | 67.132052 | 52257.351545 | True | NaT | 2020-09-03 11:02:14.521708800 | NaN |
| 4 | 4 | USNV Williams\nFPO AA 50494 | M | 84.652475 | 64.385631 | Science | 72.363807 | Comm&Mgmt | True | 2 | 74.973994 | Mkt&Fin | 49.097866 | NaN | True | 2020-05-31 06:10:52.135067904 | NaT | 6.0 |


In [21]:
new_data_pii.address.isin(data_pii.address).sum()

0
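To make the `anonymize_fields` contract concrete, here is a plain-Python sketch of the idea: a field-to-category dictionary selects a fake-value generator for each listed field, and a leak check mirrors the `isin(...).sum()` calls above. SDV delegates the fake values to Faker; the `PROVIDERS` table, `fake_address`, `anonymize` and `count_leaks` below are hypothetical helpers that draw a fresh value per row, which is an assumption, not a description of SDV's internals. The two real addresses come from the `student_placements_pii` rows shown earlier.

```python
import random

# Hypothetical fake-value generator; SDV uses Faker providers instead.
STREETS = ["Oak", "Maple", "Cedar", "Pine"]
CITIES = ["Springfield", "Riverton", "Lakeside"]


def fake_address(rng):
    return f"{rng.randint(1, 99999)} {rng.choice(STREETS)} Street\n{rng.choice(CITIES)}"


PROVIDERS = {"address": fake_address}  # category name -> generator


def anonymize(rows, anonymize_fields, seed=0):
    """Return copies of `rows` where every field in `anonymize_fields`
    (field name -> category) holds a freshly generated fake value."""
    rng = random.Random(seed)
    result = []
    for row in rows:
        new_row = dict(row)
        for field, category in anonymize_fields.items():
            new_row[field] = PROVIDERS[category](rng)
        result.append(new_row)
    return result


def count_leaks(synthetic_rows, real_rows, field):
    """Plain-Python analogue of `synthetic[field].isin(real[field]).sum()`."""
    real_values = {row[field] for row in real_rows}
    return sum(row[field] in real_values for row in synthetic_rows)


real = [
    {"student_id": 17264, "address": "70304 Baker Turnpike\nEricborough, MS 15086"},
    {"student_id": 17265, "address": "805 Herrera Avenue Apt. 134\nMaryview, NJ 36510"},
]

# Copying real values (what happened before anonymization) leaks all of them.
print(count_leaks(real, real, "address"))
# Replacing them through the field -> category mapping leaks none.
print(count_leaks(anonymize(real, {"address": "address"}), real, "address"))
```

The first call prints 2 and the second prints 0, the same pattern as the 200 and 0 counts obtained above with and without `anonymize_fields`.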

## Advanced Usage

Now that we have discovered the basics, let's go over a few more advanced usage examples
and see the different arguments that we can pass to our `CTGAN` Model in order to
customize it to our needs.

### CTGAN Hyperparameters

Apart from the common Tabular Model arguments, `CTGAN` has a number of additional
hyperparameters that control its learning behavior and can impact the
performance of the model, both in terms of the quality of the generated data
and computational time.

#### epochs and batch size

The first hyperparameters that we see are the `epochs` and `batch_size` arguments,
which control the number of iterations that the model will perform to optimize
its parameters, as well as the number of samples used in each step. Their default
values are `300` and `500` respectively, and `batch_size` always needs to be a
multiple of `10`.

These hyperparameters have a very direct effect on how long the training process lasts,
but also on the quality of the generated data. For new datasets, you might want to start
by setting low values for both of them to see how long the training process takes on
your data, and later increase them to acceptable values in order to improve
the quality.
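To see how `epochs` and `batch_size` combine into training time, the sketch below counts optimizer steps. It assumes, based on our reading of recent `ctgan` releases (not confirmed by this guide), that each epoch runs `max(num_rows // batch_size, 1)` steps; `training_steps` is a hypothetical helper that also enforces the multiple-of-10 rule stated above.

```python
def training_steps(num_rows, epochs=300, batch_size=500):
    """Rough count of optimizer steps, assuming each epoch runs
    max(num_rows // batch_size, 1) steps (an assumption, see lead-in)."""
    if batch_size % 10 != 0:
        raise ValueError("batch_size must be a multiple of 10")
    steps_per_epoch = max(num_rows // batch_size, 1)
    return epochs * steps_per_epoch


# About 200 rows: default settings vs. the settings used later in this guide.
print(training_steps(200))
print(training_steps(200, epochs=500, batch_size=100))
```

With about 200 rows, the default `batch_size` of 500 exceeds the table size, so each epoch is a single step and the defaults give 300 steps; `epochs=500` with `batch_size=100` gives 1000.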

#### log_frequency

Whether to use log frequency of categorical levels in conditional sampling. It
defaults to `True`.

This argument affects how the model processes the frequencies of the categorical
values that are used to condition the rest of the values. In some cases,
changing it to `False` could lead to better performance.
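A minimal sketch of what "log frequency" means here, assuming (as we understand the `ctgan` data sampler) that each category's conditioning weight becomes `log(count + 1)` instead of its raw count. `category_probabilities` and the counts are hypothetical, chosen only to show an imbalanced column.

```python
import math


def category_probabilities(counts, log_frequency=True):
    """Turn per-category counts into sampling probabilities, using
    log(count + 1) instead of the raw count when log_frequency is True."""
    weights = {
        category: math.log(count + 1) if log_frequency else count
        for category, count in counts.items()
    }
    total = sum(weights.values())
    return {category: weight / total for category, weight in weights.items()}


# Hypothetical counts for an imbalanced categorical column.
counts = {"Comm&Mgmt": 140, "Sci&Tech": 60, "Others": 10}
raw = category_probabilities(counts, log_frequency=False)
logged = category_probabilities(counts, log_frequency=True)
print(raw["Others"], logged["Others"])
```

The log transform flattens the distribution: the rare `Others` level goes from about 5% to about 21% of the conditioning draws, so the generator sees rare categories more often during training.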

#### Neural Network dimensions

`CTGAN` has the following hyperparameters that allow you to control the
size of the different layers that compose its neural networks:

- embedding_dim (int): Size of the random sample passed to the Generator. Defaults to 128.
- gen_dim (tuple or list of ints): Size of the output samples for each one of the Residual Layers.
  A Residual Layer will be created for each one of the values provided. Defaults to (256, 256).
- dis_dim (tuple or list of ints): Size of the output samples for each one of the Discriminator
  Layers. A Linear Layer will be created for each one of the values provided. Defaults to (256, 256).
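The sketch below shows how these three hyperparameters translate into layer sizes. It assumes, from our reading of the `ctgan` source rather than from this guide, that each Residual Layer concatenates its output to its input (so the width grows), and it ignores the conditional vector appended to the generator input, the discriminator's packing of samples, and the normalization, activation and dropout layers. `data_dim` (the width of the transformed data) is a hypothetical parameter.

```python
def generator_layer_shapes(embedding_dim=128, gen_dim=(256, 256), data_dim=10):
    """(input, output) size of each generator layer, assuming every Residual
    layer concatenates its output with its input."""
    shapes = []
    dim = embedding_dim
    for size in gen_dim:
        shapes.append((dim, size))
        dim += size  # residual output is concatenated to the input
    shapes.append((dim, data_dim))  # final Linear layer back to data width
    return shapes


def discriminator_layer_shapes(input_dim=10, dis_dim=(256, 256)):
    """(input, output) size of each discriminator Linear layer."""
    shapes = []
    dim = input_dim
    for size in dis_dim:
        shapes.append((dim, size))
        dim = size
    shapes.append((dim, 1))  # single real/fake score
    return shapes


print(generator_layer_shapes(data_dim=50))
print(discriminator_layer_shapes(input_dim=50))
```

Under these assumptions, adding a third value to `gen_dim` (as done later in this guide) adds one more residual block and widens the final Linear layer's input from 640 to 896.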

#### l2scale

The `l2scale` argument, which defaults to `1e-6`, sets the weight decay of the Adam Optimizer
used to optimize the Neural Networks.

#### verbose

Whether to print fit progress on stdout. Defaults to `False`.

<div class="alert alert-warning">
    
**WARNING**
    
The value that you set on the `batch_size` argument must always be
a multiple of `10`!

</div>

As an example, we will try to fit the `CTGAN` model slightly increasing the number of epochs,
reducing the `batch_size`, adding one additional layer to the models involved and using a
smaller weight decay.

Before we start, we will evaluate the quality of the previously generated data using the
`sdv.evaluation.evaluate` function:

In [22]:
from sdv.evaluation import evaluate

evaluate(new_data, data)

0.7358033112185358

Afterwards, we create a new instance of the `CTGAN` model with the
hyperparameter values that we want to use:

In [23]:
model = CTGAN(
    primary_key='student_id',
    epochs=500,
    batch_size=100,
    gen_dim=(256, 256, 256),
    dis_dim=(256, 256, 256),
    l2scale=1e-07
)

And fit it to our data:

In [24]:
model.fit(data)

Finally, we are ready to generate new data and evaluate the results.

In [25]:
new_data = model.sample(len(data))

In [26]:
new_data

|  | student_id | gender | second_perc | high_perc | high_spec | degree_perc | degree_type | work_experience | experience_years | employability_perc | mba_spec | mba_perc | salary | placed | start_date | end_date | duration |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | F | 61.821628 | 29.688149 | Arts | 57.488105 | Comm&Mgmt | False | 0 | 42.272295 | Mkt&HR | 65.532350 | NaN | False | NaT | NaT | NaN |
| 1 | 1 | M | 70.398101 | 57.910618 | Commerce | 53.308006 | Sci&Tech | False | 0 | 54.243483 | Mkt&HR | 69.859354 | 27747.615009 | False | NaT | NaT | 3.0 |
| 2 | 2 | M | 67.381182 | 98.496980 | Science | 90.983930 | Comm&Mgmt | True | 0 | 80.179696 | Mkt&Fin | 69.497789 | 29140.316281 | True | 2019-12-13 00:44:54.337081344 | 2020-09-08 07:16:15.742616064 | 3.0 |
| 3 | 3 | F | 52.892598 | 87.604954 | Commerce | 78.490973 | Others | False | 0 | 61.335932 | Mkt&Fin | 68.648329 | NaN | True | 2020-03-21 02:30:29.114332672 | NaT | 6.0 |
| 4 | 4 | M | 55.501661 | 43.368819 | Commerce | 61.034177 | Others | False | 0 | 66.712305 | Mkt&Fin | 66.047435 | 29093.362398 | False | 2019-12-20 08:56:08.563084800 | NaT | NaN |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 210 | 210 | M | 77.230473 | 95.292291 | Commerce | 86.871845 | Comm&Mgmt | False | 0 | 51.083926 | Mkt&HR | 71.319923 | 27761.633164 | True | 2020-03-01 16:48:09.925997568 | 2020-03-12 21:40:10.422644480 | 12.0 |
| 211 | 211 | M | 68.485307 | 73.750524 | Commerce | 78.009572 | Comm&Mgmt | False | 0 | 67.037164 | Mkt&Fin | 60.492617 | 27248.170274 | True | 2020-02-12 08:31:47.660419072 | 2020-03-12 19:23:55.035388672 | 3.0 |
| 212 | 212 | F | 96.553530 | 81.887793 | Arts | 79.522649 | Comm&Mgmt | True | 0 | 56.305032 | Mkt&Fin | 70.966608 | 16853.761961 | True | 2020-03-20 18:09:25.575686400 | 2020-09-09 13:21:02.129512448 | NaN |
| 213 | 213 | F | 67.031508 | 91.089137 | Commerce | 86.340928 | Comm&Mgmt | False | 0 | 65.853228 | Mkt&HR | 68.286300 | 27257.615784 | True | 2019-12-23 08:14:36.942717952 | NaT | 6.0 |


In [27]:
from sdv.evaluation import evaluate

evaluate(new_data, data)

0.732620396484379

As we can see, in this case these modifications changed the obtained results only slightly
(from about 0.736 to 0.733), without introducing dramatic changes in performance.

### How do I specify constraints?

If you look closely at the data you may notice that some properties were
not completely captured by the model. For example, you may have seen
that sometimes the model produces an `experience_years` number greater
than `0` while also indicating that `work_experience` is `False`. These
types of properties are what we call `Constraints` and can also be
handled using `SDV`. For further details about them please visit the
[Handling Constraints](04_Handling_Constraints.ipynb) tutorial.
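A quick way to spot this kind of violation is to count the offending rows. The sketch below is plain Python with a hypothetical `count_violations` helper; the `(work_experience, experience_years)` pairs are copied from the first five synthetic rows generated from the PII dataset earlier in this guide.

```python
def count_violations(rows):
    """Rows that claim no work experience but a positive number of years."""
    return sum(
        1 for row in rows
        if not row["work_experience"] and row["experience_years"] > 0
    )


rows = [
    {"work_experience": w, "experience_years": y}
    for w, y in [(True, 0), (False, 0), (True, 1), (True, 0), (False, 1)]
]
print(count_violations(rows))
```

This prints 1: the row with index 4 in that table has `work_experience` set to `False` and `experience_years` equal to `1`. With pandas, the equivalent check would be `((~df.work_experience) & (df.experience_years > 0)).sum()`.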

### Can I evaluate the Synthetic Data?

A very common question when someone starts using **SDV** to generate
synthetic data is: *"How good is the data that I just generated?"*

In order to answer this question, **SDV** has a collection of metrics
and tools that allow you to compare the *real* data that you provided and the
*synthetic* data that you generated using **SDV** or any other tool.

You can read more about this in the [Evaluating Synthetic Data Generators](05_Evaluating_Synthetic_Data_Generators.ipynb) tutorial.