In [0]:
#1 Importing Tools 
import openpyxl
import pandas as pd

from pyspark.sql import functions as F
from datetime import datetime
from openpyxl.styles import NamedStyle

In [0]:
#2 Reduce risk of a timeout by increasing limit to 30 minutes
# Value is given in seconds (30 * 60 = 1800) and passed as a string.
# NOTE(review): confirm this key is recognised by the Databricks runtime in use.
spark.conf.set("spark.databricks.execution.timeout", str(30 * 60))

In [0]:
#3 Loading the master hierarchies table from the lake mart
# Org hierarchy lookup; cell #9 maps Organisation_Code -> STP_Code (used as the ICB) from it.
df_master_hierarchies = spark.read.csv(
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/EROC_Collection_Queries/master_hierarchies_table.csv",
    header=True,
)
# Quick checks: display(df_master_hierarchies.limit(10)); df_master_hierarchies.count()

In [0]:
#4 loading ICB to Region table
# ICB -> Region lookup; cell #9 joins ICB_Code to return Region_Code.
# Ensure proper Azure credentials are configured for ADLS access.
df_icb_region = spark.read.csv(
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/EROC_Collection/EROC/EROC_ICB_Region_DisplayNames.csv",
    header=True,
)
# Quick checks: display(df_icb_region.limit(10)); df_icb_region.count()

In [0]:
#5 loading list of merged providers
# Old -> new provider codes; cell #9 reads Old_Provider_Code / New_Provider_Code from it.
df_merged_providers = spark.read.csv(
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/EROC_Collection/EROC/EROC_Merged_Providers.csv",
    header=True,
)
# Quick checks: display(df_merged_providers.limit(10)); df_merged_providers.count()

In [0]:
#6 creating new provider code from the provider mapping table
# Reuse the merged-providers table loaded in cell #5. EROC_Merged_Providers.csv does not need a
# second read, and df_merged_providers is left bound to the cell #5 DataFrame.
provider_code_mapping = df_merged_providers
#display(provider_code_mapping.limit(10))
#print(f"Number of rows in provider_code_mapping: {provider_code_mapping.count()}")

In [0]:
#7a importing MHS metric list and internal ID
# NOTE(review): this path is the MHS *folder*, not a single file. Given a directory, Spark's CSV
# reader loads every file in it, which would include Allowable_Org_Codes_Status.csv (read on its
# own in 7b) alongside the metric list. Confirm whether a single metric-list file was intended.
mhs_metric_list = spark.read.csv(
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/EROC_Collection/MHS",
    header=True,
)
# Quick checks: display(mhs_metric_list.limit(10)); mhs_metric_list.count()

#7b importing MHS allowable org codes
mhs_allowable_orgs = spark.read.csv(
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/EROC_Collection/MHS/Allowable_Org_Codes_Status.csv",
    header=True,
)
# Quick checks: display(mhs_allowable_orgs.limit(10)); mhs_allowable_orgs.count()

In [0]:
#8 Loading the core monthly snapshot data
from pyspark.sql import functions as F

# Published OPA core monthly snapshot (Parquet). recursiveFileLookup also picks up files in
# nested sub-folders; "header" is a CSV setting and has no effect on the Parquet reader.
df_op_activity_snapshot = (
    spark.read
    .option("header", "true")
    .option("recursiveFileLookup", "true")
    .parquet("abfss://reporting@udalstdatacuratedprod.dfs.core.windows.net/restricted/patientlevel/MESH/OPA/OPA_Core_Monthly_Snapshot/Published/1")
)
# Quick check: display(df_op_activity_snapshot.limit(10))

# Eager action (triggers a Spark job): report the raw row count.
row_count = df_op_activity_snapshot.count()
print(f"Number of rows in raw data: {row_count}")

In [0]:
#9 Creating the wide table & inserting new column for merged providers with new merger codes and mapping to ICB and Region codes
from pyspark.sql.functions import when, col, lit, create_map, coalesce
import pyspark.sql.functions as F

# Flow of this cell:
#   snapshot -> TFC group -> filter -> counts per (month, provider, TFC group)
#   -> extra "All" TFC-group rows -> "Adj Org Code" (merged providers remapped)
#   -> ICB (STP_Code) and Region_Code -> month-end date
#   -> opa_final_processed: counts re-summed per (month-end date, TFC group, Region, ICB, Adj Org Code)

# Treatment function codes reported individually; every other code is bucketed as "Other".
VALID_TREATMENT_CODES = [
    '100', '101', '102', '104', '105', '106', '108', '110', '111', '115', '120', '130', '140',
    '144', '145', '301', '302', '303', '307', '320', '330', '340', '361', '400', '410', '420',
    '430', '501', '502', '560', '650'
]

# Adding in the Treatment_Function_Code_New column
opa_with_tfc = df_op_activity_snapshot.withColumn(
    "Treatment_Function_Code_New",
    when(col("Treatment_Function_Code").isin(VALID_TREATMENT_CODES), col("Treatment_Function_Code")).otherwise("Other")
)

# Combine related codes into one reporting group (GS, OMFS, T&O); other codes pass through unchanged.
opa_with_groups = opa_with_tfc.withColumn(
    "Treatment_Function_Group",
    when(col("Treatment_Function_Code_New").isin("100", "102", "104", "105", "106"), "GS")
     .when(col("Treatment_Function_Code_New").isin("140", "144", "145"), "OMFS")
     .when(col("Treatment_Function_Code_New").isin("110", "111", "115"), "T&O")
     .otherwise(col("Treatment_Function_Code_New"))
)

# Filter dataset for relevant years, admin category, TFC, and attendance
opa_filtered = opa_with_groups.filter(
    (col("Der_Financial_Year").isin("2023/24", "2024/25", "2025/26")) &  # This can be updated manually
    (col("Administrative_Category") == "01") &
    (col("Treatment_Function_Code") != "812") &
    (col("First_Attendance").isin("1", "2", "3", "4"))
)


def _count_if(condition, alias):
    """Return sum(1 if condition else 0) aliased as `alias`; a NULL condition counts as 0."""
    return F.sum(when(condition, 1).otherwise(0)).alias(alias)


def _activity_count_exprs():
    """Build the per-group activity count aggregations, in output column order.

    Code meanings below are inferred from the output column names (confirm against the data
    dictionary): Attendance_Status 5/6 = attended, 3/7 = DNA; First_Attendance 1 = F2F first,
    2 = F2F follow-up, 3 = remote first, 4 = remote follow-up; Priority_Type 3 = two-week wait.
    """
    attended = col("Attendance_Status").isin("5", "6")
    dna = col("Attendance_Status").isin("3", "7")
    booked = col("Attendance_Status").isin("5", "6", "3", "7")  # attended or DNA
    any_contact = col("First_Attendance").isin("1", "2", "3", "4")
    first = col("First_Attendance").isin("1", "3")
    follow_up = col("First_Attendance").isin("2", "4")
    f2f = col("First_Attendance").isin("1", "2")
    remote = col("First_Attendance").isin("3", "4")
    with_proc = col("Der_Number_Procedure") > 0
    no_proc = col("Der_Number_Procedure") == 0
    two_ww = col("Priority_Type") == "3"
    return [
        # All contacts
        _count_if(attended & any_contact, "All_Total"),
        _count_if(attended & first, "All_First"),
        _count_if(attended & follow_up, "All_FU"),
        _count_if(with_proc & attended & any_contact, "All_Proc"),
        _count_if(no_proc & attended & any_contact, "All_NoProc"),
        _count_if(with_proc & attended & follow_up, "All_FU_Proc"),
        _count_if(no_proc & attended & follow_up, "All_FU_NoProc"),
        # Face-to-face
        _count_if(attended & f2f, "F2F_Total"),
        _count_if(attended & (col("First_Attendance") == "1"), "F2F_First"),
        _count_if(attended & (col("First_Attendance") == "2"), "F2F_FU"),
        # Remote
        _count_if(attended & remote, "Remote_Total"),
        _count_if(attended & (col("First_Attendance") == "3"), "Remote_First"),
        _count_if(attended & (col("First_Attendance") == "4"), "Remote_FU"),
        # Did not attends (DNAs)
        _count_if(dna & any_contact, "All_DNA"),
        _count_if(dna & first, "All_First_DNA"),
        _count_if(dna & follow_up, "All_FU_DNA"),
        _count_if(dna & f2f, "F2F_DNA"),
        _count_if(dna & remote, "Remote_DNA"),
        # 2WW DNA
        _count_if(dna & any_contact & two_ww, "All_2WW_DNA"),
        _count_if(dna & first & two_ww, "All_First_2WW_DNA"),
        _count_if(dna & follow_up & two_ww, "All_FU_2WW_DNA"),
        # All 2WW appointments (attended + DNA)
        _count_if(booked & any_contact & two_ww, "All_2WW"),
        _count_if(booked & first & two_ww, "All_First_2WW"),
        _count_if(booked & follow_up & two_ww, "All_FU_2WW"),
    ]


# Aggregates the metrics by month, provider, and Treatment_Function_Group
opa_agg = opa_filtered.groupBy(
    "Der_Activity_Month",
    "Der_Provider_Code",
    "Treatment_Function_Group"
).agg(*_activity_count_exprs())

# Add "All" TFC totals by month and provider (sum over every group, including "Other")
METRIC_COLS = [c for c in opa_agg.columns if c not in ["Der_Activity_Month", "Der_Provider_Code", "Treatment_Function_Group"]]

opa_all_tfc = opa_agg.groupBy("Der_Activity_Month", "Der_Provider_Code").agg(
    *[F.sum(col(c)).alias(c) for c in METRIC_COLS]
).withColumn("Treatment_Function_Group", lit("All"))

opa_final = opa_agg.unionByName(opa_all_tfc)

# Order by results
opa_final_ordered = opa_final.orderBy("Der_Activity_Month", "Der_Provider_Code", "Treatment_Function_Group")

# Old -> new provider code for merged organisations. Rows with a NULL old code are skipped:
# Spark rejects NULL map keys at runtime, which would fail the whole job.
# NOTE(review): if an old code appears with several new codes, the last collected row wins —
# confirm EROC_Merged_Providers.csv holds one row per Old_Provider_Code.
provider_code_mapping_dict = {
    row['Old_Provider_Code']: row['New_Provider_Code']
    for row in (
        df_merged_providers
        .select("Old_Provider_Code", "New_Provider_Code")
        .where(col("Old_Provider_Code").isNotNull())
        .distinct()
        .collect()
    )
}

