# Generative Model Training Tutorial

This notebook demonstrates how to train a generative model on network trace data and use it to generate synthetic traces.

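**Example command-line equivalent** (the command below trains `realtabformer-tabular`, whereas Step 1 of this notebook sets `model_name = 'tvae'`; change `model_name` in Step 1 to `'realtabformer-tabular'` to match the command exactly):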
```bash
python3 driver.py \
    --config_partition small-scale \
    --dataset_name caida \
    --model_name realtabformer-tabular \
    --order_csv_by_timestamp
```


## Step 1: Setup - Import Libraries and Set Parameters

Configure the experiment by setting the model, dataset, and other parameters.


In [6]:
# Import required libraries
import os
import time
import datetime
import pandas as pd
import numpy as np
# ignore all warnings
import warnings
warnings.filterwarnings('ignore')

# --- Experiment parameters: edit these to change what the notebook runs ---
config_partition = 'small-scale'  # Step 2 accepts 'small-scale' or 'medium-scale'
dataset_name = 'caida'  # Step 3 recognises 'ugr16', 'cidds', 'ton', 'caida', 'dc', 'ca', 'm57'
model_name = 'tvae'  # Step 4 implements 'realtabformer-tabular', 'ctgan', 'tvae'
order_csv_by_timestamp = True  # Forwarded to check_and_fix_syndf in Step 6 to sort rows by time

# Run identifier: current local time as YYYYmmddHHMMSS followed by 3 millisecond digits.
# It is embedded in the work-folder, CSV and time-log names below.
now = datetime.datetime.now()
cur_time = f"{now:%Y%m%d%H%M%S}{now.microsecond // 1000:03d}"

# Echo the chosen settings so they are recorded in the cell output
print(
    "Experiment Configuration:",
    f"  Model: {model_name}",
    f"  Dataset: {dataset_name}",
    f"  Config Partition: {config_partition}",
    f"  Timestamp: {cur_time}",
    sep="\n",
)


Experiment Configuration:
  Model: tvae
  Dataset: caida
  Config Partition: small-scale
  Timestamp: 20251107110332600


## Step 2: Load Configuration and Setup Directories

Load the model and dataset configuration, and create necessary output directories.


In [7]:
# Reject unsupported partitions up front, then import the matching config module
if config_partition not in ('small-scale', 'medium-scale'):
    raise ValueError(f"Unknown config partition: {config_partition}")
if config_partition == 'small-scale':
    from config_small_scale import configs, NETGPT_BASE_FOLDER
else:
    from config_medium_scale import configs, NETGPT_BASE_FOLDER

# One output sub-directory per artifact kind, all under results/<partition>/
RESULT_PATH_BASE = os.path.join(NETGPT_BASE_FOLDER, "results", config_partition)
RESULT_PATH = {
    kind: os.path.join(RESULT_PATH_BASE, kind)
    for kind in ('runs', 'csv', 'txt', 'npz', 'time')
}

# Make sure every output directory exists (no error if it already does)
for result_dir in RESULT_PATH.values():
    os.makedirs(result_dir, exist_ok=True)

# Per-run scratch folder, named <model>_<dataset>_<timestamp>
work_folder = os.path.join(
    RESULT_PATH['runs'],
    f'{model_name}_{dataset_name}_{cur_time}',
)
os.makedirs(work_folder, exist_ok=True)

# Wrap the settings for the selected model/dataset pair in a Config object
from config_io import Config
current_config = Config(configs[model_name][dataset_name])

print(f"Work folder: {work_folder}")
print(f"\nConfiguration loaded for {model_name} on {dataset_name}")


Work folder: /home/lesley/generative-trace-tutorials/results/small-scale/runs/tvae_caida_20251107110332600

Configuration loaded for tvae on caida


## Step 3: Load and Inspect the Dataset

Read the raw CSV data and perform initial preprocessing (e.g., drop unnecessary columns).


In [8]:
# Map the dataset name to its type ('pcap' or 'netflow'); unknown names are an error
if dataset_name in ('caida', 'dc', 'ca', 'm57'):
    dataset_type = 'pcap'
elif dataset_name in ('ugr16', 'cidds', 'ton'):
    dataset_type = 'netflow'
else:
    raise ValueError(f"Unknown dataset name: {dataset_name}")

# Read the raw CSV referenced by the loaded configuration
df = pd.read_csv(current_config.raw_csv_file)

# PCAP only: remove version / ihl / chksum when they are present.
# (Step 6 re-adds version and ihl as constants for PCAP output.)
if dataset_type == "pcap":
    dropped_columns = [col for col in ('version', 'ihl', 'chksum') if col in df.columns]
    if dropped_columns:
        df = df.drop(columns=dropped_columns)
        print(f"Dropped columns: {dropped_columns}")

# Quick look at what was loaded
print(f"\nDataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
print("\nFirst few rows:")
df.head()


Dropped columns: ['version', 'ihl', 'chksum']

Dataset shape: (100000, 104)
Columns: ['srcip_31', 'srcip_30', 'srcip_29', 'srcip_28', 'srcip_27', 'srcip_26', 'srcip_25', 'srcip_24', 'srcip_23', 'srcip_22', 'srcip_21', 'srcip_20', 'srcip_19', 'srcip_18', 'srcip_17', 'srcip_16', 'srcip_15', 'srcip_14', 'srcip_13', 'srcip_12', 'srcip_11', 'srcip_10', 'srcip_9', 'srcip_8', 'srcip_7', 'srcip_6', 'srcip_5', 'srcip_4', 'srcip_3', 'srcip_2', 'srcip_1', 'srcip_0', 'dstip_31', 'dstip_30', 'dstip_29', 'dstip_28', 'dstip_27', 'dstip_26', 'dstip_25', 'dstip_24', 'dstip_23', 'dstip_22', 'dstip_21', 'dstip_20', 'dstip_19', 'dstip_18', 'dstip_17', 'dstip_16', 'dstip_15', 'dstip_14', 'dstip_13', 'dstip_12', 'dstip_11', 'dstip_10', 'dstip_9', 'dstip_8', 'dstip_7', 'dstip_6', 'dstip_5', 'dstip_4', 'dstip_3', 'dstip_2', 'dstip_1', 'dstip_0', 'srcport_15', 'srcport_14', 'srcport_13', 'srcport_12', 'srcport_11', 'srcport_10', 'srcport_9', 'srcport_8', 'srcport_7', 'srcport_6', 'srcport_5', 'srcport_4', 'src

Unnamed: 0,srcip_31,srcip_30,srcip_29,srcip_28,srcip_27,srcip_26,srcip_25,srcip_24,srcip_23,srcip_22,...,dstport_1,dstport_0,proto,time,pkt_len,tos,id,flag,off,ttl
0,1,0,0,1,0,0,1,1,1,1,...,1,0,TCP,1521118773289502,40,40,0,2,0,52
1,1,0,0,0,1,1,0,0,1,1,...,1,1,TCP,1521118773289519,1500,0,15922,2,0,60
2,1,1,0,1,1,0,0,1,1,1,...,0,1,TCP,1521118773289529,1400,24,3559,2,0,55
3,1,1,0,0,0,1,0,1,0,0,...,0,0,TCP,1521118773289530,52,0,47404,2,0,59
4,1,0,0,0,1,1,0,0,1,1,...,1,1,TCP,1521118773289535,1500,0,15923,2,0,60


## Step 4: Initialize the Model

Create and configure the generative model based on the selected model type.


In [9]:
# Start the wall-clock timer; Step 6 subtracts this to report execution time
start_time = time.time()

# Build the model object that matches model_name
if model_name == "tvae":
    from ctgan import TVAE

    # Columns the model should treat as categorical (passed to fit() in Step 5)
    discrete_columns = current_config['discrete_columns']
    tvae = TVAE(epochs=100)
    print(f"TVAE model initialized with discrete columns: {discrete_columns}")

elif model_name == "ctgan":
    from ctgan import CTGAN

    # Columns the model should treat as categorical (passed to fit() in Step 5)
    discrete_columns = current_config['discrete_columns']
    ctgan = CTGAN(epochs=100, verbose=True)
    print(f"CTGAN model initialized with discrete columns: {discrete_columns}")

elif model_name == "realtabformer-tabular":
    from realtabformer import REaLTabFormer
    from transformers.models.gpt2 import GPT2Config

    # GPT-2 architecture sizes come from the config, falling back to 12/12/768
    gpt2_config = GPT2Config(
        n_layer=getattr(current_config, 'n_layer', 12),
        n_head=getattr(current_config, 'n_head', 12),
        n_embd=getattr(current_config, 'n_embd', 768)
    )

    # Checkpoints and samples are written inside this run's work folder
    rtf_model = REaLTabFormer(
        model_type="tabular",
        tabular_config=gpt2_config,
        checkpoints_dir=os.path.join(work_folder, "rtf_checkpoints"),
        samples_save_dir=os.path.join(work_folder, "rtf_samples"),
        gradient_accumulation_steps=4,
        epochs=current_config.epochs,
        batch_size=16,
        logging_steps=current_config.logging_steps,
        save_steps=current_config.save_steps,
        save_total_limit=current_config.save_total_limit,
        eval_steps=current_config.eval_steps
    )
    print("REaLTabFormer model initialized")

else:
    raise ValueError(f"Model {model_name} not implemented in this tutorial")

print("Model configuration complete!")


TVAE model initialized with discrete columns: ['srcip_31', 'srcip_30', 'srcip_29', 'srcip_28', 'srcip_27', 'srcip_26', 'srcip_25', 'srcip_24', 'srcip_23', 'srcip_22', 'srcip_21', 'srcip_20', 'srcip_19', 'srcip_18', 'srcip_17', 'srcip_16', 'srcip_15', 'srcip_14', 'srcip_13', 'srcip_12', 'srcip_11', 'srcip_10', 'srcip_9', 'srcip_8', 'srcip_7', 'srcip_6', 'srcip_5', 'srcip_4', 'srcip_3', 'srcip_2', 'srcip_1', 'srcip_0', 'dstip_31', 'dstip_30', 'dstip_29', 'dstip_28', 'dstip_27', 'dstip_26', 'dstip_25', 'dstip_24', 'dstip_23', 'dstip_22', 'dstip_21', 'dstip_20', 'dstip_19', 'dstip_18', 'dstip_17', 'dstip_16', 'dstip_15', 'dstip_14', 'dstip_13', 'dstip_12', 'dstip_11', 'dstip_10', 'dstip_9', 'dstip_8', 'dstip_7', 'dstip_6', 'dstip_5', 'dstip_4', 'dstip_3', 'dstip_2', 'dstip_1', 'dstip_0', 'srcport_15', 'srcport_14', 'srcport_13', 'srcport_12', 'srcport_11', 'srcport_10', 'srcport_9', 'srcport_8', 'srcport_7', 'srcport_6', 'srcport_5', 'srcport_4', 'srcport_3', 'srcport_2', 'srcport_1', 'src

## Step 5: Train the Model and Generate Synthetic Data

Train the model on the dataset and generate synthetic samples.


In [10]:
# Fit the selected model on df, persist it in the work folder, then draw
# as many synthetic rows as the training frame has.
if model_name in ("ctgan", "tvae"):
    # CTGAN and TVAE are driven through the same fit / save / sample calls
    if model_name == "ctgan":
        synthesizer, label = ctgan, "CTGAN"
    else:
        synthesizer, label = tvae, "TVAE"

    print(f"Training {label} model...")
    synthesizer.fit(df, discrete_columns)

    synthesizer.save(os.path.join(work_folder, "model.pt"))
    print("Model saved!")

    print(f"\nGenerating {len(df)} synthetic samples...")
    syn_df = synthesizer.sample(len(df))

elif model_name == "realtabformer-tabular":
    print("Training REaLTabFormer model...")
    rtf_model.fit(df, num_bootstrap=current_config.num_bootstrap)

    rtf_model.save(os.path.join(work_folder, "rtf_model"))
    print("Model saved!")

    print(f"\nGenerating {len(df)} synthetic samples...")
    syn_df = rtf_model.sample(n_samples=len(df), gen_batch=1024)

# Summarise the generated frame
print("\nSynthetic data generation complete!")
print(f"Synthetic data shape: {syn_df.shape}")
print("\nFirst few rows of synthetic data:")
syn_df.head()


Training TVAE model...
Start data transformation...
Data transformation finished!
Data dimension:  246
Model saved!

Generating 100000 synthetic samples...

Synthetic data generation complete!
Synthetic data shape: (100000, 104)

First few rows of synthetic data:


Unnamed: 0,srcip_31,srcip_30,srcip_29,srcip_28,srcip_27,srcip_26,srcip_25,srcip_24,srcip_23,srcip_22,...,dstport_1,dstport_0,proto,time,pkt_len,tos,id,flag,off,ttl
0,0,1,1,0,1,0,0,0,1,1,...,1,1,TCP,1521118773447938,1480,0,39179,2,0,85
1,1,0,0,0,0,1,1,1,0,0,...,1,0,TCP,1521118773516469,41,0,30465,2,0,52
2,0,0,1,0,1,1,0,0,0,1,...,0,0,TCP,1521118773405970,54,0,14386,2,0,240
3,0,0,1,0,1,1,0,0,0,1,...,1,0,TCP,1521118773417828,1460,0,5955,2,0,243
4,0,0,1,0,1,1,0,0,0,0,...,0,0,TCP,1521118773369949,1229,0,45401,2,0,244


## Step 6: Post-process and Save Results

Apply any necessary post-processing and save the synthetic data to a CSV file.


In [12]:
# Function to convert bit columns to decimal (for CTGAN/TVAE models)
def csv_bit2decimal(input_df):
    """Convert bit-encoded IP addresses and ports to decimal format.

    For ``srcip`` and ``dstip`` (32 bits each) and ``srcport`` and ``dstport``
    (16 bits each), the columns ``<field>_<k>`` are read as binary digits with
    weight ``2**k`` and collapsed into a single integer column ``<field>``.
    The four new columns are appended after the remaining columns in the order
    srcip, dstip, srcport, dstport, and all bit columns are dropped.

    The conversion is vectorised (one matrix-vector product per field) rather
    than building and parsing a string per row, and it accepts bits stored as
    ints, floats (``1.0``) or digit strings.

    Args:
        input_df: DataFrame holding every ``<field>_<k>`` bit column. It is
            copied first and never modified.

    Returns:
        A new DataFrame with int64 ``srcip``, ``dstip``, ``srcport`` and
        ``dstport`` columns in place of the bit columns.

    Raises:
        KeyError: if any expected bit column is missing.
        ValueError: if a bit column contains anything other than 0 or 1.
    """
    # (field name, number of bits)
    bit_fields = (("srcip", 32), ("dstip", 32), ("srcport", 16), ("dstport", 16))

    df = input_df.copy(deep=True)
    bit_columns = []

    for field, width in bit_fields:
        # Most-significant bit first, e.g. srcip_31 ... srcip_0
        cols = [f"{field}_{k}" for k in range(width - 1, -1, -1)]

        # Go through float so that 1 / 1.0 / '1' all compare equal to 1.0 and
        # NaN or fractional values are caught by the 0/1 check below.
        values = df.loc[:, cols].to_numpy(dtype=np.float64)
        if not np.isin(values, (0.0, 1.0)).all():
            raise ValueError(f"{field} bit columns must contain only 0 or 1")

        # Bit weights 2**(width-1) ... 2**0; int64 comfortably holds 32-bit sums
        exponents = np.arange(width - 1, -1, -1, dtype=np.int64)
        weights = np.int64(1) << exponents

        df[field] = values.astype(np.int64) @ weights
        bit_columns.extend(cols)

    # Drop the bit columns
    return df.drop(columns=bit_columns)

# Convert bit columns to decimal if needed (for CTGAN/TVAE).
# The presence check keeps this cell re-runnable: once syn_df has been
# converted its bit columns are gone, and calling csv_bit2decimal a second
# time would fail with KeyError.
if model_name in ['ctgan', 'tvae']:
    if 'srcip_31' in syn_df.columns:
        print("Converting bit columns to decimal format...")
        syn_df = csv_bit2decimal(syn_df)
        print(f"Converted data shape: {syn_df.shape}")
    else:
        print("No bit-encoded columns found in syn_df; skipping bit-to-decimal conversion.")

# Function to validate and fix synthetic dataframe
def check_and_fix_syndf(syn_df, dataset_type, order_by_timestamp=False):
    """Validate data types, add constant columns and reorder columns.

    Args:
        syn_df: Synthetic DataFrame with decimal ``srcip``/``dstip`` columns
            plus the columns required for ``dataset_type``. It is copied and
            never modified.
        dataset_type: ``'pcap'`` or ``'netflow'``. Any other value skips the
            type-specific processing and returns the data unchanged (unless
            sorting is requested, see Raises).
        order_by_timestamp: When true, sort rows by ``time`` (pcap) or
            ``ts`` (netflow).

    Returns:
        A new DataFrame. For pcap: constant ``version``=4 and ``ihl``=5 are
        added, the numeric header columns are cast to int64, and columns are
        put in the fixed 14-column order. For netflow: ``ts``/``pkt``/``byt``
        are cast to int64 and columns are reordered when ``type`` (and
        optionally ``label``) is present.

    Raises:
        ValueError: if ``srcip`` or ``dstip`` is not an integer column, or if
            sorting is requested for an unknown ``dataset_type``.
        KeyError: if a required column is missing.
    """
    # Accept any integer dtype. Comparing against the Python ``int`` type is
    # platform-dependent (it can mean a 32-bit integer) and would wrongly
    # reject int64 columns there.
    for col in ('srcip', 'dstip'):
        if not pd.api.types.is_integer_dtype(syn_df[col]):
            raise ValueError("srcip and dstip should be int")

    # Work on a copy so the caller's frame does not gain extra columns
    syn_df = syn_df.copy()
    time_col = None

    if dataset_type == 'pcap':
        # Add constant columns for PCAP
        syn_df['version'] = 4
        syn_df['ihl'] = 5

        # Explicit int64: the 'time' values seen in this notebook's data need
        # more than 32 bits, so a platform-dependent ``int`` is not safe.
        for col in ['time', 'pkt_len', 'tos', 'id', 'flag', 'off', 'ttl']:
            syn_df[col] = syn_df[col].astype('int64')

        # Reorder columns (important for downstream processing)
        syn_df = syn_df[['srcip', 'dstip', 'srcport', 'dstport', 'proto',
                         'time', 'pkt_len', 'version', 'ihl', 'tos', 'id', 'flag', 'off', 'ttl']]
        time_col = 'time'

    elif dataset_type == 'netflow':
        for col in ['ts', 'pkt', 'byt']:
            syn_df[col] = syn_df[col].astype('int64')

        # Reorder columns based on what's available; with no 'type' column
        # the existing order is kept.
        if 'label' in syn_df.columns and 'type' in syn_df.columns:
            syn_df = syn_df[['srcip', 'dstip', 'srcport', 'dstport', 'proto',
                             'ts', 'td', 'pkt', 'byt', 'label', 'type']]
        elif 'type' in syn_df.columns:
            syn_df = syn_df[['srcip', 'dstip', 'srcport', 'dstport', 'proto',
                             'ts', 'td', 'pkt', 'byt', 'type']]
        time_col = 'ts'

    # Sort by timestamp if requested
    if order_by_timestamp:
        if time_col is None:
            # Previously surfaced as an UnboundLocalError on time_col
            raise ValueError(
                f"Cannot order by timestamp for unknown dataset type: {dataset_type}"
            )
        syn_df = syn_df.sort_values(by=[time_col])

    return syn_df

# Validate, reorder and (optionally) time-sort the synthetic frame
syn_df = check_and_fix_syndf(syn_df, dataset_type, order_csv_by_timestamp)

# Write the synthetic data as <model>_<dataset>_<timestamp>.csv, without the index
output_csv_path = os.path.join(
    RESULT_PATH['csv'],
    f'{model_name}_{dataset_name}_{cur_time}.csv',
)
syn_df.to_csv(output_csv_path, index=False)
print(f"Synthetic data saved to: {output_csv_path}")

# Elapsed time is measured from start_time, which was set in Step 4
end_time = time.time()
time_elapsed = end_time - start_time
time_file_path = os.path.join(
    RESULT_PATH['time'],
    f'{model_name}_{dataset_name}_{cur_time}.txt',
)

# Time log: duration in seconds and hours, then human-readable start/end stamps
started_at = datetime.datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S')
finished_at = datetime.datetime.fromtimestamp(end_time).strftime('%Y-%m-%d %H:%M:%S')
time_report = [
    f"{time_elapsed:.2f} seconds\n",
    f"{time_elapsed / 3600:.2f} hours\n",
    f"start_time: {started_at}\n",
    f"end_time: {finished_at}\n",
]
with open(time_file_path, 'w') as time_file:
    time_file.writelines(time_report)

print(f"\nExecution time: {time_elapsed:.2f} seconds ({time_elapsed / 3600:.2f} hours)")
print(f"Time log saved to: {time_file_path}")
print("\n=== Training and Generation Complete! ===")


Converting bit columns to decimal format...
Converted data shape: (100000, 12)
Synthetic data saved to: /home/lesley/generative-trace-tutorials/results/small-scale/csv/tvae_caida_20251107110332600.csv

Execution time: 425.86 seconds (0.12 hours)
Time log saved to: /home/lesley/generative-trace-tutorials/results/small-scale/time/tvae_caida_20251107110332600.txt

=== Training and Generation Complete! ===
