In [7]:
import findspark
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import StringType
import pyspark.sql.functions as f

import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
import seaborn as sns 
import warnings 

warnings.filterwarnings("ignore")

In [8]:
# Make the local Spark installation importable before building a session.
findspark.init()

# Build (or reuse) the Spark session for this notebook.
# NOTE(review): the JDBC driver jar is referenced by a machine-specific
# absolute Windows path; the notebook will not pick up the driver elsewhere.
spark = (
    SparkSession.builder
    .appName("Risk Budget")
    .config("spark.driver.memory", "2g")
    .config("spark.executor.memory", "2g")
    .config(
        "spark.driver.extraClassPath",
        r"C:\Drivers\sqljdbc_12.10.0.0_enu\sqljdbc_12.10\enu\jars\mssql-jdbc-12.10.0.jre11.jar",
    )
    .getOrCreate()
)

# Load both input tables from CSV: first row is the header, and column
# types are inferred from the data.
customers = spark.read.options(header=True, inferSchema=True).csv(
    "../final data/customer_features.csv"
)
orders = spark.read.options(header=True, inferSchema=True).csv(
    "../final data/orders_facts.csv"
)

In [9]:
# Add the calendar year and month of each order, both taken from order_date.
# withColumn replaces a same-named column, so re-running this cell is safe.
orders = (
    orders
    .withColumn("order_year", f.year("order_date"))
    .withColumn("order_month", f.month("order_date"))
)

# RISK BUDGET 
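Not all churn is created equal: for each order year, customers are split into New, Low, Med and High groups by average order value (AOV), so the value at stake in each group can be compared.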

In [20]:
# Distinct order years in chronological order. A null order_year (from a
# missing or unparseable order_date) is dropped: it can never match the
# equality filter in the loop below, and distinct() gives no ordering of
# its own, so sorting makes the per-year output deterministic.
years = sorted(
    row["order_year"]
    for row in orders.select("order_year").distinct().collect()
    if row["order_year"] is not None
)
yearly_budget_dfs = []

# One row per (customer_id, order_id, order_year, start_year) with "aov" set
# to the mean total_price over that group's rows. start_year is the year of
# the customer's customer_first_date. The join is an inner join, so orders
# without a matching customer are dropped.
order_value = orders.join(
    customers, "customer_id"
).withColumn(
    "start_year", f.year(f.to_date("customer_first_date"))
).groupBy(
    "customer_id", "order_id", "order_year", "start_year"
).agg(
    f.avg("total_price").alias("aov")
)

for year in years:
    # Restrict to orders placed in the current year.
    this_year_df = order_value.filter(f.col("order_year") == year)

    # Tercile cut points of aov for this year, computed once per year.
    # approxQuantile returns an empty list when there are no non-null values
    # (e.g. every order of the year was removed by the join), so guard the
    # unpacking instead of failing with a ValueError.
    quantiles = this_year_df.approxQuantile("aov", [0.33, 0.66], 0.001)
    if len(quantiles) < 2:
        print(f"Skipping year {year}: no aov values available for quantiles")
        continue
    q33, q66 = quantiles

    # Label each row. "New" takes priority over the value bands, so a customer
    # whose first year is this year is never classed Low/Med/High.
    # NOTE(review): the cut points above are computed over all of the year's
    # rows, including those later labelled "New" — confirm this is intended.
    aov_grouped = this_year_df.withColumn(
        "aov_group",
        f.when(f.col("start_year") == year, "New")
         .when(f.col("aov") <= q33, "Low")
         .when((f.col("aov") > q33) & (f.col("aov") <= q66), "Med")
         .when(f.col("aov") > q66, "High")
         .otherwise("Unknown")  # reached only when aov is null
    )

    yearly_budget_dfs.append((year, aov_grouped))


In [None]:
# Print, for every year, the number of distinct customers and the mean aov
# (rounded to two decimals) in each aov_group.
for year, df in yearly_budget_dfs:
    print(f"=== Year: {year} ===")
    df.groupBy("aov_group").agg(
        f.countDistinct("customer_id").alias("no_of_customers"),
        # The alias must wrap the rounded column; aliasing inside round()
        # leaves the output column with an auto-generated name.
        f.round(f.avg("aov"), 2).alias("value"),
    ).show()