In [None]:
import random
from data_generator.csv_data_processor import CSVDataProcessor
from utils.util_funcs import get_row_count, display_df, remove_data, verify_ranking, verify_ranking_counts
from utils.read_write import read_postgres_table
from spark_instance import spark
from pyspark.sql import Window

from pyspark.sql.functions import col, lower, lit, rand, array, concat, when, row_number, concat_ws, date_format, split, size
  

In [None]:
# csv_reader = CSVDataProcessor(spark, "data/healthcare_dataset.csv")
# 
# # Read the CSV file
# raw_df = csv_reader.run()


In [None]:
# min_age_days = 1 * 365  # Minimum age in days (1 year)
# max_age_days = 90 * 365  # Maximum age in days (90 years)
# 
# raw_df = (raw_df.withColumn("DOB", expr(f"date_sub(current_date(), CAST(round(rand() * ({max_age_days} - {min_age_days}) + {min_age_days}) AS INT))"))
#         .withColumn("Age", floor(datediff(current_date(), col("DOB")) / 365.25)))




In [None]:
# The CSV load and DOB/Age generation cells above are commented out because their
# result is already persisted in Postgres; load that table instead.
df = read_postgres_table(
    "dob_age_raw_data"
)

In [None]:
from constants.admission_types_dataset import admission_mapping, AdmissionTypes
def _flatten_admission_mapping(mapping):
    """Flatten a {top_level: {sub_level: info}} mapping into row tuples.

    Each tuple is (top_level.name, sub_level.name, list(info.get("stay_types")),
    info.get("tests"), one doctor drawn with random.choice from info.get("doctors")).
    Rows follow the mapping's iteration order.
    """
    rows = []
    for top_level, sub_levels in mapping.items():
        for sub_level_key, info in sub_levels.items():
            rows.append((
                top_level.name,
                sub_level_key.name,
                list(info.get("stay_types")),
                info.get("tests"),
                random.choice(info.get("doctors")),
            ))
    return rows


# NOTE(review): the doctor is drawn once per (top-level, sub-level) pair using the unseeded
# global `random`, so the choice changes between runs.
flattened = _flatten_admission_mapping(admission_mapping)



In [None]:
# Inspect the flattened admission mapping: one tuple per (top-level, sub-level) admission pair.
flattened

In [None]:
from constants.condition_probabilities import condition_age_probability_dict

def _iter_condition_probability_rows(probability_dict):
    """Yield one row per (sub admission, condition, gender, age range) entry.

    A condition maps either directly to a list of (age_range, probability) pairs
    (no gender split) or to a dict keyed by gender whose values are such lists.
    Only 'male'/'female' survive as gender labels; anything else becomes None.
    A one-element age_range is open-ended, so its upper bound is +inf.
    Rows are (sub_admission, condition, gender, age_min, age_max, probability).
    """
    for sub_admission, conditions in probability_dict.items():
        for condition, prob_spec in conditions.items():
            if isinstance(prob_spec, dict):
                gender_groups = prob_spec.items()
            else:
                gender_groups = [(None, prob_spec)]
            for gender, age_prob_list in gender_groups:
                for age_range, probability in age_prob_list:
                    gender_label = gender if gender in ['male', 'female'] else None
                    lower_bound = float(age_range[0])
                    if len(age_range) > 1:
                        upper_bound = float(age_range[1])
                    else:
                        upper_bound = float('inf')
                    yield (sub_admission, condition, gender_label, lower_bound, upper_bound, float(probability))


flattened_condition_probabilities = list(_iter_condition_probability_rows(condition_age_probability_dict))



In [None]:
from pyspark.sql.types import StructType, StructField, StringType, FloatType

# Schema for the flattened (sub admission, condition, gender, age range, probability) rows.
# condition_gender is nullable: it is None when the source entry is not split by
# 'male'/'female'. age_max holds +inf for open-ended age ranges.
schema = StructType([
    StructField("condition_admission_type", StringType(), True),
    StructField("condition", StringType(), True),
    StructField("condition_gender", StringType(), True),
    StructField("age_min", FloatType(), False),
    StructField("age_max", FloatType(), False),
    StructField("probability", FloatType(), False),
])
# Derive the column list from the schema so the two cannot drift apart
# (a hand-written list previously said "gender" where the schema says "condition_gender").
columns = schema.fieldNames()

