# Here’s a complete PySpark pipeline to analyze and identify store closure due to severe weather using the available tables (salestransskufact, salestransorderfact, weatherdata, employee, customer, etc.) for All Entity

In [0]:
import numpy as np
import pandas as pd

In [0]:
query = """
SELECT *
FROM gold.salestransskufact
"""

df_sales = spark.sql(query)
display(df_sales)

In [0]:
columns_df = spark.createDataFrame([(col,) for col in df_sales.columns], ["Column Name"])
display(columns_df)

In [0]:
query = """
SELECT *
FROM gold.salestransorderfact
"""

df_sales_trans = spark.sql(query)
display(df_sales_trans)

print(df_sales_trans.columns)

In [0]:
from pyspark.sql.functions import to_date, col

# Convert OrderDate to date type and rename for join
df_sales_trans = df_sales_trans.withColumn("date", to_date(col("OrderDate")))


In [0]:
query = """
SELECT *
FROM gold.weatherdata_daily_historical_openmeteo_nws
"""

df_wethr = spark.sql(query)
display(df_wethr)

print(df_wethr.columns)

In [0]:
query = """
SELECT *
FROM gold.weatherdata_daily_forecast_nws
"""

df_forecasted = spark.sql(query)
display(df_forecasted)

print(df_forecasted.columns)

In [0]:
query = """
SELECT *
FROM gold.calendar
"""

df_calendar = spark.sql(query)
display(df_calendar)

print(df_calendar.columns)

In [0]:
query = """
SELECT *
FROM gold.product
"""

df_pr = spark.sql(query)
display(df_pr)

print(df_pr.columns)

# We can try to join gold.weatherdata and gold.weatherdata_forecasted right at the beginning — this will give us a unified view where each row contains both the actual and forecasted weather attributes per date, per businessEntity

In [0]:
from pyspark.sql.functions import col, to_date

df_actual = spark.sql("""
SELECT
    weather_date,
    weather_type,
    temperature,
    wind_speed,
    precipitation,
    businessentityid,
    businessentityname
FROM gold.weatherdata_daily_historical_openmeteo_nws
""").withColumn("weather_date", to_date("weather_date"))

df_forecasted = spark.sql("""
SELECT
    weather_date,
    weather_type_forecast,
    temperature_forecast,
    wind_speed_forecast,
    precipitation_forecast,
    businessentityid,
    businessentityname
FROM gold.weatherdata_daily_forecast_nws
""").withColumn("weather_date", to_date("weather_date"))

df_weather_combined = df_actual.join(
    df_forecasted,
    on=["weather_date", "businessentityid", "businessentityname"],
    how="outer"
)

display(df_weather_combined)

# Identify Past "null" ORDER DAYS

In [0]:
from pyspark.sql.functions import to_date, countDistinct, current_date
from pyspark.sql.types import DateType
from datetime import timedelta

# Extract distinct dates from calendar table for the last 24 months and filter out future dates
calendar_range = df_calendar.select("Date").filter("Date >= date_sub(current_date(), 730) AND Date <= current_date()")

# Convert OrderDate to date
df_sales_date = df_sales.withColumn("SalesDate", to_date("OrderDate"))

# Get actual sales dates per Site
sales_activity = df_sales_date.groupBy("SiteID", "SalesDate").agg(
    countDistinct("OrderID").alias("TotalOrders")
)

# Join calendar with sales to find dates with no orders
closure_candidates = calendar_range.crossJoin(df_sales.select("SiteID").distinct()).withColumnRenamed("Date", "SalesDate") \
    .join(sales_activity, on=["SiteID", "SalesDate"], how="left") \
    .filter("TotalOrders IS NULL")

display(closure_candidates)

# Now Let's CHECK The Weather Conditions On The Same Dates

In [0]:
from pyspark.sql.functions import to_date, countDistinct, current_date
from pyspark.sql.types import DateType
from datetime import timedelta

# Extract distinct dates from calendar table for the last 24 months and filter out future dates
calendar_range = df_calendar.select("Date").filter("Date >= date_sub(current_date(), 730) AND Date <= current_date()")

# Convert OrderDate to date
df_sales_date = df_sales.withColumn("SalesDate", to_date("OrderDate"))

# Get actual sales dates per Site
sales_activity = df_sales_date.groupBy("SiteID", "SalesDate").agg(
    countDistinct("OrderID").alias("TotalOrders")
)

# Join calendar with sales to find dates with no orders
closure_candidates = calendar_range.crossJoin(df_sales.select("SiteID").distinct()).withColumnRenamed("Date", "SalesDate") \
    .join(sales_activity, on=["SiteID", "SalesDate"], how="left") \
    .filter("TotalOrders IS NULL")

# Add weather_type to the resulting dates
closure_with_weather = closure_candidates.join(
    df_wethr.selectExpr("weather_date as SalesDate", "weather_type"),
    on="SalesDate",
    how="left"
)

display(closure_with_weather)

# The Prev Table Shows All Hourls weather_type Which Is Difficult To Interpret For The Daily Weather. Hence, Let's RANK The weather_type Based On It's Severity And Then Get The Overall Daily Weather For The Same Dates As In The Above Table

In [0]:
from pyspark.sql.functions import to_date, countDistinct, current_date, col, when, max as F_max

# Extract last 24 months from calendar and filter out future dates
calendar_range = df_calendar.select("Date").filter("Date >= date_sub(current_date(), 730) AND Date <= current_date()")

# Convert OrderDate to date
df_sales_date = df_sales.withColumn("SalesDate", to_date("OrderDate"))

# Get distinct SiteID + businessEntityId + businessEntityName
distinct_sites = df_sales.select("SiteID", "businessEntityId", "businessEntityName").distinct()

# Get actual sales dates per Site
sales_activity = df_sales_date.groupBy("SiteID", "SalesDate").agg(
    countDistinct("OrderID").alias("TotalOrders")
)

# Join calendar with site list (with businessEntity info), then join to sales_activity
closure_candidates = (
    calendar_range.crossJoin(distinct_sites)
    .withColumnRenamed("Date", "SalesDate")
    .join(sales_activity, on=["SiteID", "SalesDate"], how="left")
    .filter("TotalOrders IS NULL")
)

# Severity mapping for weather types
weather_severity = {
    "Clear": 1,
    "Partly cloudy": 2,
    "Overcast": 3,
    "Patchy moderate snow": 4,
    "Moderate snow": 5,
    "Ice pellets": 6,
    "Moderate or heavy snow showers": 7,
    "Blowing snow": 8,
    "Heavy snow": 9,
    "Moderate or heavy rain with thunder": 10
}

# Add severity ranking to weather data
df_wethr_with_severity = df_wethr.withColumn(
    "severity",
    when(col("weather_type") == "Clear", 1)
    .when(col("weather_type") == "Partly cloudy", 2)
    .when(col("weather_type") == "Overcast", 3)
    .when(col("weather_type") == "Patchy moderate snow", 4)
    .when(col("weather_type") == "Moderate snow", 5)
    .when(col("weather_type") == "Ice pellets", 6)
    .when(col("weather_type") == "Moderate or heavy snow showers", 7)
    .when(col("weather_type") == "Blowing snow", 8)
    .when(col("weather_type") == "Heavy snow", 9)
    .when(col("weather_type") == "Moderate or heavy rain with thunder", 10)
)

