In [1]:
import random

from src.data_generator.csv_data_processor import CSVDataProcessor
from src.utils.statistcs.statistical_functions import create_distributed_age_df
from spark_instance import spark
from src.utils.statistcs.data_visualisations import plot_age_distribution_with_sd, plot_kernel_density_age_distribution
from pyspark.sql.types import StructType, StructField, StringType, FloatType

from pyspark.sql.functions import col
  

24/04/24 21:08:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [None]:
average_age = 40.2

csv_age_file_file = CSVDataProcessor(spark, "data/uk_age_population.csv")

csv_age_uk_df = csv_age_file_file.runner()

In [None]:
csv_age_uk_df.show()

In [None]:
plot_age_distribution_with_sd(csv_age_uk_df, average_age)


In [None]:
plot_kernel_density_age_distribution(csv_age_uk_df)


In [None]:
from pyspark.sql import SparkSession
from typing import Dict
from pprint import pprint

In [39]:
from pyspark.sql import DataFrame
from typing import Dict, List

def get_config(country: str, file_path: str) -> Dict:
    """
    
    Args:
        country: 
        file_path: 

    Returns:

    """
    type_dict = {
    "ethnicity": {
        "uk": {
                "file_path": "data/uk/ethnicity.csv",
                "columns": ["ethnic_group", "population"]
        }
    },
    "population_distribution":
                            {
                                "uk": {
                                        "file_path": "data/uk/population_distrubition_2011_census.csv",
                                        "columns": ["postcode_area", "population"]
                            }
        }
    }
    return type_dict[file_path][country]

def normalize_data(df: DataFrame, required_columns: List) -> Dict:
    """
    
    Args:
        df: 
        required_columns: 

    Returns:

    """
    data_collected = df.rdd.map(lambda row: (row[required_columns[0]], row[required_columns[1]])).collect()
    dict_from_df = dict(data_collected)
    total_population = sum(dict_from_df.values())
    return {key: (value / total_population) * 100 for key, value in dict_from_df.items()}
         
        
def get_normalised_data_dict(spark: SparkSession, country: str, file_type: str) -> Dict:
    """
    
    Args:
        spark: 
        country: 
        file_type: 

    Returns:

    """
    config = get_config(country, file_type)
    data_processor_ethnicity = CSVDataProcessor(spark, config["file_path"])
    df = data_processor_ethnicity.runner()
    required_columns = config["columns"]
    if not all(col in df.columns for col in required_columns):
        raise ValueError("Required columns are missing in the data")
    return normalize_data(df, required_columns)


In [40]:
get_normalised_data_dict(spark, "uk", "population_distribution")

