In [None]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
import pprint
import pyspark
import pyspark.sql.functions as F

from pyspark.sql.functions import col
from pyspark.sql.types import StringType, IntegerType, FloatType, DateType
from pyspark.sql.window import Window

In [None]:
# Create (or reuse) the notebook's SparkSession: local master using all
# available cores, with 4g configured for both driver and executor memory.
spark = (
    pyspark.sql.SparkSession.builder
    .master("local[*]")
    .appName("dev")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "4g")
    .getOrCreate()
)

# Restrict Spark logging to ERROR so warnings do not clutter the cell output
spark.sparkContext.setLogLevel("ERROR")

## Read Silver Parquet files

In [None]:
# Define paths
# Input (silver layer): transactions table, loaded into `txn`.
silver_transactions_path = "/app/datamart/silver/transactions" 
# Input (silver layer): max-expiry transaction snapshots, loaded into `txn_snapshots`.
silver_latest_transactions_path = "/app/datamart/silver/max_expiry_transactions"
# Output (gold layer): destination directory for the churn-label Parquet files.
gold_max_expiry_path = "/app/datamart/gold/label_store"

In [None]:
# Load the silver-layer transactions table.
# Parquet is self-describing (the schema is stored with the data), so the
# CSV-style "header" / "inferSchema" reader options do not apply here.
txn = spark.read.parquet(silver_transactions_path)

In [None]:
# Load the silver-layer table stored at `silver_latest_transactions_path`.
# Parquet is self-describing (the schema is stored with the data), so the
# CSV-style "header" / "inferSchema" reader options do not apply here.
txn_snapshots = spark.read.parquet(silver_latest_transactions_path)

# Show the loaded schema so the columns used for labelling can be checked.
print("Silver Transactions schema:")
txn_snapshots.printSchema()

## Create labels

In [None]:
# Build churn labels: one row per (snapshot_date, msno) with is_churn in {0, 1}.
#
# Inputs (column notes carried over from the upstream silver step):
#   txn           -> msno, transaction_date (date), membership_expire_date (date), is_cancel (int)
#   txn_snapshots -> msno, snapshot_date (= membership_expire_date), etc.
#
# Rule: is_churn = 0 when at least one non-cancel transaction falls inside
# [snapshot_date, snapshot_date + RENEWAL_WINDOW_DAYS] (both ends inclusive);
# otherwise is_churn = 1.
#
# NOTE(review): a snapshot whose window reaches past the latest transaction_date
# present in `txn` may be labelled churn only because renewal data is not yet
# available -- confirm whether such right-censored snapshots should be excluded.

# Length in days of the look-ahead window in which a renewal cancels the churn label.
RENEWAL_WINDOW_DAYS = 30

# 1) Base population: unique (snapshot_date, msno) pairs
churn_base = txn_snapshots.select("snapshot_date", "msno").distinct()

# 2) Candidate renewals: any txn with is_cancel = 0
renewals = txn.where(F.col("is_cancel") == 0).select("msno", "transaction_date")

# 3) Match each snapshot with the same user's renewals inside the window
cond = (
    (F.col("r.msno") == F.col("c.msno")) &
    (F.col("r.transaction_date") >= F.col("c.snapshot_date")) &
    (F.col("r.transaction_date") <= F.date_add(F.col("c.snapshot_date"), RENEWAL_WINDOW_DAYS))
)

# Left join keeps snapshots with no matching renewal (their r.* columns are null)
joined = churn_base.alias("c").join(renewals.alias("r"), on=cond, how="left")

# 4) Collapse back to one row per snapshot/user. F.count over a column counts only
#    non-null values, so unmatched snapshots get 0 renewals and therefore is_churn = 1.
result_churn = (
    joined.groupBy(F.col("c.snapshot_date"), F.col("c.msno"))
          .agg(F.count(F.col("r.transaction_date")).alias("renewals_in_window"))
          .withColumn("is_churn", F.when(F.col("renewals_in_window") > 0, F.lit(0)).otherwise(F.lit(1)))
          .select(F.col("snapshot_date"), F.col("msno"), F.col("is_churn"))
          # Several actions consume result_churn (preview, write, counts); caching
          # avoids re-running the range join for each of them.
          .cache()
)