# Most severe weather for each date
df_weather_daily = df_wethr_with_severity.groupBy("weather_date").agg(
    F_max("severity").alias("max_severity")
).alias("a").join(
    df_wethr_with_severity.alias("b"),
    (col("a.weather_date") == col("b.weather_date")) & (col("a.max_severity") == col("b.severity")),
    how="left"
).select("a.weather_date", "b.weather_type").distinct()

# Add weather_type to closure candidates
closure_with_weather = closure_candidates.join(
    df_weather_daily.selectExpr("weather_date as SalesDate", "weather_type"),
    on="SalesDate",
    how="left"
)

display(closure_with_weather)


# # STEP 1: Prepare the Historical Weather + Store Closure Dataset

We’ll build a labeled dataset where:
	•	1 = store was closed
	•	0 = store was open

In [0]:
from pyspark.sql.functions import to_date, countDistinct, col, when

# Get all possible SiteID/date combinations (calendar × sites)
calendar_range = df_calendar.select(to_date("Date").alias("SalesDate"))
sites = df_sales.select("SiteID").distinct()
all_site_dates = calendar_range.crossJoin(sites)

# Aggregate sales activity: how many orders per site/day
df_sales_with_date = df_sales.withColumn("SalesDate", to_date("OrderDate"))
sales_activity = df_sales_with_date.groupBy("SiteID", "SalesDate").agg(
    countDistinct("OrderID").alias("TotalOrders")
)

# Label data: 1 = closed (no orders that day), 0 = open (≥1 order)
labeled_data = all_site_dates.join(
    sales_activity, on=["SiteID", "SalesDate"], how="left"
).withColumn(
    "StoreClosedLabel", when(col("TotalOrders").isNull(), 1).otherwise(0)
)


In [0]:
site_business_map = (
    df_sales
    .select("SiteID", "businessEntityId", "businessEntityName")
    .distinct()
)


In [0]:
labeled_with_business = labeled_data.join(site_business_map, on="SiteID", how="left")

In [0]:
display(labeled_with_business.orderBy("SalesDate", "SiteID"))

In [0]:
display(df_wethr.select("weather_type").distinct())

In [0]:
from pyspark.sql.functions import col, when, to_date, row_number
from pyspark.sql.window import Window

# Start from your weather DataFrame: df_wethr
df_wethr = df_wethr.withColumn("SalesDate", to_date(col("weather_date")))

# Custom mapping for your actual weather types
df_wethr = df_wethr.withColumn(
    "severity",
    when(col("weather_type").isin("Clear", "Mostly Clear", "Mostly sunny", "Clear and Windy"), 1)
    .when(col("weather_type").isin("Partly cloudy", "Partly Cloudy", "Mostly Cloudy", "Cloudy"), 2)
    .when(col("weather_type").isin("Haze", "Fog/Mist", "Unknown"), 3)
    .when(col("weather_type").isin("Light Rain", "Light rain", "Light Rain and Fog/Mist"), 4)
    .when(col("weather_type") == "Rain", 5)
    .when(col("weather_type").isin("Thunderstorms and Rain", "Thunderstorms"), 6)
    .when(col("weather_type") == "Heavy Thunderstorms and Heavy Rain", 7)
    .when(col("weather_type").isin("Snow", "Light snow"), 8)
    .otherwise(0)
)


In [0]:
window_spec = Window.partitionBy("SalesDate").orderBy(col("severity").desc())
weather_most_severe = (
    df_wethr.withColumn("row_num", row_number().over(window_spec))
            .filter("row_num = 1")
            .select("SalesDate", "weather_type", "severity")
)


In [0]:
df_final_labeled = labeled_with_business.join(
    weather_most_severe,
    on="SalesDate",
    how="left"
).select(
    "SalesDate",
    "SiteID",
    "businessEntityId",
    "businessEntityName",      # <-- NEW!
    "StoreClosedLabel",
    "TotalOrders",
    "weather_type",
    "severity"
)


In [0]:
from pyspark.sql.functions import to_date, countDistinct, col, when, current_date, max as F_max, row_number
from pyspark.sql.window import Window

# 1. Calendar range
calendar_range = df_calendar.select(to_date("Date").alias("SalesDate")) \
    .filter("SalesDate >= date_sub(current_date(), 730) AND SalesDate <= current_date()")

sites = df_sales.select("SiteID").distinct()
all_site_dates = calendar_range.crossJoin(sites)

# 2. Sales with dates
df_sales_with_date = df_sales.withColumn("SalesDate", to_date("OrderDate"))

sales_activity = df_sales_with_date.groupBy("SiteID", "SalesDate").agg(
    countDistinct("OrderID").alias("TotalOrders")
)

# 3. Label closure
labeled_data = all_site_dates.join(
    sales_activity, on=["SiteID", "SalesDate"], how="left"
).withColumn(
    "StoreClosedLabel", when(col("TotalOrders").isNull(), 1).otherwise(0)
)

# --- Add business entity info ---
site_business_map = (
    df_sales
    .select("SiteID", "businessEntityId", "businessEntityName")
    .distinct()
)
labeled_with_business = labeled_data.join(site_business_map, on="SiteID", how="left")

# 4. Weather severity
weather_severity = {
    "Clear": 1, "Partly cloudy": 2, "Overcast": 3, "Mist": 4,
    "Patchy moderate snow": 5, "Moderate snow": 6, "Ice pellets": 7,
    "Moderate or heavy snow showers": 8, "Blowing snow": 9, "Heavy snow": 10,
    "Moderate rain": 11, "Moderate or heavy rain with thunder": 12,
    "Heavy rain": 13, "Thunderstorm": 14
}

df_wethr_ranked = df_wethr.withColumn("SalesDate", to_date("weather_date"))
severity_col = when(col("weather_type").isNull(), None)
for wt, sev in weather_severity.items():
    severity_col = severity_col.when(col("weather_type") == wt, sev)
severity_col = severity_col.otherwise(0)
df_wethr_ranked = df_wethr_ranked.withColumn("severity", severity_col)

# 5. Most severe weather per day
window_spec = Window.partitionBy("SalesDate").orderBy(col("severity").desc())
weather_most_severe = df_wethr_ranked.withColumn(
    "row_num", row_number().over(window_spec)
).filter("row_num = 1").select("SalesDate", "weather_type", "severity")

# 6. Final join
df_final_labeled = labeled_with_business.join(
    weather_most_severe,
    on="SalesDate",
    how="left"
).select(
    "SalesDate",
    "SiteID",
    "businessEntityId",
    "businessEntityName",      # <-- Now included!
    "StoreClosedLabel",
    "TotalOrders",
    "weather_type",
    "severity"
)

# Show result
display(df_final_labeled.orderBy("SalesDate", "SiteID"))


In [0]:
from pyspark.sql import functions as F

display(df_final_labeled.orderBy("SalesDate", "SiteID"))

