In [2]:
import random
import py4j 
from data_generator.csv_data_processor import CSVDataProcessor
from utils.util_funcs import get_row_count, display_df, remove_data, verify_ranking, verify_ranking_counts
from utils.read_write import read_postgres_table
from spark_instance import spark
from pyspark.sql import Window

from pyspark.sql.functions import col, lower, lit, rand, array, concat, when, row_number, concat_ws, date_format, split, size
  



In [3]:
# csv_reader = CSVDataProcessor(spark, "data/healthcare_dataset.csv")
# 
# # Read the CSV file
# raw_df = csv_reader.run()


In [4]:
# min_age_days = 1 * 365  # Minimum age in days (18 years)
# max_age_days = 90 * 365  # Maximum age in days (90 years)
# 
# raw_df = (raw_df.withColumn("DOB", expr(f"date_sub(current_date(), CAST(round(rand() * ({max_age_days} - {min_age_days}) + {min_age_days}) AS INT))"))
#         .withColumn("Age", floor(datediff(current_date(), col("DOB")) / 365.25)))




In [5]:
# above is commented out as it's saved in the database:
df = read_postgres_table("dob_age_raw_data")

In [6]:
from constants.admission_types_dataset import admission_mapping, AdmissionTypes
flattened = [
    (
        top_level.name,
        sub_level_key.name,
        stay_type,
        sub_level_info.get("tests"),  
        random.choice(sub_level_info.get("doctors"))
    )
    for top_level, sub_level_dict in admission_mapping.items()
    for sub_level_key, sub_level_info in sub_level_dict.items()
    for stay_type in (sub_level_info["stay_types"] if isinstance(sub_level_info, dict) else sub_level_info)
]



In [7]:
from constants.condition_probabilities import condition_age_probability_dict

flattened_condition_probabilities = [
    (
        sub_admission,
        condition,
        gender if gender in ['male', 'female'] else None,
        float(age_range[0]),
        float(age_range[1]) if len(age_range) > 1 else float('inf'),
        float(probability) 
    )
    for sub_admission, conditions in condition_age_probability_dict.items()
    for condition, genders_or_age_prob_list in conditions.items()
    for gender, age_prob_list in (genders_or_age_prob_list.items() if isinstance(genders_or_age_prob_list, dict) else [(None, genders_or_age_prob_list)])
    for age_range, probability in age_prob_list
]



In [8]:
from pyspark.sql.types import StructType, StructField, StringType, FloatType

columns = ["top_level_admission", "condition", "gender", "age_min", "age_max", "probability"]
schema = StructType([
    StructField("top_level_admission", StringType(), True),
    StructField("condition", StringType(), True),
    StructField("condition_gender", StringType(), True),  # Assuming gender can be specific probability for gender 'male', 'female', or null
    StructField("age_min", FloatType(), False),
    StructField("age_max", FloatType(), False),
    StructField("probability", FloatType(), False)
])



condition_prob_df = spark.createDataFrame(data=flattened_condition_probabilities, schema=schema)


In [9]:
df.show()

                                                                                

+--------------------+---+------+----------+-----------------+-----------------+------------------+--------------------+------------------+------------------+-----------+--------------+--------------+-----------+------------+----------+
|                name|Age|gender|blood_type|medical_condition|date_of_admission|            doctor|            hospital|insurance_provider|    billing_amount|room_number|admission_type|discharge_date| medication|test_results|       DOB|
+--------------------+---+------+----------+-----------------+-----------------+------------------+--------------------+------------------+------------------+-----------+--------------+--------------+-----------+------------+----------+
|     Tiffany Ramirez| 39|Female|        O-|         Diabetes|       2022-11-17|    Patrick Parker|    Wallace-Hamilton|          Medicare| 37490.98336352819|        146|      Elective|    2022-12-01|    Aspirin|Inconclusive|1984-09-30|
|         Ruben Burns| 16|  Male|        O+|        

In [10]:
"""
STEP 1 
    From the flattened data list we wish to create a DataFrame. This contains all the possible combinations for the given 
    top level admissions, sub level admissions, stay types and list of tests available from the admission_mapping, stay_type and admission_tests lists or dictionaries in admission_types_test_dataset.py
"""
mapping_df = spark.createDataFrame(flattened, ["mapping_top_level_admission", "sub_level_admission", "stay_type", "possible_tests", "doctor"])

