In [1]:
import pandas as pd
from ctgan import CTGAN
import pickle
import os
import time

In [2]:
# --- Configuration ---
# Discrete/categorical column names. An empty list means the columns are
# auto-detected when the data is loaded; a non-empty list is used as given.
CATEGORICAL_COLUMNS = []

# Input CSV that the synthesizer is trained on.
INPUT_DATA_FILE = './dataset/housing.csv'
# Pickle file used to persist / reload the trained model.
MODEL_FILE = 'ctgan_model.pkl'
# Destination CSV for the generated synthetic rows.
SYNTHETIC_OUTPUT_FILE = 'synthetic_data.csv'

# 2. Detect categorical columns

In [3]:
def detect_categorical_columns(data: pd.DataFrame, max_unique: int = 10, max_unique_ratio: float = 0.05) -> list:
    """Guess which columns of ``data`` should be treated as categorical.

    A column is reported when either:

    1. its dtype is ``object`` or ``category``; or
    2. its dtype is ``int64``/``float64`` and it has low cardinality, i.e.
       at most ``max_unique`` distinct non-null values AND a
       distinct-to-row ratio strictly below ``max_unique_ratio``.

    Only ``int64`` and ``float64`` numeric columns are examined; other
    numeric widths (e.g. ``int32``) are not considered by rule 2.

    Parameters:
    - data: Frame to inspect. It is not modified.
    - max_unique: Maximum number of distinct values for a numeric column
      to count as categorical.
    - max_unique_ratio: Exclusive upper bound on ``n_unique / n_rows`` for
      a numeric column to count as categorical.

    Returns:
    - List of column labels in a deterministic order: object/category
      columns first (in frame order), then the low-cardinality numeric
      columns (in frame order).
    """
    # 1. Object / Category dtype
    detected_cols = data.select_dtypes(include=['object', 'category']).columns.tolist()

    # 2. Low-cardinality numeric columns
    n_rows = len(data)
    # An empty frame has no meaningful cardinality ratio; skipping the
    # numeric pass also avoids dividing by zero below.
    if n_rows > 0:
        numeric_cols = data.select_dtypes(include=['int64', 'float64']).columns
        for col in numeric_cols:
            n_unique = data[col].nunique()
            if n_unique <= max_unique and (n_unique / n_rows) < max_unique_ratio:
                detected_cols.append(col)

    # De-duplicate while preserving insertion order. A plain set() would
    # return the labels in an arbitrary order that can differ between
    # kernel sessions, which makes downstream output non-reproducible.
    return list(dict.fromkeys(detected_cols))


# 3. Load dataset

In [4]:
def load_data(filepath: str) -> pd.DataFrame | None:
    global CATEGORICAL_COLUMNS

    if not os.path.exists(filepath):
        print(f"‚ùå ERROR: Input file not found at {filepath}")
        return None

    print(f"üìÇ Loading data from {filepath}...")
    try:
        data = pd.read_csv(filepath)
        print(f"‚úÖ Data loaded. Shape: {data.shape}")

        if CATEGORICAL_COLUMNS:
            active_categorical_cols = [col for col in CATEGORICAL_COLUMNS if col in data.columns]
            source_description = "Manual list"
        else:
            active_categorical_cols = detect_categorical_columns(data)
            source_description = "Auto-detected"

        for col in active_categorical_cols:
            data[col] = data[col].astype('category')

        if not active_categorical_cols:
            print("‚ùå ERROR: No categorical columns detected.")
            return None

        print(f"üîé Source: {source_description}")
        print(f"üìä Categorical Columns: {active_categorical_cols}")

        CATEGORICAL_COLUMNS = active_categorical_cols
        return data
    except Exception as e:
        print(f"‚ùå ERROR loading data: {e}")
        return None


# 4. Train and Save Model

In [5]:
def train_and_save_model(data: pd.DataFrame, model_path: str, discrete_columns: list | None = None, epochs: int = 300, batch_size: int = 500):
    """Fit a CTGAN synthesizer on ``data`` and pickle it to ``model_path``.

    Parameters:
    - data: Training frame.
    - model_path: File the fitted model is pickled to (overwritten if it
      already exists).
    - discrete_columns: Columns passed to ``CTGAN.fit`` as discrete. When
      omitted, the module-level ``CATEGORICAL_COLUMNS`` list is used, which
      matches the previous behaviour.
    - epochs: Number of training epochs handed to ``CTGAN``.
    - batch_size: Batch size handed to ``CTGAN``.

    Returns:
    - The fitted ``CTGAN`` instance.

    NOTE(review): the run recorded in this notebook shows roughly 353 s per
    epoch with 'address' and 'name' among the discrete columns. Presumably
    those high-cardinality columns are the cause; confirm, and consider
    excluding them before training.
    """
    if discrete_columns is None:
        # Backward-compatible fallback to the list populated by load_data.
        discrete_columns = CATEGORICAL_COLUMNS

    print("üöÄ Training CTGAN model...")
    start_time = time.time()

    model = CTGAN(
        epochs=epochs,
        batch_size=batch_size,
        verbose=True
    )
    model.fit(data, discrete_columns)

    with open(model_path, 'wb') as f:
        pickle.dump(model, f)

    print(f"üíæ Model saved to {model_path}")
    # The reported duration covers both fitting and writing the pickle.
    print(f"‚è± Training time: {time.time() - start_time:.2f} seconds")
    return model

# 5. Load Saved Model

In [6]:
def load_model(model_path: str):
    """Restore a previously pickled model from ``model_path``.

    Returns the unpickled object, or ``None`` (after printing an error)
    when no file exists at that path.

    Security note: unpickling executes code embedded in the file, so only
    point this at model files you created yourself.
    """
    if os.path.exists(model_path):
        with open(model_path, 'rb') as model_file:
            restored = pickle.load(model_file)
        print("‚úÖ Model loaded successfully")
        return restored

    print("‚ùå ERROR: No model file found.")
    return None