# Check for NULLs
null_counts = df_final_labeled.select(
    [F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df_final_labeled.columns]
)
display(null_counts)

In [0]:
from pyspark.sql import functions as F

# Remove rows where SiteID is NULL
df_final_labeled = df_final_labeled.filter(F.col("SiteID").isNotNull())

# Fill NA for TotalOrders
df_final_labeled = df_final_labeled.fillna({"TotalOrders": 0})

# Standardize any remaining NAs
df_final_labeled = df_final_labeled.na.fill("Unknown")

display(df_final_labeled.orderBy("SalesDate", "SiteID"))

# Check for NULLs
null_counts = df_final_labeled.select(
    [F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df_final_labeled.columns]
)
display(null_counts)

# END

In [0]:
##Analyze Closure Patterns(Count Entity By Closure)

display(df_final_labeled.groupBy("businessEntityName", "StoreClosedLabel")
         .count().orderBy("businessEntityName", "StoreClosedLabel"))
display(df_final_labeled.groupBy("weather_type", "StoreClosedLabel")
         .count().orderBy("weather_type", "StoreClosedLabel"))


# My **aggregated closure summary** is **FINE for reporting and analysis**—you now see, by business and by weather type, where closures (label = 1) are more frequent. Here’s how to **interpret and use this information**—plus, what to do next:

---

## **How to Interpret This Table**

### **1. By Business Entity**

* **Most entities (e.g., "Namak Mirch Grocers", "Lulu's Diner II", etc.) are open 99.8%+ of the time** (730 open, 1 closed = one closure per 2 years, likely during severe weather).
* **Entities like ParadiseGifts4/5/7 and SubParadiseGifts1 have much higher closure rates** (e.g., 409–422 closed days vs. 309–1053 open), suggesting either:

  * **Less consistent sales reporting**
  * **A store that’s open seasonally or with many holidays**
  * **Data issues or special business rules**
* **"OVVI Automation"** has many sites/entries (8030 open, 11 closed)—likely a test/demo or multi-location business.

---

### **2. By Weather Type**

* **Most closures happen on clear or partly cloudy days!**

  * **Interpretation:**

    * **Closures aren’t only weather-related.**
    * May reflect holidays, low season, or missing/erroneous weather data.
    * Need to check if “closed” labels are sometimes due to other factors (e.g., holidays, off-season).
* **Higher closure rates in “Rain”, “Light rain”, “Cloudy”, “Mostly sunny”**: Could indicate weather impact—but not exclusively.
* **Very few closures on “Thunderstorms”, “Light snow”, etc.:**

  * **Possible reasons:**

    * Extreme weather is rare in your data’s geography.
    * The store closes for other reasons more often than for truly severe weather.
    * Or: Weather reporting might not always match the true local conditions.


In [0]:
## Cross Tab Entity By weather_type And Closure
# Count closures (StoreClosedLabel = 1) by business and weather
df_bad_weather = df_final_labeled.filter(col("StoreClosedLabel") == 1)

display(
    df_bad_weather.groupBy("businessEntityName", "weather_type")
        .count()
        .orderBy("businessEntityName", "count", ascending=False)
)


In [0]:
from pyspark.sql.functions import dayofweek

# Add day_of_week (1=Sunday, 7=Saturday)
df_final_labeled = df_final_labeled.withColumn(
    "day_of_week", dayofweek("SalesDate")
)


In [0]:
from pyspark.sql.functions import when

# Map day_of_week number to day name
df_final_labeled = (
    df_final_labeled.withColumn(
        "day_name",
        when(col("day_of_week") == 1, "Sunday")
        .when(col("day_of_week") == 2, "Monday")
        .when(col("day_of_week") == 3, "Tuesday")
        .when(col("day_of_week") == 4, "Wednesday")
        .when(col("day_of_week") == 5, "Thursday")
        .when(col("day_of_week") == 6, "Friday")
        .when(col("day_of_week") == 7, "Saturday")
        .otherwise("Unknown")
    )
)


In [0]:
## Closed Vs Open Rate By weekday
display(
    df_final_labeled.groupBy("businessEntityName", "day_of_week", "StoreClosedLabel")
        .count()
        .orderBy("businessEntityName", "day_of_week", "StoreClosedLabel")
)


In [0]:
## With Day Nam
display(
    df_final_labeled.groupBy("businessEntityName", "day_name", "StoreClosedLabel")
        .count()
        .orderBy("businessEntityName", "day_name", "StoreClosedLabel")
)


In [0]:
##Closed Vs Open By Mont
from pyspark.sql.functions import month

df_final_labeled = df_final_labeled.withColumn(
    "month", month("SalesDate")
)

display(df_final_labeled)

In [0]:
##Join Calendar
from pyspark.sql.functions import col

df_final_labeled = (
    df_final_labeled
    .join(
        df_calendar.select(
            col("Date").alias("SalesDate"),
            "WeekDayName",
            "MonthName",
            "Month",
            "Year",
            "IsWeekend",
            "IsHoliday",
            "HolidayName"
        ),
        on="SalesDate",
        how="left"
    )
)


In [0]:
## Use Calendar Fat For More Group Analysis

display(
    df_final_labeled.groupBy("businessEntityName", "WeekDayName", "StoreClosedLabel")
        .count()
        .orderBy("businessEntityName", "WeekDayName", "StoreClosedLabel")
)


In [0]:
## Holiday Weekend Closure
display(
    df_final_labeled.groupBy("businessEntityName", "IsHoliday", "StoreClosedLabel")
        .count()
        .orderBy("businessEntityName", "IsHoliday", "StoreClosedLabel")
)
display(
    df_final_labeled.groupBy("businessEntityName", "IsWeekend", "StoreClosedLabel")
        .count()
        .orderBy("businessEntityName", "IsWeekend", "StoreClosedLabel")
)


# Interpretatio
## Great! These two tables tell me a **lot** about which business entities are prone to closing on holidays and weekends vs. regular days.

---

## **What Can We Interpret?**

### **A. Holiday Closure Patterns**