display_df(mapping_df, 45)

                                                                                

Unnamed: 0,mapping_top_level_admission,sub_level_admission,stay_type,possible_tests,doctor
0,EMERGENCY,INJURY_RTC,Inpatient,"[X-rays, CT scans, MRI, Ultrasound, Blood tests]",Dr. Anthony Perry
1,EMERGENCY,INJURY_RTC,Day Patient,"[X-rays, CT scans, MRI, Ultrasound, Blood tests]",Dr. Aaron Long
2,EMERGENCY,SELF_INFLICTED,Inpatient,"[Psychological assessment, X-rays (for physica...",Dr. Lance Case
3,EMERGENCY,CARDIOLOGY,Inpatient,"[ECG, Echocardiogram, Stress tests, Cardiac ca...",Dr. Dawn Hawkins
4,EMERGENCY,CARDIOLOGY,Day Patient,"[ECG, Echocardiogram, Stress tests, Cardiac ca...",Dr. Amanda Young
5,EMERGENCY,NEUROLOGY,Inpatient,"[MRI or CT scans of the brain, Electroencephal...",Dr. Michael Johnson
6,EMERGENCY,NEUROLOGY,Day Patient,"[MRI or CT scans of the brain, Electroencephal...",Dr. Cameron Barrera
7,EMERGENCY,NEUROLOGY,Outpatient,"[MRI or CT scans of the brain, Electroencephal...",Dr. Maria Camacho
8,EMERGENCY,GASTROENTEROLOGY,Inpatient,"[Endoscopy, Colonoscopy, Blood tests, Stool te...",Dr. Michael Pearson
9,EMERGENCY,GASTROENTEROLOGY,Day Patient,"[Endoscopy, Colonoscopy, Blood tests, Stool te...",Dr. Zachary Hernandez


In [11]:
joined_tbl = mapping_df.join(condition_prob_df, on=[mapping_df.sub_level_admission == condition_prob_df.top_level_admission], how="left")

In [12]:
joined_tbl.show(n=2000)

