In [7]:
# stdlib
import os
import sys
import warnings

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import TimeSeriesSurvivalDataLoader
from synthcity.utils.datasets.time_series.pbc import PBCDataloader
from synthcity.utils.serialization import save, load

# third party
import pandas as pd
import numpy as np

# Remove any sinks registered by an earlier execution of this cell before
# adding a new one; otherwise every re-run attaches another stderr sink and
# each log record is printed once per accumulated sink.
log.remove()
log.add(sink=sys.stderr, level="INFO")
# Suppress all Python warnings so the cell outputs stay readable.
warnings.filterwarnings("ignore")

In [8]:
# Number of synthetic records requested from the model (passed as `count` to
# `syn_model.generate`).
SYNTHETIC_DATA_COUNT = 500
# Serialized model location: the model is loaded from here when the file
# exists, otherwise a new model is trained and written to this path.
MODEL_OUTPUT_PATH = "../models/timegan_pbc_model.pkl"
# Destination CSV for the generated synthetic dataframe.
DATA_OUTPUT_PATH = "../data/pbc_synthetic/synthetic_pbc_data_3.csv"

In [9]:
def get_pbc_data_loader():
    """Assemble a ``TimeSeriesSurvivalDataLoader`` for the PBC dataset.

    The ``drug`` and ``age`` columns are taken from the first row of each
    per-patient temporal frame, added to a copy of the static frame, and then
    removed from the temporal frames. The time horizons are the 0.25, 0.5 and
    0.75 quantiles of ``T``, the first element of the outcome pair returned by
    ``PBCDataloader().load()``.

    Returns:
        A ``(loader, time_horizons)`` tuple, where ``time_horizons`` is the
        list of quantile values that was also passed to the loader.
    """
    static_frame, temporal_frames, observation_times, (T, E) = PBCDataloader().load()

    promoted_columns = ["drug", "age"]

    # Only the first observation is read: both columns are assumed to be
    # constant within a patient.
    static_with_baseline = static_frame.copy()
    for column in promoted_columns:
        static_with_baseline[column] = [
            frame[column].iloc[0] for frame in temporal_frames
        ]

    # The promoted columns now live in the static frame, so strip them from
    # every temporal frame (replacing the entries of the loaded sequence).
    for position, frame in enumerate(temporal_frames):
        temporal_frames[position] = frame.drop(columns=promoted_columns)

    time_horizons = np.quantile(T, [0.25, 0.5, 0.75]).tolist()

    return (
        TimeSeriesSurvivalDataLoader(
            temporal_data=temporal_frames,
            observation_times=observation_times,
            static_data=static_with_baseline,
            T=T,
            E=E,
            time_horizons=time_horizons,
        ),
        time_horizons,
    )

In [10]:
# Build the real-data loader and the evaluation horizons derived from it.
loader, time_horizons = get_pbc_data_loader()

if os.path.exists(MODEL_OUTPUT_PATH):
    # Reuse the cached model so a full re-run does not retrain.
    print(f"Loading model from {MODEL_OUTPUT_PATH}")
    # NOTE(review): `load` deserializes raw bytes (presumably pickle-based) --
    # only point MODEL_OUTPUT_PATH at a file produced by this notebook.
    with open(MODEL_OUTPUT_PATH, "rb") as f:
        syn_model = load(f.read())
else:
    # Create the target directory *before* training: otherwise a missing
    # folder is only discovered when the file is opened for writing, after
    # the training run has already completed, and the model is lost.
    model_dir = os.path.dirname(MODEL_OUTPUT_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    print("Training new timegan model...")
    syn_model = Plugins().get("timegan")
    syn_model.fit(loader)
    with open(MODEL_OUTPUT_PATH, "wb") as f:
        f.write(save(syn_model))
    print(f"Model saved to {MODEL_OUTPUT_PATH}")

Loading model from ../models/timegan_pbc_model.pkl


In [11]:
print("Generating synthetic data...")

# `DataFrame.to_csv` does not create missing parent directories, so make sure
# the output folder exists before spending time on generation.
data_dir = os.path.dirname(DATA_OUTPUT_PATH)
if data_dir:
    os.makedirs(data_dir, exist_ok=True)

# Sample from the fitted model using the same horizons as the real-data loader.
synthetic_data = syn_model.generate(count=SYNTHETIC_DATA_COUNT, time_horizons=time_horizons)
df_synth = synthetic_data.dataframe()

# Write the generated rows without the dataframe index.
df_synth.to_csv(DATA_OUTPUT_PATH, index=False)
print(f"Synthetic dataset saved to {DATA_OUTPUT_PATH}")

Generating synthetic data...


[2025-05-15T21:52:20.428468-0400][14020][INFO] [seq_time_id] quality loss for constraints ge = 0.0027379257474500207. Remaining 2385. prev length 3161. Original dtype float64.
[2025-05-15T21:52:20.428468-0400][14020][INFO] [seq_time_id] quality loss for constraints ge = 0.0027379257474500207. Remaining 2385. prev length 3161. Original dtype float64.


Synthetic dataset saved to ../data/pbc_synthetic/synthetic_pbc_data_3.csv