* **Most “classic” restaurants** (Altoona, Carniceria Sonora, Demo, Lulu's, Namak Mirch, Green Bawarchi Katy) **rarely close on holidays**—holiday rows show almost exclusively 0’s (open), with maybe one or two closures across all years.
* **“Paradise” and “OVVI Automation” chains** have a much higher closure rate on holidays:

  * ParadiseGift5, ParadiseGifts3/4/7, SubParadiseGifts1 all show **nonzero closure counts on holidays**. For example, ParadiseGift5 has 26 closures on holidays, and even on non-holidays, closure counts are high (e.g., 396 closures vs 299 open).
  * These stores are generally **much more likely to close**—potentially due to operational policy, seasonality, or being “gift”/specialty shops.

### **B. Weekend Closure Patterns**

* **Again, classic restaurants and grocers are open on nearly all weekends** (very few closures with IsWeekend = 1).
* **Paradise and SubParadise** shops, plus some OVVI units, show a significant number of closures on weekends (compare open vs. closed counts for IsWeekend = 1).
* **Exceptionally high closure rates** for “Paradise” entities on both weekends and holidays compared to the food businesses.

---
## **Summary Table Example**

| Entity              | Holiday Closure Rate | Weekend Closure Rate | Non-Holiday/Weekend Rate |
| ------------------- | -------------------- | -------------------- | ------------------------ |
| Namak Mirch Grocers | Very Low             | Very Low             | Very Low                 |
| ParadiseGift5       | High                 | High                 | High                     |


# NEW

# Rule BASED Store Closure Basic

# 1. Ingest & Materialize a “Master” DataFrame

Join all six tables on the date dimension so that every row is one calendar‐date for “Namak Mirch Grocers,” with sales, orders, weather (hist + forecast), and calendar attribute

In [0]:
from pyspark.sql import SparkSession, functions as F, types as T

spark = SparkSession.builder.appName("StoreClosureSevereWeather").getOrCreate()

# 1. Load tables filtered to Namak Mirch Grocers
sales_line = (
    spark.table("gold.salestransskufact")
         .filter(F.col("businessEntityName") == "Namak Mirch Grocers")
         .select(
             F.to_date("OrderDate").alias("date"),
             F.col("SalesQuantity"),
             F.col("NetSalesPrice")
         )
)

sales_order = (
    spark.table("gold.salestransorderfact")
         .filter(F.col("businessEntityName") == "Namak Mirch Grocers")
         .select(
             F.to_date("OrderDate").alias("date"),
             F.col("TotalSalesQuantity"),
             F.col("TotalSalesAmount")
         )
)

weather_hist = (
    spark.table("gold.weatherdata_daily_historical_openmeteo_nws")
         .filter(F.col("businessEntityName") == "Namak Mirch Grocers")
         .select(
             F.to_date("weather_date").alias("date"),
             "weather_type", "precipitation", "wind_speed"
         )
)

weather_forecast = (
    spark.table("gold.weatherdata_daily_forecast_nws")
         .filter(F.col("businessEntityName") == "Namak Mirch Grocers")
         .select(
             F.to_date("weather_date").alias("date"),
             F.col("weather_type_forecast").alias("weather_type"),
             F.col("precipitation_forecast").alias("precipitation"),
             F.col("wind_speed_forecast").alias("wind_speed")
         )
)

calendar = spark.table("gold.calendar").select(
    F.to_date("Date").alias("date"),
    "IsWeekend", "IsHoliday", "WeekDayName", "MonthName", "Year"
)

products = spark.table("gold.product") \
                .filter(F.col("businessEntityName") == "Namak Mirch Grocers") \
                .select("ProductID", "ProductName", "activeFlag")


# 2. Aggregate daily activity
daily_sales = (
    sales_line.groupBy("date")
        .agg(
            F.sum("SalesQuantity").alias("total_units_sold"),
            F.sum("NetSalesPrice").alias("total_revenue")
        )
)

daily_orders = (
    sales_order.groupBy("date")
        .agg(
            F.sum("TotalSalesQuantity").alias("order_units"),
            F.sum("TotalSalesAmount").alias("order_revenue")
        )
)

daily_activity = (
    calendar
      .join(daily_sales, "date", "left")
      .join(daily_orders, "date", "left")
      .na.fill(0, subset=["total_units_sold", "total_revenue", "order_units", "order_revenue"])
)


# 3. Define UDF to flag severe weather
severe_types = F.array(
    *[F.lit(w) for w in ["Thunderstorm", "Heavy Rain", "Storm", "Hurricane", "Snow", "Extreme"]]
)
flag_severe = (
    (F.col("precipitation") >= F.lit(5.0))        # ≥5 mm rain
    | (F.col("wind_speed") >= F.lit(20.0))         # ≥20 m/s wind
    | F.array_contains(severe_types, F.col("weather_type"))
)


# 4. Build historical and forecast views with severe‐weather flag
weather_h_flag = (
    weather_hist
      .withColumn("is_severe", flag_severe.cast("boolean"))
      .select("date", "is_severe")
)

weather_f_flag = (
    weather_forecast
      .withColumn("is_severe", flag_severe.cast("boolean"))
      .select("date", "is_severe")
)


# 5. Identify actual closures: no sales AND severe historical weather
closures_historical = (
    daily_activity
      .join(weather_h_flag, "date", "left")
      .filter(
          (F.col("total_revenue") == 0)  # no revenue → no open transactions
          & (F.col("is_severe") == True) # flagged severe weather
      )
      .select(
          "date", "WeekDayName", "MonthName", "Year",
          "IsWeekend", "IsHoliday",
          F.col("is_severe").alias("severe_weather"),
          F.col("total_revenue"),
          F.col("order_revenue")
      )
      .orderBy("date")
)


# 6. Identify forecasted closure risks: forecast severe AND historical low baseline
#    e.g. if past same weekday or month typically low sales
#    Here we simply mark days with forecast severe weather and zero forecasted baseline
closures_forecast = (
    daily_activity
      .join(weather_f_flag, "date", "left")
      .filter(F.col("is_severe") == True)
      .select(
          "date", "WeekDayName", "MonthName", "Year",
          "IsWeekend", "IsHoliday",
          F.col("is_severe").alias("forecast_severe"),
          "total_revenue", "order_revenue"
      )
      .orderBy("date")
)


# 7. Show results
print("=== Historical Store Closures Attributable to Severe Weather ===")
display(closures_historical)

print("=== Forecasted Store Closure Risk Due to Severe Weather ===")
display(closures_forecast)

In [0]:
sales_order = (
    spark.table("gold.salestransorderfact")
         .select(
             F.to_date("OrderDate").alias("date"),
             F.col("TotalSalesQuantity"),
             F.col("TotalSalesAmount"),
             "businessEntityId",
             "businessEntityName"
         )
)


In [0]:
# Verify the column names in the sales_order DataFrame
sales_order_columns = sales_order.columns
print(sales_order_columns)

# Mapping of businessEntityId/name to each date (and possibly SiteID if needed)
entity_map = (
    sales_order
    .select("date", "businessEntityId", "businessEntityName")
    .distinct()
)

display(entity_map)

In [0]:
print(spark.table("gold.weatherdata_daily_historical_openmeteo_nws").columns)
print(spark.table("gold.salestransskufact").columns)
print(spark.table("gold.salestransorderfact").columns)

In [0]:
from pyspark.sql.functions import to_date, countDistinct, col, when, current_date, max as F_max
from pyspark.sql.window import Window

# 1. Prepare calendar dates (for last 24 months)
calendar_range = (
    spark.table("gold.calendar")
    .select(to_date("Date").alias("date"))
    .filter("date >= date_sub(current_date(), 730) AND date <= current_date()")
)

# 2. Get all (businessEntityId, businessEntityName, SiteID) pairs
sites = (
    spark.table("gold.salestransskufact")
    .select("businessEntityId", "businessEntityName", "SiteID")
    .distinct()
)

# 3. Get all combinations (cross join)
all_site_dates = (
    calendar_range.crossJoin(sites)
)

# 4. Get sales activity per site/date
sales = (
    spark.table("gold.salestransskufact")
    .withColumn("date", to_date("OrderDate"))
)

sales_activity = sales.groupBy(
    "businessEntityId", "businessEntityName", "SiteID", "date"
).agg(
    countDistinct("OrderID").alias("TotalOrders")
)

# 5. Label store closures (1 = closed, 0 = open)
labeled_data = (
    all_site_dates.join(
        sales_activity,
        on=["businessEntityId", "businessEntityName", "SiteID", "date"],
        how="left"
    )
    .withColumn("StoreClosedLabel", when(col("TotalOrders").isNull(), 1).otherwise(0))
)

# 6. Prepare weather: best daily weather per business entity
weather = (
    spark.table("gold.weatherdata_daily_historical_openmeteo_nws")
    .select(
        "businessEntityId", "businessEntityName",
        to_date("weather_date").alias("date"),
        "weather_type", "precipitation", "wind_speed"
    )
)

# Add severity ranking (optional)
weather_severity = {
    "Clear": 1, "Partly cloudy": 2, "Overcast": 3, "Mist": 4, "Patchy moderate snow": 5,
    "Moderate snow": 6, "Ice pellets": 7, "Moderate or heavy snow showers": 8,
    "Blowing snow": 9, "Heavy snow": 10, "Moderate rain": 11,
    "Moderate or heavy rain with thunder": 12, "Heavy rain": 13, "Thunderstorm": 14
}
# UDF to map type to severity (or use a when-chain as before)
from pyspark.sql.functions import create_map, lit
from itertools import chain

# ... inside your code after defining weather_severity:
mapping_expr = create_map([lit(x) for x in chain.from_iterable(weather_severity.items())])
weather = weather.withColumn("severity", mapping_expr[col("weather_type")])


# For each entity/date, pick the worst (max) severity and its weather_type
window = Window.partitionBy("businessEntityId", "businessEntityName", "date").orderBy(col("severity").desc())
weather_most_severe = (
    weather.withColumn("row_num", F.row_number().over(window))
           .filter("row_num = 1")
           .select("businessEntityId", "businessEntityName", "date", "weather_type", "severity")
)

# 7. Join weather to labeled data
df_final = (
    labeled_data.join(
        weather_most_severe,
        on=["businessEntityId", "businessEntityName", "date"],
        how="left"
    )
    .select("date", "businessEntityId", "businessEntityName", "SiteID", "StoreClosedLabel", "TotalOrders", "weather_type", "severity")
    .orderBy("date", "businessEntityId", "SiteID")
)

display(df_final)


In [0]:
calendar = (
    spark.table("gold.calendar")
         .select(
             F.to_date("Date").alias("date"),
             "IsWeekend", "IsHoliday", "WeekDayName", "MonthName", "Year"
         )
)

df_final_labeled = (
    df_final.join(calendar, on="date", how="left")
)


In [0]:
from pyspark.sql.functions import count, when, col

null_counts = df_final_labeled.select(
    [count(when(col(c).isNull(), c)).alias(c) for c in df_final_labeled.columns]
)

display(null_counts)

In [0]:
df_final_labeled = df_final_labeled.fillna({'severity': 0})

In [0]:
df_final_labeled = df_final_labeled.fillna({'weather_type': 'Unknown'})

In [0]:
df_final_labeled = df_final_labeled.filter(col("SiteID").isNotNull())

# Compute Closure Rates
We’ll calculate closure rate per entity across:

Weekdays

IsHoliday

IsWeekend

In [0]:
##Closure Rate by WeekDayName

from pyspark.sql.functions import count, when, col

weekday_closure_rate = (
    df_final_labeled.groupBy("businessEntityName", "WeekDayName")
    .agg(
        count("*").alias("total_days"),
        count(when(col("StoreClosedLabel") == 1, True)).alias("closed_days")
    )
    .withColumn("closure_rate", col("closed_days") / col("total_days"))
)

display(weekday_closure_rate)


In [0]:
#Closure Rate by IsHoliday
holiday_closure_rate = (
    df_final_labeled.groupBy("businessEntityName", "IsHoliday")
    .agg(
        count("*").alias("total_days"),
        count(when(col("StoreClosedLabel") == 1, True)).alias("closed_days")
    )
    .withColumn("closure_rate", col("closed_days") / col("total_days"))
)

display(holiday_closure_rate)


In [0]:
#Closure Rate by IsWeekend
weekend_closure_rate = (
    df_final_labeled.groupBy("businessEntityName", "IsWeekend")
    .agg(
        count("*").alias("total_days"),
        count(when(col("StoreClosedLabel") == 1, True)).alias("closed_days")
    )
    .withColumn("closure_rate", col("closed_days") / col("total_days"))
)

display(weekend_closure_rate)


Excellent! You’ve now successfully computed **closure rates** by:

* `WeekDayName`
* `IsHoliday`
* `IsWeekend`

---

## ✅ **Key Findings (Interpretation)**

### 🏪 1. **Classic Retailers & Restaurants**

Entities like **Namak Mirch Grocers**, **Demo Restaurant**, **Carniceria Sonora #4**, **Green Bawarchi Katy**:

* Closure rates on weekends and holidays are **very low** (mostly < 5%)
* These stores are clearly **operational almost daily**, with minimal exceptions.

---

### 🎁 2. **Gift Stores (Paradise*, SubParadise*)\*\*

Entities like **ParadiseGift5**, **ParadiseGifts3/4/7**, **SubParadiseGifts1**:

* Show **extremely high closure rates**, sometimes up to **99.5%** on weekends/holidays
* Also notable weekday closures, especially on **Mondays, Tuesdays, Fridays**
* Behavior suggests **deliberate, possibly seasonal/weekend-based closure policies**

---

### 🏢 3. **OVVI Automation**

* Has **100% closure rate** across all days — possibly:

  * A non-retail automation/testing entity
  * Or a data quality error (check if these represent actual store sites)

---

## 📊 Summary Table Snapshot

| Entity              | Holiday Closure | Weekend Closure | Weekday (avg) Closure |
| ------------------- | --------------- | --------------- | --------------------- |
| Namak Mirch Grocers | 0%              | \~0.4%          | \~1.9% max            |
| Green Bawarchi Katy | \~7%            | \~2%            | Mostly <2.5%          |
| ParadiseGift5       | 72%             | 99%             | \~38–41%              |
| SubParadiseGifts1   | 83%             | 99.2%           | \~68%+ on avg         |
| ParadiseGifts4      | 84.7%           | 99.7%           | \~69%+ on avg         |
| OVVI Automation     | 100%            | 100%            | 100% (all days)       |

---


Compute Closure Rate by IsHoliday x Weather_Type

In [0]:
holiday_weather_closure = (
    df_final_labeled.groupBy("businessEntityName", "IsHoliday", "weather_type")
    .agg(
        count("*").alias("total_days"),
        count(when(col("StoreClosedLabel") == 1, True)).alias("closed_days")
    )
    .withColumn("closure_rate", col("closed_days") / col("total_days"))
    .orderBy("businessEntityName", "IsHoliday", "closure_rate", ascending=False)
)

display(holiday_weather_closure)


In [0]:
from pyspark.sql.functions import lower, regexp_replace, when

# Standardize weather_type
df = df_final_labeled.withColumn(
    "weather_type_cleaned",
    lower(regexp_replace("weather_type", "[-_]", " "))
)

# Example: group weather
df = df.withColumn(
    "weather_grouped",
    when(col("weather_type_cleaned").rlike("snow"), "Snow")
    .when(col("weather_type_cleaned").rlike("rain|thunderstorm"), "Rain")
    .when(col("weather_type_cleaned").rlike("clear|sunny"), "Clear")
    .when(col("weather_type_cleaned").rlike("cloudy"), "Cloudy")
    .otherwise("Other")
)


In [0]:
display(df)

In [0]:
## Add Binary Flags for Weather
from pyspark.sql.functions import when, col

df = df.withColumn("is_rain", when(col("weather_grouped") == "Rain", 1).otherwise(0)) \
       .withColumn("is_snow", when(col("weather_grouped") == "Snow", 1).otherwise(0)) \
       .withColumn("is_clear", when(col("weather_grouped") == "Clear", 1).otherwise(0)) \
       .withColumn("is_cloudy", when(col("weather_grouped") == "Cloudy", 1).otherwise(0)) \
       .withColumn("is_other_weather", when(col("weather_grouped") == "Other", 1).otherwise(0))

In [0]:
##Extract & Add Numerical Month (Optional)
from pyspark.sql.functions import month

df = df.withColumn("month_num", month(col("date")))

In [0]:
##Add Rolling Lag Features

from pyspark.sql.window import Window
from pyspark.sql.functions import lag, avg

window_spec = Window.partitionBy("businessEntityId").orderBy("date")

# Previous day's closure
df = df.withColumn("prev_day_closed", lag("StoreClosedLabel").over(window_spec))

# 3-day rolling average of closures
rolling_window = window_spec.rowsBetween(-3, -1)
df = df.withColumn("rolling_3day_closure_rate", avg("StoreClosedLabel").over(rolling_window))

In [0]:
## Optional - One-Hot Encode WeekDayName / MonthName
from pyspark.ml.feature import StringIndexer, OneHotEncoder

indexer = StringIndexer(inputCol="WeekDayName", outputCol="WeekDayIndex")
encoder = OneHotEncoder(inputCols=["WeekDayIndex"], outputCols=["WeekDayVec"])

# Fit + transform in pipeline if needed

In [0]:
display(df)

In [0]:
#Add both StringIndexers to your DataFrame
# Step 1: StringIndexers
from pyspark.ml import Pipeline

indexers = [
    StringIndexer(inputCol="WeekDayName", outputCol="WeekDayName_index", handleInvalid="keep"),
    StringIndexer(inputCol="MonthName", outputCol="MonthName_index", handleInvalid="keep")
]

pipeline = Pipeline(stages=indexers)
df = pipeline.fit(df).transform(df)

In [0]:
#Now re-check for nulls

# Redefine feature_cols
feature_cols = [
    "TotalOrders",
    "IsWeekend",
    "IsHoliday",
    "month_num",
    "prev_day_closed",
    "rolling_3day_closure_rate",
    "is_rain",
    "is_snow",
    "is_clear",
    "is_cloudy",
    "is_other_weather",
    "WeekDayName_index",
    "MonthName_index"
]

# Check for nulls in selected features
df.select([
    col(c).isNull().alias(c + "_isNull") for c in feature_cols
]).summary("count").show()

In [0]:
print(df.columns)

In [0]:
# Skip re-creation if columns exists
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

indexers = []

if "WeekDayName_index" not in df.columns:
    indexers.append(StringIndexer(inputCol="WeekDayName", outputCol="WeekDayName_index", handleInvalid="keep"))

if "MonthName_index" not in df.columns:
    indexers.append(StringIndexer(inputCol="MonthName", outputCol="MonthName_index", handleInvalid="keep"))

if indexers:
    pipeline = Pipeline(stages=indexers)
    df = pipeline.fit(df).transform(df)

In [0]:
df.select([
    col(c).isNull().alias(c + "_isNull") for c in feature_cols
]).summary("count").show()

In [0]:
print(feature_cols)

In [0]:
missing_cols = [c for c in feature_cols if c not in df.columns]
print("Missing columns:", missing_cols)

In [0]:
from pyspark.sql.functions import col

df.select([
    col(c).isNull().alias(f"{c}_isNull") for c in feature_cols
]).summary("count").show(truncate=False)

In [0]:
from pyspark.sql.functions import col, when, isnan, count
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline

# --- 1. Cast all flag columns to integer ---
for flag_col in ["IsHoliday", "IsWeekend"]:
    df = df.withColumn(
        flag_col,
        when(
            (col(flag_col) == True) | 
            (col(flag_col) == "true") | 
            (col(flag_col) == "True") | 
            (col(flag_col) == 1), 1
        ).otherwise(0).cast("int")
    )

# --- 2. Fill missing values for ALL potential feature columns ---
feature_cols = [
    "TotalOrders",
    "IsWeekend",
    "IsHoliday",
    "month_num",
    "prev_day_closed",
    "rolling_3day_closure_rate",
    "is_rain",
    "is_snow",
    "is_clear",
    "is_cloudy",
    "is_other_weather",
    "WeekDayName_index",
    "MonthName_index"
]

# Create comprehensive fill dictionary
fill_dict = {col: 0 for col in feature_cols}
df = df.fillna(fill_dict)

# --- 3. StringIndexer for WeekDayName and MonthName ---
indexers = []
if "WeekDayName_index" not in df.columns:
    indexers.append(StringIndexer(inputCol="WeekDayName", outputCol="WeekDayName_index", handleInvalid="keep"))
if "MonthName_index" not in df.columns:
    indexers.append(StringIndexer(inputCol="MonthName", outputCol="MonthName_index", handleInvalid="keep"))

if indexers:
    pipeline = Pipeline(stages=indexers)
    df = pipeline.fit(df).transform(df)

# --- 4. Ensure all feature columns are double type and handle nulls/NaNs ---
for colname in feature_cols:
    # Handle potential nulls from casting
    df = df.withColumn(
        colname, 
        when(col(colname).isNull() | isnan(col(colname)), 0.0)
        .otherwise(col(colname).cast("double"))
)

# --- 5. Final null check ---
print("Final null counts:")
null_counts = df.select([
    count(when(col(c).isNull() | isnan(col(c)), c)).alias(c) 
    for c in feature_cols
])
null_counts.show()

# --- 6. Assemble feature vector with safe handling ---
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="keep"  # Critical: Handle any unexpected values
)