+---------------------------+--------------------+-----------+--------------------+--------------------+-------------------+--------------------+----------------+-------+--------+-----------+
|mapping_top_level_admission| sub_level_admission|  stay_type|      possible_tests|              doctor|top_level_admission|           condition|condition_gender|age_min| age_max|probability|
+---------------------------+--------------------+-----------+--------------------+--------------------+-------------------+--------------------+----------------+-------+--------+-----------+
|          HOSPITAL_REFERRAL|             UROLOGY|  Inpatient|[Urinalysis, Bloo...|    Dr. Dwayne Bates|               NULL|                NULL|            NULL|   NULL|    NULL|       NULL|
|          HOSPITAL_REFERRAL|             UROLOGY|Day Patient|[Urinalysis, Bloo...|     Dr. John Thomas|               NULL|                NULL|            NULL|   NULL|    NULL|       NULL|
|          HOSPITAL_REFERRAL|           

In [13]:
"""
STEP 2 
    Create a list of admission_types randomly assign this to the original patient in the original data set, 
    whilst dropping the original admission_type column. Then join with mapping_df on top_level_admission col to give access to possible 
    conditions, mappings and so on.
"""
admission_type_names = [member.name for member in AdmissionTypes]

keys_array = array([lit(name) for name in admission_type_names])

# Define constants and conditions
female_only = ["MATERNITY", 'obstetrics']
is_female = lower(col('gender')) == 'female'
is_pediatric = col("Age") < 18
is_geriatric = (col("Age") >= 65)

df = (df.withColumn("is_female", is_female)
        .withColumn("is_pediatric", is_pediatric)
        .withColumn("is_geriatric", is_geriatric)
        .withColumn("unique_id", concat_ws("_", "name", date_format("DOB", "yyyyMMdd")))
        .drop("doctor", "medical_condition", "test_results", "medication")
      )
get_row_count(df, True)

"""
1. udf function that takes in, age, is_female, is_pediatric, is_geriatric, along with mapping_df 
2. in the function choose_condition_for_patient this should loop over the mapping df, filtering on the age of the individual, gender specific (is_female) and if the probability is == 0 then throw that condition out, also orderBy random(seed=1234567), to reduce repetition of condition selection.
3. then loop over this as it does now 

"""



10000


'\n1. udf function that takes in, age, is_female, is_pediatric, is_geriatric, along with mapping_df \n2. in the function choose_condition_for_patient this should loop over the mapping df, filtering on the age of the individual, gender specific (is_female) and if the probability is == 0 then throw that condition out, also orderBy random(seed=1234567), to reduce repetition of condition selection.\n3. then loop over this as it does now \n\n'

In [14]:
df.show()

+--------------------+---+------+----------+-----------------+--------------------+------------------+------------------+-----------+--------------+--------------+----------+---------+------------+------------+--------------------+
|                name|Age|gender|blood_type|date_of_admission|            hospital|insurance_provider|    billing_amount|room_number|admission_type|discharge_date|       DOB|is_female|is_pediatric|is_geriatric|           unique_id|
+--------------------+---+------+----------+-----------------+--------------------+------------------+------------------+-----------+--------------+--------------+----------+---------+------------+------------+--------------------+
|     Tiffany Ramirez| 39|Female|        O-|       2022-11-17|    Wallace-Hamilton|          Medicare| 37490.98336352819|        146|      Elective|    2022-12-01|1984-09-30|     true|       false|       false|Tiffany Ramirez_1...|
|         Ruben Burns| 16|  Male|        O+|       2023-06-01|Burke, Gri

In [15]:
joined_tbl.show()

+---------------------------+-------------------+-----------+--------------------+--------------------+-------------------+--------------------+----------------+-------+-------+-----------+
|mapping_top_level_admission|sub_level_admission|  stay_type|      possible_tests|              doctor|top_level_admission|           condition|condition_gender|age_min|age_max|probability|
+---------------------------+-------------------+-----------+--------------------+--------------------+-------------------+--------------------+----------------+-------+-------+-----------+
|                  EMERGENCY|          MATERNITY|  Inpatient|[Ultrasound, Bloo...|    Dr. Miguel Logan|               NULL|                NULL|            NULL|   NULL|   NULL|       NULL|
|                  EMERGENCY|          MATERNITY|Day Patient|[Ultrasound, Bloo...|Dr. Christopher J...|               NULL|                NULL|            NULL|   NULL|   NULL|       NULL|
|                  EMERGENCY|          MATERNITY| 

In [56]:
import json
import random
from pprint import pprint
from typing import List, Union, Tuple

import numpy as np
from IPython.core.display_functions import display
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, count, lit, when, rand, struct, collect_list
from faker import Faker
from spark_instance import spark

from constants.type_constants import SubAdmissionTypes


def find_probability_for_age_gender(age: int,
                                    condition_probabilities: Tuple[Tuple[int, int], float]) -> float | int:
    """
    This function returns probability based on age and a given condition probability list.
    Args:
        age: The age we will be comparing to
        condition_probabilities: The list of conditional probabilities

    Returns:
        int: the probability, if 0 returns, edge case, should be investigated.

    """
    (age_min, age_max), prob = condition_probabilities
    if age_min <= age <= age_max:
        return prob
    return 0


def filter_female_conditions(df):
    """
    Filters out entries from a DataFrame based on gender and age criteria:
    - Excludes female-only conditions (MATERNITY, OBSTETRICS) for non-female subjects.
    - Applies age restrictions specifically for MATERNITY-related entries.
    - Filter pediatric patients who cannot be pregnant (based on legal age in the UK, 16)
      No assumption made an individual cannot choose to get pregnant before this age.
      upper age bound defined  here: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4822427/#:~:text=Menopause%20typically%20occurs%20between%2045,reproducing%20many%20years%20before%20menopause.
      between 45 and 55 upper bounding will be 50.
      # TODO Possibility to include outliers in a new func in the future.
    Args:
        df (DataFrame): The input DataFrame with condition and demographic data.

    Returns:
        DataFrame: The filtered DataFrame.
    """

    female_only = [SubAdmissionTypes.MATERNITY.name, SubAdmissionTypes.OBSTETRICS.name]  #  TODO Obstetrics needs sorting still.

    filtered_df = df.filter((~(col("top_level_admission").isin(female_only))) & (~col("is_female")))

    filtered_df = filtered_df.withColumn("Age_Filter",
                       when((col("top_level_admission") == SubAdmissionTypes.MATERNITY.name) &
                            ((col("Age") < 16) | (col("Age") > 50)),
                            False).otherwise(True))
    
    # Filter rows based on the Age_Filter col
    filtered_df =  filtered_df.filter(col("Age_Filter")).drop("Age_Filter")
    
    return filtered_df.filter(
                     (~col("is_female") & (col("condition_gender") != "female")) | 
                     (col("condition_gender").isNull()) | col("is_female")
                     )
    
    


def filter_geriatric_conditions(df: DataFrame) -> DataFrame:
    """
    This function filters all GERIATRIC submission types
    Args:
        df (DataFrame): The unfiltered dataframe

    Returns:
        df (DataFrame): The filtered dataframe
    """
    return df.filter(~(col("sub_level_admission") == SubAdmissionTypes.GERIATRICS.name) & ~col("is_geriatric"))


def choose_condition_for_patient(probability_df: DataFrame,
                                 driver_df: DataFrame
                                 ) -> str | None:
    """
    This function chooses a condition for a patient based on the age and looking at the condition_age_probability_dict
    Args:
        age (int): The age of the individual we will be comparing to
        spark (SparkSession): The current SparkSession.
        is_pediatric (col): The col of the row we're looking at for the patient if True the pediatric patient.
        is_female (col): The col of the row we're looking at for the patient if True then Female else gender Male.
        is_geriatric (col): The col of the row we're looking at for the patient if True the patient is geriatric.
    Returns:
        A condition that will be assigned to the patient or None if issues found.
    """

    df_joined = driver_df.crossJoin(probability_df)

    df = df_joined.filter((col("age") >= col("age_min")) & (col("age") <= col("age_max")))

    df = filter_female_conditions(df)
    df = filter_geriatric_conditions(df)

    df = df.orderBy(col("unique_id"))
    print(df.count())
    conditions_probabilities = []
    df_transformed = df.withColumn("probability_entry", 
                               struct(col("age_min"), col("age_max"), col("probability")))
    df_transformed = df_transformed.withColumn("row_info", 
                                           struct(col("Age"), col("condition"), col("is_pediatric"), 
                                                  col("unique_id"), col("top_level_admission")))
    df_aggregated = df_transformed.groupBy("unique_id").agg(
                                                            collect_list("probability_entry").alias("probability_entries"),
                                                            collect_list("row_info").alias("row_infos")
                                                        )
    for row in df_aggregated.toLocalIterator():
        unique_id = row["unique_id"]
        probability_entries = [((entry.age_min, entry.age_max), entry.probability) for entry in row["probability_entries"]]
        row_infos = [(entry.Age, entry.condition, entry.is_pediatric, entry.unique_id, entry.top_level_admission) for entry in row["row_infos"]]
        assert len(probability_entries) == len(row_infos), f"Mismatch in lengths for {unique_id}"

        for entry in probability_entries:
            assert isinstance(entry, tuple) and len(entry) == 2, f"Invalid entry structure in probability_entries for {unique_id}"
            assert isinstance(entry[0], tuple) and len(entry[0]) == 2, f"Invalid age range structure in probability_entries for {unique_id}"
            assert isinstance(entry[1], float), f"Probability is not a float for {unique_id}"
    
        for info in row_infos:
            assert isinstance(info, tuple) and len(info) == 5, f"Invalid entry structure in row_infos for {unique_id}"

        for probability_info, patient_info in list(zip(probability_entries, row_infos)):
            age_prob = probability_info
            age, condition, is_pediatric, unique_id, top_level_admission = patient_info

            prob = find_probability_for_age_gender(age, age_prob)
            if prob > 0:
                condition_label = f"pediatric_{condition}" if is_pediatric else condition
            else:
                condition_label = "pediatric no condition for patient edge case" if is_pediatric else "no condition for patient edge case"
                prob = 0 
            conditions_probabilities.append((f"{unique_id}_{top_level_admission}", condition_label, prob))   
    print("before write")
    with open('output_list.txt', 'w') as f:
        f.write(json.dumps(conditions_probabilities))
    # print(conditions_probabilities)
    raise Exception()
    # If no condition is applicable based on age, return None or handle as appropriate
    if not conditions_probabilities:
        return None

    # Sort the conditions by probability for easier handling (optional)
    conditions_probabilities.sort(key=lambda x: x[1], reverse=True)

    # Use cumulative probabilities to select a condition
    total_prob = sum(prob for _, prob in conditions_probabilities)
    random_prob = np.random.uniform(0, total_prob)
    cumulative_prob = 0
    for condition, prob in conditions_probabilities:
        cumulative_prob += prob
        if random_prob < cumulative_prob:
            return condition

    return None

In [57]:
# from utils.util_funcs import choose_condition_for_patient

df = choose_condition_for_patient(joined_tbl, df)

                                                                                

695448


[Stage 530:>                                                        (0 + 1) / 1]

before write


Exception: 

In [None]:
df.show()

In [None]:
"""
STEP 3 
    Create a row ranking using a unique cols, created from stay_name and unique_id. 
    Verification done below. 
    
"""

df =  df.withColumn('stay_name', 
                   when(col('stay_type') == 'out_patient', concat(col('name'), lit('_out_patient')))
                   .when(col('stay_type') == 'inpatient', concat(col('name'), lit('_inpatient')))
                   .when(col('stay_type') == 'day_patient', concat(col('name'), lit('_day_patient')))
                   .otherwise(col('name'))
                  )


# Define a window specification that partitions data by 'top_level_admission' (or another unique patient identifier if needed)
windowSpec = Window.partitionBy('stay_name', 'unique_id').orderBy(rand())

# Assign row numbers within each partition in a random order
ranked_df = df.withColumn("row_num", row_number().over(windowSpec))


In [None]:
ranked_df.createGlobalTempView("ranked_df")

unique_dobs_df = spark.sql("""
WITH NameCounts AS (
    SELECT name
    FROM global_temp.ranked_df
)

SELECT DISTINCT r.name, r.unique_id
FROM global_temp.ranked_df r
JOIN NameCounts n ON r.name = n.name
ORDER BY r.name
LIMIT 10
""")

In [None]:
# COSTLY WAY TO VERIFY THAT THE RANK WORKS BELOW....  due to .collect() and .count() in verify_ranking_counts() function

# Example usage:
unique_names = [row['name'] for row in unique_dobs_df.select("name").collect()]
unique_ids = [row['unique_id'] for row in unique_dobs_df.select("unique_id").collect()]
verify_ranking_counts(df, ranked_df, unique_names, unique_ids)

# this function is faster than above
verify_ranking(df, ranked_df)


In [None]:
"""
Step 4: 
    Sort out geriatrics data and verify that individuals are not geriatrics 
"""
ranked_df = ranked_df.withColumn("is_geriatric", is_geriatric)

not_geriatric_df = ranked_df.where((col('sub_level_admission') == "GERIATRICS") &( col("is_geriatric") == False))

filtered_df = remove_data(ranked_df, not_geriatric_df, (col('sub_level_admission') == "GERIATRICS"), ( col("is_geriatric") == False))


In [None]:
"""
Step 5: 
    Sort out geriatrics data and verify that individuals who are not female and should not have a female sub level admission 
"""

not_female_df = filtered_df.where((col("sub_level_admission").isin(female_only)) & (col("is_female") == False))
filtered_df_female = remove_data(filtered_df, not_female_df, (col("gender") == "Male"),  (col("sub_level_admission").isin(female_only)))


In [None]:
# filter pediatric patients who cannot be pregnant (based on legal age in the UK, 16) No assumption made an individual cannot choose to get pregnant before this age. 
# upper age bound defined  here: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4822427/#:~:text=Menopause%20typically%20occurs%20between%2045,reproducing%20many%20years%20before%20menopause.
# between 45 and 55 upper bounding will be 50. # Possiblilty to include outliers in a new func in the future.
under16_female_df = filtered_df_female.where((col("Age") < 16) &
                                             (col("Age") <= 50) &
                                             (col("sub_level_admission") == "MATERNITY")
                                             ).orderBy("Age")

In [None]:
df = remove_data(filtered_df_female, under16_female_df, (col("sub_level_admission") == "MATERNITY"),
                 (col("Age") < 16) & (col("Age") > 50))



In [None]:

split_col = split(df['name'], ' ')

df_renamed = (df.withColumn('name', 
                    when(size(split_col) == 4, concat(split_col.getItem(1).substr(1, 1), lit('. '), split_col.getItem(2).substr(1, 1), lit('. '), split_col.getItem(3)))
                    .when(size(split_col) == 3, concat(split_col.getItem(1).substr(1, 1), lit('. '), split_col.getItem(2)))
                    .otherwise(
                        concat(split_col.getItem(0).substr(1, 1), lit('. '), split_col.getItem(1))
                               )
                    )
      )


In [None]:
df_renamed.show()

In [None]:
(df_renamed.select("name", "DOB", "Age", "gender", "blood_type", "date_of_admission", "discharge_date", "top_level_admission", "sub_level_admission", "possible_tests", "conditions", "doctor", "hospital", "room_number", "insurance_provider", "billing_amount", "stay_type", "is_female", "is_geriatric", "is_pediatric", "stay_name", "row_num", "unique_id")
 # .where(col("stay_name") =="Abigail Lamb")
 .show(n=200))

In [None]:

from pyspark.sql.types import StringType
from pyspark.sql.functions import udf

sub_selection_cols = ["Age", "sub_level_admission", "conditions", "row_num", "unique_id"]

# df_subset = df_renamed.select(*sub_selection_cols)
# 
# choose_condition_udf = udf(choose_condition, StringType())
# df_test = df_subset.withColumn("assigned_condition", choose_condition_udf(df_subset["age"], df_subset["sub_level_admission"], df_subset["conditions"]))


In [None]:
# df_subset.persist()
# df_subset.show()  # The action that triggers the computation


In [None]:
# df_test.persist()
# df_test.show(n=20)

In [None]:
df_new = df_renamed.drop("row_num")

windowSpec = Window.partitionBy('stay_name', 'unique_id').orderBy(rand())

# Assign row numbers within each partition in a random order
df_new_part = df_new.withColumn("row_num", row_number().over(windowSpec))

df_new_part.select(["name", "DOB", "Age", "gender", "blood_type", "date_of_admission", "discharge_date", "top_level_admission", "sub_level_admission", "possible_tests", "conditions", "doctor", "hospital", "room_number", "insurance_provider", "billing_amount", "stay_type", "is_female", "is_geriatric", "is_pediatric", "stay_name", "row_num", "unique_id"]).where(col("stay_name") =="Tiffany Ramirez").show() 

In [None]:
get_row_count(df_new_part)

In [None]:
["name", "DOB", "Age", "gender", "blood_type", "date_of_admission", "discharge_date", "top_level_admission", "sub_level_admission", "possible_tests", "conditions", "doctor", "hospital", "room_number", "insurance_provider", "billing_amount", "stay_type", "is_female", "is_geriatric" "is_pediatric", "stay_name", "row_num", "unique_id"]

In [None]:
# join_with_condition_prob_df = df_new_part.join(condition_prob_df, 
#                                                on=["sub_level_admission"], 
#                                                how="inner")
# join_with_condition_prob_df = join_with_condition_prob_df.filter(
#     (col("Age") >= col("age_min")) &
#     (col("Age") <= col("age_max"))
# )

In [None]:
df_new_part.show(n=2000)

In [None]:
# TODO:  
# medical condition to be  chosen 
# tests to be chosen
# admission date to be checked again dob,  
# TODO filter on is pediatric, geriatric and is_female to be done here and same people with dob? needs  to be considered 
# drop stay_name and unique_id 


In [None]:
# Filter to keep only the top-ranked row within each partition
ranked_df = filtered_df_female

In [None]:
# from pyspark.sql.types import StringType
# get_row_count(df, True)
df = join_with_condition_prob_df.select([col(c).cast(StringType()).alias(c) for c in join_with_condition_prob_df.columns])
df.show()
df.repartition(10).write.csv('./temp_data/join_with_condition_prob_df/renamed.csv', mode = 'overwrite', header=True)

In [None]:
# TODO initialise first name to make data more realistic for name columns i.e. Daniel Mccoy is seen a Male name but here its Female

In [None]:
spark.stop()

In [None]:
dict_ = admission_mapping.get("emergency")