# Training a CTGAN synthesizer model

In [37]:
%%capture
# Load libraries

import pandas as pd
import os
import warnings

import plotly.graph_objects as go

from datetime import datetime
from sdv.single_table import CTGANSynthesizer
from sdv.datasets.local import load_csvs
from sdv.metadata import SingleTableMetadata

In [38]:
# Ignore warnings
warnings.filterwarnings("ignore")

# Data preparation

In [39]:
%%capture
# Timestamp captured once per run; reused later to make log/artifact names unique.
current_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Load the real (training) data from ./input/<name>.csv.
input_folder = './input/'
# Normalise the answer: trim stray whitespace and accept the name with or
# without a trailing ".csv", so "foo" and "foo.csv" both resolve to foo.csv.
# NOTE(review): the name is lower-cased, which only matches mixed-case file
# names on case-insensitive filesystems.
input_file_name = input("Enter the input file name: ").strip().lower()
if input_file_name.endswith('.csv'):
    input_file_name = input_file_name[:-len('.csv')]
real_data_path = os.path.join(input_folder, input_file_name + '.csv')
real_data = pd.read_csv(real_data_path)

# Training on an empty frame cannot succeed; fail early with a clear message.
assert not real_data.empty, f"No rows found in {real_data_path}"

In [40]:
%%capture
# Describe the real dataset to SDV by hand instead of auto-detecting it:
# one datetime column, five categorical network identifiers and five
# numerical flow statistics. Names must match the CSV header exactly.
# NOTE(review): treating IPs/ports as categorical makes CTGAN create thousands
# of encoded columns (see the PerformanceAlert in the training output).
timestamp_format = "%Y-%m-%d %H:%M:%S"
categorical_columns = [
    "Source.IP",
    "Source.Port",
    "Destination.IP",
    "Destination.Port",
    "Protocol",
]
numerical_columns = [
    "Flow.Duration",
    "Total.Fwd.Packets",
    "Total.Backward.Packets",
    "Total.Length.of.Fwd.Packets",
    "Total.Length.of.Bwd.Packets",
]

# Insertion order is preserved: Timestamp first, then categoricals, then numericals.
column_specs = {"Timestamp": {"sdtype": "datetime", "datetime_format": timestamp_format}}
column_specs.update({column: {"sdtype": "categorical"} for column in categorical_columns})
column_specs.update({column: {"sdtype": "numerical"} for column in numerical_columns})

metadata_dict = {
    "columns": column_specs,
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
}
metadata = SingleTableMetadata.load_from_dict(metadata_dict)

# Modelling

In [41]:
# Either reload a previously saved synthesizer or configure a new one from
# interactive answers. Both branches define `synthesizer`, `epochs` and `cuda`,
# because later cells (model naming, logging) read them.
load_existing = input("Do you want to load an existing GAN model? (yes/no): ").strip().lower()
if load_existing == "yes":
    model_name = input("Enter the name of the model to load (without '.pkl' extension): ").strip() + '.pkl'
    # NOTE(review): this deserialises a .pkl file; only load models you trust.
    synthesizer = CTGANSynthesizer.load(filepath=model_name)
    # Recover the training settings from the loaded model so that `epochs` and
    # `cuda` exist for downstream cells.
    loaded_parameters = synthesizer.get_parameters()
    epochs = loaded_parameters['epochs']
    cuda = loaded_parameters['cuda']
else:
    # Both prompts advertise "yes" as the default, so only an explicit "no"
    # disables them; an empty answer keeps the default (True).
    enforce_rounding = input("Enforce rounding for synthetic data? (yes [default] /no): ").strip().lower() != "no"  # keep the same number of decimal digits as the real data
    enforce_min_max_values = input("Enforce min/max values for synthetic data? (yes [default] /no): ").strip().lower() != "no"  # keep synthetic values within the real data's min/max
    # Empty answer -> 300 epochs.
    epochs = int(input("Enter the number of epochs for training [default=300]: ").strip() or 300)
    cuda = input("Enable CUDA computing? (yes/no): ").strip().lower() == "yes"
    synthesizer = CTGANSynthesizer(
        metadata=metadata,
        enforce_rounding=enforce_rounding,
        enforce_min_max_values=enforce_min_max_values,
        epochs=epochs,
        verbose=True,
        cuda=cuda
    )

In [42]:
# Choose a unique path model/mdl<N>_ctgan_i_<input>_ep<epochs>.pkl, bumping <N>
# until the file name is not already taken.
model_folder = 'model'
os.makedirs(model_folder, exist_ok=True)

# This notebook has no `args` namespace. Use the notebook's own input name and
# the synthesizer's configured epoch count, which exists whether the model was
# loaded or newly created.
model_epochs = synthesizer.get_parameters()['epochs']
model_base_name = f'_ctgan_i_{input_file_name}_ep{model_epochs}'

model_number = 1
model_name = f'mdl{model_number}{model_base_name}.pkl'
model_path = os.path.join(model_folder, model_name)
print(model_path)

while os.path.exists(model_path):
    print('Model path already exists:', model_path)
    model_number += 1
    model_name = f'mdl{model_number}{model_base_name}.pkl'
    model_path = os.path.join(model_folder, model_name)

print("Final model path:", model_path)