df_model_ready = assembler.transform(df)

# --- 7. Verify features before training ---
print("Sample features:")
df_model_ready.select("features", "StoreClosedLabel").show(5, truncate=False)

# Now safe to train your model

In [0]:
# Check for nulls in label column
label_null_count = df_model_ready.filter(col("StoreClosedLabel").isNull()).count()
print(f"Null labels: {label_null_count}")

In [0]:
df_model_ready.select(feature_cols).printSchema()

In [0]:
# Check if indexers produced expected values
if "WeekDayName_index" in feature_cols:
    df_model_ready.groupBy("WeekDayName", "WeekDayName_index").count().show()
if "MonthName_index" in feature_cols:
    df_model_ready.groupBy("MonthName", "MonthName_index").count().show()

# Train a Classifier

In [0]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# --- 1. Prepare the label column ---
# Ensure label is integer type (0/1)
df_model_ready = df_model_ready.withColumn(
    "StoreClosedLabel", 
    col("StoreClosedLabel").cast("integer")
)

# --- 2. Split data into training and test sets ---
train_data, test_data = df_model_ready.randomSplit([0.8, 0.2], seed=42)

print(f"Training count: {train_data.count()}")
print(f"Test count: {test_data.count()}")

# --- 3. Configure and train Random Forest ---
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="StoreClosedLabel",
    numTrees=100,                # Increase for better performance (tradeoff: longer training)
    maxDepth=10,                 # Tune this parameter
    minInstancesPerNode=5,        # Helps prevent overfitting
    seed=42,
    subsamplingRate=0.8,          # Use 80% of data for each tree
    featureSubsetStrategy="sqrt"  # Use sqrt(features) for each split
)

