In [None]:
# K-means Clustering with PySpark - Simple Customer Segmentation

# Import required libraries for big data processing and machine learning
from pyspark.sql import SparkSession  # Main entry point for Spark SQL functionality
from pyspark.sql.functions import *   # SQL functions for data transformations
from pyspark.ml.feature import VectorAssembler, StandardScaler  # Feature preprocessing tools
from pyspark.ml.clustering import KMeans  # K-means clustering algorithm
from pyspark.ml import Pipeline  # ML pipeline for chaining transformations
import pandas as pd  # For local data manipulation and visualization
import matplotlib.pyplot as plt  # For creating plots and charts
import numpy as np  # For numerical computations

# Spark session setup.
# Both driver settings point the driver at the loopback address (127.0.0.1);
# they are intended to avoid networking issues when Spark runs locally.
driver_network_conf = {
    "spark.driver.host": "127.0.0.1",         # host name/address used for the driver
    "spark.driver.bindAddress": "127.0.0.1",  # address the driver binds to
}

# Name the application, then apply each driver setting in turn
session_builder = SparkSession.builder.appName("CustomerSegmentation")
for conf_key, conf_value in driver_network_conf.items():
    session_builder = session_builder.config(conf_key, conf_value)

# Reuse an existing session if one is available, otherwise create a new one
spark = session_builder.getOrCreate()

# Keep the log quiet: only warnings and errors are shown
spark_context = spark.sparkContext
spark_context.setLogLevel("WARN")

# Print the Spark version, which is handy when debugging environment problems
print(f"Spark Version: {spark.version}")

## 1. Data Loading and Basic Features

In [None]:
# Load the raw transactions and derive transaction-level features
print("Loading transactions...")

# header=True: the first row contains the column names
# inferSchema=True: Spark detects column types automatically (can be slow for large files)
df = spark.read.csv("../data/transactions_data.csv", header=True, inferSchema=True)
print(f"Raw data loaded: {df.count():,} transactions")  # count() triggers a full data scan

# Cleaning and feature engineering, step by step:
# 1. amount_numeric: strip "$" and "," from the amount text, then cast to double.
#    The pattern is a raw string so that "\$" reaches the regex engine unchanged
#    instead of being an invalid Python escape sequence.
# 2. is_online: 1 when merchant_city equals "ONLINE", otherwise 0
# 3. is_weekend: 1 for weekend dates (Spark's dayofweek: Sunday=1, Saturday=7), otherwise 0
# 4. filter: keep only rows whose amount parsed successfully and is strictly positive
df_processed = (
    df
    .withColumn("amount_numeric", regexp_replace(col("amount"), r"[\$,]", "").cast("double"))
    .withColumn("is_online", (col("merchant_city") == "ONLINE").cast("int"))
    .withColumn("is_weekend", dayofweek(col("date")).isin([1, 7]).cast("int"))
    .filter(col("amount_numeric").isNotNull() & (col("amount_numeric") > 0))
)

# Cache the cleaned DataFrame: it is counted and displayed right below and
# aggregated again in the feature-engineering step, so caching avoids
# re-reading and re-parsing the CSV each time.
df_processed.cache()
print(f"Processed {df_processed.count():,} transactions")
df_processed.show(5)  # Display the first 5 rows to inspect data quality

## 2. Customer Feature Engineering

Aggregate the transaction-level data into one row per customer, creating customer-level features for clustering based on spending patterns, behavioral patterns, and temporal patterns.

In [None]:
# Aggregate transactions into one row of clustering features per customer
print("Creating customer features...")

# Columns the aggregation below reads; checked up front so a schema problem
# produces a readable message instead of a Spark analysis error.
required_cols = ["client_id", "amount_numeric", "is_online", "merchant_id"]
available_cols = df_processed.columns
# The loop variable is deliberately not called `col`, to avoid confusion with
# the PySpark `col` function used elsewhere in this notebook.
missing_cols = [name for name in required_cols if name not in available_cols]

if missing_cols:
    print(f"Missing columns: {missing_cols}")
    print(f"Available columns: {available_cols}")
else:
    print("All required columns found. Proceeding with feature creation...")

    # NOTE: `sum`, `avg`, `count` and `countDistinct` are the PySpark column
    # functions brought in by `from pyspark.sql.functions import *`
    # (so `sum` here is NOT the Python builtin).
    customer_features = df_processed.groupBy("client_id").agg(
        sum("amount_numeric").alias("total_spend"),             # total amount spent
        avg("amount_numeric").alias("avg_transaction_amount"),  # mean transaction size
        count("*").alias("transaction_count"),                  # number of transactions
        avg("is_online").alias("online_ratio"),                 # mean of the 0/1 flag = online share
        countDistinct("merchant_id").alias("merchant_diversity"),  # distinct merchants visited
    )

    # Statistical summary of all features to understand their distribution
    customer_features.describe().show()

## 3. K-means Clustering

In [None]:
# Assemble, scale and cluster the customer-level features with K-means
# Feature columns used for clustering (all numerical)
feature_cols = ["total_spend", "avg_transaction_amount", "transaction_count", "online_ratio", "merchant_diversity"]

# Guard against running this cell before customer_features has been created
if 'customer_features' not in locals():
    print("ERROR: customer_features not created. Please run the previous cell first.")