# Flatten to [k1, v1, k2, v2, ...] literals for create_map
mapping_list = [lit(code) for pair in provider_code_mapping_dict.items() for code in pair]
mapping_expr = create_map(mapping_list)

# With no mappings the map literal is empty; skip the lookup so every provider keeps its own code.
mapped_provider_code = (
    mapping_expr.getItem(col("Der_Provider_Code")) if mapping_list else lit(None).cast("string")
)

# Add "Adj Org Code": the new code for merged providers, otherwise the original provider code
opa_final_ordered_with_adj = opa_final_ordered.withColumn(
    "Adj Org Code",
    coalesce(mapped_provider_code, col("Der_Provider_Code"))
)

# Add "ICB" column by joining to df_master_hierarchies on Organisation_Code and returning STP_Code.
# NOTE(review): a left join multiplies rows if Organisation_Code is not unique in the hierarchy
# table (likewise ICB_Code below), which would inflate every count — worth asserting uniqueness.
opa_final_ordered_with_icb = opa_final_ordered_with_adj.join(
    df_master_hierarchies.select(
        F.col("Organisation_Code").alias("join_org_code"),
        F.col("STP_Code").alias("ICB")
    ),
    opa_final_ordered_with_adj["Adj Org Code"] == F.col("join_org_code"),
    "left"
).drop("join_org_code")

# Add "Region" column by joining to df_icb_region on ICB column and returning Region_Code
opa_final_ordered_with_icb_region = opa_final_ordered_with_icb.join(
    df_icb_region.select(
        F.col("ICB_Code").alias("join_icb"),
        F.col("Region_Code")
    ),
    opa_final_ordered_with_icb["ICB"] == F.col("join_icb"),
    "left"
).drop("join_icb")

from pyspark.sql.functions import last_day, to_date, concat_ws

# Month-end date from Der_Activity_Month, read as YYYYMM (chars 1-4 = year, 5-6 = month).
# NOTE(review): assumed format — values that do not parse presumably give a NULL date.
opa_final_ordered_with_icb_region = opa_final_ordered_with_icb_region.withColumn(
    "Der_Activity_Month_Date",
    last_day(
        to_date(
            concat_ws(
                '-',
                col("Der_Activity_Month").substr(1, 4),
                col("Der_Activity_Month").substr(5, 2),
                lit("01")
            )
        )
    )
)

# display(opa_final_ordered_with_icb_region.limit(10))

# Eager action: runs the whole pipeline above once; the value is not used in this cell.
opa_final_ordered_with_icb_region_row_count = opa_final_ordered_with_icb_region.count()

# Drop provider-level columns and re-sum counts to one row per (month-end date, TFC group,
# Region, ICB, Adj Org Code), so merged providers collapse into a single org row.
id_cols = ["Der_Activity_Month_Date", "Treatment_Function_Group", "Region_Code", "ICB", "Adj Org Code"]

# Determine metric columns (exclude identifiers and the three columns to drop)
metric_cols = [
    c for c in opa_final_ordered_with_icb_region.columns
    if c not in id_cols + ["Der_Activity_Month", "Der_Provider_Code", "Treatment_Function_Code_New"]
]

opa_final_processed = (
    opa_final_ordered_with_icb_region
    .groupBy(*[F.col(c) for c in id_cols])
    .agg(*[F.sum(F.col(c)).alias(c) for c in metric_cols])
    .orderBy("Der_Activity_Month_Date", "Region_Code", "ICB", "Adj Org Code", "Treatment_Function_Group")
)

# display(opa_final_processed.limit(10))
# print(f"Number of rows in opa_final_processed: {opa_final_processed.count()}")


In [0]:
#10 — Safe metric calculation (robust to missing columns)
from pyspark.sql import functions as F

# Start from the org-level table built at the end of cell #9: one row per (month-end date,
# TFC group, Region, ICB, Adj Org Code). Starting from opa_final_ordered_with_icb_region
# instead keeps one row per *original* provider code. Once Der_Provider_Code is dropped below,
# merged providers would then appear as several rows sharing one "Adj Org Code", each with its own rates.
df = opa_final_processed

def safe_add(df, new_col, expr_fn, required_cols):
    """Add a derived column, falling back to a typed NULL when its inputs are missing.

    Args:
        df: Spark DataFrame to extend.
        new_col: name of the column to add (an existing column of that name is replaced).
        expr_fn: callable taking `df` and returning the Column expression for `new_col`.
        required_cols: names of the columns `expr_fn` reads.

    Returns:
        `df.withColumn(new_col, expr_fn(df))` when every required column is present;
        otherwise `df` with `new_col` set to NULL cast to double.
    """
    available = set(df.columns)
    if all(c in available for c in required_cols):
        return df.withColumn(new_col, expr_fn(df))
    # An uncast lit(None) has Spark's void (NullType) type, which the Parquet writes later in
    # this notebook cannot store; every metric built with this helper is a ratio, so use double.
    return df.withColumn(new_col, F.lit(None).cast("double"))

def _dna_rate(base_col, dna_col):
    """Builder for a DNA percentage: dna / (base + dna) * 100, NULL when the denominator is 0.

    Returns (expr_fn, required_cols) in the shape `safe_add` expects.
    """
    def build(_df):
        booked = F.col(base_col) + F.col(dna_col)
        return F.when(booked != 0, (F.col(dna_col) / booked) * 100)
    return build, [base_col, dna_col]


def _ratio(numerator_col, denominator_col, scale=100, require_non_null=False):
    """Builder for numerator / denominator (times `scale` unless None), NULL when denominator is 0.

    `require_non_null` adds an explicit IS NOT NULL guard on the denominator (kept for the 2WW
    metrics); a NULL denominator yields NULL either way.
    Returns (expr_fn, required_cols) in the shape `safe_add` expects.
    """
    def build(_df):
        denominator = F.col(denominator_col)
        condition = denominator != 0
        if require_non_null:
            condition = condition & denominator.isNotNull()
        value = F.col(numerator_col) / denominator
        if scale is not None:
            value = value * scale
        return F.when(condition, value)
    return build, [numerator_col, denominator_col]


# (output column, expr_fn, required columns); columns are added in this order.
metrics = [
    ("All_DNA_Over_All_Total", *_dna_rate("All_Total", "All_DNA")),
    ("All_DNA_Over_All_Total_IG", *_dna_rate("All_Total", "All_DNA")),
    ("All_First_DNA_Over_All_First", *_dna_rate("All_First", "All_First_DNA")),
    ("All_First_DNA_Over_All_First_IG", *_dna_rate("All_First", "All_First_DNA")),
    ("All_FU_DNA_Over_All_FU", *_dna_rate("All_FU", "All_FU_DNA")),
    ("All_FU_DNA_Over_All_FU_IG", *_dna_rate("All_FU", "All_FU_DNA")),
    ("All_2WW_DNA_Over_All_2WW", *_ratio("All_2WW_DNA", "All_2WW", require_non_null=True)),
    ("All_FU_2WW_DNA_Over_All_FU_2WW", *_ratio("All_FU_2WW_DNA", "All_FU_2WW", require_non_null=True)),
    ("All_First_2WW_DNA_Over_All_First_2WW", *_ratio("All_First_2WW_DNA", "All_First_2WW", require_non_null=True)),
    ("All_FU_NoProc_Over_All_FU", *_ratio("All_FU_NoProc", "All_FU")),
    ("All_FU_Proc_Over_All_FU", *_ratio("All_FU_Proc", "All_FU")),
    ("All_FU_To_All_First", *_ratio("All_FU", "All_First", scale=None)),  # plain ratio, not a %
    ("All_FU_Over_All_Total", *_ratio("All_FU", "All_Total")),
    ("All_First_Over_All_Total", *_ratio("All_First", "All_Total")),
    ("All_NoProc_Over_All_Total", *_ratio("All_NoProc", "All_Total")),
    ("All_Proc_Over_All_Total", *_ratio("All_Proc", "All_Total")),
    ("Remote_Total_Over_All_Total", *_ratio("Remote_Total", "All_Total")),
    ("Remote_FU_Over_All_FU", *_ratio("Remote_FU", "All_FU")),
    ("Remote_First_Over_All_First", *_ratio("Remote_First", "All_First")),
    ("F2F_DNA_Over_F2F_Total", *_dna_rate("F2F_Total", "F2F_DNA")),
    ("Remote_DNA_Over_Remote_Total", *_dna_rate("Remote_Total", "Remote_DNA")),
]

for name, expr, req in metrics:
    df = safe_add(df, name, expr, req)

# Duplicate count columns under extra names (new column, source column)
simple_copies = [
    ("All_DNA1", "All_DNA"),
    ("All_DNA2", "All_DNA"),
    ("All_First1", "All_First"),
    ("All_First2", "All_First"),
    ("All_First3", "All_First"),
    ("All_FU1", "All_FU"),
    ("All_FU2", "All_FU"),
    ("All_FU3", "All_FU"),
    ("All_FU4", "All_FU"),
    ("All_FU5", "All_FU"),
    ("All_Total1", "All_Total"),
    ("All_Total2", "All_Total"),
    ("All_Total3", "All_Total"),
    ("All_Total4", "All_Total"),
    ("All_Total5", "All_Total"),
    ("All_Total6", "All_Total"),
    ("Remote_Total1", "Remote_Total"),
    ("Remote_Total2", "Remote_Total"),
]
for newc, base in simple_copies:
    if base in df.columns:
        df = df.withColumn(newc, F.col(base))
    else:
        # Typed NULL: an uncast lit(None) is Spark's void type, which the Parquet writes reject.
        df = df.withColumn(newc, F.lit(None).cast("long"))

# Attended + DNA totals (booked appointments)
combos = [
    ("All_First_plus_All_First_DNA", ["All_First", "All_First_DNA"]),
    ("All_FU_plus_All_FU_DNA", ["All_FU", "All_FU_DNA"]),
    ("All_Total_plus_All_DNA", ["All_Total", "All_DNA"]),
    ("F2F_Total_plus_F2F_DNA", ["F2F_Total", "F2F_DNA"]),
    ("Remote_Total_plus_Remote_DNA", ["Remote_Total", "Remote_DNA"]),
]
for newc, cols in combos:
    if all(c in df.columns for c in cols):
        df = df.withColumn(newc, F.col(cols[0]).cast("long") + F.col(cols[1]).cast("long"))
    else:
        df = df.withColumn(newc, F.lit(None).cast("long"))