condition_prob_df = spark.createDataFrame(data=flattened_condition_probabilities, schema=schema)


In [None]:
# Preview the condition probability table (PySpark defaults made explicit).
condition_prob_df.show(n=20, truncate=True)

In [None]:
# Preview the raw patient data loaded from Postgres (PySpark defaults made explicit).
df.show(n=20, truncate=True)

In [None]:
"""
STEP 1
    Build a DataFrame from the flattened admission mapping. Each row holds one
    (top-level admission, sub-level admission) pair together with its stay types,
    its possible tests and the doctor picked for it, as defined by admission_mapping
    in constants/admission_types_dataset.py.
"""
mapping_columns = ["top_level_admission", "sub_level_admission", "stay_types", "possible_tests", "doctor"]
mapping_df = spark.createDataFrame(flattened, mapping_columns)

display_df(mapping_df, 45)

In [None]:
# Attach the per-condition age/gender probabilities to each sub-level admission.
# Left join: sub-levels with no probability rows are kept with null condition columns.
sub_level_match = mapping_df.sub_level_admission == condition_prob_df.condition_admission_type
joined_tbl = mapping_df.join(condition_prob_df, on=[sub_level_match], how="left")

In [None]:
# Inspect up to 2000 rows of the mapping + condition probability join.
joined_tbl.show(2000, truncate=True)

In [None]:
"""
STEP 2
    Prepare the patient DataFrame:
    - derive helper flags (is_female, is_pediatric, is_geriatric);
    - build a per-patient unique_id from name + DOB;
    - drop the original doctor, medical_condition, test_results, medication and
      admission_type columns (presumably regenerated later in the pipeline).
"""
admission_type_names = [member.name for member in AdmissionTypes]

# Literal array of every top-level admission type name.
# NOTE(review): keys_array is not used by any later cell in this notebook.
keys_array = array([lit(name) for name in admission_type_names])

# Sub-level admissions that only apply to female patients. sub_level_admission holds
# enum member names (upper case, e.g. "MATERNITY", "GERIATRICS"), so every entry here
# must be upper case too; the previous lower-case 'obstetrics' could never match.
# NOTE(review): confirm the exact SubAdmissionTypes member name for obstetrics.
female_only = ["MATERNITY", "OBSTETRICS"]

# Reusable column predicates.
is_female = lower(col('gender')) == 'female'  # case-insensitive gender check
is_pediatric = col("Age") < 18
is_geriatric = col("Age") >= 65

df = (df.withColumn("is_female", is_female)
        .withColumn("is_pediatric", is_pediatric)
        .withColumn("is_geriatric", is_geriatric)
        # name + DOB (yyyyMMdd), intended to tell apart patients who share a name
        .withColumn("unique_id", concat_ws("_", "name", date_format("DOB", "yyyyMMdd")))
        .drop("doctor", "medical_condition", "test_results", "medication", "admission_type")
      )
get_row_count(df, True)



In [None]:
from constants.admission_types_dataset import SubAdmissionTypes
"""
Build a single-column DataFrame (admission_type) holding every SubAdmissionTypes
member name, ready to be cross-joined onto the main patient DataFrame.
"""
enum_values = list(map(lambda member: member.name, SubAdmissionTypes))

enum_df = (
    spark.createDataFrame(enum_values, StringType())
    .toDF("admission_type")
)


In [None]:
# Pair every patient row with every sub-level admission type
# (row count = patients x number of SubAdmissionTypes members).
df_enum_cross = df.crossJoin(other=enum_df)

In [None]:
# Spot-check a single patient/admission combination after the cross join.
sample_patient_cardiology = (col("unique_id") == "Aaron Jones_20141222") & (col("admission_type") == "CARDIOLOGY")
df_enum_cross.filter(sample_patient_cardiology).show(n=2000)

In [None]:
# Left-join each (patient, sub-level admission) row to the mapping/condition table.
df_joined_temp = df_enum_cross.join(
    joined_tbl,
    on=df_enum_cross.admission_type == joined_tbl.condition_admission_type,
    how="left",
)