# Preview
result_churn.show(10, truncate=False)

## Write Parquet files to the Gold layer

In [None]:
# Persist the churn labels to the gold layer as Parquet, replacing any
# previous output at that location. No partitionBy is applied, so the
# dataset is written without partition columns.
label_writer = result_churn.write.mode("overwrite")
label_writer.parquet(gold_max_expiry_path)

print(f"✅ Gold layer - Labelling based on max membership_expire_date successfully written to: {gold_max_expiry_path}")

##### If you are converting this notebook into a .py script, you can stop here.

## Review labels & EDA

In [None]:
# ===============================
# 1. Quick overview
# ===============================
# Schema first, then the three headline counts (each count is a Spark action).
result_churn.printSchema()

n_label_rows = result_churn.count()
print(f"Total rows: {n_label_rows:,}")

n_users = result_churn.select('msno').distinct().count()
print(f"Distinct users: {n_users:,}")

n_snapshot_dates = result_churn.select('snapshot_date').distinct().count()
print(f"Distinct snapshot dates: {n_snapshot_dates}")

In [None]:
# ===============================
# 2. Class balance
# ===============================

# Denominator for the percentage column: total number of label rows
total_count = result_churn.count()

# Per-class row count and its share of all rows, as a percentage rounded to 2 decimals
rows_per_class = F.count("*").alias("count")
share_of_total = F.round(F.col("count") / total_count * 100, 2)

class_dist = (
    result_churn
    .groupBy("is_churn")
    .agg(rows_per_class)
    .withColumn("percentage", share_of_total)
    .orderBy("is_churn")
)

class_dist.show()


In [None]:
# ===============================
# 3. Churn rate over time
# ===============================
# Per snapshot date: number of label rows, number of churn labels, and the
# churn share as a percentage rounded to 2 decimals.
snapshot_total = F.count("*").alias("total")
snapshot_churned = F.sum("is_churn").alias("churned")
churn_rate_pct = F.round(F.col("churned") / F.col("total") * 100, 2)

churn_over_time = (
    result_churn
    .groupBy("snapshot_date")
    .agg(snapshot_total, snapshot_churned)
    .withColumn("churn_rate", churn_rate_pct)
    .orderBy("snapshot_date")
)
churn_over_time.show(20, truncate=False)

In [None]:
# Plot the per-snapshot churn rate (requires a front end that can render matplotlib figures)
pdf = churn_over_time.toPandas()  # collect the aggregated rows to the driver as pandas

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(pdf["snapshot_date"], pdf["churn_rate"], marker="o")
ax.set_title("Churn Rate Over Time")
ax.set_xlabel("Snapshot Date")
ax.set_ylabel("Churn Rate (%)")
ax.grid(True)
plt.show()

In [None]:
# ===============================
# 4. User churn frequency
# ===============================
# A user can appear under several snapshot dates; summarise, per user, how many
# snapshots they have and what percentage of those snapshots are churn labels.
snapshots_per_user = F.count("*").alias("num_snapshots")
churns_per_user = F.sum("is_churn").alias("num_churns")
user_churn_pct = F.round(F.col("num_churns") / F.col("num_snapshots") * 100, 2)

user_churn_stats = (
    result_churn
    .groupBy("msno")
    .agg(snapshots_per_user, churns_per_user)
    .withColumn("churn_rate_user", user_churn_pct)
)

# Summary statistics across users for the two per-user measures
user_churn_stats.describe(["num_snapshots", "churn_rate_user"]).show()

In [None]:
# ===============================
# 5. Correlation sanity check (optional)
# ===============================
# Number the snapshot dates 1..N in chronological order and correlate that
# index with the churn rate, as a rough check for a trend over time.
snapshot_order = Window.orderBy("snapshot_date")  # no partitioning: a single global ordering
indexed = churn_over_time.withColumn("time_index", F.row_number().over(snapshot_order))

corr_value = indexed.stat.corr("time_index", "churn_rate")
print(f"Correlation between time progression and churn rate: {corr_value:.3f}")

## Stop Spark Session

In [None]:
# Stop the SparkSession created at the top of the notebook.
spark.stop()