In [7]:
# K-means Clustering with PySpark - Simple Customer Segmentation
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.clustering import KMeans
from pyspark.ml import Pipeline
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Start (or reuse) the notebook's Spark session and keep log output at WARN level
spark = (
    SparkSession.builder
    .appName("CustomerSegmentation")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")
print(f"Spark Version: {spark.version}")

Spark Version: 3.5.3


## 1. Data Loading and Basic Features

In [None]:
# Load transaction data
print("Loading transactions...")

# header=True takes column names from the first row; inferSchema=True lets Spark
# derive the column types from the data.
df = spark.read.csv("../data/transactions_data.csv", header=True, inferSchema=True)
print(f"Raw data loaded: {df.count():,} transactions")

# Create basic row-level features:
#   amount_numeric - "amount" with "$" and "," stripped, cast to double
#   is_online      - 1 when merchant_city is the literal "ONLINE", else 0
#   is_weekend     - 1 when dayofweek(date) is 1 or 7 (Spark: 1 = Sunday, 7 = Saturday)
# The regex is a raw string: in a normal string literal "\$" is an invalid escape
# sequence and makes Python emit a SyntaxWarning. The pattern value is unchanged.
# Rows whose amount cannot be parsed, or is not strictly positive, are dropped.
df_processed = (
    df
    .withColumn("amount_numeric", regexp_replace(col("amount"), r"[\$,]", "").cast("double"))
    .withColumn("is_online", (col("merchant_city") == "ONLINE").cast("int"))
    .withColumn("is_weekend", dayofweek(col("date")).isin([1, 7]).cast("int"))
    .filter(col("amount_numeric").isNotNull() & (col("amount_numeric") > 0))
)

# Cached because the cells below aggregate this frame repeatedly.
df_processed.cache()
print(f"Processed {df_processed.count():,} transactions")
df_processed.show(5)

Loading transactions...


  .withColumn("amount_numeric", regexp_replace(col("amount"), "[\$,]", "").cast("double")) \
  .withColumn("amount_numeric", regexp_replace(col("amount"), "[\$,]", "").cast("double")) \
25/08/11 23:25:14 WARN CacheManager: Asked to cache already cached data.        
25/08/11 23:25:14 WARN MemoryStore: Not enough space to cache rdd_4_4 in memory! (computed 8.5 MiB so far)
25/08/11 23:25:14 WARN MemoryStore: Not enough space to cache rdd_4_7 in memory! (computed 8.5 MiB so far)
25/08/11 23:25:14 WARN MemoryStore: Not enough space to cache rdd_4_8 in memory! (computed 8.5 MiB so far)
25/08/11 23:25:14 WARN CacheManager: Asked to cache already cached data.        
25/08/11 23:25:14 WARN MemoryStore: Not enough space to cache rdd_4_4 in memory! (computed 8.5 MiB so far)
25/08/11 23:25:14 WARN MemoryStore: Not enough space to cache rdd_4_7 in memory! (computed 8.5 MiB so far)
25/08/11 23:25:14 WARN MemoryStore: Not enough space to cache rdd_4_8 in memory! (computed 8.5 MiB so far)


Loaded 12,635,227 transactions
+-------+-------------------+---------+-------+-------+-----------------+-----------+-------------+--------------+-------+----+------+--------------+---------+
|     id|               date|client_id|card_id| amount|         use_chip|merchant_id|merchant_city|merchant_state|    zip| mcc|errors|amount_numeric|is_online|
+-------+-------------------+---------+-------+-------+-----------------+-----------+-------------+--------------+-------+----+------+--------------+---------+
|7475328|2010-01-01 00:02:00|      561|   4575| $14.57|Swipe Transaction|      67570|   Bettendorf|            IA|52722.0|5311|  NULL|         14.57|        0|
|7475329|2010-01-01 00:02:00|     1129|    102| $80.00|Swipe Transaction|      27092|        Vista|            CA|92084.0|4829|  NULL|          80.0|        0|
|7475331|2010-01-01 00:05:00|      430|   2860|$200.00|Swipe Transaction|      27092|  Crown Point|            IN|46307.0|4829|  NULL|         200.0|        0|
|7475332|

## 2. Customer Feature Engineering

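Aggregate the transactions into one row per customer for clustering. The features cover spending patterns (total spend, average transaction amount) and behavioral patterns (transaction count, share of online purchases, number of distinct merchants). Temporal patterns are not aggregated yet: the `is_weekend` flag from the previous step is available but not used in this cell.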

In [11]:
# Create customer features for clustering
print("Creating customer features...")

# Check if we have the required columns
required_cols = ["client_id", "amount_numeric", "is_online", "merchant_id"]
available_cols = df_processed.columns
# The loop variable is deliberately not named `col`, so it cannot be confused with
# the pyspark `col` function that this cell uses further down.
missing_cols = [name for name in required_cols if name not in available_cols]

if missing_cols:
    # Report only; `customer_features` stays undefined and the clustering cell checks for that.
    print(f"Missing columns: {missing_cols}")
    print(f"Available columns: {available_cols}")
else:
    print("All required columns found. Proceeding with feature creation...")

    # One row per client. `sum`, `avg`, `count` and `countDistinct` are the
    # pyspark.sql.functions versions (brought in by the star import), not Python builtins.
    customer_features = (
        df_processed
        .groupBy("client_id")
        .agg(
            sum("amount_numeric").alias("total_spend"),
            avg("amount_numeric").alias("avg_transaction_amount"),
            count("*").alias("transaction_count"),
            avg("is_online").alias("online_ratio"),
            countDistinct("merchant_id").alias("merchant_diversity"),
        )
        .filter(col("transaction_count") >= 5)  # Minimum 5 transactions
    )

    # Persist the aggregated frame: without this, every action on it (the count and
    # describe below, and the clustering cell's count / fit / transform) re-runs the
    # full groupBy over all transactions.
    customer_features.cache()
    print(f"Created features for {customer_features.count():,} customers")
    customer_features.describe().show()

Creating customer features...
All required columns found. Proceeding with feature creation...


ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/Users/julienlook/Documents/Coding/big-data-analytics/.venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/julienlook/Documents/Coding/big-data-analytics/.venv/lib/python3.12/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/homebrew/Cellar/python@3.12/3.12.11/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socket.py", line 720, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

## 3. K-means Clustering

In [None]:
# Prepare features and run K-means
feature_cols = ["total_spend", "avg_transaction_amount", "transaction_count", "online_ratio", "merchant_diversity"]

# Check if customer_features exists (the feature cell leaves it undefined when
# required columns are missing or when it was interrupted).
if 'customer_features' not in locals():
    print("ERROR: customer_features not created. Please run the previous cell first.")
else:
    print(f"Using {customer_features.count():,} customers for clustering")

    # Handle any null values by replacing them with 0.0 in the numeric columns
    customer_features_clean = customer_features.fillna(0.0)

    # Create pipeline: features -> scaling -> clustering
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")
    # predictionCol must be set explicitly: KMeans writes to "prediction" by default,
    # while this cell and every later cell read the assignment from "cluster".
    kmeans = KMeans(
        featuresCol="scaledFeatures",
        predictionCol="cluster",
        k=4,      # Use 4 clusters
        seed=42,  # fixed seed so repeated runs use the same initialization
    )

    pipeline = Pipeline(stages=[assembler, scaler, kmeans])

    print("Training K-means model...")
    model = pipeline.fit(customer_features_clean)
    predictions = model.transform(customer_features_clean)

    print("K-means clustering completed!")
    predictions.groupBy("cluster").count().orderBy("cluster").show()

## 4. Analyze Clusters

In [None]:
# Profile each cluster: its size plus the mean of total spend, average
# transaction amount and online ratio across the customers assigned to it.
cluster_summary = (
    predictions
    .groupBy("cluster")
    .agg(
        count("*").alias("customer_count"),
        avg("total_spend").alias("avg_total_spend"),
        avg("avg_transaction_amount").alias("avg_transaction_amount"),
        avg("online_ratio").alias("avg_online_ratio"),
    )
    .orderBy("cluster")
)

print("Cluster Analysis:")
cluster_summary.show()

## 5. Simple Visualization

In [None]:
# Simple cluster visualization
# Only a 10% sample of the customers is brought to the driver for the scatter
# plot; the seed keeps the sample the same across runs.
sample_data = predictions.sample(fraction=0.1, seed=42).toPandas()
cluster_summary_pd = cluster_summary.toPandas()

fig, (ax_scatter, ax_bar) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: total spend vs. transaction count, colored by assigned cluster
points = ax_scatter.scatter(
    sample_data['total_spend'],
    sample_data['transaction_count'],
    c=sample_data['cluster'],
    cmap='viridis',
    alpha=0.6,
)
ax_scatter.set_xlabel('Total Spend ($)')
ax_scatter.set_ylabel('Transaction Count')
ax_scatter.set_title('Customer Segments: Spending vs Frequency')
fig.colorbar(points, ax=ax_scatter)

# Right panel: number of customers in each cluster
ax_bar.bar(cluster_summary_pd['cluster'], cluster_summary_pd['customer_count'])
ax_bar.set_xlabel('Cluster')
ax_bar.set_ylabel('Number of Customers')
ax_bar.set_title('Customers per Cluster')

fig.tight_layout()
plt.show()

## 6. Business Insights  

In [None]:
# Simple business interpretation
print("Customer Segmentation Results:")
print("=" * 40)

# collect() returns a list of Row objects, so iterate over the rows directly.
# Unpacking each row as `i, row` fails because a summary Row has more than two fields.
for row in cluster_summary.collect():
    cluster_id = row['cluster']
    print(f"\nCluster {cluster_id}: {row['customer_count']:,} customers")
    print(f"  Average spend: ${row['avg_total_spend']:,.2f}")
    print(f"  Online ratio: {row['avg_online_ratio']*100:.1f}%")

    # Simple persona assignment; the first matching rule wins.
    # NOTE(review): the 1000 and 0.5 thresholds are hard-coded — confirm they suit
    # the actual spend scale of this dataset.
    if row['avg_total_spend'] > 1000:
        persona = "High-Value Customer"
    elif row['avg_online_ratio'] > 0.5:
        persona = "Digital Customer"
    else:
        persona = "Traditional Customer"

    print(f"  Persona: {persona}")

print("\n✅ Clustering completed! Found meaningful customer segments for targeted marketing.")

# Cleanup: stops the Spark session, so earlier cells must be re-run before any
# further Spark work in this kernel.
spark.stop()