# Train model
rf_model = rf.fit(train_data)

# --- 4. Make predictions ---
predictions = rf_model.transform(test_data)

# --- 5. Evaluate model performance ---
evaluator = BinaryClassificationEvaluator(
    labelCol="StoreClosedLabel",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

auc = evaluator.evaluate(predictions)
print(f"\n\n=== MODEL PERFORMANCE ===")
print(f"Test AUC: {auc:.4f}")

# Additional metrics
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol="StoreClosedLabel",
    predictionCol="prediction",
    metricName="accuracy"
)

precision_evaluator = MulticlassClassificationEvaluator(
    labelCol="StoreClosedLabel",
    predictionCol="prediction",
    metricName="weightedPrecision"
)

accuracy = accuracy_evaluator.evaluate(predictions)
precision = precision_evaluator.evaluate(predictions)

print(f"Accuracy: {accuracy:.4f}")
print(f"Weighted Precision: {precision:.4f}")

# --- 6. Inspect feature importance ---
import pandas as pd

feature_importance = pd.DataFrame({
    "Feature": feature_cols,
    "Importance": rf_model.featureImportances.toArray()
}).sort_values("Importance", ascending=False)

print("\n=== FEATURE IMPORTANCE ===")
print(feature_importance)

# --- 7. Save model (optional) ---
# rf_model.save("dbfs:/path/to/rf_store_closures_model")