# Provider-level identifiers are not part of the org-level output
cols_to_drop = [c for c in ["Der_Activity_Month", "Der_Provider_Code"] if c in df.columns]
opa_final_with_added_metrics = df.drop(*cols_to_drop)

display(opa_final_with_added_metrics.limit(10))


In [0]:
#11 reshapes the wide outpatient dataset into a long (tidy) format for easier analysis
from pyspark.sql.functions import col, explode, array, struct, lit, concat_ws

# Columns that identify a row and are repeated on every melted record.
id_cols = [
    "Der_Activity_Month_Date",
    "Treatment_Function_Group",
    "Adj Org Code",
    "ICB",
    "Region_Code",
]

# Every remaining column of the container-10 table is melted as a metric.
metric_cols = [name for name in opa_final_with_added_metrics.columns if name not in id_cols]

# Melt: pack each metric into a (Metric_Name, Metric_Value) struct, explode the array so each
# metric becomes its own row, then lift the struct fields back to top-level columns.
opa_long = opa_final_with_added_metrics.select(
    *id_cols,
    explode(
        array(*[
            struct(lit(name).alias("Metric_Name"), col(name).alias("Metric_Value"))
            for name in metric_cols
        ])
    ).alias("kv"),
).select(*id_cols, col("kv.Metric_Name"), col("kv.Metric_Value"))

# Metric key that is distinct per treatment-function group, e.g. "All_Total_GS".
opa_long = opa_long.withColumn(
    "Metric_Name_Treatment_Function_Group",
    concat_ws("_", col("Metric_Name"), col("Treatment_Function_Group")),
)

opa_long_ordered = opa_long.orderBy("Der_Activity_Month_Date")

display(opa_long_ordered.limit(10))
print(f"Number of rows in opa_long_ordered: {opa_long_ordered.count()}")


In [0]:
#12 – Aggregation and final metric derivation (Org, ICB, Region)
from pyspark.sql import functions as F
from pyspark.sql.functions import when, col, lit

# Org-level rows come from container 10; the org code column is renamed to drop the space.
df_org = opa_final_with_added_metrics.withColumnRenamed("Adj Org Code", "Adj_Org_Code")

# Raw count columns that are summed when rolling Org rows up to ICB and Region.
count_cols = [
    # all contacts
    "All_Total", "All_First", "All_FU", "All_Proc", "All_NoProc", "All_FU_Proc", "All_FU_NoProc",
    # face-to-face
    "F2F_Total", "F2F_First", "F2F_FU", "F2F_DNA",
    # remote
    "Remote_Total", "Remote_First", "Remote_FU", "Remote_DNA",
    # DNAs
    "All_DNA", "All_First_DNA", "All_FU_DNA",
    # two-week-wait appointments and their DNAs
    "All_2WW", "All_First_2WW", "All_FU_2WW",
    "All_2WW_DNA", "All_First_2WW_DNA", "All_FU_2WW_DNA",
]

# Rate/ratio metrics derived from summed counts (used at ICB and Region level)
def add_rate_metrics(df):
    """Recompute every rate/ratio metric from the count columns of `df`.

    Called on the ICB- and Region-level sums so rates reflect the aggregated counts.
    Each metric is NULL when its denominator is 0 or NULL. Percentages are multiplied by 100;
    All_FU_To_All_First is a plain ratio. Existing columns with the same names are replaced.

    Args:
        df: Spark DataFrame carrying the count columns referenced below (All_Total, All_DNA, ...).

    Returns:
        A new DataFrame with the 21 metric columns added in a fixed order.
    """
    c = F.col

    def pct(numerator, denominator):
        # numerator / denominator * 100, NULL when the denominator is 0
        return F.when(denominator != 0, (numerator / denominator) * 100).otherwise(None)

    def dna_pct(base, dna):
        # DNA share of booked appointments: dna / (base + dna) * 100
        return pct(c(dna), c(base) + c(dna))

    metric_specs = [
        # DNA metrics (the *_IG variants are identical copies)
        ("All_DNA_Over_All_Total", dna_pct("All_Total", "All_DNA")),
        ("All_DNA_Over_All_Total_IG", dna_pct("All_Total", "All_DNA")),
        ("All_First_DNA_Over_All_First", dna_pct("All_First", "All_First_DNA")),
        ("All_First_DNA_Over_All_First_IG", dna_pct("All_First", "All_First_DNA")),
        ("All_FU_DNA_Over_All_FU", dna_pct("All_FU", "All_FU_DNA")),
        ("All_FU_DNA_Over_All_FU_IG", dna_pct("All_FU", "All_FU_DNA")),
        # FU metrics
        ("All_FU_NoProc_Over_All_FU", pct(c("All_FU_NoProc"), c("All_FU"))),
        ("All_FU_Proc_Over_All_FU", pct(c("All_FU_Proc"), c("All_FU"))),
        ("All_FU_To_All_First",
         F.when(c("All_First") != 0, (c("All_FU") / c("All_First"))).otherwise(None)),
        # 2WW rates
        ("All_2WW_DNA_Over_All_2WW", pct(c("All_2WW_DNA"), c("All_2WW"))),
        ("All_First_2WW_DNA_Over_All_First_2WW", pct(c("All_First_2WW_DNA"), c("All_First_2WW"))),
        ("All_FU_2WW_DNA_Over_All_FU_2WW", pct(c("All_FU_2WW_DNA"), c("All_FU_2WW"))),
        # Mix shares
        ("All_FU_Over_All_Total", pct(c("All_FU"), c("All_Total"))),
        ("All_First_Over_All_Total", pct(c("All_First"), c("All_Total"))),
        ("All_NoProc_Over_All_Total", pct(c("All_NoProc"), c("All_Total"))),
        ("All_Proc_Over_All_Total", pct(c("All_Proc"), c("All_Total"))),
        ("Remote_Total_Over_All_Total", pct(c("Remote_Total"), c("All_Total"))),
        ("Remote_FU_Over_All_FU", pct(c("Remote_FU"), c("All_FU"))),
        ("Remote_First_Over_All_First", pct(c("Remote_First"), c("All_First"))),
        ("Remote_DNA_Over_Remote_Total", dna_pct("Remote_Total", "Remote_DNA")),
        ("F2F_DNA_Over_F2F_Total", dna_pct("F2F_Total", "F2F_DNA")),
    ]

    for metric_name, metric_expr in metric_specs:
        df = df.withColumn(metric_name, metric_expr)
    return df

# Aggregate to ICB: sum the raw counts, then recompute the rates from the summed counts
df_icb = (
    df_org.groupBy("Der_Activity_Month_Date", "ICB", "Treatment_Function_Group")
    .agg(*[F.sum(F.col(c)).alias(c) for c in count_cols])
)
df_icb = add_rate_metrics(df_icb).withColumn("Level", F.lit("ICB"))

# Aggregate to Region (same approach)
df_region = (
    df_org.groupBy("Der_Activity_Month_Date", "Region_Code", "Treatment_Function_Group")
    .agg(*[F.sum(F.col(c)).alias(c) for c in count_cols])
)
df_region = add_rate_metrics(df_region).withColumn("Level", F.lit("Region"))

# Label Org-level rows
df_org = df_org.withColumn("Level", F.lit("Org"))

# Combine all levels; a column missing on one side (e.g. Region_Code on ICB rows) becomes NULL
final_output = (
    df_org.unionByName(df_icb, allowMissingColumns=True)
          .unionByName(df_region, allowMissingColumns=True)
)

# One code column per row, chosen by Level
final_output = final_output.withColumn(
    "Adj_Org_Code_Final",
    when(col("Level") == "Org", col("Adj_Org_Code"))
    .when(col("Level") == "ICB", col("ICB"))
    .when(col("Level") == "Region", col("Region_Code"))
)

# Re-create the "simple copy" and "combo" columns for every level. ICB/Region rows do not carry
# them after the union (NULL), so without this they would be empty at those levels. The copy
# list matches simple_copies in container 10, including All_First1-3.
copy_map = {
    "All_DNA1": "All_DNA", "All_DNA2": "All_DNA",
    "All_First1": "All_First", "All_First2": "All_First", "All_First3": "All_First",
    "All_FU1": "All_FU", "All_FU2": "All_FU", "All_FU3": "All_FU", "All_FU4": "All_FU", "All_FU5": "All_FU",
    "All_Total1": "All_Total", "All_Total2": "All_Total", "All_Total3": "All_Total",
    "All_Total4": "All_Total", "All_Total5": "All_Total", "All_Total6": "All_Total",
    "Remote_Total1": "Remote_Total", "Remote_Total2": "Remote_Total",
}
for newc, base in copy_map.items():
    if base in final_output.columns:
        final_output = final_output.withColumn(newc, F.col(base))

combo_pairs = {
    "All_First_plus_All_First_DNA": ("All_First", "All_First_DNA"),
    "All_FU_plus_All_FU_DNA": ("All_FU", "All_FU_DNA"),
    "All_Total_plus_All_DNA": ("All_Total", "All_DNA"),
    "F2F_Total_plus_F2F_DNA": ("F2F_Total", "F2F_DNA"),
    "Remote_Total_plus_Remote_DNA": ("Remote_Total", "Remote_DNA"),
}
for newc, (a, b) in combo_pairs.items():
    if a in final_output.columns and b in final_output.columns:
        final_output = final_output.withColumn(newc, F.col(a).cast("long") + F.col(b).cast("long"))

# Persist the wide, all-levels table
final_output.write.mode("overwrite").parquet("/mnt/output/opa_final_all_levels")

# display(final_output.limit(10))
# print(f"Rows in final output: {final_output.count()}")



In [0]:
#13 – long/skinny OPRT format (Level preserved)
from pyspark.sql.functions import col, lit, explode, array, struct

# Input: the wide all-levels table from container 12 (has Level + Adj_Org_Code_Final).
# This cell rebinds `final_output` to its long result for container 14, so re-running it
# would melt an already-long table. Stop with a clear message instead.
if "OPRT_Metric_Name" in final_output.columns:
    raise ValueError(
        "final_output is already in long OPRT format; re-run container 12 before re-running container 13."
    )
df_wide = final_output