{'BS': 1.676434615989839,
 'B': 3.3959198024278234,
 'BN': 1.431698516111517,
 'BD': 1.031412560887106,
 'BL': 0.6775262076878928,
 'CA': 0.5672917098521733,
 'BR': 0.5334857505304595,
 'BH': 0.9845760511215582,
 'AL': 0.4456655114231579,
 'CH': 1.1767619579686908,
 'BA': 0.7749120513635159,
 'CB': 0.7510372724744985,
 'CF': 1.7927572894400718,
 'BB': 0.8712546663529966,
 'GL': 1.0797113741101527,
 'EX': 0.9761785060223362,
 'E': 1.7639071122017596,
 'FY': 0.4932474392926503,
 'CV': 1.4641456031958964,
 'DN': 1.348275887158108,
 'DL': 0.6432958237041245,
 'CT': 0.8606672326613253,
 'EN': 0.6148451049712754,
 'DE': 1.3020296486662581,
 'EC': 0.06055362951564657,
 'CR': 0.7226561023207255,
 'DY': 0.7326621812231962,
 'CM': 1.165609219159913,
 'DA': 0.7670655450062052,
 'DG': 0.00011591429846027292,
 'CW': 0.551907207501146,
 'DT': 0.3794784470023421,
 'DH': 0.5510922408181252,
 'CO': 0.7336679606744514,
 'LA': 0.5865727159283651,
 'HR': 0.314409509737443,
 'LN': 0.5227574363837364,
 'LL'

In [52]:
import numpy as np

def adjust_postcode_weights() -> Dict|None:
    """
    Adjusts postcode weights by setting zero weights to the lower quartile of the non-zero weights
    and scales the non-zero weights so that the total sums to 100.

    Returns:
    dict: Adjusted postcode weights where all weights sum to 100.
    """
    postcode_weights = get_normalised_data_dict(spark, "uk", "population_distribution")

    non_zero_weights = [weight for weight in postcode_weights.values() if weight > 0]
    if not non_zero_weights:
        print("No non-zero weights found in the data.")
        return 

    lower_quartile = np.percentile(non_zero_weights, 10)

    # Count zero weights and calculate the total weight to allocate to them
    zero_count = sum(1 for weight in postcode_weights.values() if weight == 0)
    total_allocated_to_zeros = lower_quartile * zero_count

    # Scale the non-zero weights to ensure the total sums to 100
    scale_factor = (100 - total_allocated_to_zeros) / sum(non_zero_weights)

    # Set zero weights to the lower quartile and adjust non-zero weights
    adjusted_weights = {}
    for postcode, weight in postcode_weights.items():
        if weight == 0:
            adjusted_weights[postcode] = lower_quartile
        else:
            adjusted_weights[postcode] = weight * scale_factor

    return adjusted_weights

In [53]:
_dict = adjust_postcode_weights()

In [54]:
_dict

{'BS': 1.5555037766101003,
 'B': 3.1509526392251033,
 'BN': 1.3284218946192647,
 'BD': 0.9570108600720476,
 'BL': 0.6286523582600757,
 'CA': 0.5263697066966346,
 'BR': 0.49500236502085776,
 'BH': 0.9135529362564331,
 'AL': 0.41351710320910945,
 'CH': 1.0918753718949032,
 'BA': 0.7190132027457103,
 'CB': 0.6968606485253537,
 'CF': 1.6634354287791806,
 'BB': 0.8084060726108709,
 'GL': 1.001825603013797,
 'EX': 0.9057611542259834,
 'E': 1.6366663802150239,
 'FY': 0.4576666738475586,
 'CV': 1.3585283873022187,
 'DN': 1.251017018130746,
 'DL': 0.5968912081062254,
 'CT': 0.7985823712061609,
 'EN': 0.5704928028155495,
 'DE': 1.2081067859379984,
 'EC': 0.05618554908174719,
 'CR': 0.6705267748760426,
 'DY': 0.6798110579452405,
 'CM': 1.0815271440719192,
 'DA': 0.7117327098737677,
 'DG': 0.000107552735608245,
 'CW': 0.5120949766952142,
 'DT': 0.3521044911768016,
 'DH': 0.5113387982310147,
 'CO': 0.6807442847588259,
 'LA': 0.5442598632719631,
 'HR': 0.2917293493787455,
 'LN': 0.4850479456759469,


In [41]:
get_normalised_data_dict(spark, "uk", "ethnicity")

{'Black African': 2.2933022069769655,
 'Black Caribbean': 0.9512288672018454,
 'Black other': 0.449862040806843,
 'Mixed': 2.5953725700404577,
 'Mixed White/Asian': 0.7375730636240838,
 'Mixed White/Black African': 0.3770637359371563,
 'Mixed White/Black Caribbean': 0.7750569089886282,
 'Mixed other': 0.7056788614905896,
 'Gypsy Or Irish Traveller': 0.10872460370147081,
 'Roma': 0.1525277673458753,
 'White English': 67.00780309662912,
 'White Scottish': 6.716149713548157,
 'White Irish': 0.8483613031778828,
 'White other': 5.788016647891257,
 'Arab': 0.515488984482531,
 'Any Other': 1.403053325997103,
 'Mixed or multiple ethnic groups': 0.029934805573853245,
 'Bangladeshi': 0.9799822941252442,
 'Chinese': 0.7241639627271108,
 'Indian': 2.865840298847103,
 'Pakistani': 2.4733461261634746,
 'Asian other': 1.501468814723253}

In [None]:
from pyspark import RDD
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import col, lit, rand, sum
from pyspark.sql import Row

from src.data_generator.csv_data_processor import CSVDataProcessor
from utils.column_creator_functions import choose_blood_type, random_gender_chooser, generate_name, dob_creator, \
    check_patient_is_pediatric, check_gender_is_female, check_patient_is_geriatric, create_unique_id


def create_distributed_age_df(spark: SparkSession, file_path: str, dataset_size: int = 10000) -> DataFrame:

    data_processor_age = CSVDataProcessor(spark, file_path)

    csv_age_sq_df = data_processor_age.runner()

    normalised_df = normalise_population_density(csv_age_sq_df)
    sampled_rdd = oversample_ages(normalised_df, dataset_size)

    row_rdd = create_rows_rdd(sampled_rdd)

    sampled_df = spark.createDataFrame(row_rdd)

    return sampled_df.orderBy(rand()).limit(dataset_size)

def get_ethnicity_df(spark: SparkSession, file_path: str) -> DataFrame:
    data_processor_ethnicity = CSVDataProcessor(spark, file_path)

    ethnicity_df = data_processor_ethnicity.runner()
    data_collected = ethnicity_df.rdd.map(lambda row: (row['ethnic_group'], row['population'])).collect()
    dict_from_df = dict(data_collected)
    print(dict_from_df)

def create_rows_rdd(sample_rdd: RDD) -> RDD:

    return sample_rdd.map(
        lambda age: (
            lambda gender, name, dob: Row(
                Age=int(age),
                DOB=dob,
                Blood_type=choose_blood_type(),
                Gender=gender,
                Name=name,
                is_female=check_gender_is_female(gender),
                is_pediatric=check_patient_is_pediatric(age),
                is_geriatric=check_patient_is_geriatric(age),
                unique_id=create_unique_id(name, dob)
            )
        )(random_gender_chooser("uk"), generate_name(), dob_creator(age))
    )


def normalise_population_density(csv_age_sq_df: DataFrame) -> DataFrame:

    # Calculate the total population and add density column
    total_population = csv_age_sq_df.select(sum("population_total")).collect()[0][0]
    density_df = csv_age_sq_df.withColumn("density", col("population_total") / lit(total_population))

    # normalise the density to ensure it sums to 1
    total_density = density_df.select(sum("density")).collect()[0][0]
    return density_df.withColumn("normalised_density", col("density") / lit(total_density))


def oversample_ages(dataframe: DataFrame, dataset_size: int, oversample_factor: float = 1.1) -> RDD:

    oversample_num = int(dataset_size * oversample_factor)
    return dataframe.rdd.flatMap(
        lambda row: [row['age']] * int(row['normalised_density'] * oversample_num)
    )

In [None]:
sampled_df = create_distributed_age_df(spark,  "data/uk_age_population.csv", 5000)

In [None]:
sampled_df_aggregated = sampled_df.groupby(col("Age")).count()

In [None]:
sampled_df_aggregated.orderBy("Age").show(n=200)

In [None]:
plot_age_distribution_with_sd(sampled_df_aggregated)

In [None]:
sampled_df.createOrReplaceGlobalTempView("sampled_df_with_unique_id_gt")

In [None]:
from src.constants.admission_types_dataset import admission_mapping
flattened = [
    (
        top_level.name,
        sub_level_key.name,
        list(sub_level_info.get("stay_types")),
        sub_level_info.get("tests"),  
        random.choice(sub_level_info.get("doctors"))
    )
    for top_level, sub_level_dict in admission_mapping.items()
    for sub_level_key, sub_level_info in sub_level_dict.items()
]



In [None]:
flattened


In [None]:
from src.constants.condition_probabilities import condition_age_probability_dict

flattened_condition_probabilities = [
    (
        sub_admission,
        condition,
        gender if gender in ['male', 'female'] else None,
        float(age_range[0]),
        float(age_range[1]) if len(age_range) > 1 else float('inf'),
        float(probability) 
    )
    for sub_admission, conditions in condition_age_probability_dict.items()
    for condition, genders_or_age_prob_list in conditions.items()
    for gender, age_prob_list in (genders_or_age_prob_list.items() if isinstance(genders_or_age_prob_list, dict) else [(None, genders_or_age_prob_list)])
    for age_range, probability in age_prob_list
]



In [None]:
flattened_condition_probabilities

In [None]:
columns = ["condition_admission_type", "condition", "gender", "age_min", "age_max", "probability"]
schema = StructType([
    StructField("condition_admission_type", StringType(), True),
    StructField("condition", StringType(), True),
    StructField("condition_gender", StringType(), True),  # Assuming gender can be specific probability for gender 'male', 'female', or null
    StructField("age_min", FloatType(), False),
    StructField("age_max", FloatType(), False),
    StructField("probability", FloatType(), False)
])

condition_prob_df = spark.createDataFrame(data=flattened_condition_probabilities, schema=schema)


In [None]:
condition_prob_df.show()

In [None]:
"""
STEP 1 
    From the flattened data list we wish to create a DataFrame. This contains all the possible combinations for the given 
    top level admissions, sub level admissions, stay types and list of tests available from the admission_mapping, stay_type and admission_tests lists or dictionaries in admission_types_test_dataset.py
"""

mapping_df = spark.createDataFrame(flattened, ["top_level_admission", "sub_level_admission", "stay_types", "possible_tests", "doctor"])

joined_tbl = mapping_df.join(condition_prob_df, on=[mapping_df.sub_level_admission == condition_prob_df.condition_admission_type], how="left")

In [None]:
gt_df = spark.sql("SELECT * FROM global_temp.sampled_df_with_unique_id_gt")


In [None]:
"""
STEP 2
Create Enum class df and join it on to main driver df. 
Then create conditions 
"""

from src.data_generator.conditions_creator import ConditionsCreator
from src.utils.thread_operations import runner
from src.constants.type_constants import DepartmentTypes

enum_values = [e.name for e in DepartmentTypes]

enum_df = spark.createDataFrame(enum_values, StringType()).toDF("admission_type")

selected_conditions_df = runner(spark, ConditionsCreator, gt_df, joined_tbl, enum_df)


In [None]:
from utils.util_funcs import get_row_count

get_row_count(selected_conditions_df)

In [None]:
selected_conditions_df.show()

In [None]:
# TODO:  
# tests to be chosen
# admission date to be chosen 
# hospital name 
# patient postcode 
# ethnicity distribution  
# _________________
# Long term 
# investigation which disease affects specific ethnicities
# blood type udf would need to consider ethnicity also eventually



In [None]:
spark.stop()

In [8]:
all_postcodes = [
"AB","AL","B","BA","BB","BD","BH","BL","BN","BR","BS","BT","CA","CB","CF","CH","CM","CO","CR","CT","CV","CW","DA","DD","DE","DG","DH","DL","DN","DT","DY","E","EC","EH","EN","EX","FK","FY","G","GL","GU","GY","HA","HD","HG","HP","HR","HS","HU","HX","IG","IM","IP","IV","JE","KA","KT","KW","KY","L","LA","LD","LE","LL","LN","LS","LU","M","ME","MK","ML","N","NE","NG","NN","NP","NR","NW","OL","OX","PA","PE","PH","PL","PO","PR","RG","RH","RM","S","SA","SE","SG","SK","SL","SM","SN","SO","SP","SR","SS","ST","SW","SY","TA","TD","TF","TN","TQ","TR","TS","TW","UB","W","WA","WC","WD","WF","WN","WR","WS","WV","YO","ZE"
]


In [9]:
len(all_postcodes)

124

In [10]:
present_postcodes = [
"BS",
"B",
"BN",
"BD",
"BL",
"CA",
"BR",
"BH",
"AL",
"CH",
"BA",
"CB",
"CF",
"BB",
"GL",
"EX",
"E",
"FY",
"CV",
"DN",
"DL",
"CT",
"EN",
"DE",
"EC",
"CR",
"DY",
"CM",
"DA",
"DG",
"CW",
"DT",
"DH",
"CO",
"LA",
"HR",
"LN",
"LL",
"HP",
"L",
"GU",
"KT",
"HA",
"HD",
"LD",
"LS",
"LE",
"HU",
"IG",
"IP",
"HX",
"HG",
"MK",
"NW",
"OL",
"NP",
"NE",
"M",
"NR",
"ME",
"NN",
"N",
"NG",
"LU",
"RG",
"PL",
"RH",
"SA",
"PO",
"OX",
"SG",
"S",
"SE",
"PE",
"PR",
"RM",
"SL",
"SM",
"SW",
"SO",
"TW",
"SP",
"TR",
"TS",
"TN",
"TD",
"TQ",
"SK",
"SN",
"SY",
"SS",
"TF",
"TA",
"SR",
"ST",
"WC",
"YO",
"WA",
"WF",
"W",
"WS",
"WR",
"WD",
"WV",
"UB",
"WN",
]
len(present_postcodes)

106

In [12]:
missing_postcodes = list(set(all_postcodes) - set(present_postcodes))

In [13]:
missing_postcodes

['AB',
 'JE',
 'EH',
 'KA',
 'HS',
 'ML',
 'ZE',
 'IV',
 'GY',
 'KY',
 'PA',
 'G',
 'KW',
 'PH',
 'FK',
 'DD',
 'IM',
 'BT']