# 1. Investigate Data Leakage (Critical First Step)

In [0]:
# Check correlation between features and target
from pyspark.sql.functions import corr

for feature in feature_cols:
    correlation = df_model_ready.select(corr(feature, "StoreClosedLabel")).first()[0]
    print(f"{feature}: {correlation:.4f}")

**1. Fix Temporal Feature Calculation**

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import lag, sum as spark_sum, col, datediff, current_date

# Create proper window specification (assuming 'date' column exists)
window_spec = Window.partitionBy().orderBy("date")

# Calculate correct lag features
df_fixed = (df_model_ready
    .withColumn("correct_prev_day_closed", 
                lag("StoreClosedLabel", 1).over(window_spec))
    
    .withColumn("rolling_3day_closures", 
                spark_sum(col("StoreClosedLabel")).over(window_spec.rowsBetween(-3, -1)))
    
    .withColumn("correct_3day_closure_rate",
                col("rolling_3day_closures") / 3)
)

# Validate calculation
display(df_fixed.select("date", "StoreClosedLabel", "correct_prev_day_closed", "correct_3day_closure_rate"))

**Remove Leaky Features**

In [0]:
safe_feature_cols = [
    "IsWeekend",
    "IsHoliday",
    "month_num",
    "is_rain",
    "is_snow",
    "is_clear",
    "is_cloudy",
    "WeekDayName_index",
    "MonthName_index"
]

# Update assembler
safe_assembler = VectorAssembler(
    inputCols=safe_feature_cols,
    outputCol="safe_features",
    handleInvalid="keep"
)

df_safe = safe_assembler.transform(df_fixed)

Rebuild Model with Safe Feature

In [0]:
# Split data
train_safe, test_safe = df_safe.randomSplit([0.8, 0.2], seed=42)

# Train new model
rf_safe = RandomForestClassifier(
    featuresCol="safe_features",
    labelCol="StoreClosedLabel",
    numTrees=50,
    maxDepth=5
)

safe_model = rf_safe.fit(train_safe)

# Evaluate
safe_predictions = safe_model.transform(test_safe)
evaluator.evaluate(safe_predictions)  # Expect <1.0 AUC now

4. Validate with Business Rules

In [0]:
# 1. Generate predictions for the entire dataset
full_predictions = safe_model.transform(df_safe)

# 2. Analyze closure reasons
closure_reasons = full_predictions.filter(col("StoreClosedLabel") == 1).select(
    "date", 
    "IsWeekend", 
    "IsHoliday", 
    "is_snow",
    "is_rain",
    "prediction",  # Now available after transformation
    "probability"  # Useful for confidence analysis
).distinct()

display(closure_reasons.orderBy("date"))

In [0]:
# Feature Importance Analysis
# Get feature importance
safe_importance = pd.DataFrame({
    "Feature": safe_feature_cols,
    "Importance": safe_model.featureImportances.toArray()
}).sort_values("Importance", ascending=False)

print("SAFE FEATURE IMPORTANCE:")
print(safe_importance)

# Plot importance
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.barh(safe_importance['Feature'], safe_importance['Importance'])
plt.gca().invert_yaxis()  # Most important on top
plt.xlabel('Importance')
plt.title('Feature Importance (Safe Features)')
plt.show()

In [0]:
# Weather Impact Analysis
# Weather impact on closures
weather_effect = full_predictions.groupBy(
    "is_rain", "is_snow", "is_clear", "is_cloudy", "is_other_weather"
).agg(
    spark_sum("StoreClosedLabel").alias("actual_closures"),
    spark_sum("prediction").alias("predicted_closures"),
    (spark_sum("prediction") / count("*")).alias("closure_rate")
)

display(weather_effect.orderBy("closure_rate", ascending=False))

In [0]:
# Temporal Patterns Analysis
# Day of week patterns
weekday_effect = full_predictions.groupBy("WeekDayName_index").agg(
    (spark_sum("StoreClosedLabel") / count("*")).alias("actual_closure_rate"),
    (spark_sum("prediction") / count("*")).alias("predicted_closure_rate")
)

# Month patterns
month_effect = full_predictions.groupBy("MonthName_index").agg(
    (spark_sum("StoreClosedLabel") / count("*")).alias("actual_closure_rate"),
    (spark_sum("prediction") / count("*")).alias("predicted_closure_rate")
)

print("Weekday Closure Rates:")
display(weekday_effect.orderBy("WeekDayName_index"))

print("\nMonthly Closure Rates:")
display(month_effect.orderBy("MonthName_index"))

# Feature Engineering