# --- Step 1: Drop the pre-level org code and selected raw counts so they are not melted into
# OPRT metrics (their *1/*2/... copies and the derived rates remain) ---
cols_to_drop = [
    "Adj_Org_Code",
    "All_DNA", "All_First", "All_FU",
    "All_Total", "F2F_Total", "Remote_Total"
]
df_wide = df_wide.drop(*[c for c in cols_to_drop if c in df_wide.columns])

# --- Step 2: Identifier columns (Level keeps Org/ICB/Region rows distinguishable) ---
id_cols = [
    "Der_Activity_Month_Date",
    "Region_Code",
    "ICB",
    "Adj_Org_Code_Final",
    "Treatment_Function_Group",
    "Level",
]

# --- Step 3: Every other column is a metric ---
metric_cols = [c for c in df_wide.columns if c not in id_cols]

# --- Step 4: Melt into long/skinny format: one row per (identifiers, metric) ---
opa_oprt_long = (
    df_wide.select(
        *id_cols,
        explode(array(*[
            struct(lit(c).alias("OPRT_Metric_Name"), col(c).alias("Metric_Value"))
            for c in metric_cols
        ])).alias("kv")
    )
    .select(
        *id_cols,
        col("kv.OPRT_Metric_Name"),
        col("kv.Metric_Value")
    )
)

display(opa_oprt_long.limit(10))
print(f"Container 13 complete — {opa_oprt_long.count()} rows, {len(opa_oprt_long.columns)} columns")

# Container 14 reads the long table through `final_output`
final_output = opa_oprt_long





In [0]:
#14 — Join Internal IDs and Clean Long OPRT Dataset (no restacking; Level preserved)
from pyspark.sql.functions import (
    col, lit, concat_ws, lower, regexp_replace, trim
)
from pyspark.sql.types import StringType, DoubleType, DateType
import pyspark.sql.functions as F


def _add_join_metric(df, source_col):
    """Return `df` with a normalised `join_metric` key derived from column `source_col`.

    Normalisation: trim, whitespace runs -> "_", lowercase, drop characters outside
    [a-z0-9_], collapse repeated "_", strip leading/trailing "_". Both sides of the
    metric-ID join are keyed with this helper so they normalise identically.
    """
    key = lower(regexp_replace(trim(col(source_col)), r"\s+", "_"))
    key = regexp_replace(key, r"[^a-z0-9_]", "")
    key = regexp_replace(key, r"_+", "_")
    key = regexp_replace(key, r"^_+|_+$", "")
    return df.withColumn("join_metric", key)


# --- Step 1: Start from container 13 output (already has Level + Adj_Org_Code_Final) ---
df_long = final_output

# --- Step 2: DO NOT restack; just keep the dataset as-is ---
df_stacked = df_long

# --- Step 3: Filter out Treatment_Function_Group = 'Other' (NULL groups are dropped too) ---
df_stacked = df_stacked.filter(col("Treatment_Function_Group") != "Other")

# --- Step 4: Build combined metric name for joining to ID list ---
df_stacked = df_stacked.withColumn(
    "OPRT_Metric_Name_TFC",
    concat_ws("_", col("OPRT_Metric_Name"), col("Treatment_Function_Group"))
)

# --- Step 5: Normalize join keys on our long dataset ---
df_stacked_clean = _add_join_metric(df_stacked, "OPRT_Metric_Name_TFC")

# --- Step 5b: Find the OPRT/TFC metric-name column in mhs_metric_list and key it the same way ---
metric_list_col = [c for c in mhs_metric_list.columns if "OPRT" in c and "TFC" in c]
if len(metric_list_col) == 0:
    raise ValueError("Could not find OPRT TFC metric column name in mhs_metric_list. Review mhs_metric_list columns.")
metric_list_col = metric_list_col[0]  # first matching column wins

mhs_metric_list_clean = _add_join_metric(mhs_metric_list, metric_list_col)

select_cols = ["join_metric", "InternalID"]
if "Description" in mhs_metric_list_clean.columns:
    select_cols.append("Description")
else:
    print("WARNING: mhs_metric_list does not contain a column named 'Description'.")

mhs_metric_list_clean_sel = mhs_metric_list_clean.select(*select_cols).distinct()

# --- Step 6a: Left join first — to capture unmatched metrics for debugging ---
df_with_id_left = df_stacked_clean.join(
    mhs_metric_list_clean_sel,
    on="join_metric",
    how="left"
)

# Metric keys with no InternalID match
unmatched = df_with_id_left.filter(col("InternalID").isNull()).select("OPRT_Metric_Name_TFC").distinct()
# print("Number of distinct unmatched OPRT metric keys:", unmatched.count())

# write unmatched to disk for inspection
unmatched.write.mode("overwrite").parquet("/mnt/output/unmatched_oprt_metrics.parquet")
# display(unmatched.limit(50))

# Diagnostic: show how the '2WW' metrics matched
df_with_id_left.filter(lower(col("OPRT_Metric_Name")).like("%2ww%")) \
    .select("OPRT_Metric_Name_TFC", "InternalID").distinct().show(200, truncate=False)

# --- Step 6b: Final join for production (inner; drop unmapped) ---
df_with_id = df_stacked_clean.join(
    mhs_metric_list_clean_sel,
    on="join_metric",
    how="inner"
).drop("join_metric")

# --- Step 7: Keep only allowable org codes, if the list provides them ---
if "Org_Code" in mhs_allowable_orgs.columns:
    # distinct(): a code listed more than once would otherwise duplicate every matching row
    allowable_codes = mhs_allowable_orgs.select(col("Org_Code").alias("allowable_code")).distinct()
    df_with_id = df_with_id.join(
        allowable_codes,
        df_with_id["Adj_Org_Code_Final"] == col("allowable_code"),
        "inner"
    ).drop("allowable_code")
else:
    print("WARNING: mhs_allowable_orgs has no 'Org_Code' column; allowable-org filter NOT applied.")

# --- Step 8: Enforce clean data types and ensure Description present ---
df_with_id = (
    df_with_id
    .withColumn("Der_Activity_Month_Date", col("Der_Activity_Month_Date").cast(DateType()))
    .withColumn("Adj_Org_Code_Final", col("Adj_Org_Code_Final").cast(StringType()))
    .withColumn("Level", col("Level").cast(StringType()))
    .withColumn("Treatment_Function_Group", col("Treatment_Function_Group").cast(StringType()))
    .withColumn("OPRT_Metric_Name", col("OPRT_Metric_Name").cast(StringType()))
    .withColumn("OPRT_Metric_Name_TFC", col("OPRT_Metric_Name_TFC").cast(StringType()))
    .withColumn("Metric_Value", col("Metric_Value").cast(DoubleType()))
)

if "Description" not in df_with_id.columns:
    df_with_id = df_with_id.withColumn("Description", lit(None).cast(StringType()))

# --- Step 9: Write final tidy dataset ---
df_with_id.write.mode("overwrite").parquet("/mnt/output/opa_oprt_final")

# display(df_with_id.limit(20))
# print(f" Container 14 complete — {df_with_id.count()} rows, {len(df_with_id.columns)} columns")
# print("Unmatched metric keys written to /mnt/output/unmatched_oprt_metrics.parquet for inspection.")






In [0]:
#15 – Add Remote Lower Benchmark
# Adds Remote_Lower_Benchmark (approximate 25th percentile of Remote_Total) to the container 10 output
# and also saves the benchmark as a standalone table.
from pyspark.sql import functions as F

# Start from container 10 output
df_benchmark = opa_final_with_added_metrics

# Step 1: Calculate 25th percentile of Remote_Total by month and Adj Org Code
# NOTE(review): because the grouping key is month + org, each org is only compared with its own rows in that
# month (e.g. its Treatment_Function_Group rows, if present), not with other orgs. The QA step "container 15"
# measures how often the benchmark equals the org's own Remote_Total; keep this behaviour until it is signed off.
remote_lower = (
    df_benchmark
    .groupBy("Der_Activity_Month_Date", "Adj Org Code")
    .agg(
        F.expr("percentile_approx(Remote_Total, 0.25)").alias("Remote_Lower_Benchmark")
    )
)

# Step 2: Join benchmark back to main dataset
# Left join on the same keys used for grouping, so every input row is kept.
df_with_benchmark = df_benchmark.join(
    remote_lower,
    on=["Der_Activity_Month_Date", "Adj Org Code"],
    how="left"
)

# Step 3: For missing values, fill with 0
# NULLs here come from rows whose join keys are NULL, or from groups where every Remote_Total is NULL
# (percentile_approx ignores NULLs and returns NULL for an all-NULL group).
df_with_benchmark = df_with_benchmark.fillna({"Remote_Lower_Benchmark": 0})

# Step 4: Final outputs
opa_final_with_remote_benchmark = df_with_benchmark

# Create a standalone lower benchmark table for downstream use
# The org column is renamed to Adj_Org_Code_Final to match the naming used in the long table (container 14).
df_lower_bm = (
    remote_lower
    .withColumnRenamed("Adj Org Code", "Adj_Org_Code_Final")
    .select("Der_Activity_Month_Date", "Adj_Org_Code_Final", "Remote_Lower_Benchmark")
)

# Optional: Save for reference or later join
df_lower_bm.write.mode("overwrite").parquet("/mnt/output/opa_lower_benchmark")

#display(opa_final_with_remote_benchmark.limit(10))
#print(f"Container 15 complete — {opa_final_with_remote_benchmark.count()} rows, {len(opa_final_with_remote_benchmark.columns)} columns")
#print(f"Lower benchmark table created — {df_lower_bm.count()} rows, {len(df_lower_bm.columns)} columns")



In [0]:
#16 DNA opportunities
# For each Org/Level, compare its rolling DNA rate with national percentiles for the month, assign a
# reduction band (25% / 15% / 10% / none) and convert that band into a whole-number DNA opportunity.
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Length of the rolling look-back in calendar months (current month included).
DNA_ROLLING_MONTHS = 6

# Start from container 10 output (with All_DNA_Over_All_Total metric)
df = opa_final_with_added_metrics

# Add Level column if missing (Org level)
if "Level" not in df.columns:
    df = df.withColumn("Level", F.lit("Org"))

# Work at Org/Level grain: keep only the "All" Treatment_Function_Group rows, so each Org/Level/month
# contributes exactly one row. Otherwise the per-group rows would leak into the rolling window (six *rows*
# could then span a single month) and into the national percentiles, and the "latest month" pick below
# would return an arbitrary group row.
if "Treatment_Function_Group" in df.columns:
    df = df.filter(F.col("Treatment_Function_Group") == "All")

