In [None]:
from pyspark.sql import SparkSession
# import pyarrow

# Reuse the active Spark session if there is one; otherwise start one named "sdv_dem".
session_builder = SparkSession.builder.appName("sdv_dem")
spark = session_builder.getOrCreate()

# Ask Spark to use Arrow for Spark <-> pandas conversions.
# NOTE(review): this relies on pyarrow being installed in the kernel -- confirm.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "True")


In [None]:

import pandas as pd
from sdv.datasets.local import load_csvs
from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

import hashlib
from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

def _sha1_hex(value):
    """Return the SHA-1 hex digest of ``str(value)``, or ``None`` for a null input.

    A null stays null instead of being hashed as the literal text "None",
    which would give every row with a missing ID the same fabricated digest.
    """
    if value is None:
        return None
    return hashlib.sha1(str(value).encode("utf-8")).hexdigest()


# Spark UDF wrapper around _sha1_hex; yields a string column (null for null inputs).
# NOTE(review): an unsalted SHA-1 of a low-entropy identifier can be reversed by
# enumerating candidate IDs -- confirm this is acceptable for the intended use.
calculate_sha1_udf = udf(_sha1_hex, StringType())


In [None]:
# NOTE(review): absolute, machine-specific path -- the notebook only runs where
# this exact file exists.
meds_file_path = '/home/harraz/my_tensorflow/venv/synthetic_phi/data/patient_test_data.csv'

# Comma-delimited file; the first row is treated as the header.
csv_reader = spark.read.option('delimiter', ",").option("header", "True")
real_data_raw = csv_reader.csv(meds_file_path)

# Preview the first five rows.
real_data_raw.show(5)

In [None]:
# Hash each patient ID with the SHA-1 UDF and attach the result as "hashed_id".
patient_id_digest = calculate_sha1_udf(col("patientid"))
real_data_raw = real_data_raw.withColumn("hashed_id", patient_id_digest)
real_data_raw.show(5)

In [None]:
# Collect the Spark DataFrame into a pandas DataFrame; the SDV calls below take pandas input.
real_data = real_data_raw.toPandas()

# Create metadata and let SDV detect the column definitions from the data itself
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(real_data)

# Convert metadata to a dictionary
metadata_dict = metadata.to_dict()
# Access the 'columns' dictionary from the resulting dictionary
# (falls back to an empty dict when the key is absent)
columns_dict = metadata_dict.get('columns', {})
# Extract column names from the 'columns' dictionary
# NOTE(review): found_column_names is not read again in the visible cells.
found_column_names = columns_dict.keys()

# Disabled override: would declare 'patientid' as an id column with a fixed
# regex format instead of relying on the detected type.
# metadata.update_column(
#     column_name='patientid',
#     sdtype='id',
#     regex_format='ID_[0-9]{5}')

# Show the detected metadata for inspection.
# NOTE(review): real_data still contains the raw 'patientid' column at this point,
# so the synthesizer is later fitted on it; it is only dropped from the sampled output.
print(metadata)

# Build the synthesizer from the detected metadata (it is fitted in a later cell)
synthesizer = GaussianCopulaSynthesizer(metadata)

In [None]:
import warnings

# Silence two known warnings attributed to rdt's transformer modules.
for noisy_message, noisy_module in (
    ("The behavior of Series.replace.*", "rdt.transformers.categorical"),
    ("Downcasting object dtype arrays.*", "rdt.transformers.utils"),
):
    warnings.filterwarnings("ignore", message=noisy_message, module=noisy_module)

# Fit the synthesizer on the real table, then draw ten synthetic rows.
synthesizer.fit(real_data)
sampled_rows = synthesizer.sample(num_rows=10)

# Remove the raw patient ID from the synthetic output.
without_raw_id = sampled_rows.drop(columns=['patientid'])

# Move 'hashed_id' to the first position, keeping the other columns in their existing order.
hashed_id_column = without_raw_id['hashed_id']
remaining_columns = [name for name in without_raw_id.columns if name != 'hashed_id']
synthetic_data = without_raw_id[['hashed_id'] + remaining_columns]


In [None]:
# Write the synthetic rows as pipe-delimited text without the index,
# overwriting any previous export at this path.
synthetic_data.to_csv(
    './data/fake_patients.csv',
    sep='|',
    lineterminator='\n',
    mode='w',
    index=False,
)