# 6. Generate Synthetic Data

In [7]:
def generate_synthetic_data(model: CTGAN, num_samples: int, output_path: str):
    """Sample rows from a fitted model, write them to CSV and preview them.

    Parameters:
    - model: Fitted synthesizer; its ``sample`` method is called.
    - num_samples: Number of rows requested from ``model.sample``.
    - output_path: CSV file the samples are written to (without the index).

    Returns:
    - The frame returned by ``model.sample``.
    """
    print(f"‚ú® Generating {num_samples} synthetic samples...")
    started_at = time.time()

    samples = model.sample(num_samples)
    samples.to_csv(output_path, index=False)
    print(f"üíæ Synthetic data saved to {output_path}")

    # Elapsed time covers both sampling and writing the CSV.
    elapsed_seconds = time.time() - started_at
    print(f"‚è± Generation time: {elapsed_seconds:.2f} seconds")

    # Rich notebook preview of the first rows.
    display(samples.head())
    return samples

# 7. Run Pipeline

In [None]:
# Step 1: Load data
data = load_data(INPUT_DATA_FILE)

if data is not None:
    # Step 2: Train or Load Model
    # A model file that already exists is reused as-is; nothing checks that
    # it was trained on the current data. Delete MODEL_FILE to force a retrain.
    if os.path.exists(MODEL_FILE):
        ctgan_model = load_model(MODEL_FILE)
    else:
        ctgan_model = train_and_save_model(data, MODEL_FILE)

    # Step 3: Generate Synthetic Data
    # load_model signals failure with None; test for that explicitly rather
    # than relying on the truthiness of the model object.
    if ctgan_model is not None:
        synthetic_data = generate_synthetic_data(ctgan_model, num_samples=data.shape[0], output_path=SYNTHETIC_OUTPUT_FILE)

üìÇ Loading data from ./dataset/housing.csv...
‚úÖ Data loaded. Shape: (10000, 3)
üîé Source: Auto-detected
üìä Categorical Columns: ['address', 'name', 'building no']
üöÄ Training CTGAN model...


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Gen. (6.11) | Discrim. (-0.00):   0%|‚ñè                                             | 1/300 [05:53<29:19:11, 353.02s/it]

# 8. Compare Real vs Synthetic Data

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def compare_distributions(real_data: pd.DataFrame, synthetic_data: pd.DataFrame, categorical_cols: list, max_cols: int = 5, max_categories: int | None = 20):
    """
    Compare distributions of real vs synthetic data.

    Draws a two-row figure: KDE overlays for numeric columns on the top
    row and side-by-side relative-frequency bars for categorical columns
    on the bottom row.

    Parameters:
    - real_data: Original dataset
    - synthetic_data: Generated dataset; must contain the plotted columns
    - categorical_cols: List of categorical columns
    - max_cols: Number of columns to visualize per type
    - max_categories: Maximum number of bars per categorical panel. When a
      column has more categories, only the ones most frequent in the real
      data are drawn. Pass None to draw every category.

    Returns:
    - None. The figure is shown with ``plt.show()``.
    """
    categorical_cols = list(categorical_cols)
    # A non-positive max_cols means "plot nothing" instead of a negative slice.
    per_row = max(int(max_cols), 0)

    # Separate numeric and categorical
    numeric_cols = [
        col
        for col in real_data.select_dtypes(include=['int64', 'float64']).columns
        if col not in categorical_cols
    ]
    numeric_to_plot = numeric_cols[:per_row]
    categorical_to_plot = categorical_cols[:per_row]

    print("üìä Visualizing distributions...")

    n_panels = max(len(numeric_to_plot), len(categorical_to_plot))
    if n_panels == 0:
        print("Nothing to plot: no numeric or categorical columns selected.")
        return

    # Explicit figure/axes instead of the pyplot state machine; squeeze=False
    # keeps `axes` two-dimensional even when there is a single column.
    fig, axes = plt.subplots(2, n_panels, figsize=(16, 10), squeeze=False)

    # --- Numeric Features ---
    for ax, col in zip(axes[0], numeric_to_plot):
        sns.kdeplot(real_data[col], label="Real", fill=True, alpha=0.5, ax=ax)
        sns.kdeplot(synthetic_data[col], label="Synthetic", fill=True, alpha=0.5, ax=ax)
        ax.set_title(f"Numeric: {col}")
        ax.legend()

    # --- Categorical Features ---
    for ax, col in zip(axes[1], categorical_to_plot):
        real_counts = real_data[col].value_counts(normalize=True)
        synth_counts = synthetic_data[col].value_counts(normalize=True)

        # Align on category label; a category missing on one side gets 0.
        df_plot = pd.DataFrame({
            "Real": real_counts,
            "Synthetic": synth_counts
        }).fillna(0)

        title = f"Categorical: {col}"
        if max_categories is not None and len(df_plot) > max_categories:
            # High-cardinality columns would otherwise produce thousands of
            # unreadable bars; keep the categories most common in real data.
            df_plot = df_plot.nlargest(max_categories, "Real")
            title = f"Categorical: {col} (top {max_categories})"

        df_plot.plot(kind="bar", ax=ax)
        ax.set_title(title)
        ax.set_ylabel("Proportion")

    # Hide grid cells that received no plot.
    for ax in axes[0, len(numeric_to_plot):]:
        ax.set_visible(False)
    for ax in axes[1, len(categorical_to_plot):]:
        ax.set_visible(False)

    fig.tight_layout()
    plt.show()