df_filtered = df.select(
    "Der_Activity_Month_Date",
    "Adj Org Code",
    "Level",
    "All_DNA",
    "All_Total",
    "All_DNA_Over_All_Total"
)

# Sequential month number (year * 12 + month) used to order the window by calendar month.
# Ordering by a numeric cast of the date does not work: Spark cannot cast a DATE (or an ISO date string)
# to a number. It yields NULL, or an error under ANSI mode, which leaves the window without a usable order.
month_index = F.year("Der_Activity_Month_Date") * 12 + F.month("Der_Activity_Month_Date")

# Rolling window per Org/Level covering the last DNA_ROLLING_MONTHS calendar months, including the current one.
# A RANGE frame over the month number means a gap in the series shortens the window rather than reaching
# back further than intended.
window_spec = (
    Window.partitionBy("Adj Org Code", "Level")
          .orderBy(month_index)
          .rangeBetween(-(DNA_ROLLING_MONTHS - 1), 0)
)

# Calculate rolling averages
avg_df = df_filtered.withColumn(
    "Avg_All_DNA", F.avg("All_DNA").over(window_spec)
).withColumn(
    "Avg_DNA_rate", F.avg("All_DNA_Over_All_Total").over(window_spec)
)

# Calculate national median and 25th percentile for each month (across the Org/Level rows kept above)
national_stats = (
    df_filtered.groupBy("Der_Activity_Month_Date")
    .agg(
        F.expr("percentile_approx(All_DNA_Over_All_Total, 0.5)").alias("National_Median"),
        F.expr("percentile_approx(All_DNA_Over_All_Total, 0.25)").alias("Percentile_25")
    )
)

# Join stats back to avg_df
avg_df = avg_df.join(
    national_stats,
    on="Der_Activity_Month_Date",
    how="left"
)

# Apply DNA_Opportunity_reduction rules:
#   above the national median               -> 25% reduction
#   at/below the median, above the 25th pct -> 15% reduction
#   at/below the 25th percentile            -> 10% reduction
# A NULL rate or zero/NULL average DNAs gives "No reduction"; NULL national stats fall through to NULL.
avg_df = avg_df.withColumn(
    "DNA_Opportunity_reduction",
    F.when(F.col("Avg_DNA_rate").isNull(), "No reduction")
     .when(F.col("Avg_All_DNA").isNull() | (F.col("Avg_All_DNA") == 0), "No reduction")
     .when(F.col("Avg_DNA_rate") > F.col("National_Median"), "25% reduction")
     .when((F.col("Avg_DNA_rate") <= F.col("National_Median")) & (F.col("Avg_DNA_rate") > F.col("Percentile_25")), "15% reduction")
     .when(F.col("Avg_DNA_rate") <= F.col("Percentile_25"), "10% reduction")
     .otherwise(None)
)

# Calculate DNA_Opportunity as an integer (the band share of the rolling average DNAs, rounded)
avg_df = avg_df.withColumn(
    "DNA_Opportunity",
    F.when(F.col("DNA_Opportunity_reduction") == "No reduction", F.lit(0))
     .when(F.col("DNA_Opportunity_reduction") == "25% reduction", F.round(0.25 * F.col("Avg_All_DNA")).cast("int"))
     .when(F.col("DNA_Opportunity_reduction") == "15% reduction", F.round(0.15 * F.col("Avg_All_DNA")).cast("int"))
     .when(F.col("DNA_Opportunity_reduction") == "10% reduction", F.round(0.10 * F.col("Avg_All_DNA")).cast("int"))
     .otherwise(F.lit(None))
)

# Optional: keep only one row per Org/Level for latest month
latest_month_window = Window.partitionBy("Adj Org Code", "Level").orderBy(F.col("Der_Activity_Month_Date").desc())
dna_opp_df = avg_df.withColumn("row_num", F.row_number().over(latest_month_window)).filter(F.col("row_num") == 1).drop("row_num")

display(dna_opp_df.limit(10))

In [0]:
#17 — Post-run QA (robust): ICB/Region completeness for copy/combos & F2F DNA metrics
# Read-only checks. Every load is best-effort: a missing input is reported and its checks are skipped.
from pyspark.sql import functions as F

print("=== Container 17: Post-run QA (read-only, robust) ===")

# ---------- 0) Try to read wide output from C12 (authoritative for column-based checks) ----------
df_wide = None
df_long = None

# Load wide (C12)
try:
    df_wide = spark.read.parquet("/mnt/output/opa_final_all_levels")
    print("Loaded wide table from /mnt/output/opa_final_all_levels")
except Exception as e:
    print("Could not load wide table (/mnt/output/opa_final_all_levels):", str(e))

# Load long (C14) for secondary checks; fall back to the in-memory C14 result if the parquet cannot be read.
try:
    df_long = spark.read.parquet("/mnt/output/opa_oprt_final")
    print("Loaded long table from /mnt/output/opa_oprt_final")
except Exception as e:
    # Report why the read failed (as the wide branch does) instead of silently discarding the error.
    print("Could not load long table (/mnt/output/opa_oprt_final):", str(e))
    # Notebook cells run at module level, so C14's df_with_id (if it was run) lives in globals().
    if 'df_with_id' in globals():
        df_long = df_with_id
        print("Using in-memory df_with_id for long checks")
    else:
        print("No long table available for OPRT-based checks.")

# ---------- 1) If we have the wide table, run column-based checks there ----------
# Columns that are straight copies of a base metric (see spotcheck_copies for the mapping).
copy_cols  = [
    "All_DNA1","All_DNA2",
    "All_FU1","All_FU2","All_FU3","All_FU4","All_FU5",
    "All_Total1","All_Total2","All_Total3","All_Total4","All_Total5","All_Total6",
    "Remote_Total1","Remote_Total2"
]
# Columns that should equal the sum of two base metrics (see spotcheck_combos).
combo_cols = [
    "All_First_plus_All_First_DNA",
    "All_FU_plus_All_FU_DNA",
    "All_Total_plus_All_DNA",
    "F2F_Total_plus_F2F_DNA",
    "Remote_Total_plus_Remote_DNA"
]
extra_cols = ["F2F_DNA", "F2F_DNA_Over_F2F_Total"]

def null_summary(df, label, cols):
    """Print per-column NULL counts for `cols` in `df`.

    Prints the row count, warns about any names in `cols` that are missing from `df`, then lists
    every present column that contains NULLs with its count and percentage of rows (largest first).
    Returns early, printing a message, when `df` is empty or none of `cols` exist.

    Args:
        df: Spark DataFrame to inspect.
        label: Text used in the printed header.
        cols: Expected column names to check.
    """
    row_count = df.count()
    print(f"\n--- {label}: {row_count:,} rows ---")
    if row_count == 0:
        print("No rows to check.")
        return

    available = set(df.columns)
    checked = [name for name in cols if name in available]
    absent = sorted(set(cols).difference(checked))
    if absent:
        print("WARNING: Missing expected columns:", absent)
    if len(checked) == 0:
        print("No expected columns present — skipping null summary.")
        return

    # One aggregation pass: for each column, the number of NULL rows.
    null_counts = df.agg(
        *(F.sum(F.when(F.col(name).isNull(), 1).otherwise(0)).alias(name) for name in checked)
    ).first().asDict()

    offenders = []
    for name, n_null in null_counts.items():
        if n_null and n_null > 0:
            offenders.append((name, int(n_null), round(100 * float(n_null) / row_count, 2)))

    if offenders:
        print("Columns with NULLs (count, % of rows):")
        offenders.sort(key=lambda item: (-item[1], item[0]))
        for name, n_null, pct in offenders:
            print(f"  {name:30s} {n_null:8d} ({pct:5.2f}%)")
    else:
        print("All checked columns are fully populated (no NULLs). ✅")

def spotcheck_copies(df, label, n=100):
    """Check, on a random sample, that every copy column equals its base column.

    Args:
        df: Spark DataFrame to check (e.g. one Level slice of the wide table).
        label: Text used in the printed header.
        n: Number of randomly ordered rows to sample.

    Only pairs whose copy and base columns both exist in `df` are checked. The comparison is
    NULL-safe (`eqNullSafe`): a copy that is NULL while its base is populated, or the reverse, is
    counted as a mismatch. A plain `!=` evaluates to NULL in that case, so such rows used to pass
    silently, which hid exactly the NULL gaps this QA cell looks for.
    """
    print(f"\n--- Spot-check copy columns @ {label} (random {n} rows) ---")
    mapping = {
        "All_DNA1":"All_DNA","All_DNA2":"All_DNA",
        "All_FU1":"All_FU","All_FU2":"All_FU","All_FU3":"All_FU","All_FU4":"All_FU","All_FU5":"All_FU",
        "All_Total1":"All_Total","All_Total2":"All_Total","All_Total3":"All_Total",
        "All_Total4":"All_Total","All_Total5":"All_Total","All_Total6":"All_Total",
        "Remote_Total1":"Remote_Total","Remote_Total2":"Remote_Total"
    }
    pairs = [(k,v) for k,v in mapping.items() if k in df.columns and v in df.columns]
    if not pairs:
        print("No copy columns present — skipping.")
        return
    # Random order then take n rows; a full shuffle is acceptable for a QA spot-check.
    sample = df.orderBy(F.rand()).limit(n)
    exprs = [
        F.sum(F.when(~F.col(k).eqNullSafe(F.col(v)), 1).otherwise(0)).alias(k + "_neq")
        for k, v in pairs
    ]
    out = sample.agg(*exprs).collect()[0].asDict()
    bad = [(k,v) for k,v in out.items() if v and v > 0]
    if not bad:
        print("All copy columns match their base values in the sample. ✅")
    else:
        print("Mismatches found in sample:")
        for k,v in bad:
            print(f"  {k}: {v}")

