# DATA5000 Assessment 1: Customer Retention and Revenue Optimization

## Business Context

You are a Business Analyst at a leading SaaS technology company. The executive team is concerned about customer churn and wants to develop a data-driven retention strategy.

Your task is to analyze customer data, build predictive models, and evaluate the effectiveness of retention campaigns.

## Important Instructions

1. Run all cells in sequence from top to bottom
2. Pay attention to the outputs and visualizations
3. Take notes on key insights for your business report
4. Save important charts by right-clicking and selecting "Save image as"
5. This notebook should take 30-45 minutes to complete
6. You may modify the code to fit your narrative, and you are not required to use every visualization (ONLY SELECT THOSE THAT ARE SUITABLE).

## Dataset Overview

The dataset contains information about 15,000 SaaS customers including:
- Customer demographics and firmographics
- Engagement and usage metrics
- Revenue and payment information
- Churn status
- Retention campaign participation

---

# Section 0: Setup and Installation

First, we will install all required libraries and import necessary packages.

In [1]:
import sys
sys.version

'3.11.13 | packaged by conda-forge | (main, Jun  4 2025, 14:39:58) [MSC v.1943 64 bit (AMD64)]'

### INSTALL CELL -- Install the required libraries

In [None]:
%pip install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install pytorch_forecasting==1.5.0 pytorch-lightning==2.6.0
%pip install tensorboard==2.15.1
%pip install numpy==1.26.4 scipy==1.11.4
%pip install shap==0.48.0
%pip install econml==0.16.0

#### IMPORTANT – AFTER RUNNING THE INSTALL CELL, CLICK “Runtime → Restart Runtime” BEFORE CONTINUING, IF PROMPTED BY COLAB

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Machine Learning
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelEncoder

# Explainable AI
import shap

# Causal Inference
from econml.dr import DRLearner
from econml.sklearn_extensions.linear_model import StatsModelsLinearRegression
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

# Time Series Forecasting
import torch
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import RMSE
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from pytorch_forecasting.data.encoders import GroupNormalizer
# Allow GroupNormalizer to be unpickled when checkpoints are loaded with weights_only=True
torch.serialization.add_safe_globals([GroupNormalizer])

# Settings
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
pd.set_option('display.max_columns', None)
np.random.seed(42)

print("All libraries imported successfully")
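
The pinned versions in the install cell can also be verified programmatically rather than one import at a time. The following is a minimal sketch using only the standard library's `importlib.metadata`; the `PINNED` mapping and the `check_versions` helper are hypothetical names introduced here, and the pins are copied from the install cell.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell; extend the mapping as needed.
PINNED = {
    "torch": "2.5.1",
    "pytorch-forecasting": "1.5.0",
    "numpy": "1.26.4",
    "scipy": "1.11.4",
    "shap": "0.48.0",
    "econml": "0.16.0",
}


def check_versions(pinned):
    """Return {package: (expected, installed_or_None, matches)}."""
    report = {}
    for name, expected in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        # Local build tags such as "+cu121" are ignored for the comparison.
        base = installed.split("+")[0] if installed else None
        report[name] = (expected, installed, base == expected)
    return report


for name, (expected, installed, ok) in check_versions(PINNED).items():
    status = "OK" if ok else "MISMATCH"
    print(f"{name:22s} expected {expected:8s} installed {installed!s:14s} {status}")
```

A `MISMATCH` line after a fresh install usually means the runtime was not restarted, so the previously loaded version is still in use.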

---

# Section 1A: Data Exploration and Preparation

In this section, we will load the customer dataset and perform exploratory data analysis to understand customer behavior patterns.

In [None]:
url = "https://raw.githubusercontent.com/clarkian-teachings/data5k/main/assessments/saas_customer_data.csv"
df = pd.read_csv(url)

print(f"Dataset loaded successfully")
print(f"Number of customers: {len(df):,}")
print(f"Number of features: {df.shape[1]}")
print(f"\nFirst few rows:")
df.head()

In [None]:
# Display dataset information
print("Dataset Information:")
print("="*60)
df.info()

In [None]:
# Display summary statistics
print("Summary Statistics:")
print("="*60)
df.describe()

## Key Business Metrics

In [None]:
# Calculate key business metrics
churn_rate = df['churned'].mean() * 100
avg_monthly_revenue = df['monthly_revenue'].mean()
total_mrr = df['monthly_revenue'].sum()
campaign_coverage = df['retention_campaign_received'].mean() * 100

print("Key Business Metrics")
print("="*60)
print(f"Overall Churn Rate: {churn_rate:.2f}%")
print(f"Average Monthly Revenue per Customer: ${avg_monthly_revenue:.2f}")
print(f"Total Monthly Recurring Revenue (MRR): ${total_mrr:,.2f}")
print(f"Retention Campaign Coverage: {campaign_coverage:.2f}%")
print(f"\nChurn by Plan Type:")
print(df.groupby('plan_type')['churned'].mean().sort_values(ascending=False) * 100)

## Data Quality Check

In [None]:
# Check for missing values
print("Missing Values:")
print("="*60)
missing_data = df.isnull().sum()
missing_percent = (missing_data / len(df)) * 100
missing_df = pd.DataFrame({'Missing Count': missing_data, 'Percentage': missing_percent})
missing_df = missing_df[missing_df['Missing Count'] > 0].sort_values('Missing Count', ascending=False)
print(missing_df)

## Exploratory Data Analysis: Visualizations

The following visualizations will help you understand customer behavior patterns. Save these charts for your business report.

In [None]:
# Visualization 1: Churn Rate by Customer Segments
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Churn by Plan Type
churn_by_plan = df.groupby('plan_type')['churned'].mean().sort_values(ascending=False) * 100
axes[0, 0].bar(churn_by_plan.index, churn_by_plan.values, color='coral')
axes[0, 0].set_title('Churn Rate by Plan Type', fontsize=14, fontweight='bold')
axes[0, 0].set_ylabel('Churn Rate (%)')
axes[0, 0].set_xlabel('Plan Type')

# Churn by Company Size
churn_by_size = df.groupby('company_size')['churned'].mean().sort_values(ascending=False) * 100
axes[0, 1].bar(churn_by_size.index, churn_by_size.values, color='skyblue')
axes[0, 1].set_title('Churn Rate by Company Size', fontsize=14, fontweight='bold')
axes[0, 1].set_ylabel('Churn Rate (%)')
axes[0, 1].set_xlabel('Company Size')

# Churn by Industry
churn_by_industry = df.groupby('industry')['churned'].mean().sort_values(ascending=False) * 100
axes[1, 0].bar(churn_by_industry.index, churn_by_industry.values, color='lightgreen')
axes[1, 0].set_title('Churn Rate by Industry', fontsize=14, fontweight='bold')
axes[1, 0].set_ylabel('Churn Rate (%)')
axes[1, 0].set_xlabel('Industry')
axes[1, 0].tick_params(axis='x', rotation=45)

# Churn by Contract Length
churn_by_contract = df.groupby('contract_length')['churned'].mean().sort_values(ascending=False) * 100
axes[1, 1].bar(churn_by_contract.index, churn_by_contract.values, color='plum')
axes[1, 1].set_title('Churn Rate by Contract Length', fontsize=14, fontweight='bold')
axes[1, 1].set_ylabel('Churn Rate (%)')
axes[1, 1].set_xlabel('Contract Length')
axes[1, 1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('churn_by_segments.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: churn_by_segments.png")

In [None]:
# Visualization 2: Revenue Distribution
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Monthly Revenue Distribution
axes[0].hist(df[df['monthly_revenue'] > 0]['monthly_revenue'], bins=50, color='steelblue', edgecolor='black')
axes[0].set_title('Distribution of Monthly Revenue', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Monthly Revenue ($)')
axes[0].set_ylabel('Number of Customers')
axes[0].axvline(df['monthly_revenue'].mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: ${df["monthly_revenue"].mean():.2f}')
axes[0].legend()

# Monthly Revenue by Churn Status
churned_revenue = df[df['churned'] == 1]['monthly_revenue']
retained_revenue = df[df['churned'] == 0]['monthly_revenue']
axes[1].hist([churned_revenue, retained_revenue], bins=30, label=['Churned', 'Retained'], color=['red', 'green'], alpha=0.7)
axes[1].set_title('Monthly Revenue: Churned vs Retained', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Monthly Revenue ($)')
axes[1].set_ylabel('Number of Customers')
axes[1].legend()

plt.tight_layout()
plt.savefig('revenue_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: revenue_distribution.png")

In [None]:
# Visualization 3: Engagement Metrics by Churn Status
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Monthly Active Days
df.boxplot(column='monthly_active_days', by='churned', ax=axes[0, 0])
axes[0, 0].set_title('Monthly Active Days by Churn Status', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Churned (0=No, 1=Yes)')
axes[0, 0].set_ylabel('Monthly Active Days')
axes[0, 0].get_figure().suptitle('')

# Customer Health Score
df.boxplot(column='customer_health_score', by='churned', ax=axes[0, 1])
axes[0, 1].set_title('Customer Health Score by Churn Status', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Churned (0=No, 1=Yes)')
axes[0, 1].set_ylabel('Health Score')

# Days Since Last Login
df.boxplot(column='days_since_last_login', by='churned', ax=axes[1, 0])
axes[1, 0].set_title('Days Since Last Login by Churn Status', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('Churned (0=No, 1=Yes)')
axes[1, 0].set_ylabel('Days Since Last Login')

# Feature Adoption Score
df.boxplot(column='feature_adoption_score', by='churned', ax=axes[1, 1])
axes[1, 1].set_title('Feature Adoption Score by Churn Status', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('Churned (0=No, 1=Yes)')
axes[1, 1].set_ylabel('Feature Adoption Score')

fig.suptitle('')  # each pandas boxplot call re-adds a "Boxplot grouped by churned" suptitle
plt.tight_layout()
plt.savefig('engagement_metrics.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: engagement_metrics.png")

In [None]:
# Visualization 4: Correlation Heatmap of Key Metrics
# Select numerical features for correlation analysis
numerical_features = ['monthly_revenue', 'monthly_active_days', 'feature_adoption_score',
                     'support_tickets_count', 'login_frequency', 'session_duration_avg',
                     'days_since_last_login', 'customer_health_score', 'nps_score', 'churned']

correlation_matrix = df[numerical_features].corr()

plt.figure(figsize=(12, 10))
sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm', center=0,
            square=True, linewidths=1, cbar_kws={"shrink": 0.8})
plt.title('Correlation Heatmap of Customer Metrics', fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('correlation_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: correlation_heatmap.png")
print("\nKey Correlations with Churn:")
print(correlation_matrix['churned'].sort_values(ascending=False))

## Data Preparation for Machine Learning

In [None]:
# Handle missing values
# For numerical columns, fill with median
numerical_cols = df.select_dtypes(include=[np.number]).columns
for col in numerical_cols:
    if df[col].isnull().sum() > 0:
        # Assign back instead of using inplace=True on a column selection (deprecated in pandas 2.x)
        df[col] = df[col].fillna(df[col].median())

print("Missing values handled")
print(f"Remaining missing values: {df.isnull().sum().sum()}")

In [None]:
# Encode categorical variables for machine learning
# Create label encoders for categorical columns
categorical_cols = ['company_size', 'industry', 'country', 'plan_type', 'contract_length', 'payment_method']

label_encoders = {}
for col in categorical_cols:
    le = LabelEncoder()
    df[f'{col}_encoded'] = le.fit_transform(df[col])
    label_encoders[col] = le

print("Categorical variables encoded")
print(f"\nNew encoded columns: {[f'{col}_encoded' for col in categorical_cols]}")
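
To make the encoding step concrete, here is a dependency-free sketch of what a label encoder does: it sorts the distinct labels and maps each one to an integer. The `fit_label_encoder` helper and the plan names are hypothetical and are not taken from the dataset.

```python
def fit_label_encoder(values):
    """Map each distinct label to an integer code, in sorted order."""
    classes = sorted(set(values))
    return {label: code for code, label in enumerate(classes)}


plan_types = ["Pro", "Basic", "Enterprise", "Basic", "Pro"]  # hypothetical values
mapping = fit_label_encoder(plan_types)
encoded = [mapping[p] for p in plan_types]

print(mapping)
print(encoded)
```

The codes follow alphabetical order, not business order, so they should not be read as magnitudes (for example, a code of 2 does not mean "twice" a code of 1).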

---

# Section 1B: Predictive Modeling

In this section, we will build two predictive models:
1. LightGBM for customer churn prediction
2. Temporal Fusion Transformer for revenue forecasting

## Part 1: Customer Churn Prediction with LightGBM

In [None]:
# Prepare features for churn prediction
feature_columns = [
    'company_size_encoded', 'industry_encoded', 'country_encoded',
    'plan_type_encoded', 'contract_length_encoded', 'payment_method_encoded',
    'monthly_revenue', 'payment_failures', 'discount_received',
    'monthly_active_days', 'feature_adoption_score', 'support_tickets_count',
    'login_frequency', 'session_duration_avg', 'api_calls_count',
    'days_since_last_login', 'feature_usage_decline', 'customer_health_score',
    'nps_score', 'product_feedback_submitted'
]

X = df[feature_columns]
y = df['churned']

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

print(f"Training set size: {len(X_train):,} customers")
print(f"Testing set size: {len(X_test):,} customers")
print(f"\nChurn rate in training set: {y_train.mean()*100:.2f}%")
print(f"Churn rate in testing set: {y_test.mean()*100:.2f}%")
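
The `stratify=y` argument keeps the churn rate roughly equal in both splits. A minimal sketch of the idea, assuming a toy label list and a hypothetical `stratified_split` helper (this is not scikit-learn's implementation): each class is shuffled and split separately.

```python
import random


def stratified_split(labels, test_size=0.2, seed=42):
    """Return (train_idx, test_idx), sampling test_size from each class separately."""
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        n_test = round(len(idx) * test_size)
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Toy labels: 20% churned (1), 80% retained (0).
labels = [1] * 20 + [0] * 80
train_idx, test_idx = stratified_split(labels)

train_rate = sum(labels[i] for i in train_idx) / len(train_idx)
test_rate = sum(labels[i] for i in test_idx) / len(test_idx)
print(f"train churn rate: {train_rate:.2f}, test churn rate: {test_rate:.2f}")
```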

In [None]:
# Train LightGBM model
print("Training LightGBM model...")
print("This may take 1-2 minutes")

lgbm_model = lgb.LGBMClassifier(
    n_estimators=200,
    learning_rate=0.05,
    max_depth=6,
    num_leaves=31,
    min_child_samples=20,
    random_state=42,
    verbose=-1
)

lgbm_model.fit(X_train, y_train)

print("\nModel training complete")

In [None]:
# Make predictions
y_pred = lgbm_model.predict(X_test)
y_pred_proba = lgbm_model.predict_proba(X_test)[:, 1]

# Calculate performance metrics
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
roc_auc = roc_auc_score(y_test, y_pred_proba)

print("LightGBM Churn Prediction Model Performance")
print("="*60)
print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
print(f"Recall: {recall:.4f} ({recall*100:.2f}%)")
print(f"F1 Score: {f1:.4f}")
print(f"ROC-AUC Score: {roc_auc:.4f}")
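
Each of these metrics can be derived by hand from the four confusion-matrix counts, which helps when explaining them in the report. The sketch below uses hypothetical counts (not results from this notebook); for labels 0 and 1, scikit-learn's `confusion_matrix` lays the counts out as `[[tn, fp], [fn, tp]]`.

```python
def classification_metrics(tn, fp, fn, tp):
    """Compute standard binary metrics from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical counts for a 3,000-customer test set.
metrics = classification_metrics(tn=2300, fp=100, fn=200, tp=400)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

In business terms, precision answers "of the customers flagged as churners, how many actually churned?" and recall answers "of the customers who churned, how many did the model flag?".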


In [None]:
# Confusion Matrix Visualization
cm = confusion_matrix(y_test, y_pred)

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True,
            xticklabels=['Retained', 'Churned'],
            yticklabels=['Retained', 'Churned'])
plt.title('Confusion Matrix: Churn Prediction Model', fontsize=14, fontweight='bold', pad=20)
plt.ylabel('Actual', fontsize=12)
plt.xlabel('Predicted', fontsize=12)
plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: confusion_matrix.png")

## Part 2: Revenue Forecasting with Temporal Fusion Transformer

In [None]:
# Load time series data
url = "https://raw.githubusercontent.com/clarkian-teachings/data5k/main/assessments/revenue_timeseries.csv"
ts_df = pd.read_csv(url)
ts_df['date'] = pd.to_datetime(ts_df['date'])

print(f"Time series data loaded: {len(ts_df):,} records")
print(f"Number of unique customers: {ts_df['customer_id'].nunique():,}")
print(f"Time range: {ts_df['date'].min()} to {ts_df['date'].max()}")
print(f"\nFirst few rows:")
ts_df.head()

In [None]:
# Encode categorical variables for time series
for col in ['plan_type', 'company_size', 'customer_id']:
    le = LabelEncoder()
    ts_df[f'{col}_encoded'] = le.fit_transform(ts_df[col])

print("Time series data prepared for forecasting")

### Prepare data for Temporal Fusion Transformer

In [None]:
print("Preparing Temporal Fusion Transformer model...")

max_prediction_length = 6  # Forecast 6 months ahead
max_encoder_length = 12    # Use 12 months of history

training_cutoff = ts_df["time_idx"].max() - max_prediction_length

# Convert encoded categorical columns to string type
ts_df['plan_type_encoded'] = ts_df['plan_type_encoded'].astype(str)
ts_df['company_size_encoded'] = ts_df['company_size_encoded'].astype(str)
ts_df['customer_id_encoded'] = ts_df['customer_id_encoded'].astype(str)

training = TimeSeriesDataSet(
    ts_df[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="revenue",
    group_ids=["customer_id_encoded"],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["plan_type_encoded", "company_size_encoded"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["revenue"],
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

validation = TimeSeriesDataSet.from_dataset(training, ts_df, predict=True, stop_randomization=True)

batch_size = 64
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

print("\nData prepared for TFT model")
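
The cutoff and window lengths determine which months the model learns from and which it is asked to forecast. A minimal sketch of that arithmetic, assuming each customer has 24 monthly observations (`time_idx` 0 to 23); the `forecast_windows` helper is hypothetical and only mirrors the settings used above.

```python
def forecast_windows(last_time_idx, max_encoder_length=12, max_prediction_length=6):
    """Return the training cutoff plus the encoder/decoder index ranges used for validation."""
    training_cutoff = last_time_idx - max_prediction_length
    decoder = list(range(training_cutoff + 1, last_time_idx + 1))
    encoder_start = max(0, training_cutoff - max_encoder_length + 1)
    encoder = list(range(encoder_start, training_cutoff + 1))
    return training_cutoff, encoder, decoder


# Assumption: 24 monthly observations per customer, time_idx 0..23.
cutoff, encoder, decoder = forecast_windows(last_time_idx=23)
print(f"training rows: time_idx <= {cutoff}")
print(f"validation encoder (history): {encoder[0]}..{encoder[-1]} ({len(encoder)} months)")
print(f"validation decoder (forecast): {decoder[0]}..{decoder[-1]} ({len(decoder)} months)")
```

Under this assumption the final six months are held out of training, so the validation forecast is evaluated on revenue the model has not seen.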

### Train the TFT Model

In [None]:
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

# Train Temporal Fusion Transformer
print("Training Temporal Fusion Transformer...")
print("This will take 5-10 minutes")
# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs", default_hp_metric=False)  # logging results to a tensorboard
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=RMSE(),
    log_interval=10,
    optimizer="Adam",
    reduce_on_plateau_patience=4,
)
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath="checkpoints",
    filename="best_tft_model"
)


trainer = pl.Trainer(
    max_epochs=5,
    accelerator="cpu",
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # cap each epoch at 50 training batches to keep runtime short
    # fast_dev_run=True,  # uncomment to check that the network and dataset have no serious bugs
    callbacks=[lr_logger, early_stop_callback, checkpoint_callback],
    logger=logger
)

trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

print("\nTFT model training complete")

### Evaluate and Visualize the Model Predictions

In [None]:
# Make predictions with the best checkpoint; fall back to the in-memory model if none was saved
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
    best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
else:
    best_tft = tft

predictions = best_tft.predict(
    val_dataloader,
    return_x=True,
    trainer_kwargs=dict(accelerator="cpu")
)

print("Revenue forecasts generated")
print(f"Predictions shape: {predictions.output.shape}")

In [None]:
# Visualize sample forecasts
# Select first 4 customers for visualization
raw_predictions = best_tft.predict(
    val_dataloader,
    mode="raw",
    return_x=True,
    trainer_kwargs=dict(accelerator="cpu")
)

for idx in range(min(4, len(raw_predictions.x['decoder_target']))):
    best_tft.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)
    plt.savefig(f'tft_forecast_customer_{idx+1}.png', dpi=300, bbox_inches='tight')
    plt.show()

print("Forecast charts saved")

In [None]:
# Calculate forecast accuracy metrics
from sklearn.metrics import mean_absolute_error, mean_squared_error

# Get actual and predicted values as NumPy arrays (the predictions are returned as torch tensors)
actuals = predictions.x['decoder_target'].detach().cpu().numpy()
forecasts = predictions.output.squeeze().detach().cpu().numpy()

# Calculate metrics
mae = mean_absolute_error(actuals.flatten(), forecasts.flatten())
rmse = np.sqrt(mean_squared_error(actuals.flatten(), forecasts.flatten()))
# mape = np.mean(np.abs((actuals.flatten() - forecasts.flatten()) / actuals.flatten())) * 100

print("Revenue Forecasting Model Performance")
print("="*60)
print(f"Mean Absolute Error (MAE): ${mae:.2f}")
print(f"Root Mean Squared Error (RMSE): ${rmse:.2f}")
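
For interpreting these numbers: MAE is the average absolute dollar error, while RMSE squares errors first and therefore penalizes large misses more heavily. The sketch below works through both on hypothetical revenue values, and adds a MAPE that skips months with zero actual revenue, since a percentage error is undefined there. That handling is an assumption of this sketch, not something the notebook does.

```python
import math


def forecast_errors(actuals, forecasts):
    """Return MAE, RMSE, and MAPE (in %, skipping zero actuals)."""
    errors = [f - a for a, f in zip(actuals, forecasts)]
    mae = sum(abs(e) for e in errors) / len(errors)
    rmse = math.sqrt(sum(e * e for e in errors) / len(errors))
    pct = [abs(e) / abs(a) for a, e in zip(actuals, errors) if a != 0]
    mape = 100 * sum(pct) / len(pct) if pct else float("nan")
    return mae, rmse, mape


# Hypothetical monthly revenue values in dollars.
actuals = [100.0, 200.0, 0.0, 400.0]
forecasts = [110.0, 180.0, 20.0, 400.0]
mae, rmse, mape = forecast_errors(actuals, forecasts)
print(f"MAE: ${mae:.2f}  RMSE: ${rmse:.2f}  MAPE: {mape:.2f}%")
```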


---

# Section 1C: Explainable AI Analysis with SHAP

In this section, we use SHAP (SHapley Additive exPlanations) to understand which features are driving customer churn predictions.

In [None]:
# Initialize SHAP explainer
print("Generating SHAP values...")
print("This may take 2-3 minutes")

explainer = shap.Explainer(lgbm_model)
shap_values = explainer.shap_values(X_test)

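# Handle the different return shapes across SHAP versions
if isinstance(shap_values, list):
    shap_values_class1 = shap_values[1]  # list of per-class arrays: take the churn class
elif np.ndim(shap_values) == 3:
    shap_values_class1 = shap_values[:, :, 1]  # (samples, features, classes): take the churn class
else:
    shap_values_class1 = shap_values  # single array already refers to the churn class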

print("SHAP values calculated successfully")
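
To build intuition for what a SHAP value is, the sketch below computes exact Shapley values for a tiny hypothetical scoring function by averaging each feature's marginal contribution over every ordering of the features. This is a simplified, single-baseline illustration; the tree explainer used above relies on a much faster tree-specific algorithm, and `toy_churn_score`, its coefficients, and the two customers are invented for this example.

```python
from itertools import permutations
from math import factorial


def shapley_values(predict, x, baseline):
    """Exact Shapley values of predict at x, relative to a baseline input."""
    n = len(x)
    contributions = [0.0] * n
    for order in permutations(range(n)):
        current = list(baseline)
        previous = predict(current)
        for i in order:
            current[i] = x[i]  # switch feature i from its baseline value to its actual value
            value = predict(current)
            contributions[i] += value - previous
            previous = value
    return [c / factorial(n) for c in contributions]


def toy_churn_score(features):
    """Hypothetical score: inactivity and payment failures raise it, health lowers it."""
    days_inactive, payment_failures, health = features
    return (0.02 * days_inactive + 0.1 * payment_failures - 0.005 * health
            + 0.01 * days_inactive * payment_failures)


x = [30, 2, 40]        # customer being explained
baseline = [5, 0, 70]  # reference customer
phi = shapley_values(toy_churn_score, x, baseline)

print([round(p, 4) for p in phi])
# Additivity: the contributions sum to the prediction minus the baseline prediction.
print(round(sum(phi), 4), round(toy_churn_score(x) - toy_churn_score(baseline), 4))
```

The additivity property shown in the last line is what allows a waterfall or force plot to decompose a single prediction into per-feature contributions.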

In [None]:
# SHAP Summary Plot: Global Feature Importance
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values_class1, X_test, plot_type="bar", show=False)
plt.title('Feature Importance for Churn Prediction (SHAP)', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('shap_feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: shap_feature_importance.png")

In [None]:
# SHAP Summary Plot: Detailed View
plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values_class1, X_test, show=False)
plt.title('SHAP Summary Plot: Impact of Features on Churn', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('shap_summary_plot.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: shap_summary_plot.png")

In [None]:
# Identify Top 5 Features
feature_importance = np.abs(shap_values_class1).mean(axis=0)
feature_importance_df = pd.DataFrame({
    'Feature': X_test.columns,
    'Importance': feature_importance
}).sort_values('Importance', ascending=False)

print("Top 5 Features Driving Customer Churn")
print("="*60)
for rank, (_, row) in enumerate(feature_importance_df.head(5).iterrows(), start=1):
    print(f"{rank}. {row['Feature']}: {row['Importance']:.4f}")

print("\nFull Feature Importance Ranking:")
print(feature_importance_df)

In [None]:
# SHAP Waterfall Plot: Individual Customer Explanation
# Select a churned customer example
churned_indices = X_test.index[y_test == 1].tolist()
sample_idx = churned_indices[0] if churned_indices else X_test.index[0]
sample_position = X_test.index.get_loc(sample_idx)

plt.figure(figsize=(10, 6))
shap.waterfall_plot(shap.Explanation(values=shap_values_class1[sample_position],
                                     base_values=float(np.atleast_1d(explainer.expected_value)[-1]),  # scalar or per-class values: use the churn class
                                     data=X_test.iloc[sample_position],
                                     feature_names=X_test.columns.tolist()))
plt.tight_layout()
plt.savefig('shap_waterfall.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: shap_waterfall.png")
print("\nThis chart shows how each feature contributes to the churn prediction for one customer")

In [None]:
# SHAP Force Plot: Same Customer as the Waterfall Plot
plt.figure(figsize=(20, 3))
shap.force_plot(
    float(np.atleast_1d(explainer.expected_value)[-1]),  # scalar or per-class values: use the churn class
    shap_values_class1[sample_position],
    X_test.iloc[sample_position],
    matplotlib=True,
    show=False
)
plt.tight_layout()
plt.savefig('shap_force_plot.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: shap_force_plot.png")

In [None]:
# SHAP Dependence Plots for Top 3 Features
top_features = feature_importance_df.head(3)['Feature'].tolist()

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for idx, feature in enumerate(top_features):
    shap.dependence_plot(feature, shap_values_class1, X_test, ax=axes[idx], show=False)
    axes[idx].set_title(f'SHAP Dependence: {feature}', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('shap_dependence_plots.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: shap_dependence_plots.png")

---

# Section 1D: Causal Analysis with EconML

In this section, we use causal inference to evaluate the effectiveness of retention campaigns.

In [None]:
# Prepare data for causal analysis
# Treatment: retention_campaign_received
# Outcome: churned (we want to see if campaign reduces churn)
# Controls: customer characteristics

# Select features for causal model (exclude treatment and outcome)
causal_features = [
    'company_size_encoded', 'industry_encoded', 'plan_type_encoded',
    'monthly_revenue', 'monthly_active_days', 'feature_adoption_score',
    'customer_health_score', 'days_since_last_login'
]

X_causal = df[causal_features]
T_causal = df['retention_campaign_received']  # Treatment
Y_causal = df['churned']  # Outcome

print("Causal Analysis Data Prepared")
print("="*60)
print(f"Number of customers: {len(X_causal):,}")
print(f"Treated customers (received campaign): {T_causal.sum():,} ({T_causal.mean()*100:.1f}%)")
print(f"Control customers (no campaign): {(1-T_causal).sum():,} ({(1-T_causal.mean())*100:.1f}%)")

In [None]:
# Naive comparison (before causal analysis)
churn_with_campaign = df[df['retention_campaign_received'] == 1]['churned'].mean()
churn_without_campaign = df[df['retention_campaign_received'] == 0]['churned'].mean()

naive_effect = churn_without_campaign - churn_with_campaign

print("Naive Comparison (Without Causal Adjustment)")
print("="*60)
print(f"Churn rate WITH campaign: {churn_with_campaign*100:.2f}%")
print(f"Churn rate WITHOUT campaign: {churn_without_campaign*100:.2f}%")
print(f"Naive effect: {naive_effect*100:.2f} percentage points")
print("\nNote: This naive comparison does not account for selection bias.")
print("At-risk customers are more likely to receive the campaign.")
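
The selection-bias warning can be illustrated with a small worked example. In the hypothetical counts below, at-risk customers are both more likely to churn and more likely to receive the campaign, so the naive comparison makes a helpful campaign look harmful. Comparing within each risk segment and then weighting by segment size recovers the benefit. All numbers and segment names are invented for illustration.

```python
# Hypothetical counts: (segment, treated) -> (customers, churned)
counts = {
    ("at_risk", 1): (800, 320),
    ("at_risk", 0): (200, 100),
    ("low_risk", 1): (200, 10),
    ("low_risk", 0): (800, 80),
}


def churn_rate(treated, segment=None):
    """Churn rate for a treatment arm, optionally within one segment."""
    rows = [v for (seg, t), v in counts.items()
            if t == treated and segment in (None, seg)]
    return sum(c for _, c in rows) / sum(n for n, _ in rows)


# Naive effect, defined as in the cell above: churn without minus churn with.
naive_effect = churn_rate(0) - churn_rate(1)

# Adjusted effect: compare within each segment, then weight by segment size.
segments = ["at_risk", "low_risk"]
sizes = {s: sum(n for (seg, _), (n, _) in counts.items() if seg == s) for s in segments}
total = sum(sizes.values())
adjusted_effect = sum(
    (churn_rate(0, s) - churn_rate(1, s)) * sizes[s] / total for s in segments
)

print(f"Naive effect:    {naive_effect * 100:+.1f} percentage points")
print(f"Adjusted effect: {adjusted_effect * 100:+.1f} percentage points")
```

With these counts the naive effect is negative (the campaign appears to increase churn), while the adjusted effect is positive (the campaign reduces churn within every segment).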

In [None]:
# Train Doubly Robust Learner for causal inference
print("Training Causal Inference Model (Doubly Robust Learner)...")
print("This may take 2-3 minutes")

# Initialize the DRLearner
dml = DRLearner(
    model_propensity=RandomForestClassifier(n_estimators=100, random_state=42),
    model_regression=RandomForestRegressor(n_estimators=100, random_state=42),
    model_final=StatsModelsLinearRegression(),
    cv=3,
    random_state=42
)

# Fit the model
dml.fit(Y=Y_causal, T=T_causal, X=X_causal)

print("\nCausal model training complete")
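
For readers who want to see what "doubly robust" means, the sketch below implements the augmented inverse-propensity-weighted (AIPW) estimate of the average treatment effect on hypothetical data. It is a simplified illustration, not EconML's implementation: the propensity and outcome models here are plain per-segment averages, whereas the learner above fits cross-fitted random forests. The `aipw_ate` and `expand` helpers and all counts are invented for this example.

```python
def aipw_ate(rows):
    """Doubly robust (AIPW) ATE from rows of (segment, treated, churned).

    Nuisance models are simple per-segment averages in this sketch.
    """
    segments = sorted({seg for seg, _, _ in rows})
    propensity, mu = {}, {}
    for seg in segments:
        seg_rows = [(t, y) for s, t, y in rows if s == seg]
        propensity[seg] = sum(t for t, _ in seg_rows) / len(seg_rows)
        for arm in (0, 1):
            outcomes = [y for t, y in seg_rows if t == arm]
            mu[seg, arm] = sum(outcomes) / len(outcomes)

    scores = []
    for seg, t, y in rows:
        e, m0, m1 = propensity[seg], mu[seg, 0], mu[seg, 1]
        # Outcome-model difference plus inverse-propensity-weighted residual corrections.
        score = (m1 - m0) + t * (y - m1) / e - (1 - t) * (y - m0) / (1 - e)
        scores.append(score)
    return sum(scores) / len(scores)


def expand(segment, treated, n, churned):
    """Build n individual rows containing the given number of churners."""
    return [(segment, treated, 1)] * churned + [(segment, treated, 0)] * (n - churned)


# Hypothetical data: at-risk customers are treated more often.
rows = (expand("at_risk", 1, 800, 320) + expand("at_risk", 0, 200, 100)
        + expand("low_risk", 1, 200, 10) + expand("low_risk", 0, 800, 80))

print(f"AIPW ATE on churn: {aipw_ate(rows):+.4f}")
```

Because the outcome is churn, a negative ATE means the campaign lowers the probability of churning.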

In [None]:
# Calculate Average Treatment Effect (ATE)
ate = dml.ate(X=X_causal)
ate_interval = dml.ate_interval(X=X_causal, alpha=0.05)

print("Average Treatment Effect (ATE) Analysis")
print("="*60)
print(f"ATE: {ate:.4f}")
print(f"95% Confidence Interval: [{ate_interval[0]:.4f}, {ate_interval[1]:.4f}]")
print("\nInterpretation:")
if ate < 0:
    print(f"- The retention campaign REDUCES churn by {abs(ate)*100:.2f} percentage points on average")
    print(f"- This means the campaign is EFFECTIVE")
else:
    print(f"- The retention campaign INCREASES churn by {ate*100:.2f} percentage points on average")
    print(f"- This suggests the campaign may not be effective")

if ate_interval[0] < 0 < ate_interval[1]:
    print(f"- However, the confidence interval includes zero, so the effect may not be statistically significant")
else:
    print(f"- The effect is statistically significant at the 95% confidence level")

In [None]:
# Calculate Conditional Average Treatment Effect (CATE)
# This shows how treatment effect varies by customer segment
cate = dml.effect(X=X_causal)

# Add CATE to dataframe
df['cate'] = cate

print("Conditional Average Treatment Effect (CATE) Analysis")
print("="*60)
print(f"CATE Statistics:")
print(f"Mean CATE: {cate.mean():.4f}")
print(f"Median CATE: {np.median(cate):.4f}")
print(f"Min CATE: {cate.min():.4f}")
print(f"Max CATE: {cate.max():.4f}")


In [None]:
# CATE Distribution
plt.figure(figsize=(12, 6))
plt.hist(cate, bins=50, color='steelblue', edgecolor='black', alpha=0.7)
plt.axvline(cate.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean CATE: {cate.mean():.4f}')
plt.axvline(0, color='black', linestyle='-', linewidth=1, label='No Effect')
plt.xlabel('CATE (Treatment Effect)', fontsize=12)
plt.ylabel('Number of Customers', fontsize=12)
plt.title('Distribution of Conditional Average Treatment Effects', fontsize=14, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('cate_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: cate_distribution.png")

In [None]:
# CATE by Customer Segments
# Analyze treatment effect by different customer characteristics

fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# CATE by Plan Type
cate_by_plan = df.groupby('plan_type')['cate'].mean().sort_values()
axes[0, 0].barh(cate_by_plan.index, cate_by_plan.values, color='skyblue')
axes[0, 0].axvline(0, color='black', linestyle='--', linewidth=1)
axes[0, 0].set_xlabel('Average CATE')
axes[0, 0].set_title('Treatment Effect by Plan Type', fontweight='bold')

# CATE by Company Size
cate_by_size = df.groupby('company_size')['cate'].mean().sort_values()
axes[0, 1].barh(cate_by_size.index, cate_by_size.values, color='lightcoral')
axes[0, 1].axvline(0, color='black', linestyle='--', linewidth=1)
axes[0, 1].set_xlabel('Average CATE')
axes[0, 1].set_title('Treatment Effect by Company Size', fontweight='bold')

# CATE by Revenue Segment
df['revenue_segment'] = pd.cut(df['monthly_revenue'], bins=[0, 50, 150, 500, 2000],
                                labels=['Low', 'Medium', 'High', 'Very High'],
                                include_lowest=True)  # keep customers whose revenue is exactly 0
cate_by_revenue = df.groupby('revenue_segment')['cate'].mean().sort_values()
axes[1, 0].barh(cate_by_revenue.index, cate_by_revenue.values, color='lightgreen')
axes[1, 0].axvline(0, color='black', linestyle='--', linewidth=1)
axes[1, 0].set_xlabel('Average CATE')
axes[1, 0].set_title('Treatment Effect by Revenue Segment', fontweight='bold')

# CATE by Health Score Segment
df['health_segment'] = pd.cut(df['customer_health_score'], bins=[0, 40, 60, 80, 100],
                               labels=['Poor', 'Fair', 'Good', 'Excellent'],
                               include_lowest=True)  # keep customers whose health score is exactly 0
cate_by_health = df.groupby('health_segment')['cate'].mean().sort_values()
axes[1, 1].barh(cate_by_health.index, cate_by_health.values, color='plum')
axes[1, 1].axvline(0, color='black', linestyle='--', linewidth=1)
axes[1, 1].set_xlabel('Average CATE')
axes[1, 1].set_title('Treatment Effect by Health Score', fontweight='bold')

plt.tight_layout()
plt.savefig('cate_by_segments.png', dpi=300, bbox_inches='tight')
plt.show()

print("Chart saved as: cate_by_segments.png")
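
When reading the revenue and health-score segments, it helps to know how `pd.cut` assigns values to bins. By default the intervals are right-closed, so a value equal to the lowest edge falls outside every bin unless `include_lowest=True`, and values above the last edge are left unassigned. The dependency-free sketch below mimics that rule; the `cut` helper is hypothetical and handles one value at a time.

```python
from bisect import bisect_left


def cut(value, bins, labels, include_lowest=False):
    """Assign a label using right-closed intervals (bins[i], bins[i+1]]."""
    if include_lowest and value == bins[0]:
        return labels[0]
    if value <= bins[0] or value > bins[-1]:
        return None  # outside every interval
    return labels[bisect_left(bins, value) - 1]


bins = [0, 50, 150, 500, 2000]
labels = ["Low", "Medium", "High", "Very High"]

for revenue in [0, 50, 50.01, 2000, 2500]:
    print(revenue, cut(revenue, bins, labels), cut(revenue, bins, labels, include_lowest=True))
```

Customers who land outside every interval are silently excluded from the segment averages, so it is worth checking the minimum and maximum of a column before choosing bin edges.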

---

# Section 2: Summary and Report Guidance


## Key Findings Checklist

Use this checklist to ensure you capture all important insights for your report:

### Section 1: Business Problem and Data Overview
- [ ] Overall churn rate and customer demographics
- [ ] Key patterns in customer behavior
- [ ] Data quality observations

### Section 2: Predictive Analytics Findings
- [ ] LightGBM model performance metrics (accuracy, precision, recall, F1, ROC-AUC)
- [ ] Confusion matrix interpretation
- [ ] Revenue forecasting accuracy (MAE, RMSE)
- [ ] Sample forecast visualizations

### Section 3: Explainable AI Insights
- [ ] Top 5 features driving churn (from SHAP analysis)
- [ ] How different customer segments behave differently
- [ ] Feature importance rankings

### Section 4: Causal Analysis
- [ ] Average Treatment Effect (ATE) of retention campaigns
- [ ] Which customer segments benefit most from campaigns (CATE)
- [ ] Recommendations for targeted interventions

### Section 5: Recommendations
- [ ] Data-driven retention strategy recommendations
- [ ] Customer segmentation for targeted campaigns
- [ ] Resource allocation suggestions
- [ ] Expected ROI and business impact

## Writing Your Business Report

Remember to:

1. **Write for a business audience**: Avoid technical jargon. Explain findings in terms of business impact.

2. **Use the outputs above**: Reference specific numbers, charts, and findings from this notebook.

3. **Connect insights to actions**: Every finding should lead to a recommendation.

4. **Focus on ROI**: Quantify the business impact where possible (e.g., revenue at risk, cost savings).

5. **Be concise**: Maximum 1,200 words. Every sentence should add value.

6. **Professional formatting**: Use headers, bullet points, and charts effectively.

7. **Executive summary**: Start with a clear, compelling summary of your key findings and recommendations.

Good luck with your report!

---

## Additional Notes

- All code cells have been executed and outputs generated
- Charts have been saved to your working directory
- You can re-run any cell to regenerate outputs
- For questions about the analysis, consult the course materials or your instructor

**Submission Reminder:**
- Submit your business report (.docx format) via Turnitin by Tuesday (23:55 AEST), Week 5

- Ensure your report is maximum 1,200 words (excluding charts and tables)

In [2]:
import numpy as np
np.__version__

'1.26.4'

In [3]:
import scipy
scipy.__version__

'1.11.4'

In [4]:
import pandas as pd
pd.__version__

'2.3.3'

In [5]:
import shap
shap.__version__

'0.48.0'

In [6]:
import torch
torch.__version__

'2.5.1+cu121'

In [7]:
import pytorch_forecasting
pytorch_forecasting.__version__

'1.5.0'

In [8]:
import sklearn
sklearn.__version__

'1.6.1'

In [9]:
import seaborn as sns
sns.__version__

'0.13.2'

In [10]:
import lightgbm as lgb
lgb.__version__

'4.5.0'

In [11]:
import econml
econml.__version__

'0.16.0'

In [12]:
import lightning.pytorch as pl
pl.__version__

'2.6.0'