# Retail Analytics (PySpark)

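This notebook uses PySpark to answer two business questions about the retail dataset:
how much each customer spends per week, and which products and shops are fast-, medium-, or slow-selling.
The input is the Parquet output of `prepare_data.py`, read from `data/processed_transactions`.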

In [None]:
import os

import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, countDistinct, date_trunc, lit, sum as _sum, when, window

# Start (or reuse) this notebook's Spark session; only ERROR-level logs are shown
spark = (
    SparkSession.builder
    .appName("RetailAnalytics")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

# Seaborn theme shared by every figure below
sns.set_theme(style="whitegrid")

## 1. Load Processed Data

In [None]:
# Directory of Parquet files produced by prepare_data.py
PROCESSED_DIR = 'data/processed_transactions'

# Read the transactions, then report row count, schema and a 5-row preview
df = spark.read.parquet(PROCESSED_DIR)

record_count = df.count()
print("Total records:", record_count)
df.printSchema()
df.show(5)

## 2. Analysis: Weekly Purchase Data for Customers
**Question:** How much does each customer spend per week? Transactions are grouped into weekly buckets and `price_paid` is summed per customer.

In [None]:
# Weekly spend per customer.
# date_trunc("week", ...) maps each transaction to the Monday (in the Spark session
# time zone) that starts its week. window(col("date"), "1 week") is deliberately not
# used: its buckets are aligned to the Unix epoch (1970-01-01 00:00 UTC, a Thursday),
# which yields Thursday-to-Wednesday "weeks".
week_start_col = date_trunc("week", col("date"))

weekly_sales = (
    df.groupBy("customer_name", week_start_col.alias("week_start"))
      .agg(_sum("price_paid").alias("total_spend"))
      .orderBy("customer_name", "week_start")
)

weekly_sales.show(10)

In [None]:
# Plot weekly spend for the highest-spending customers.
# Ranking and filtering happen in Spark, so only the top customers' weekly rows are
# collected to the driver. Converting the whole weekly table with toPandas() would
# pull every customer's history into driver memory just to draw a few lines.
TOP_N_CUSTOMERS = 5

top_customers = [
    row["customer_name"]
    for row in (
        weekly_sales
        .where(col("customer_name").isNotNull())  # pandas groupby also dropped null keys
        .groupBy("customer_name")
        .agg(_sum("total_spend").alias("overall_spend"))
        .orderBy(col("overall_spend").desc())
        .limit(TOP_N_CUSTOMERS)
        .collect()
    )
]

pdf_top_weekly_sales = (
    weekly_sales
    .where(col("customer_name").isin(top_customers))
    .toPandas()
)

fig, ax = plt.subplots(figsize=(12, 6))
sns.lineplot(data=pdf_top_weekly_sales, x='week_start', y='total_spend',
             hue='customer_name', marker='o', ax=ax)
ax.set(title=f'Weekly Spend for Top {TOP_N_CUSTOMERS} Customers',
       xlabel='Week', ylabel='Total Spend ($)')
ax.legend(title='Customer')
fig.autofmt_xdate()  # rotate date tick labels so they don't overlap
plt.show()

## 3. Classification: Fast, Medium, Slow Items & Stores
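Rank products (items) and shops (stores) by their average weekly sales (sum of `price_paid`).
Split each group into terciles, using approximate percentile thresholds from `approxQuantile`:
- **Fast**: top third
- **Medium**: middle third
- **Slow**: bottom third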

In [None]:
def classify_entity(df, entity_col, entity_name_col, metric_col='price_paid', date_col='date'):
    """Tier entities into 'Fast' / 'Medium' / 'Slow' by average weekly sales.

    Each entity's total ``metric_col`` is divided by the number of distinct weeks
    (Monday-start, via ``date_trunc("week", ...)``) present anywhere in ``df``.
    Weeks in which an entity sold nothing therefore count as zero. Averaging only
    over the weeks an entity appears in would let one large sale rank a
    rarely-sold entity as 'Fast'.

    Entities are then split at the approximate 1/3 and 2/3 quantiles of that
    average:
    above the upper threshold -> 'Fast', below the lower -> 'Slow', else 'Medium'.

    Args:
        df: Spark DataFrame of transactions.
        entity_col: Column identifying the entity (e.g. "product_id").
        entity_name_col: Display-name column grouped together with ``entity_col``.
        metric_col: Numeric column summed as "sales" (default 'price_paid').
        date_col: Column used to assign rows to weeks (default 'date').

    Returns:
        DataFrame with ``entity_col``, ``entity_name_col``, ``avg_weekly_sales`` and
        ``classification``, ordered by ``avg_weekly_sales`` descending.

    Raises:
        ValueError: if ``df`` has no non-null ``date_col`` values, or no entity has a
            non-null average to derive thresholds from.
    """
    week_start = date_trunc("week", col(date_col))

    # 1. Length of the observation period, in weeks.
    # NOTE: a week with no transactions at all anywhere in df is not counted.
    n_weeks = df.agg(countDistinct(week_start).alias("n_weeks")).first()["n_weeks"]
    if not n_weeks:
        raise ValueError(f"No non-null '{date_col}' values to compute weekly sales from")

    # 2. Average weekly sales per entity = total sales / weeks in the period
    avg_weekly_sales = df.groupBy(entity_col, entity_name_col) \
        .agg((_sum(metric_col) / lit(n_weeks)).alias("avg_weekly_sales"))

    # 3. Tercile thresholds; approxQuantile (1% relative error) avoids a full sort.
    # It ignores nulls and returns [] when no non-null value remains.
    quantiles = avg_weekly_sales.approxQuantile("avg_weekly_sales", [1 / 3, 2 / 3], 0.01)
    if len(quantiles) < 2:
        raise ValueError(f"No non-null average weekly sales to classify for {entity_name_col}")
    low_threshold, high_threshold = quantiles

    print(f"Classification Thresholds for {entity_name_col}:")
    print(f"  Slow < {low_threshold:.2f}")
    print(f"  {low_threshold:.2f} <= Medium <= {high_threshold:.2f}")
    print(f"  Fast > {high_threshold:.2f}")

    # 4. Classify (a null average falls through to 'Medium', as before)
    classified_df = avg_weekly_sales.withColumn(
        "classification",
        when(col("avg_weekly_sales") > high_threshold, "Fast")
        .when(col("avg_weekly_sales") < low_threshold, "Slow")
        .otherwise("Medium")
    )

    return classified_df.orderBy(col("avg_weekly_sales").desc())

# Tier products, then shops, with the same helper and preview the top 10 of each
print("--- Product Classification ---")
classified_products = classify_entity(df, entity_col="product_id", entity_name_col="product_name")
classified_products.show(10)

print("\n--- Shop Classification ---")
classified_shops = classify_entity(df, entity_col="shop_id", entity_name_col="shop_name")
classified_shops.show(10)

In [None]:
# Side-by-side counts of entities per tier (classified tables are small enough for pandas)
pdf_products = classified_products.toPandas()
pdf_shops = classified_shops.toPandas()

tier_order = ['Fast', 'Medium', 'Slow']
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

panels = (
    (pdf_products, 'Product Classification Distribution'),
    (pdf_shops, 'Shop Classification Distribution'),
)
for ax, (frame, title) in zip(axes, panels):
    sns.countplot(data=frame, x='classification', order=tier_order, ax=ax)
    ax.set_title(title)

plt.show()