# Customer Lifetime Value (CLV) Prediction with PySpark, XGBoost, and MLflow

In this mini project, we use the **Online Retail II** dataset from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/dataset/502/online+retail+ii) to predict 6‑month customer lifetime value (CLV) for a UK-based online retail business. We treat this as a supervised regression problem and focus on reproducible, production-minded workflows.

**Business goal:** given a customer's historical transactions (recency, frequency, monetary value, product mix, etc.), estimate their **future 6‑month revenue** so marketing and retention teams can prioritize high‑value customers.

---

## 1. Environment Setup & Imports

_This section imports PySpark, pandas, XGBoost, scikit-learn, MLflow, and the plotting libraries, and defines the data paths used throughout the notebook._


In [1]:
# Environment Setup & Imports
import os
from datetime import timedelta

import numpy as np
import pandas as pd

# PySpark
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col,
    sum as spark_sum,
    count as spark_count,
    countDistinct,
    max as spark_max,
    min as spark_min,
    avg as spark_avg,
    stddev,
    datediff,
    lit,
    first,
)
from pyspark.sql.window import Window

# Modeling
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, r2_score
from xgboost import XGBRegressor

# Experiment tracking
import mlflow
import mlflow.xgboost

import matplotlib.pyplot as plt
import seaborn as sns

# Plotting style
sns.set_theme(style="whitegrid", context="notebook")  # sns.set is the older alias

# Paths
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
EXCEL_PATH = os.path.join(DATA_DIR, "online_retail_II.xlsx")
CSV_PATH = os.path.join(DATA_DIR, "transactions.csv")

print("Project root:", PROJECT_ROOT)
print("Excel path:", EXCEL_PATH)
print("CSV path:", CSV_PATH)


**Note:** If the XGBoost import fails with `XGBoostError: XGBoost Library (libxgboost.dylib) could not be loaded`, the OpenMP runtime is missing (`libomp.dylib` on macOS, `libgomp.so` on Linux, `vcomp140.dll` or `libgomp-1.dll` on Windows). On macOS, run `brew install libomp` and re-run this cell before continuing. The same error can also appear when running 32-bit Python on a 64-bit OS.


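## 2. Spark Session & Data Loading

_In this section we start a Spark session, convert the Online Retail II Excel workbook to a CSV once, and load that CSV into a Spark DataFrame._


In [None]: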
# 2. Spark Session & Data Loading

spark = (
    SparkSession.builder
    .appName("clv-prediction")
    .getOrCreate()
)

print("Spark version:", spark.version)

# Create a CSV from the Online Retail II Excel file (one-time, idempotent)
if not os.path.exists(CSV_PATH):
    print("CSV not found, creating from Excel...")

    # Read all sheets from the Excel file and concatenate
    excel_sheets = pd.read_excel(EXCEL_PATH, sheet_name=None)
    raw_df = pd.concat(excel_sheets.values(), ignore_index=True)

    # Drop completely empty rows
    raw_df = raw_df.dropna(how="all")

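    # Rename columns to snake_case / modeling-friendly names. The Online Retail II
    # workbook uses "Invoice", "Price", and "Customer ID"; the older Online Retail
    # file uses "InvoiceNo", "UnitPrice", and "CustomerID". Map both spellings.
    raw_df = raw_df.rename(
        columns={
            "Invoice": "invoice_id",
            "InvoiceNo": "invoice_id",
            "StockCode": "product_id",
            "Description": "description",
            "Quantity": "quantity",
            "InvoiceDate": "invoice_date",
            "Price": "unit_price",
            "UnitPrice": "unit_price",
            "Customer ID": "customer_id",
            "CustomerID": "customer_id",
            "Country": "country",
        }
    )

    # Remove cancelled invoices (invoice IDs starting with 'C')
    raw_df = raw_df[~raw_df["invoice_id"].astype(str).str.startswith("C")]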

    raw_df.to_csv(CSV_PATH, index=False)
    print(f"Saved combined CSV to {CSV_PATH}")
else:
    print("Using existing CSV at:", CSV_PATH)

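# Load CSV into Spark.
# pandas escapes embedded quotes by doubling them, so tell Spark to use '"' as the escape character.
df = spark.read.csv(CSV_PATH, header=True, inferSchema=True, escape='"')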

print("Raw Spark schema:")
df.printSchema()
df.show(5)


## 3. Data Cleaning & Basic EDA (PySpark)

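_In this section we cast the invoice date to a timestamp, compute a line-level `amount`, drop rows with a missing customer ID or a non-positive amount (which also removes returns), and print a high-level summary of customers, orders, and revenue. Cancelled invoices were already removed when the CSV was created._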


In [3]:
# 3. Data Cleaning & Basic EDA (PySpark)

from pyspark.sql.functions import to_timestamp

# Cast invoice_date to timestamp and create amount column

df = df.withColumn("invoice_date", to_timestamp("invoice_date"))
df = df.withColumn("amount", col("quantity") * col("unit_price"))

# Basic filtering: non-null customer_id and positive amount

df = df.filter(col("customer_id").isNotNull())
df = df.filter(col("amount") > 0)

summary = df.select(
    countDistinct("customer_id").alias("n_customers"),
    countDistinct("invoice_id").alias("n_orders"),
    spark_sum("amount").alias("total_revenue"),
)

print("High-level summary after basic cleaning:")
summary.show()

print("Row count after cleaning:", df.count())



## 4. Train / Prediction Window Definition (CLV Target Construction)

We will:
- Define an **observation window** (history period) per customer to build features.
- Define a **6‑month prediction window** after the observation end date.
- Compute **future revenue** per customer over the prediction window as the CLV target.
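
The windowing logic above can be checked on a handful of rows before running it in Spark. The following is a minimal, dependency-free sketch, not the notebook's implementation: the toy transactions and the helper names `split_windows` and `future_revenue` are illustrative assumptions, and the 180-day horizon mirrors the "6 months" used in this notebook.

```python
from datetime import datetime, timedelta

# Toy transactions: (customer_id, invoice_date, amount). Values are illustrative only.
transactions = [
    ("A", datetime(2011, 1, 10), 50.0),
    ("A", datetime(2011, 9, 1), 30.0),
    ("B", datetime(2011, 3, 5), 20.0),
    ("B", datetime(2011, 12, 1), 70.0),
    ("C", datetime(2011, 2, 14), 15.0),
]


def split_windows(rows, horizon_days=180):
    """Split rows into history (<= cutoff) and future (> cutoff) windows."""
    max_date = max(date for _, date, _ in rows)
    cutoff = max_date - timedelta(days=horizon_days)
    history = [r for r in rows if r[1] <= cutoff]
    future = [r for r in rows if r[1] > cutoff]
    return cutoff, history, future


def future_revenue(history, future):
    """Sum future revenue per customer; customers seen only in history get 0.0."""
    labels = {cust: 0.0 for cust, _, _ in history}
    for cust, _, amount in future:
        if cust in labels:  # customers first seen in the future have no features
            labels[cust] += amount
    return labels


cutoff, history, future = split_windows(transactions)
labels = future_revenue(history, future)
print("Cutoff:", cutoff.date())
print("Labels:", labels)
```

Customers who appear only after the cutoff are excluded from the labels here because they would have no history features to predict from.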



In [2]:
# 4. Train / Prediction Window Definition (CLV Target Construction)

# Find the maximum invoice_date in the cleaned data
max_date_row = df.select(spark_max("invoice_date").alias("max_date")).collect()[0]
max_date = max_date_row["max_date"]
print("Max invoice_date:", max_date)

# Define cutoff date 6 months (180 days) before the max date
cutoff_date = max_date - timedelta(days=180)
print("Cutoff date (6 months before max):", cutoff_date)

# Split into history (features) and future (CLV target) windows
history_df = df.filter(col("invoice_date") <= lit(cutoff_date))
future_df = df.filter(col("invoice_date") > lit(cutoff_date))

print("History rows:", history_df.count())
print("Future rows:", future_df.count())

# Compute future 6-month revenue (CLV) per customer
clv_df = future_df.groupBy("customer_id").agg(
    spark_sum("amount").alias("future_6m_revenue")
)

print("Sample of CLV labels:")
clv_df.orderBy(col("future_6m_revenue").desc()).show(5)



## 5. Feature Engineering with PySpark (RFM & More)

Here we will create customer-level features such as:
- **Recency**: days since last purchase in the observation window
- **Frequency**: number of transactions
- **Monetary value**: total and average revenue
- Product/category diversity, country, etc.
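
To make the recency, frequency, and monetary definitions concrete, here is a small plain-Python sketch on toy rows. The helper `rfm`, the sample data, and the cutoff date are illustrative assumptions. Note that frequency counts distinct invoices rather than line items, since one invoice usually spans several rows.

```python
from datetime import datetime

# Toy history rows: (customer_id, invoice_id, invoice_date, amount). Illustrative values.
history = [
    ("A", "1001", datetime(2011, 1, 10), 20.0),
    ("A", "1001", datetime(2011, 1, 10), 30.0),  # second line item on the same invoice
    ("A", "1002", datetime(2011, 3, 1), 50.0),
    ("B", "1003", datetime(2011, 2, 1), 40.0),
]
cutoff = datetime(2011, 6, 4)  # hypothetical end of the observation window


def rfm(rows, cutoff):
    """Compute recency, frequency, and monetary features per customer."""
    per_customer = {}
    for cust, invoice, date, amount in rows:
        stats = per_customer.setdefault(
            cust, {"invoices": set(), "last": date, "total": 0.0}
        )
        stats["invoices"].add(invoice)
        stats["last"] = max(stats["last"], date)
        stats["total"] += amount
    return {
        cust: {
            "recency_days": (cutoff - s["last"]).days,
            "frequency": len(s["invoices"]),
            "monetary_total": s["total"],
            "monetary_avg": s["total"] / len(s["invoices"]),
        }
        for cust, s in per_customer.items()
    }


features = rfm(history, cutoff)
for customer, values in sorted(features.items()):
    print(customer, values)
```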



In [None]:
# 5. Feature Engineering with PySpark (RFM & more)

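# Roll line items up to one row per order so order-level statistics are correct
# (counting or averaging raw rows would measure line items, not orders).
orders_df = history_df.groupBy("customer_id", "invoice_id").agg(
    spark_sum("amount").alias("order_amount"),
    spark_min("invoice_date").alias("order_date"),
)

order_agg = orders_df.groupBy("customer_id").agg(
    spark_count("invoice_id").alias("num_orders"),
    spark_sum("order_amount").alias("total_spent"),
    spark_avg("order_amount").alias("avg_order_value"),
    stddev("order_amount").alias("order_amount_std"),
    spark_min("order_date").alias("first_order_date"),
    spark_max("order_date").alias("last_order_date"),
)

# Product diversity and country come from the line-item level
product_agg = history_df.groupBy("customer_id").agg(
    countDistinct("product_id").alias("num_unique_products"),
    first("country").alias("country"),
)

customer_agg = order_agg.join(product_agg, on="customer_id", how="inner")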

# Recency & tenure
customer_agg = customer_agg.withColumn(
    "recency_days",
    datediff(lit(cutoff_date), col("last_order_date")),
).withColumn(
    "tenure_days",
    datediff(col("last_order_date"), col("first_order_date")),
)

# Join with CLV labels (future 6-month revenue)
features_df = (
    customer_agg.join(clv_df, on="customer_id", how="left").fillna({"future_6m_revenue": 0.0})
)

# Replace null stddev with 0
features_df = features_df.fillna({"order_amount_std": 0.0})

print("Feature dataframe sample:")
features_df.show(5)


## 6. Train/Test Split at Customer Level

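We split customers into train and test sets (80/20). Because the feature table holds exactly one row per customer, a random row split is already a customer-level split, so no customer appears in both sets. Both sets share the same history and prediction windows, so this is not an out-of-time validation.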
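
If the feature table ever contains more than one row per customer (for example, one row per customer per snapshot date), the split should be made on customer IDs rather than on rows. The sketch below shows one way to do that in plain Python; `split_customers` is a hypothetical helper, not part of scikit-learn or this notebook.

```python
import random


def split_customers(customer_ids, test_size=0.2, seed=42):
    """Assign each unique customer ID to exactly one of train or test."""
    unique_ids = sorted(set(customer_ids))  # sort first so the shuffle is reproducible
    rng = random.Random(seed)
    rng.shuffle(unique_ids)
    n_test = max(1, round(len(unique_ids) * test_size))
    test_ids = set(unique_ids[:n_test])
    train_ids = set(unique_ids[n_test:])
    return train_ids, test_ids


# Toy IDs with duplicates, as would happen with several rows per customer
ids = [f"C{i}" for i in range(10)] + ["C0", "C1"]
train_ids, test_ids = split_customers(ids)

print("Train customers:", len(train_ids))
print("Test customers:", len(test_ids))
print("Overlap:", train_ids & test_ids)
```

Rows can then be routed to train or test by checking which set their customer ID belongs to.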



In [None]:
# 6. Train/Test Split at Customer Level

features_pd = features_df.toPandas()
print("Pandas feature shape:", features_pd.shape)

target_col = "future_6m_revenue"
id_col = "customer_id"

cols_to_drop = [target_col, id_col, "first_order_date", "last_order_date"]
feature_cols = [c for c in features_pd.columns if c not in cols_to_drop]

X = features_pd[feature_cols].copy()
y = features_pd[target_col].values

# One-hot encode country if present
if "country" in X.columns:
    X = pd.get_dummies(X, columns=["country"], dummy_na=True)

# Simple missing value handling
X = X.fillna(0)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print("Train shape:", X_train.shape, "Test shape:", X_test.shape)


## 7. Modeling CLV with XGBoost + MLflow

- Train an XGBoost regressor on engineered features.
- Use MLflow to track experiments (hyperparameters, metrics, artifacts).
- Log the best model for potential deployment.
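
Tracking several experiments implies looping over candidate hyperparameter sets and logging each one as its own run. The sketch below only covers the bookkeeping: it enumerates candidates and selects the best result by a metric, and leaves the training and MLflow calls out. The search values and the helpers `expand_grid` and `pick_best` are assumptions for illustration, not tuned recommendations.

```python
from itertools import product

# Fixed settings shared by every candidate (taken from the baseline configuration)
base_params = {"n_estimators": 400, "subsample": 0.8, "random_state": 42}

# Hypothetical search space; the values are placeholders
search_space = {"max_depth": [4, 6], "learning_rate": [0.05, 0.1]}


def expand_grid(base, space):
    """Yield one full parameter dict per combination in the search space."""
    keys = sorted(space)
    for values in product(*(space[k] for k in keys)):
        yield {**base, **dict(zip(keys, values))}


def pick_best(results, metric="mae"):
    """Return the result dict with the lowest value of the given error metric."""
    return min(results, key=lambda r: r[metric])


candidates = list(expand_grid(base_params, search_space))
for candidate in candidates:
    print(candidate)
```

Each candidate would be trained inside its own `mlflow.start_run()` block, with its parameters and metrics logged, before `pick_best` decides which model to keep.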



In [None]:
# 7. Modeling CLV with XGBoost + MLflow

mlflow.set_experiment("clv-prediction")

params = {
    "n_estimators": 400,
    "max_depth": 6,
    "learning_rate": 0.05,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "reg_lambda": 1.0,
    "random_state": 42,
    "n_jobs": -1,
}

with mlflow.start_run(run_name="xgb_clv_baseline"):
    model = XGBRegressor(**params)
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)

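    mae = mean_absolute_error(y_test, y_pred)
    # RMSE computed with NumPy so it does not depend on the scikit-learn version
    rmse = float(np.sqrt(np.mean((y_test - y_pred) ** 2)))
    r2 = r2_score(y_test, y_pred)

    mlflow.log_params(params)
    mlflow.log_metric("mae", float(mae))
    mlflow.log_metric("rmse", rmse)
    mlflow.log_metric("r2", float(r2))

    mlflow.xgboost.log_model(model, artifact_path="model")

    print(f"MAE: {mae:.2f}")
    print(f"RMSE: {rmse:.2f}")
    print(f"R2: {r2:.3f}")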


## 8. Evaluation & Business Interpretation

- Evaluate regression performance (RMSE, MAE, R²).
- Analyze feature importance and segments (e.g., top‑decile customers by predicted CLV).
- Discuss how marketing/retention teams could use these insights.
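
The metrics and the top-decile view listed above can be written out explicitly. The following dependency-free sketch is for illustration only: `regression_metrics` and `top_decile_revenue_share` are hypothetical helpers, and the top-decile share is just one possible way to summarize how well the ranking of predictions captures actual revenue.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return RMSE, MAE, and R² for two equal-length sequences."""
    n = len(y_true)
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    ss_res = sum(e * e for e in errors)
    rmse = math.sqrt(ss_res / n)
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return {"rmse": rmse, "mae": mae, "r2": r2}


def top_decile_revenue_share(y_true, y_pred):
    """Share of actual revenue held by the top 10% of customers ranked by prediction."""
    n_top = max(1, len(y_true) // 10)
    ranked = sorted(range(len(y_true)), key=lambda i: y_pred[i], reverse=True)
    top_actual = sum(y_true[i] for i in ranked[:n_top])
    total = sum(y_true)
    return top_actual / total if total > 0 else 0.0


# Toy values for illustration only
metrics = regression_metrics([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 8.0])
share = top_decile_revenue_share([10.0] * 9 + [110.0], list(range(10)))
print(metrics)
print(f"Top-decile revenue share: {share:.2f}")
```

A high top-decile share means the model ranks the most valuable customers well even when its absolute errors are large, which is often what a retention campaign needs.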



In [None]:
# 8. Evaluation & Business Interpretation (basic plots)

plt.figure(figsize=(6, 6))
plt.scatter(y_test, y_pred, alpha=0.3)
plt.xlabel("Actual CLV (future_6m_revenue)")
plt.ylabel("Predicted CLV")
plt.title("Predicted vs Actual CLV")

min_val = min(y_test.min(), y_pred.min())
max_val = max(y_test.max(), y_pred.max())
plt.plot([min_val, max_val], [min_val, max_val], "r--")

plt.tight_layout()
plt.show()

# Feature importances
importances = model.feature_importances_
sorted_idx = np.argsort(importances)[-20:]

plt.figure(figsize=(8, 6))
plt.barh(range(len(sorted_idx)), importances[sorted_idx])
plt.yticks(range(len(sorted_idx)), X_train.columns[sorted_idx])
plt.xlabel("Feature Importance")
plt.title("XGBoost Feature Importances")
plt.tight_layout()
plt.show()


## 9. Production Mindset & Next Steps

Ideas for future work:
- Move heavy feature engineering into dedicated PySpark jobs (e.g., in `scripts/preprocess.py`).
- Schedule regular CLV refreshes (daily/weekly) with updated transactions.
- Serve predictions via an API or batch exports to marketing tools.



In [None]:
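# Placeholder for future preprocessing helpers (if needed)

if __name__ == "__main__":
    # __name__ is "__main__" when this code is run directly (including in a notebook), not when imported
    print("Run directly; no preprocessing helpers defined yet.")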