def spotcheck_combos(df, label, n=100):
    """Check, on a random sample, that each combo column equals the sum of its two parts.

    Args:
        df: Spark DataFrame to check (e.g. one Level slice of the wide table).
        label: Text used in the printed header.
        n: Number of randomly ordered rows to sample.

    Both parts are cast to long before adding. The comparison is NULL-safe (`eqNullSafe`), so a
    NULL combo next to non-NULL parts counts as a mismatch. So does a populated combo whose parts
    sum to NULL (i.e. a part is NULL). A plain `!=` returned NULL for those rows and never flagged them.
    """
    print(f"\n--- Spot-check combo columns @ {label} (random {n} rows) ---")
    pairs = {
        "All_First_plus_All_First_DNA": ("All_First","All_First_DNA"),
        "All_FU_plus_All_FU_DNA": ("All_FU","All_FU_DNA"),
        "All_Total_plus_All_DNA": ("All_Total","All_DNA"),
        "F2F_Total_plus_F2F_DNA": ("F2F_Total","F2F_DNA"),
        "Remote_Total_plus_Remote_DNA": ("Remote_Total","Remote_DNA")
    }
    usable = [(newc,a,b) for newc,(a,b) in pairs.items() if all(c in df.columns for c in [newc,a,b])]
    if not usable:
        print("No combo columns present — skipping.")
        return
    # Random order then take n rows; a full shuffle is acceptable for a QA spot-check.
    sample = df.orderBy(F.rand()).limit(n)
    exprs = [
        F.sum(F.when(
            ~F.col(newc).eqNullSafe(F.col(a).cast("long") + F.col(b).cast("long")), 1
        ).otherwise(0)).alias(newc + "_neq")
        for (newc, a, b) in usable
    ]
    out = sample.agg(*exprs).collect()[0].asDict()
    bad = [(k,v) for k,v in out.items() if v and v > 0]
    if not bad:
        print("All combo columns equal the sum of their parts in the sample. ✅")
    else:
        print("Mismatches found in sample:")
        for k,v in bad:
            print(f"  {k}: {v}")

def rate_cols(df):
    """Return the names of rate columns in `df`, sorted.

    A column counts as a rate when its name contains "_Over_" or it is one of the explicitly named
    rate fields. Anything with a `columns` sequence of names is accepted.
    """
    explicit = {"F2F_DNA_Over_F2F_Total", "Remote_DNA_Over_Remote_Total"}
    picked = filter(lambda name: "_Over_" in name or name in explicit, df.columns)
    return sorted(picked)


def rate_bounds(df, label):
    """Report, per rate column of `df`, how many rows fall outside the 0–100 range.

    Rate columns come from `rate_cols`. NULL rates are not counted, because a NULL comparison
    never satisfies the out-of-range condition.

    Args:
        df: Spark DataFrame to check.
        label: Text used in the printed header.
    """
    rate_names = rate_cols(df)
    if len(rate_names) == 0:
        print(f"\n--- {label}: No rate columns detected — skipping. ---")
        return

    print(f"\n--- {label}: rate bounds (0–100) ---")
    out_of_range = df.agg(*[
        F.sum(F.when((F.col(name) < 0) | (F.col(name) > 100), 1).otherwise(0)).alias(name)
        for name in rate_names
    ]).first().asDict()

    offenders = [(name, cnt) for name, cnt in out_of_range.items() if cnt and cnt > 0]
    if offenders:
        print("Out-of-range rate values detected:")
        for name, cnt in offenders:
            print(f"  {name}: {cnt}")
    else:
        print("All rates within [0,100] (ignoring NULLs). ✅")

# Column-based checks on the wide table, split by aggregation level.
if df_wide is not None:
    df_icb    = df_wide.filter(F.col("Level") == "ICB")
    df_region = df_wide.filter(F.col("Level") == "Region")

    # 1A) Null coverage. Pass the full expected list, not just the columns present, so null_summary
    # can warn about expected columns missing from the wide table (pre-filtering silenced that warning).
    cols_to_check = copy_cols + combo_cols + extra_cols
    null_summary(df_icb, "ICB level (wide)", cols_to_check)
    null_summary(df_region, "Region level (wide)", cols_to_check)

    # 1B) Spotchecks on random samples
    spotcheck_copies(df_icb, "ICB (wide)")
    spotcheck_copies(df_region, "Region (wide)")
    spotcheck_combos(df_icb, "ICB (wide)")
    spotcheck_combos(df_region, "Region (wide)")

    # 1C) Rate bounds (0–100)
    rate_bounds(df_icb, "ICB (wide)")
    rate_bounds(df_region, "Region (wide)")
else:
    print("\nWide table unavailable — skipping column-based checks.")

# ---------- 2) If we have the long OPRT table, check presence of these metrics by name ----------
if df_long is None:
    print("\nLong table unavailable — skipping OPRT checks.")
else:
    print("\n--- OPRT long checks (by metric name) ---")
    # Names expected as OPRT_Metric_Name entries when the corresponding metrics exist upstream.
    target_metric_names = [
        "All_DNA1","All_DNA2",
        "All_FU1","All_FU2","All_FU3","All_FU4","All_FU5",
        "All_Total1","All_Total2","All_Total3","All_Total4","All_Total5","All_Total6",
        "Remote_Total1","Remote_Total2",
        "All_First_plus_All_First_DNA",
        "All_FU_plus_All_FU_DNA",
        "All_Total_plus_All_DNA",
        "F2F_Total_plus_F2F_DNA",
        "Remote_Total_plus_Remote_DNA",
        "F2F_DNA","F2F_DNA_Over_F2F_Total"
    ]

    # Row count per target metric at ICB/Region level. The name filter runs before grouping; targets
    # absent from the table simply do not appear in the result.
    present = (
        df_long
        .filter(F.col("Level").isin("ICB","Region") & F.col("OPRT_Metric_Name").isin(target_metric_names))
        .groupBy("OPRT_Metric_Name")
        .count()
        .orderBy("OPRT_Metric_Name")
    )

    print("Presence of target metric names at ICB/Region in long table:")
    display(present)

    # Per Level and metric: total rows, rows with NULL Metric_Value, and that NULL share in percent.
    mv_nulls = (
        df_long
        .filter(F.col("Level").isin("ICB","Region") & F.col("OPRT_Metric_Name").isin(target_metric_names))
        .groupBy("Level","OPRT_Metric_Name")
        .agg(
            F.count("*").alias("rows"),
            F.sum(F.when(F.col("Metric_Value").isNull(),1).otherwise(0)).alias("nulls"),
            F.round(F.sum(F.when(F.col("Metric_Value").isNull(),1).otherwise(0))/F.count("*")*100,2).alias("null_pct"),
        )
        .orderBy("Level","OPRT_Metric_Name")
    )

    print("Null share for those metrics (long):")
    display(mv_nulls)

print("\n=== QA complete. ===")



In [0]:
#18 sanity check on nulls in aggregation
# For ICB and Region rows of the C12 wide table, print how many F2F DNA rates are NULL and how many rows
# have a NULL/zero F2F_Total and a NULL/zero F2F_DNA. The two counts are expected to agree if a NULL rate
# only arises from an empty denominator (assumed — confirm against the C12 rate definition).
from pyspark.sql import functions as F

# Use a dedicated name so the `df` built in container 16 is not silently overwritten.
df_all_levels_c18 = spark.read.parquet("/mnt/output/opa_final_all_levels")

zero_denominator = (
    (F.col("F2F_Total").isNull() | (F.col("F2F_Total") == 0)) &
    (F.col("F2F_DNA").isNull() | (F.col("F2F_DNA") == 0))
)

for lvl in ["ICB", "Region"]:
    # One aggregation pass per level instead of two count() jobs. F.count(F.when(cond, 1)) counts the rows
    # where cond is true (NULL conditions are not counted) and returns 0, not NULL, for an empty level.
    counts = (
        df_all_levels_c18
        .filter(F.col("Level") == lvl)
        .agg(
            F.count(F.when(F.col("F2F_DNA_Over_F2F_Total").isNull(), 1)).alias("n_null"),
            F.count(F.when(zero_denominator, 1)).alias("n_zero"),
        )
        .first()
    )
    print(lvl, "NULL rates:", counts["n_null"], "| zero denominators:", counts["n_zero"])


In [0]:
#19 — QA audit writer (Delta + ADLS CSV export)
# Computes per-level QA counts on the wide table, appends them (plus optional long/unmatched summaries)
# to Delta tables for history, and exports the per-level summary as a CSV snapshot to ADLS.

from pyspark.sql import functions as F
from pyspark.sql import types as T
from datetime import datetime, timezone

print("=== Container 19: QA audit writer (Delta + ADLS) ===")

# ---------- 0) Load inputs ----------
# The wide table (C12 output) is required; a read failure stops the cell.
df_wide = spark.read.parquet("/mnt/output/opa_final_all_levels")
print("Loaded wide table from /mnt/output/opa_final_all_levels")

# The long table (C14) and the unmatched-metric list are optional. A failed read is reported and the
# corresponding QA dataset is skipped.
df_long = None
try:
    df_long = spark.read.parquet("/mnt/output/opa_oprt_final")
    print("Loaded long table from /mnt/output/opa_oprt_final")
except Exception as e:
    print("No long table:", str(e))

unmatched = None
try:
    unmatched = spark.read.parquet("/mnt/output/unmatched_oprt_metrics.parquet").distinct()
    print("Loaded unmatched metrics.")
except Exception as e:
    print("No unmatched metrics file:", str(e))

# ---------- 1) Helpers ----------
# UTC run timestamp shared by every row written in this run. datetime.utcnow() is deprecated (Python 3.12+);
# an aware UTC datetime formats to the same string.
run_ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

def rate_cols(df):
    """Return sorted rate-column names: any name containing "_Over_", plus the explicit rate fields."""
    explicit = {"F2F_DNA_Over_F2F_Total", "Remote_DNA_Over_Remote_Total"}
    return sorted([c for c in df.columns if ("_Over_" in c) or (c in explicit)])

rate_columns = rate_cols(df_wide)
levels = ["Org", "ICB", "Region"]
df_levels = {lvl: df_wide.filter(F.col("Level") == lvl) for lvl in levels}

# Every level slice has df_wide's columns, so the F2F column check only needs doing once.
has_f2f = all(c in df_wide.columns for c in ["F2F_Total","F2F_DNA","F2F_DNA_Over_F2F_Total"])