In [None]:
# Narrow the joined result to one sample patient/admission for inspection.
df_joined_temp = df_joined_temp.filter(
    (col("admission_type") == "CARDIOLOGY") & (col("unique_id") == "Aaron Jones_20141222")
)

In [None]:
from utils.conditions_creator import ConditionsCreator

In [None]:


# Run ConditionsCreator (utils/conditions_creator.py) over the full patient x admission
# cross join and the mapping/condition table; presumably it assigns conditions per row.
con = ConditionsCreator(
    spark,
    df_enum_cross,
    joined_tbl,
)
new_df = con.runner()


In [None]:
# Inspect up to 200 rows of the ConditionsCreator output.
new_df.show(200, truncate=True)

In [None]:

from data_generator.constants import properties, POSTGRES_URL

# Persist the ConditionsCreator output (`new_df`, built two cells above) to Postgres.
# Previously this referenced `df_con`, which is never defined and raised NameError.
new_df.write.jdbc(url=POSTGRES_URL, table="chosen_conditions_2503", mode="overwrite", properties=properties)


In [None]:
"""
STEP 3
    Rank rows within each (stay_name, unique_id) group in a random order.
    stay_name is the patient name suffixed with the stay type. The ranking is
    verified in the cells below.
"""

# Suffix the name with "_<stay_type>" for the three known stay types; any other
# (or null) stay_type keeps the plain name.
known_stay_types = ['out_patient', 'inpatient', 'day_patient']
df = df.withColumn(
    'stay_name',
    when(col('stay_type').isin(known_stay_types), concat(col('name'), lit('_'), col('stay_type')))
    .otherwise(col('name')),
)

# Partition by (stay_name, unique_id) and order randomly inside each partition.
# NOTE(review): rand() is unseeded, so row_num differs between runs.
windowSpec = Window.partitionBy('stay_name', 'unique_id').orderBy(rand())

ranked_df = df.withColumn("row_num", row_number().over(windowSpec))


In [None]:
# Register ranked_df for SQL. createOrReplace... keeps this cell re-runnable:
# createGlobalTempView raises if the view already exists.
ranked_df.createOrReplaceGlobalTempView("ranked_df")

# Sample of the first 10 distinct (name, unique_id) pairs, used to verify the ranking below.
# - The former self-join on name only dropped null names, and it produced count^2 rows per name.
# - Ordering by unique_id as well makes the LIMIT 10 sample deterministic.
unique_dobs_df = spark.sql("""
SELECT DISTINCT name, unique_id
FROM global_temp.ranked_df
WHERE name IS NOT NULL
ORDER BY name, unique_id
LIMIT 10
""")

In [None]:
# COSTLY way to verify that the ranking works, because verify_ranking_counts() uses
# .collect() and .count().
#
# Collect the sample once so that unique_names[i] and unique_ids[i] come from the same row.
# Two separate collects of an ORDER BY ... LIMIT query are not guaranteed to return
# the same rows in the same order.
# Presumably verify_ranking_counts pairs the two lists by position; verify.
sample_rows = unique_dobs_df.select("name", "unique_id").collect()
unique_names = [row['name'] for row in sample_rows]
unique_ids = [row['unique_id'] for row in sample_rows]
verify_ranking_counts(df, ranked_df, unique_names, unique_ids)

# Cheaper whole-DataFrame check.
verify_ranking(df, ranked_df)


In [None]:
"""
Step 4:
    Remove GERIATRICS admissions for patients who are not geriatric (Age < 65).
    remove_data (utils.util_funcs) receives both the offending rows and the two
    conditions that identify them.
"""
# Recompute the flag from the shared is_geriatric predicate (Age >= 65).
ranked_df = ranked_df.withColumn("is_geriatric", is_geriatric)

is_geriatrics_admission = col('sub_level_admission') == "GERIATRICS"
is_not_geriatric = col("is_geriatric") == False

not_geriatric_df = ranked_df.where(is_geriatrics_admission & is_not_geriatric)

filtered_df = remove_data(ranked_df, not_geriatric_df, is_geriatrics_admission, is_not_geriatric)