else:
    print(f"Using {customer_features.count():,} customers for clustering")

    # Replace nulls with 0.0 so VectorAssembler does not fail on missing values
    customer_features_clean = customer_features.fillna(0.0)

    # Stage 1: pack the feature columns into a single vector column
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

    # Stage 2: standardize each feature (zero mean, unit variance) so that
    # large-valued features such as total_spend do not dominate the distances
    scaler = StandardScaler(
        inputCol="features",
        outputCol="scaled_features",
        withMean=True,
        withStd=True,
    )

    # Stage 3: K-means with 4 segments; the fixed seed makes the run reproducible
    kmeans = KMeans(
        featuresCol="scaled_features",
        predictionCol="prediction",
        k=4,
        seed=42,
    )

    pipeline = Pipeline(stages=[assembler, scaler, kmeans])

    print("Training K-means model...")
    model = pipeline.fit(customer_features_clean)
    # Adds the assembled/scaled vectors and the "prediction" (cluster id) column
    predictions = model.transform(customer_features_clean)

    print("K-means clustering completed!")
    # Show the distribution of customers across clusters
    predictions.groupBy("prediction").count().orderBy("prediction").show()

## 4. Analyze Clusters

In [None]:
# Summarize each cluster so the numeric cluster ids can be read as segments.

# Grouping key: the model's "prediction" column, exposed under the name "cluster"
cluster_key = col("prediction").alias("cluster")

# One aggregate per metric reported for each cluster:
# - customer_count: relative size of the segment
# - avg_total_spend: separates high-value from low-value segments
# - avg_transaction_amount: typical size of a single transaction
# - avg_online_ratio: how strongly the segment leans towards online purchases
cluster_metrics = [
    count("*").alias("customer_count"),
    avg("total_spend").alias("avg_total_spend"),
    avg("avg_transaction_amount").alias("avg_transaction_amount"),
    avg("online_ratio").alias("avg_online_ratio"),
]

# Group, aggregate, and sort by cluster number for easier reading
cluster_summary = (
    predictions
    .groupBy(cluster_key)
    .agg(*cluster_metrics)
    .orderBy("cluster")
)

print("Cluster Analysis:")
cluster_summary.show()  # Display the per-cluster summary table

## 5. Simple Visualization

In [None]:
# Visualize the clusters: a scatter of sampled customers and a bar chart of cluster sizes

# Sample 10% of the customers for the scatter plot (plotting every customer
# would be slow and cluttered) and bring the sample to pandas for matplotlib
sample_data = predictions.sample(0.1, seed=42).toPandas()

# Cluster sizes are computed on the full data in Spark; the result has one row
# per cluster, so collecting it to pandas is cheap
cluster_counts = (
    predictions.groupBy("prediction").count().orderBy("prediction").toPandas()
)

fig, (ax_scatter, ax_bar) = plt.subplots(1, 2, figsize=(10, 4))

# Plot 1: total_spend vs transaction_count, one color per cluster
for cluster_label, members in sample_data.groupby("prediction"):
    ax_scatter.scatter(
        members["total_spend"],
        members["transaction_count"],
        s=12,
        alpha=0.6,
        label=str(cluster_label),
    )
ax_scatter.set_xlabel("Total spend")
ax_scatter.set_ylabel("Transaction count")
ax_scatter.set_title("Customer segments (10% sample)")
if not sample_data.empty:
    # Only draw a legend when at least one cluster was plotted
    ax_scatter.legend(title="Cluster")

# Plot 2: number of customers in each cluster (labels as strings so the
# x-axis is treated as categorical)
ax_bar.bar(cluster_counts["prediction"].astype(str), cluster_counts["count"])
ax_bar.set_xlabel("Cluster")
ax_bar.set_ylabel("Customers")
ax_bar.set_title("Customers per cluster")

fig.tight_layout()
plt.show()

## 6. Business Insights  

In [None]:
# Turn the per-cluster statistics into simple, readable customer personas.

print("Customer Segmentation Results:")
print("=" * 40)

# collect() pulls the (small) summary table to the driver so it can be
# processed with plain Python
for cluster_row in cluster_summary.collect():
    avg_spend = cluster_row['avg_total_spend']
    online_share = cluster_row['avg_online_ratio']

    # Key metrics for this cluster (cluster numbers are 0, 1, 2, 3)
    print(f"\nCluster {cluster_row['cluster']}: {cluster_row['customer_count']:,} customers")
    print(f"  Average spend: ${avg_spend:,.2f}")
    print(f"  Online ratio: {online_share*100:.1f}%")

    # Rule-based persona; the first matching rule wins:
    # - average spend above $1000       -> high-value (retention, premium services)
    # - otherwise online ratio above 50% -> digital (online promotions, mobile features)
    # - otherwise                        -> traditional (in-store experiences, phone support)
    # This is a deliberately basic scheme; it yields three distinct marketing
    # strategies instead of one generic approach.
    persona = (
        "High-Value Customer" if avg_spend > 1000
        else "Digital Customer" if online_share > 0.5
        else "Traditional Customer"
    )

    print(f"  Persona: {persona}")

# Closing status message
print("\n✅ Clustering completed! Found meaningful customer segments for targeted marketing.")

# Stop the Spark session to release memory and other resources.
# Useful for local development; may not be desired in production notebooks.
spark.stop()

# What each persona implies for marketing:
# 1. High-Value Customers: premium loyalty programs, personal shoppers, exclusive events
# 2. Digital Customers: mobile app features, online-only deals, social media campaigns
# 3. Traditional Customers: in-store promotions, direct mail, phone-based customer service

# Suggested next steps for the business:
# 1. Validate the personas with business stakeholders
# 2. Design targeted marketing campaigns for each segment
# 3. Set up an automated segmentation pipeline for new customers
# 4. Measure campaign effectiveness by persona
# 5. Refine the clustering model based on business outcomes