# Training a CTGAN model / synthesizer

In [1]:
%%capture
# Load libraries

import os
import warnings
from datetime import datetime

import pandas as pd
from sdv.datasets.local import load_csvs
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CTGANSynthesizer

In [2]:
# Ignore warnings
# Installs a blanket "ignore" filter: warnings of every category raised after
# this cell runs are silenced until the warning filters are changed again.
warnings.filterwarnings("ignore")

# Data preparation

In [3]:
%%capture
# Get current timestamp for unique identifiers (reused later in output file names)
current_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Load real data
input_folder = './input/'

# Normalise the typed name: surrounding whitespace is dropped and a trailing
# ".csv" is tolerated, so both "my_data" and "my_data.csv" resolve to
# ./input/my_data.csv instead of ./input/my_data.csv.csv.
# NOTE: the name is lower-cased, so on case-sensitive file systems the CSV
# itself must have a lower-case file name.
input_file_name = input("Enter the input file name: ").strip().lower()
if input_file_name.endswith('.csv'):
    input_file_name = input_file_name[:-len('.csv')]
if not input_file_name:
    raise ValueError("No input file name was entered.")

real_data_path = os.path.join(input_folder, input_file_name + '.csv')
# Fail early with the resolved path in the message rather than deep inside pandas.
if not os.path.isfile(real_data_path):
    raise FileNotFoundError(f"Input file not found: {real_data_path}")
real_data = pd.read_csv(real_data_path)

In [4]:
%%capture
# Hand-written metadata describing the real dataset: one datetime column,
# the flow identifiers (addresses, ports, protocol) as categoricals, and the
# duration / packet counters as numericals. Column order matches the
# original literal: Timestamp, then categoricals, then numericals.
metadata_dict = {
    "columns": {
        "Timestamp": {
            "sdtype": "datetime",
            "datetime_format": "%Y-%m-%d %H:%M:%S",
        },
        **{
            column: {"sdtype": "categorical"}
            for column in (
                "Source.IP",
                "Source.Port",
                "Destination.IP",
                "Destination.Port",
                "Protocol",
            )
        },
        **{
            column: {"sdtype": "numerical"}
            for column in (
                "Flow.Duration",
                "Total.Fwd.Packets",
                "Total.Backward.Packets",
                "Total.Length.of.Fwd.Packets",
                "Total.Length.of.Bwd.Packets",
            )
        },
    },
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
}

# Turn the plain dict into the metadata object the synthesizer expects.
metadata = SingleTableMetadata.load_from_dict(metadata_dict)
# print(metadata)  # DEBUG, validate metadata

# Modelling

In [5]:
# Input arguments
def _ask_yes_no(prompt, default=False):
    """Prompt for a yes/no answer and return it as a bool.

    Args:
        prompt: Text shown to the user.
        default: Value returned when the user just presses Enter.

    Returns:
        True for "yes"/"y" (case-insensitive, surrounding whitespace ignored),
        ``default`` for an empty answer, and False for anything else.
    """
    answer = input(prompt).strip().lower()
    if not answer:
        return default
    return answer in ("yes", "y")


load_existing = input("Do you want to load an existing GAN model? (yes/no): ").strip().lower()
if load_existing == "yes":
    # The path is used exactly as typed (plus '.pkl'), i.e. relative to the
    # current working directory unless the user enters a folder prefix.
    model_name = input("Enter the name of the model to load (without '.pkl' extension): ").strip() + '.pkl'
    # NOTE(review): .pkl models are presumably unpickled on load — only load files you trust.
    synthesizer = CTGANSynthesizer.load(filepath=model_name)
    # The model-path cell builds the file name from `epochs`; take it from the
    # loaded model so that cell does not fail with a NameError on this branch.
    epochs = synthesizer.get_parameters()["epochs"]