In [None]:
"""
Step 5:
    Remove female-only sub-level admissions (see female_only) for patients who
    are not female.
"""
is_female_only_admission = col("sub_level_admission").isin(female_only)
# Use the is_female flag, derived case-insensitively from gender, for both the row
# selection and the removal condition. The previous `col("gender") == "Male"` was
# case-sensitive and disagreed with the rows selected into not_female_df.
is_not_female = col("is_female") == False

not_female_df = filtered_df.where(is_female_only_admission & is_not_female)
filtered_df_female = remove_data(filtered_df, not_female_df, is_not_female, is_female_only_admission)


In [None]:
# Find MATERNITY admissions for patients outside the assumed childbearing age range [16, 50].
# Lower bound 16: the legal age in the UK. This is a modelling choice, not an assumption
# that pregnancy cannot happen earlier.
# Upper bound 50: menopause typically occurs between 45 and 55, see
# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4822427/#:~:text=Menopause%20typically%20occurs%20between%2045,reproducing%20many%20years%20before%20menopause.
# Outliers could be handled by a separate function in the future.
# The two bounds must be OR-ed. The previous `(Age < 16) & (Age <= 50)` selected only
# under-16s and ignored the upper bound. The name is kept for the next cell, although
# the frame now also holds over-50s.
under16_female_df = filtered_df_female.where(
    ((col("Age") < 16) | (col("Age") > 50))
    & (col("sub_level_admission") == "MATERNITY")
).orderBy("Age")

In [None]:
# Drop MATERNITY admissions for patients younger than 16 or older than 50.
# The age condition must use OR: no age is both < 16 and > 50, so the previous `&`
# condition matched nothing.
df = remove_data(filtered_df_female, under16_female_df, (col("sub_level_admission") == "MATERNITY"),
                 (col("Age") < 16) | (col("Age") > 50))



In [None]:

# Abbreviate names to "<initials>. <surname>" based on the number of space-separated tokens:
#   4 tokens: skip token 0 (presumably a title), initial tokens 1 and 2, keep token 3
#   3 tokens: skip token 0, initial token 1, keep token 2
#   otherwise: initial token 0, keep token 1 (a single-token name yields null)
split_col = split(df['name'], ' ')
name_token_count = size(split_col)


def _initial_of(token_index):
    """Return the column '<first letter of token>. ' for the given name token."""
    return concat(split_col.getItem(token_index).substr(1, 1), lit('. '))


abbreviated_name = (
    when(name_token_count == 4, concat(_initial_of(1), _initial_of(2), split_col.getItem(3)))
    .when(name_token_count == 3, concat(_initial_of(1), split_col.getItem(2)))
    .otherwise(concat(_initial_of(0), split_col.getItem(1)))
)
df_renamed = df.withColumn('name', abbreviated_name)


In [None]:
# Preview the abbreviated names (PySpark defaults made explicit).
df_renamed.show(n=20, truncate=True)

In [None]:
# Display the renamed frame with a fixed column order.
display_cols = ["name", "DOB", "Age", "gender", "blood_type", "date_of_admission", "discharge_date", "top_level_admission", "sub_level_admission", "possible_tests", "conditions", "doctor", "hospital", "room_number", "insurance_provider", "billing_amount", "stay_type", "is_female", "is_geriatric", "is_pediatric", "stay_name", "row_num", "unique_id"]
# Optional single-patient filter: .where(col("stay_name") =="Abigail Lamb")
df_renamed.select(*display_cols).show(n=200)

In [None]:

from pyspark.sql.types import StringType
from pyspark.sql.functions import udf

# Columns for the (currently disabled) condition-assignment UDF experiment below.
sub_selection_cols = "Age sub_level_admission conditions row_num unique_id".split()

# NOTE(review): the disabled code below calls choose_condition, which is not defined in
# this notebook; define it before re-enabling these lines.
# df_subset = df_renamed.select(*sub_selection_cols)
# 
# choose_condition_udf = udf(choose_condition, StringType())
# df_test = df_subset.withColumn("assigned_condition", choose_condition_udf(df_subset["age"], df_subset["sub_level_admission"], df_subset["conditions"]))


In [None]:
# df_subset.persist()
# df_subset.show()  # The action that triggers the computation