def _level_counts(d):
    """Compute all QA counts for one level slice in a single aggregation pass.

    Returns (rows, F2F rate NULL count, F2F zero-denominator count, {rate column: out-of-range count}).
    The two F2F counts are None when the F2F columns are absent. F.count(F.when(cond, 1)) counts the rows
    where cond is true and ignores NULL conditions, matching d.filter(cond).count().
    """
    aggs = [F.count(F.lit(1)).alias("_rows")]
    if has_f2f:
        aggs.append(F.count(F.when(F.col("F2F_DNA_Over_F2F_Total").isNull(), 1)).alias("_null_rate"))
        aggs.append(F.count(F.when(
            (F.col("F2F_Total").isNull() | (F.col("F2F_Total")==0)) &
            (F.col("F2F_DNA").isNull()  | (F.col("F2F_DNA")==0)), 1)).alias("_denom_zero"))
    # Positional aliases keep helper names from colliding with real column names.
    aggs.extend(
        F.count(F.when((F.col(rc) < 0) | (F.col(rc) > 100), 1)).alias(f"_oor_{i}")
        for i, rc in enumerate(rate_columns)
    )
    row = d.agg(*aggs).first()
    oor_by_col = {rc: row[f"_oor_{i}"] for i, rc in enumerate(rate_columns)}
    if has_f2f:
        return row["_rows"], row["_null_rate"], row["_denom_zero"], oor_by_col
    return row["_rows"], None, None, oor_by_col

# ---------- 2) Build summaries ----------
# One Spark job per level feeds both the summary row and the per-column detail rows.
summary_rows = []
detail_rows = []
for lvl, d in df_levels.items():
    total, null_rate, denom_zero, oor_by_col = _level_counts(d)
    oor_total = sum(oor_by_col.values())
    summary_rows.append((run_ts, lvl, total, null_rate, denom_zero, oor_total))
    detail_rows.extend((run_ts, lvl, rc, oor_by_col[rc]) for rc in rate_columns)

schema = T.StructType([
    T.StructField("run_ts", T.StringType(), False),
    T.StructField("Level", T.StringType(), False),
    T.StructField("rows", T.LongType(), True),
    T.StructField("F2F_DNA_rate_nulls", T.LongType(), True),
    T.StructField("F2F_DNA_zero_denoms", T.LongType(), True),
    T.StructField("rates_out_of_range_total", T.LongType(), True),
])
qa_summary = spark.createDataFrame(summary_rows, schema)

detail_schema = T.StructType([
    T.StructField("run_ts", T.StringType(), False),
    T.StructField("Level", T.StringType(), False),
    T.StructField("rate_column", T.StringType(), False),
    T.StructField("out_of_range_count", T.LongType(), False),
])
qa_rate_detail = spark.createDataFrame(detail_rows, detail_schema)

# NOTE(review): this list is not used below. qa_long_presence records every metric, unlike container 17,
# which filters to these names. Confirm whether the Delta history should be restricted to them.
target_metric_names = [
    "All_DNA1","All_DNA2","All_FU1","All_FU2","All_FU3","All_FU4","All_FU5",
    "All_Total1","All_Total2","All_Total3","All_Total4","All_Total5","All_Total6",
    "Remote_Total1","Remote_Total2",
    "All_First_plus_All_First_DNA","All_FU_plus_All_FU_DNA","All_Total_plus_All_DNA",
    "F2F_Total_plus_F2F_DNA","Remote_Total_plus_Remote_DNA",
    "F2F_DNA","F2F_DNA_Over_F2F_Total"
]

qa_long_presence = None
if df_long is not None:
    lp = (df_long.filter(F.col("Level").isin("ICB","Region"))
          .groupBy("Level","OPRT_Metric_Name")
          .count()
          .withColumn("run_ts", F.lit(run_ts)))
    qa_long_presence = lp.select("run_ts","Level","OPRT_Metric_Name","count")

qa_unmatched = None
if unmatched is not None:
    qa_unmatched = (unmatched.groupBy()
                    .agg(F.countDistinct("OPRT_Metric_Name_TFC").alias("unmatched_metrics"))
                    .withColumn("run_ts", F.lit(run_ts))
                    .select("run_ts","unmatched_metrics"))

# ---------- 3) Write as Delta (for history) ----------
# Append mode: each run adds rows tagged with its run_ts.
delta_base = "/mnt/output/_qa_delta"
(qa_summary.write.format("delta").mode("append").save(f"{delta_base}/summary_level"))
(qa_rate_detail.write.format("delta").mode("append").save(f"{delta_base}/rate_detail"))
if qa_long_presence is not None:
    (qa_long_presence.write.format("delta").mode("append").save(f"{delta_base}/long_presence"))
if qa_unmatched is not None:
    (qa_unmatched.write.format("delta").mode("append").save(f"{delta_base}/unmatched_summary"))

print("Appended all QA datasets as Delta.")

# ---------- 4) Export CSV snapshot to ADLS ----------
# Spark writes this path as a directory containing a single part file (coalesce(1)), overwritten each run.
adls_csv_path = (
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/"
    "ElectiveRecovery/Projects/OP_QA_STP_Region_all_metrics_Audit.csv"
)

csv_summary = (
    qa_summary
    .withColumn("rates_out_of_range_any", (F.col("rates_out_of_range_total") > 0).cast("boolean"))
)

(csv_summary.coalesce(1)
 .write.mode("overwrite")
 .option("header", True)
 .csv(adls_csv_path))

print(f"Wrote CSV snapshot to {adls_csv_path}")
print("=== QA audit writer complete ===")

display(qa_summary.orderBy("Level"))



In [0]:
#20 Saving the file to the lake mart for QA (filtered for a small sample)
# Extract of the C14 long table (df_with_id) for one month and one provider, for manual QA.
# To use the ICB/Region filters below instead, uncomment one and join it to the others with "&".
df_sample = df_with_id.filter(
    (F.col("Der_Activity_Month_Date") == "2025-07-31") & 
    (F.col("Adj_Org_Code_Final") == "RH8")
    #(F.col("ICB") == "QE1")
    #(F.col("Region_Code") == "Y59")
)
# count() runs a Spark job; its result is only used by the commented-out print below.
df_sample_count = df_sample.count()
#print(f"Number of rows in filtered sample: {df_sample_count}")

# NOTE(review): the output name ends in "_Y59", but the active filter is provider RH8 — confirm the filename.
# Spark writes a directory at this path containing a single part file (coalesce(1)).
df_sample.coalesce(1).write \
    .format("csv") \
    .mode("overwrite") \
    .option("header", "true") \
    .save("abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/Projects/OP_QA_STP_Region_new_metrics_short_Y59.csv")

In [0]:
#21 Saving the file to the lake mart for QA
# Export the full long table (df_with_id) as CSV with a header row. coalesce(1) makes Spark write
# a single part file inside the directory at this path; any previous output is replaced.
(
    df_with_id.coalesce(1)
    .write
    .mode("overwrite")
    .option("header", "true")
    .csv("abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/Projects/OP_QA_STP_Region_new_metrics_new.csv")
)

#display(df_with_id.limit(10))
#print(f"Number of rows in df_with_id: {df_with_id.count()}")

In [0]:
#22 Saving the full remote-benchmark table (container 15 output, not filtered) to the lake mart for QA
# coalesce(1) gives a single CSV part file inside the directory written at this path.
opa_final_with_remote_benchmark.coalesce(1).write \
    .format("csv") \
    .mode("overwrite") \
    .option("header", "true") \
    .save("abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/Projects/lower_bm_new.csv")

#display(opa_final_with_remote_benchmark.limit(10))
#print(f"Number of rows in opa_final_with_remote_benchmark: {opa_final_with_remote_benchmark.count()}")

In [0]:
#23 Saving the DNA opportunities table (container 16 output, latest month per Org/Level) to the lake mart for QA
# coalesce(1) gives a single CSV part file inside the directory written at this path.
dna_opp_df.coalesce(1).write \
    .format("csv") \
    .mode("overwrite") \
    .option("header", "true") \
    .save("abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/ElectiveRecovery/Projects/DNA_Opportunities_New.csv")

#display(dna_opp_df.limit(10))
#print(f"Number of rows in dna_opp_df: {dna_opp_df.count()}")

In [0]:
#QA step Container 9a
# For each month/ICB/Region/org key, check that every metric on the "All" Treatment_Function_Group row
# equals the sum of that metric over the individual group rows. The cell output lists (metric_neq, count)
# for every metric with at least one mismatching key. `cmp` is reused by QA step 9b.

from pyspark.sql import functions as F

id_cols = ["Der_Activity_Month_Date", "ICB", "Region_Code", "Adj Org Code"]
# Every non-key column is treated as an additive metric.
# NOTE(review): a ratio column here would never match a sum — confirm only counts are present.
metric_cols = [c for c in opa_final_processed.columns if c not in id_cols + ["Treatment_Function_Group"]]

# Sum of each metric over the non-"All" rows (rows with a NULL group are also excluded by the != filter).
sum_by_groups = (
  opa_final_processed
  .filter(F.col("Treatment_Function_Group") != "All")
  .groupBy(*id_cols)
  .agg(*[F.sum(F.col(c)).alias(c + "_sum_groups") for c in metric_cols])
)

# The "All" rows, with metric columns suffixed _all for the side-by-side comparison.
all_rows = (
  opa_final_processed
  .filter(F.col("Treatment_Function_Group") == "All")
  .select(*id_cols, *metric_cols)
  .toDF(*id_cols, *[c + "_all" for c in metric_cols])
)

# Inner join: keys missing from either side are not compared.
cmp = sum_by_groups.join(all_rows, on=id_cols, how="inner")

# NOTE(review): `!=` is not NULL-safe (a NULL on one side is never counted) and compares exactly,
# so floating-point rounding would show up as a mismatch.
mismatches = []
for c in metric_cols:
    mismatches.append(F.sum(F.when(F.col(c + "_sum_groups") != F.col(c + "_all"), 1).otherwise(0)).alias(c+"_neq"))

res = cmp.agg(*mismatches).collect()[0].asDict()
[ (k,v) for k,v in res.items() if v != 0 ]


In [0]:
#QA step Container 9b
# Show up to 50 keys where All_Total on the "All" row differs from the sum over the group rows (uses `cmp` from 9a).
bad = cmp.filter(F.col("All_Total_sum_groups") != F.col("All_Total_all")).limit(50)
display(bad)


In [0]:
#QA step Container 9c
# List every Treatment_Function_Group value present (collected to the driver as Row objects).
distinct_groups = (
    opa_final_processed
    .select("Treatment_Function_Group")
    .dropDuplicates()
    .collect()
)
distinct_groups


In [0]:
#QA step Container 10a
# Percentage-style columns are recognised by name suffix, plus the two explicit F2F/Remote DNA rates.
pct_cols = [
    c for c in opa_final_with_added_metrics.columns
    if c.endswith(("_Over_All_Total", "_Over_All_FU", "_Over_All_First", "_2WW"))
    or c in ["F2F_DNA_Over_F2F_Total", "Remote_DNA_Over_Remote_Total"]
]

