In [None]:
from data_generator.csv_data_processor import CSVDataProcessor
from utils.util_funcs import get_row_count, display_df
from pyspark.sql import SparkSession
from pyspark.sql.functions import to_date, col, floor, datediff, current_date, lower, lit, array, rand, collect_list, size, array, floor



# Start (or reuse) the Spark session that backs this ETL notebook.
spark = (
    SparkSession.builder
    .appName("ETL")
    .getOrCreate()
)

# Hand the session and the CSV path to the project's CSV processor;
# run() yields the DataFrame that every later cell transforms.
csv_reader = CSVDataProcessor(spark, "data/healthcare_dataset.csv")
df = csv_reader.run()


In [None]:
from pyspark.sql.functions import lit, to_date, datediff, current_date, floor, date_add, months_between


# Synthetic dates of birth are drawn uniformly from a fixed window.
start_date = to_date(lit("1935-01-01"))  # Earliest possible date of birth
range_days = 365 * 90  # Width of the window in days (about 90 years, leap days ignored)

df = (
    # Per-row random offset (in whole days) from start_date.
    df.withColumn("RandomDays", (rand() * range_days).cast("int"))
    .withColumn("DOB", date_add(start_date, "RandomDays"))
    # Remove any age column carried over from the input so it can be re-derived from DOB.
    .drop("age")
    # Age in completed years. months_between / 12 follows the calendar, whereas
    # dividing a day count by 365 ignores leap days and rolls the age over
    # several days (up to ~3 weeks for the oldest rows) before the real birthday.
    .withColumn("Age", floor(months_between(current_date(), col("DOB")) / 12))
)




In [None]:
# Preview the first rows after the DOB / Age derivation.
df.show(n=20, truncate=True)

In [None]:
from constants.admission_types_tests_dataset import admission_mapping, admission_tests
def _flatten_admissions(mapping, tests):
    """Expand the nested admission mapping into flat row tuples.

    Produces one ``(top_level, sub_level, stay_type, tests_for_sub_level)``
    tuple per stay type, falling back to ``["No tests"]`` when the sub-level
    has no entry in ``tests``.
    """
    rows = []
    for top_level, sub_level_dict in mapping.items():
        for sub_level, stay_types in sub_level_dict.items():
            for stay_type in stay_types:
                rows.append(
                    (top_level, sub_level, stay_type, tests.get(sub_level, ["No tests"]))
                )
    return rows


# Flat rows ready to be turned into a DataFrame.
flattened = _flatten_admissions(admission_mapping, admission_tests)



In [None]:
# Column names, in the same order as the fields of each tuple in `flattened`.
mapping_columns = [
    "top_level_admission",
    "sub_level_admission",
    "stay_type",
    "possible_tests",
]
mapping_df = spark.createDataFrame(flattened, schema=mapping_columns)

display_df(mapping_df)

In [None]:
# Give every row a randomly chosen top-level admission type; this column is
# the one shared with mapping_df.
admission_types = [*admission_mapping]

print(admission_types)

# Array literal holding every admission type, indexed by a per-row random position.
keys_array = array(*(lit(name) for name in admission_types))
random_index = floor(rand() * len(admission_types))

df = df.withColumn("top_level_admission", keys_array[random_index]).drop("admission_type")


In [None]:
# Preview the first rows after assigning top_level_admission.
df.show(n=20, truncate=True)

In [None]:

# Shared values and Column predicates for the admission rules.
female_only = ["maternity", "obstetrics"]
is_female = lower(col("gender")) == "female"
is_pediatric = col("Age") < 18
# NOTE(review): "sub_level_admission" only appears on mapping_df in the cells shown here;
# presumably df has to be joined with it before this predicate is applied — confirm.
is_geriatric = (col("sub_level_admission") == "geriatrics") & (col("Age") >= 65)

In [None]:
from pyspark.sql.functions import concat, when
from data_generator.constants import ColConstants

# Materialise the predicates as boolean flag columns.
df = df.withColumn("is_female", is_female).withColumn("is_pediatric", is_pediatric)

# Pediatric rows get the pediatric prefix prepended to their admission type;
# every other row keeps its current value.
# Note: df is reassigned here, so running this cell again prepends the prefix a second time.
prefixed_admission = concat(lit(ColConstants.peds), col("top_level_admission"))
df = df.withColumn(
    "top_level_admission",
    when(col("is_pediatric"), prefixed_admission).otherwise(col("top_level_admission")),
)

In [None]:
# Preview a small sample only: printing thousands of rows floods the cell output,
# bloats the saved notebook, and collects all of those rows onto the driver.
df.show(n=20)

In [None]:
from pyspark.sql.types import StringType

# Cast every column to string, keeping the original column names.
string_columns = [col(name).cast(StringType()).alias(name) for name in df.columns]
df = df.select(*string_columns)

# Write the result as CSV with a header row, replacing any previous output at this path.
df.write.mode("overwrite").option("header", True).csv('./temp_data/temp.csv')