In [None]:
# df_test.persist()
# df_test.show(n=20)

In [None]:
# Re-rank after the filtering steps: drop the old row_num and recompute it in a fresh
# random order within each (stay_name, unique_id) partition.
windowSpec = Window.partitionBy('stay_name', 'unique_id').orderBy(rand())
df_new = df_renamed.drop("row_num")
df_new_part = df_new.withColumn("row_num", row_number().over(windowSpec))

# Spot-check one stay_name.
# NOTE(review): stay_name was built from the full (unabbreviated) name plus a stay-type
# suffix, so "Tiffany Ramirez" only matches rows whose stay_type was not one of the
# three known types.
inspect_cols = ["name", "DOB", "Age", "gender", "blood_type", "date_of_admission", "discharge_date", "top_level_admission", "sub_level_admission", "possible_tests", "conditions", "doctor", "hospital", "room_number", "insurance_provider", "billing_amount", "stay_type", "is_female", "is_geriatric", "is_pediatric", "stay_name", "row_num", "unique_id"]
(df_new_part
    .select(inspect_cols)
    .where(col("stay_name") == "Tiffany Ramirez")
    .show())

In [None]:
# Row count after re-ranking; compare with the count taken after Step 2.
get_row_count(df_new_part)

In [None]:
# Canonical output column order. Note the comma between "is_geriatric" and "is_pediatric":
# without it Python silently concatenates them into one bogus "is_geriatricis_pediatric".
output_columns = ["name", "DOB", "Age", "gender", "blood_type", "date_of_admission", "discharge_date", "top_level_admission", "sub_level_admission", "possible_tests", "conditions", "doctor", "hospital", "room_number", "insurance_provider", "billing_amount", "stay_type", "is_female", "is_geriatric", "is_pediatric", "stay_name", "row_num", "unique_id"]
output_columns

In [None]:
# join_with_condition_prob_df = df_new_part.join(condition_prob_df, 
#                                                on=["sub_level_admission"], 
#                                                how="inner")
# join_with_condition_prob_df = join_with_condition_prob_df.filter(
#     (col("Age") >= col("age_min")) &
#     (col("Age") <= col("age_max"))
# )

In [None]:
df_new_part.show(n=2000)

In [None]:
# TODO:
# - medical condition to be chosen
# - tests to be chosen
# - admission date to be checked against DOB
# - filter on is_pediatric, is_geriatric and is_female to be done here; patients with the same DOB need to be considered
# - drop stay_name and unique_id


In [None]:
# Intended: keep only the top-ranked row within each partition.
# NOTE(review): as written this only re-binds ranked_df to filtered_df_female; no
# row_num filter (e.g. row_num == 1) is applied here.
ranked_df = filtered_df_female

In [None]:
# Export the condition-probability join as CSV, with every column cast to string.
# NOTE(review): join_with_condition_prob_df is only created in a commented-out cell above,
# so this raises NameError on a fresh kernel. That commented join also uses
# on=["sub_level_admission"], a column condition_prob_df does not have (its key is
# condition_admission_type).
# NOTE(review): this re-binds `df`, overwriting the frame from the earlier steps.
# from pyspark.sql.types import StringType
# get_row_count(df, True)
df = join_with_condition_prob_df.select([col(c).cast(StringType()).alias(c) for c in join_with_condition_prob_df.columns])
df.show()
# Spark writes a directory named "renamed.csv" containing one part file per partition (10 here).
df.repartition(10).write.csv('./temp_data/join_with_condition_prob_df/renamed.csv', mode = 'overwrite', header=True)

In [None]:
# TODO: initialise first names consistently with gender to make the name column more realistic, e.g. "Daniel Mccoy" reads as a male name but the row's gender is Female.

In [None]:
# Shut down the Spark session; cells that use Spark after this need a new session.
spark.stop()

In [None]:
# Look up the admission_mapping entry for "emergency".
# NOTE(review): the code building `flattened` reads `.name` from this mapping's keys, so they
# appear to be AdmissionTypes enum members. A plain string key only matches if the enum is
# str-valued with value "emergency"; otherwise .get returns None. Verify.
dict_ = admission_mapping.get("emergency")