Final model path: model\ctgan_medium_input_ep2_mdl1.pkl


# Training the model synthesizer

In [43]:
# Train only a newly configured synthesizer. A model loaded from disk is
# already fitted; calling fit() again would re-learn from `real_data` instead
# of reusing the loaded model.
if load_existing != "yes":
    synthesizer.fit(real_data)

# Save the (trained or loaded) model under the unique path chosen above.
synthesizer.save(filepath=model_path)

PerformanceAlert: Using the CTGANSynthesizer on this data is not recommended. To model this data, CTGAN will generate a large number of columns.

Original Column Name        Est # of Columns (CTGAN)
Timestamp                   11
Source.IP                   183
Source.Port                 2494
Destination.IP              348
Destination.Port            1089
Protocol                    1
Flow.Duration               11
Total.Fwd.Packets           11
Total.Backward.Packets      11
Total.Length.of.Fwd.Packets 11
Total.Length.of.Bwd.Packets 11

We recommend preprocessing discrete columns that can have many values, using 'update_transformers'. Or you may drop columns that are not necessary to model. (Exit this script using ctrl-C)


Gen. (5.34) | Discrim. (-0.42): 100%|██████████| 2/2 [01:14<00:00, 37.44s/it]


In [44]:
# Check parameters
synthesizer.get_parameters()
# synthesizer.get_metadata()

{'enforce_min_max_values': True,
 'enforce_rounding': True,
 'locales': ['en_US'],
 'embedding_dim': 128,
 'generator_dim': (256, 256),
 'discriminator_dim': (256, 256),
 'generator_lr': 0.0002,
 'generator_decay': 1e-06,
 'discriminator_lr': 0.0002,
 'discriminator_decay': 1e-06,
 'batch_size': 500,
 'discriminator_steps': 1,
 'log_frequency': True,
 'verbose': True,
 'epochs': 2,
 'pac': 10,
 'cuda': False}

In [45]:
# Append a human-readable training record to logs/information.txt and store
# the per-epoch losses as a CSV in the same folder.
log_folder = 'logs'
os.makedirs(log_folder, exist_ok=True)
training_log_file_name = os.path.join(log_folder, 'information.txt')

# Read epochs/CUDA from the synthesizer itself. These values are correct
# whether the model was trained here or loaded from disk.
trained_parameters = synthesizer.get_parameters()
trained_with_cuda = trained_parameters['cuda']
trained_epochs = trained_parameters['epochs']

with open(training_log_file_name, 'a', encoding='utf-8') as log:
    log.write(f"--- CTGAN model training started at: {current_timestamp}\n")
    log.write(f"\tModel name: {model_name}\n")
    log.write(f"\tModel path: {model_path}\n")
    # Test the boolean itself. A set literal such as `{cuda}` is always truthy,
    # which would make "NOT" impossible to log.
    log.write(f"\tModel was {'' if trained_with_cuda else 'NOT '}trained using CUDA\n")
    log.write(f"\tTraining on {trained_epochs} epochs\n")
    # Log the exact path the data was read from.
    log.write(f"\tInput file path: {real_data_path}\n")

# Save model losses in the log folder as CSV. Drop ".pkl" from the model name
# so the file is "losses_<model>_<timestamp>.csv".
losses = synthesizer.get_loss_values()
model_stem = os.path.splitext(model_name)[0]
model_log_file_name = f'losses_{model_stem}_{current_timestamp}.csv'
model_log_file_path = os.path.join(log_folder, model_log_file_name)
losses.to_csv(model_log_file_path, index=False)

print("Final logs for losses path: ", model_log_file_path)

Final logs for losses path:  logs\losses_ctgan_medium_input_ep2_mdl1.pkl_20240419_111435.csv


# Some outputs

In [46]:
# Preview the first epochs of the recorded generator/discriminator losses.
losses.head(5)

Unnamed: 0,Epoch,Generator Loss,Discriminator Loss
0,0,5.464275,-0.296103
1,1,5.339439,-0.421249


# Examine loss values

In [57]:
# Re-read the loss log written by the logging cell (round-trips the saved CSV).
loss_values = pd.read_csv(filepath_or_buffer=model_log_file_path)

In [59]:
# Plot generator and discriminator loss per epoch, one trace per loss column.
fig = go.Figure()
for loss_column in ('Generator Loss', 'Discriminator Loss'):
    fig.add_trace(go.Scatter(x=loss_values['Epoch'], y=loss_values[loss_column], name=loss_column))

# Horizontal legend above the plot area. The model name goes in the title so
# the saved HTML file can be identified on its own.
fig.update_layout(
    template='plotly_white',
    legend=dict(orientation="h", x=0, y=1.1),
    title=f'CTGAN loss function: {model_name}',
    xaxis_title='Epoch',
    yaxis_title='Loss',
)
fig.show()

# Create evaluation folder if it does not exist
evaluation_folder = 'evaluation'
os.makedirs(evaluation_folder, exist_ok=True)

# Strip ".csv" from the loss-log name so the file is "plot_<name>.html"
# rather than "plot_<name>.csv.html".
plot_stem = os.path.splitext(model_log_file_name)[0]
plot_html_file_path = os.path.join(evaluation_folder, f'plot_{plot_stem}.html')
fig.write_html(plot_html_file_path)