In [0]:
# Create combined weather severity feature
df_fixed = df_fixed.withColumn(
    "weather_severity",
    when(col("is_snow") == 1, 3)
    .when(col("is_rain") == 1, 2)
    .when(col("is_other_weather") == 1, 2)
    .otherwise(1)
)

# Add weekend+holiday combination
df_fixed = df_fixed.withColumn(
    "special_day",
    when((col("IsWeekend") == 1) | (col("IsHoliday") == 1), 1)
    .otherwise(0)
)

# Update feature list
enhanced_features = safe_feature_cols + ["weather_severity", "special_day"]

In [0]:
#Address Class Imbalance (if exist)
from pyspark.sql.functions import when

class_balance = full_predictions.groupBy("StoreClosedLabel").count()
class_balance.show()

# If imbalance > 10:1, use class weights
rf_balanced = RandomForestClassifier(
    featuresCol="safe_features",
    labelCol="StoreClosedLabel",
    weightCol="class_weight"
)

# Create weight column
df_weighted = df_safe.withColumn(
    "class_weight",
    when(col("StoreClosedLabel") == 1, 10).otherwise(1)
)

In [0]:
#Hyperparameter Tuning

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

param_grid = (ParamGridBuilder()
    .addGrid(rf_safe.numTrees, [30, 50, 100])
    .addGrid(rf_safe.maxDepth, [3, 5, 7])
    .addGrid(rf_safe.minInstancesPerNode, [1, 5, 10])
    .build()
)

cv = CrossValidator(
    estimator=rf_safe,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5
)

cv_model = cv.fit(train_safe)
best_model = cv_model.bestModel

In [0]:
#Threshold Tuning (for precision/recall tradeoff):
# Generate probability scores
probabilities = best_model.transform(test_safe).select(
    "probability", "StoreClosedLabel"
).rdd.map(lambda row: (
    float(row.probability[1]), 
    float(row.StoreClosedLabel)
)).toDF(["probability", "label"])

# Find optimal threshold
thresholds = [i/100 for i in range(10, 60)]
results = []
for t in thresholds:
    predicted = probabilities.withColumn("prediction", when(col("probability") > t, 1).otherwise(0))
    precision = predicted.filter(col("prediction") == 1).filter(col("label") == 1).count() / max(1, predicted.filter(col("prediction") == 1).count())
    recall = predicted.filter(col("prediction") == 1).filter(col("label") == 1).count() / max(1, predicted.filter(col("label") == 1).count())
    results.append((t, precision, recall))

# Plot precision-recall curve
pd_results = pd.DataFrame(results, columns=["threshold", "precision", "recall"])
pd_results.plot(x="recall", y="precision", title="Precision-Recall Tradeoff")

Based on your precision-recall tradeoff curve, here's a detailed analysis and recommendations for next steps:

Precision-Recall Analysis
Your curve shows a strong tradeoff between precision and recall, which is expected. Key observations:

High precision (0.90) comes at the cost of low recall (0.65):

You'll rarely predict closures incorrectly (good for avoiding false alarms)

But you'll miss 35% of actual closures (risking unpreparedness)

Balanced point (0.80 precision, 0.80 recall):

Good compromise for most business scenarios

Captures 80% of closures while keeping false alarms at 20%

High recall (0.95+) comes at significant precision cost:

Catches nearly all closures but with many false alarms

Only suitable if missing closures is extremely costly

In [0]:
#Apply your chosen threshold to predictions
from pyspark.sql.functions import udf, col, lit
from pyspark.sql.types import DoubleType

# 1. Get probability of positive class (closure)
get_positive_prob = udf(lambda v: float(v[1]), DoubleType())
predictions = safe_predictions.withColumn("closure_prob", get_positive_prob("probability"))

# 2. Apply optimal threshold
CHOSEN_THRESHOLD = 0.40  # Balanced scenario
final_predictions = predictions.withColumn(
    "optimized_prediction",
    when(col("closure_prob") > CHOSEN_THRESHOLD, 1).otherwise(0)
)

# 3. Evaluate new predictions
tp = final_predictions.filter((col("optimized_prediction") == 1) & (col("StoreClosedLabel") == 1)).count()
fp = final_predictions.filter((col("optimized_prediction") == 1) & (col("StoreClosedLabel") == 0)).count()
fn = final_predictions.filter((col("optimized_prediction") == 0) & (col("StoreClosedLabel") == 1)).count()

precision = tp / (tp + fp)
recall = tp / (tp + fn)

print(f"Optimized Precision: {precision:.2f}")
print(f"Optimized Recall: {recall:.2f}")

In [0]:
## FE
# Create time-sensitive features
from pyspark.sql.functions import dayofmonth, weekofyear

df_fixed = (df_fixed
    .withColumn("day_of_month", dayofmonth("date"))
    .withColumn("week_of_year", weekofyear("date"))
    .withColumn("is_month_end", (dayofmonth("date") >= 25).cast("int"))
)

# Add weather severity index
df_fixed = df_fixed.withColumn(
    "weather_severity",
    when(col("is_snow") == 1, 3)
    .when(col("is_rain") == 1, 2)
    .when(col("is_other_weather") == 1, 1)
    .otherwise(0)
)

# Update feature set
enhanced_features = safe_feature_cols + [
    "day_of_month", 
    "week_of_year",
    "is_month_end",
    "weather_severity"
]

In [0]:
##Address Class Imbalanc
# Check current class distribution
class_balance = df_fixed.groupBy("StoreClosedLabel").count()
class_balance.show()

# Apply class weights
weighted_rf = RandomForestClassifier(
    featuresCol="safe_features",
    labelCol="StoreClosedLabel",
    weightCol="class_weights"
)

# Create weight column
df_weighted = df_fixed.withColumn(
    "class_weights",
    when(col("StoreClosedLabel") == 1, 5.0).otherwise(1.0)  # Upweight closure events
)

In [0]:
#Hyperparameter Tunin
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

param_grid = (ParamGridBuilder()
    .addGrid(rf_safe.numTrees, [30, 50, 100])
    .addGrid(rf_safe.maxDepth, [3, 5, 7])
    .addGrid(rf_safe.minInstancesPerNode, [1, 5, 10])
    .build()
)

crossval = CrossValidator(
    estimator=rf_safe,
    estimatorParamMaps=param_grid,
    evaluator=BinaryClassificationEvaluator(
        labelCol="StoreClosedLabel",
        rawPredictionCol="rawPrediction",
        metricName="areaUnderPR"  # Focus on precision-recall curve
    ),
    numFolds=5,
    seed=42
)

cv_model = crossval.fit(train_safe)
best_model = cv_model.bestModel

In [0]:
#Business Rule Integration
# Create fallback rules based on domain knowledge
final_predictions = final_predictions.withColumn(
    "final_decision",
    when((col("is_snow") == 1) & (col("optimized_prediction") == 0), 1)  # Always predict closure in snow
    .when((col("IsHoliday") == 1) & (col("month_num").isin([11,12])), 1)  # Holiday season closures
    .otherwise(col("optimized_prediction"))
)

In [0]:
display(final_predictions)

In [0]:
final_predictions.write.mode("overwrite").saveAsTable("Store_Closure_Pred")

# NOW