else:
    # Both prompts advertise "yes" as the default, so an empty answer must map to True.
    enforce_rounding = _ask_yes_no(
        "Enforce rounding for synthetic data? (yes [default] /no): ", default=True
    )  # control whether the synthetic data should have same number of decimal digits as the real data
    enforce_min_max_values = _ask_yes_no(
        "Enforce min/max values for synthetic data? (yes [default] /no): ", default=True
    )  # control whether the synthetic data should adhere to same min/max boundaries as the real data
    # Empty (or whitespace-only) answer falls back to 300 epochs.
    epochs = int(input("Enter the number of epochs for training (default=300): ").strip() or 300)
    cuda = _ask_yes_no("Enable CUDA computing? (yes/no): ", default=False)
    synthesizer = CTGANSynthesizer(
        metadata=metadata,
        enforce_rounding=enforce_rounding,
        enforce_min_max_values=enforce_min_max_values,
        epochs=epochs,
        verbose=True,
        cuda=cuda
    )

In [11]:
# Create model path
model_folder = 'model'
os.makedirs(model_folder, exist_ok=True)

# File names follow ctgan_<input file>_ep<epochs>_mdl<N>.pkl, where N is the
# first number that does not collide with an existing file.
model_base_name = f"ctgan_{input_file_name}_ep{epochs}_mdl"
model_number = 1
model_name = f"{model_base_name}{model_number}.pkl"
model_path = os.path.join(model_folder, model_name)

# Bump the suffix until the path is free. model_name is updated together with
# model_path so that later cells which use model_name (e.g. for the loss log
# file name) refer to the model that is actually saved.
while os.path.exists(model_path):
    print('Model path already exists:', model_path)
    model_number += 1
    model_name = f"{model_base_name}{model_number}.pkl"
    model_path = os.path.join(model_folder, model_name)

print("Final model path:", model_path)

Final model path: model\ctgan_filtered_new_data_ep1_mdl1.pkl


# Training the model synthesizer

In [14]:
# Training
# Fit the synthesizer on the real data read from the input CSV.
# NOTE(review): this also runs when an existing model was loaded earlier —
# confirm whether refitting a loaded model is intended.
synthesizer.fit(real_data)

# Save trained model
# Writes the fitted synthesizer to the model path chosen in the previous cell.
synthesizer.save(filepath=model_path)

PerformanceAlert: Using the CTGANSynthesizer on this data is not recommended. To model this data, CTGAN will generate a large number of columns.

Original Column Name        Est # of Columns (CTGAN)
Timestamp                   11
Source.IP                   97
Source.Port                 2373
Destination.IP              524
Destination.Port            404
Protocol                    3
Flow.Duration               11
Total.Fwd.Packets           11
Total.Backward.Packets      11
Total.Length.of.Fwd.Packets 11
Total.Length.of.Bwd.Packets 11

We recommend preprocessing discrete columns that can have many values, using 'update_transformers'. Or you may drop columns that are not necessary to model. (Exit this script using ctrl-C)


Gen. (5.11) | Discrim. (-0.39): 100%|██████████| 1/1 [00:23<00:00, 23.29s/it]


Final logs for losses path: logs\losses_ctgan_filtered_new_data_ep1_mdl1.pkl.csv


In [15]:
# Save model losses in log folder as CSV
# The log folder is defined and created here so the cell works on a fresh
# kernel; to_csv does not create missing directories.
log_folder = 'logs'
os.makedirs(log_folder, exist_ok=True)

losses = synthesizer.get_loss_values()

# Name the log after the file the model was actually saved to (basename of
# model_path), followed by the run timestamp.
model_log_file_name = f'losses_{os.path.basename(model_path)}_{current_timestamp}.csv'
model_log_file_path = os.path.join(log_folder, model_log_file_name)
losses.to_csv(model_log_file_path, index=False)

print("Final logs for losses path: ", model_log_file_path)

Final logs for losses path:  logs\losses_ctgan_filtered_new_data_ep1_mdl1.pkl.csv


#  Some outputs

In [13]:
# Preview the first rows of the loss table retrieved from the synthesizer.
losses.head()

Unnamed: 0,Epoch,Generator Loss,Discriminator Loss
0,0,5.293233,-0.391187