# Per column, count values outside [0, 100] (NULLs are not counted); show only columns with violations.
violations = (
  opa_final_with_added_metrics
  .select(*[F.sum(F.when((F.col(c) < 0) | (F.col(c) > 100), 1).otherwise(0)).alias(c) for c in pct_cols])
  .collect()[0]
  .asDict()
)
[(k, v) for k, v in violations.items() if v != 0]


In [0]:
#QA step Container 10b
# Count rows where each "<X>_plus_<Y>" column differs from X + Y (parts cast to long before adding).
# NOTE(review): `!=` is not NULL-safe, so rows where either side is NULL are never counted as mismatches.
checks = opa_final_with_added_metrics.select(
  F.sum(F.when(F.col("All_Total_plus_All_DNA") != (F.col("All_Total").cast("long")+F.col("All_DNA").cast("long")), 1).otherwise(0)).alias("all_att_plus_dna_mismatch"),
  F.sum(F.when(F.col("F2F_Total_plus_F2F_DNA") != (F.col("F2F_Total").cast("long")+F.col("F2F_DNA").cast("long")), 1).otherwise(0)).alias("f2f_att_plus_dna_mismatch"),
  F.sum(F.when(F.col("Remote_Total_plus_Remote_DNA") != (F.col("Remote_Total").cast("long")+F.col("Remote_DNA").cast("long")), 1).otherwise(0)).alias("remote_att_plus_dna_mismatch")
).collect()[0]
checks


In [0]:
#QA step container 11
# Compare the number of metric cells in the wide table with the number of rows in the long table.
# NOTE(review): the two match only if the wide-to-long reshape keeps every cell, including NULLs —
# confirm how opa_long_ordered treats NULL metric values.
id_cols = ["Der_Activity_Month_Date","Treatment_Function_Group","Adj Org Code","ICB","Region_Code"]
m_cols = [c for c in opa_final_with_added_metrics.columns if c not in id_cols]

wide_cells = opa_final_with_added_metrics.count() * len(m_cols)
long_rows  = opa_long_ordered.count()
(wide_cells, long_rows)


In [0]:
#QA step container 12
from pyspark.sql import functions as F

# Sample ICB parity check: recompute the ICB DNA rate as All_DNA / (All_Total + All_DNA) * 100 and count rows
# where it disagrees with the stored All_DNA_Over_All_Total. A row disagrees if both values exist and differ
# by more than 1e-9, or if exactly one of them is NULL.
# pyspark Column defines no `^` operator, so "exactly one NULL" is written as `!=` on the two boolean isNull flags.
# NOTE(review): final_output must still be the wide C12 table here (see the note in QA step container 13).
sample_icb = (
  final_output
  .filter(F.col("Level")=="ICB")
  .select("Der_Activity_Month_Date","ICB","Treatment_Function_Group","All_DNA","All_Total","All_DNA_Over_All_Total")
  .withColumn("recalc", F.when((F.col("All_Total")+F.col("All_DNA"))!=0, (F.col("All_DNA")/(F.col("All_Total")+F.col("All_DNA")))*100))
  .withColumn("neq", F.when(
      (F.abs(F.col("recalc")-F.col("All_DNA_Over_All_Total")) > 1e-9)
      | (F.col("recalc").isNull() != F.col("All_DNA_Over_All_Total").isNull()),
      1).otherwise(0))
)

sample_icb.agg(F.sum("neq")).show()


In [0]:
#QA step container 13
# Key columns of the C12 wide output; every other column is treated as a metric column.
id_cols = ["Der_Activity_Month_Date","Region_Code","ICB","Adj_Org_Code_Final","Treatment_Function_Group","Level"]
metric_cols = [c for c in final_output.columns if c not in id_cols]  # final_output is wide before melt in C13

# If final_output has already been overwritten by long, reload the C12 parquet:
# df_wide_c12 = spark.read.parquet("/mnt/output/opa_final_all_levels")  # assuming you switch to Delta/Parquet path used in C12
# metric_cols = [c for c in df_wide_c12.columns if c not in id_cols]

# Compare cell counts (if you have both wide and long at hand)
# wide_cells = df_wide_c12.count() * len(metric_cols)
# long_rows  = opa_oprt_long.count()
# (wide_cells, long_rows)


In [0]:
#QA step container 14a
# Left-joined version produced earlier:
# Number of distinct unmatched metric rows written by C14 (0 means every metric key mapped to an InternalID).
df_left = spark.read.parquet("/mnt/output/unmatched_oprt_metrics.parquet")  # written by C14
unmatched_cnt = df_left.distinct().count()
unmatched_cnt


In [0]:
#QA step container 14b
# NOTE(review): these type imports are not used in this cell; the schema is just inspected by eye.
from pyspark.sql.types import DoubleType, DateType, StringType
df_with_id.printSchema()  # confirm Date/Double/String as intended


In [0]:
#QA step container 15
# Count how many rows have Remote_Lower_Benchmark exactly equal to the row's own Remote_Total,
# alongside the total row count (the demonstration described in the QA notes).
from pyspark.sql import functions as F

probe = (
  opa_final_with_remote_benchmark
  .select("Der_Activity_Month_Date","Adj Org Code","Remote_Total","Remote_Lower_Benchmark")
  .withColumn("same", F.when(F.col("Remote_Total") == F.col("Remote_Lower_Benchmark"), 1).otherwise(0))
)

probe.agg(F.sum("same").alias("same_cnt"), F.count("*").alias("rows")).show()


In [0]:
#QA step container 16a
# Count negative DNA_Opportunity values and values that change when cast to int.
# NOTE(review): container 16 builds DNA_Opportunity from lit(0) and .cast("int") branches, so the column is
# presumably already an integer type and "non_ints" is then always 0 — confirm via dna_opp_df.printSchema().
from pyspark.sql import functions as F

dna_checks = dna_opp_df.select(
  F.sum(F.when(F.col("DNA_Opportunity") < 0, 1).otherwise(0)).alias("negatives"),
  F.sum(F.when(F.col("DNA_Opportunity").isNotNull() & (F.col("DNA_Opportunity") != F.col("DNA_Opportunity").cast("int")), 1).otherwise(0)).alias("non_ints")
).collect()[0].asDict()
dna_checks


In [0]:
#QA step container 16b
# Row count per reduction band (a NULL band means none of the container 16 rules applied).
dna_opp_df.groupBy("DNA_Opportunity_reduction").count().orderBy("DNA_Opportunity_reduction").show(truncate=False)


In [0]:
#QA step final wide table
# Number of key combinations that occur more than once in the wide processed table.
id_cols = ["Der_Activity_Month_Date","Treatment_Function_Group","Region_Code","ICB","Adj Org Code"]
dupes = (
  opa_final_processed.groupBy(*id_cols)
  .count()
  .filter("count > 1")
  .count()
)
dupes  # expect 0


In [0]:
#QA step final long table
# Number of key combinations (including InternalID) that occur more than once in the final long table.
id_cols = ["Der_Activity_Month_Date","Adj_Org_Code_Final","Level","Treatment_Function_Group","InternalID"]
dupes = (
  df_with_id.groupBy(*id_cols)
  .count()
  .filter("count > 1")
  .count()
)
dupes  # expect 0


In [0]:
#QA step volume reconciliation v source
# Compare attendance and DNA counts in the filtered record-level source with the totals in the
# container 10 output.
from pyspark.sql import functions as F

# After all filters in C9 (assumed to mirror the C9 filters exactly — keep in sync)
src = (
  opa_with_groups
  .filter(
    (F.col("Der_Financial_Year").isin("2023/24","2024/25","2025/26")) &
    (F.col("Administrative_Category")=="01") &
    (F.col("Treatment_Function_Code")!="812") &
    (F.col("First_Attendance").isin("1","2","3","4"))
  )
)

att = src.filter(F.col("Attendance_Status").isin("5","6")).count()
dna = src.filter(F.col("Attendance_Status").isin("3","7")).count()

# The rolled-up table presumably holds one row per Treatment_Function_Group *plus* an "All" total row per key
# (QA step Container 9a checks this for opa_final_processed). Summing every row would then count each
# attendance twice, so only the "All" rows are summed here.
rollup_base = opa_final_with_added_metrics
if "Treatment_Function_Group" in rollup_base.columns:
    rollup_base = rollup_base.filter(F.col("Treatment_Function_Group") == "All")

rollup = rollup_base.select(
  F.sum("All_Total").alias("att_total"),
  F.sum("All_DNA").alias("dna_total")
).first()

(att, dna, rollup.att_total, rollup.dna_total)


What the QA steps aim to establish:

Confidence that “All” equals the sum across TFC groups.

All percentages are within 0–100, and their denominators match the intended definitions.

Long/wide reshapes conserve information.

Aggregated levels recompute rates correctly.

Mapping to Internal IDs is complete (or you get a tidy list to fill).

A demonstration (not a change) that the current remote benchmark equals the org’s own value, so you can sign off current behavior before changing it later.

Opportunity numbers are non-negative integers with sensible band coverage.

Uniqueness at expected grains and volume reconciliation with the filtered source.

In [0]:
# import sys
# sys.path.append('/Workspace/Repos/MHS-analytics/MHS-analytics/Python_Packages')
# from mhs_import import MHS_IngestionHub
# from mhs_db_config import INSERT, DELETE

# def main():
#     max_date = df_sample.agg(F.max("reportingDate")).collect()[0][0]
#     desc = f"df_sample_{max_date}"
#     InjectionHub(df_sample, desc, True)
#     print("df_sample sent to InjectionHub")

# def InjectionHub(dfih, desc, tf):
#     sdf = dfih.withColumn('reportingDate', F.to_date('reportingDate'))
#     display(sdf.sample(False, 0.01))
#     display(sdf.dtypes)
#     MHS_IngestionHub.upload(
#         mhs_df=sdf,
#         description=desc,
#         loaded_by="steven.evans4@nhs.net",
#         mhs_mode=INSERT,
#         skip_existing_data_check=tf
#     )

# if __name__ == '__main__':
#     main()

In [0]:
#from pyspark.sql import Row

#col_names = opa_final_with_added_metrics.columns
#df_col_names = spark.createDataFrame([Row(Column_Name=c) for c in col_names])
#display(df_col_names)