# Model Training & Optimization - Churn Prediction

This notebook demonstrates:
1. **Baseline Training**: First training run using our pipeline
2. **Model Evaluation**: Performance analysis and diagnostics
3. **Feature Importance**: Understanding key drivers
4. **Optimization Opportunities**: Hyperparameter tuning strategies
5. **Model Interpretation**: SHAP values and explainability

---

In [None]:
# Standard imports
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import shap
from catboost import CatBoostClassifier
from plotly.subplots import make_subplots
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split

# Add project root to path
# Assumes the kernel's working directory is a direct child of the project root
# (e.g. a notebooks/ folder) so that `src` becomes importable — TODO confirm.
sys.path.append(str(Path.cwd().parent))

# Load environment
# Reads the `.env` file located in the parent of the working directory; the
# GCS_KEY_ID / GCS_SECRET values read later via os.getenv presumably live there — verify.
from dotenv import load_dotenv
load_dotenv(Path.cwd().parent / '.env')

# Project imports
# Kept after the sys.path tweak above: unless `src` is installed as a package,
# these imports rely on the project root having been appended to sys.path.
from src.data.loaders import load_features_from_gcs
from src.data.schemas import ChurnFeatureSchema
from src.models.catboost import CatBoostModel

# Configuration
# Global plotting defaults for every matplotlib/seaborn figure in this notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Imports loaded")

## 1. Load and Prepare Data

In [None]:
# Fetch the validated feature table from GCS; credentials come from the environment.
features_uri = "gs://modern-tabular-dev/data/features/churn_features.parquet"
gcs_key_id = os.getenv('GCS_KEY_ID')
gcs_secret = os.getenv('GCS_SECRET')
df_features = load_features_from_gcs(features_uri, gcs_key_id, gcs_secret)

# Convert to pandas for the sklearn / CatBoost steps that follow.
df = df_features.to_pandas()

# Column lists are owned by the schema so the notebook stays in sync with the pipeline.
feature_cols = ChurnFeatureSchema.get_feature_columns()
categorical_cols = ChurnFeatureSchema.get_categorical_columns()

# Feature matrix (copied so the casts below never touch `df`) and target vector.
X = df[feature_cols].copy()
y = df['has_churned']

# Cast every categorical column that is actually present to string.
present_categoricals = [name for name in categorical_cols if name in X.columns]
X = X.astype({name: str for name in present_categoricals})

print(f"Features: {X.shape[1]}")
print(f"Samples: {X.shape[0]:,}")
print(f"Churn rate: {y.mean():.2%}")

In [None]:
# Stratified 80/20 hold-out split with a fixed seed so the class balance
# matches in both partitions and the split is reproducible.
split_options = {'test_size': 0.2, 'random_state': 42, 'stratify': y}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

print(f"Train set: {len(X_train):,} samples ({y_train.mean():.2%} churn)")
print(f"Test set: {len(X_test):,} samples ({y_test.mean():.2%} churn)")

## 2. Baseline Model Training

Train a CatBoost model with a small, fixed configuration (100 iterations, learning rate 0.1, depth 6) as a baseline.

In [None]:
# Train baseline model
#
# The eval_set must not be the test set: CatBoost turns on `use_best_model`
# whenever an eval set is supplied, so (assuming the CatBoostModel wrapper
# forwards `eval_set` to CatBoost unchanged — TODO confirm) the held-out test
# data would choose the final number of trees and inflate every metric
# reported on X_test. A validation slice is carved out of the training data
# instead, leaving X_test untouched until evaluation.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

baseline_model = CatBoostModel(
    iterations=100,
    learning_rate=0.1,
    depth=6,
    random_state=42,
    verbose=False
)

print("Training baseline model...")
baseline_model.fit(X_fit, y_fit, eval_set=(X_val, y_val))
print("✓ Training complete")

## 3. Model Evaluation

In [None]:
# Hard labels and positive-class probabilities on the held-out test set.
y_pred = baseline_model.predict(X_test)
y_pred_proba = baseline_model.predict_proba(X_test)[:, 1]

# Label-based scores share one call shape; ROC-AUC needs probabilities instead.
label_scorers = {
    'Accuracy': accuracy_score,
    'Precision': precision_score,
    'Recall': recall_score,
    'F1 Score': f1_score,
}
metrics = {name: scorer(y_test, y_pred) for name, scorer in label_scorers.items()}
metrics['ROC-AUC'] = roc_auc_score(y_test, y_pred_proba)

# Two-column table with values formatted to four decimals for display.
metrics_df = pd.DataFrame(list(metrics.items()), columns=['Metric', 'Value'])
metrics_df['Value'] = metrics_df['Value'].map(lambda x: f"{x:.4f}")

banner_rule = "="*50
print("\n" + banner_rule)
print("BASELINE MODEL PERFORMANCE")
print(banner_rule)
print(metrics_df.to_string(index=False))
print(banner_rule)

In [None]:
# Confusion matrix heatmap followed by the per-class text report.
class_names = ['No Churn', 'Churn']
cm = confusion_matrix(y_test, y_pred)

heatmap_options = {
    'text_auto': True,
    'labels': {'x': "Predicted", 'y': "Actual"},
    'x': class_names,
    'y': class_names,
    'title': 'Confusion Matrix',
    'color_continuous_scale': 'Blues',
}
fig = px.imshow(cm, **heatmap_options)
# Keep the predicted-class tick labels underneath the matrix.
fig.update_xaxes(side="bottom")
fig.show()

print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=class_names))

In [None]:
# ROC curve of the baseline model against the chance diagonal.
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = roc_auc_score(y_test, y_pred_proba)

model_trace = go.Scatter(
    x=fpr,
    y=tpr,
    mode='lines',
    name=f'ROC Curve (AUC = {roc_auc:.4f})',
    line={'color': 'darkorange', 'width': 2},
)
# Diagonal reference: what a random classifier would achieve.
chance_trace = go.Scatter(
    x=[0, 1],
    y=[0, 1],
    mode='lines',
    name='Random Classifier',
    line={'color': 'navy', 'width': 2, 'dash': 'dash'},
)

fig = go.Figure(data=[model_trace, chance_trace])
fig.update_layout(
    title='ROC Curve',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    width=700,
    height=500,
)
fig.show()

In [None]:
# Precision-recall trade-off across all decision thresholds.
precision, recall, pr_thresholds = precision_recall_curve(y_test, y_pred_proba)

pr_trace = go.Scatter(
    x=recall,
    y=precision,
    mode='lines',
    name='Precision-Recall Curve',
    line={'color': 'darkgreen', 'width': 2},
)

fig = go.Figure(data=[pr_trace])
fig.update_layout(
    title='Precision-Recall Curve',
    xaxis_title='Recall',
    yaxis_title='Precision',
    width=700,
    height=500,
)
fig.show()

## 4. Feature Importance Analysis

In [None]:
# Model-reported importances as a table, highest first.
importance = baseline_model.get_feature_importance()
importance_records = list(importance.items())
importance_df = pd.DataFrame(importance_records, columns=['Feature', 'Importance'])
importance_df = importance_df.sort_values('Importance', ascending=False)

# Keep the 20 strongest features for the chart.
top_features = importance_df.head(20)

# Sorted ascending for plotting so the largest bar ends up at the top.
bars_bottom_up = top_features.sort_values('Importance')
fig = px.bar(
    bars_bottom_up,
    x='Importance',
    y='Feature',
    orientation='h',
    title='Top 20 Feature Importances',
    labels={'Importance': 'Importance Score'},
)
fig.update_layout(height=600)
fig.show()

print("\nTop 10 Features:")
print(importance_df.head(10).to_string(index=False))

## 5. Model Interpretation with SHAP

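SHAP (SHapley Additive exPlanations) attributes each prediction to the individual input features. Here we use `TreeExplainer`, SHAP's explainer for tree-based models, on the baseline model.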

In [None]:
# Build a tree explainer around the estimator held by the wrapper
# (`baseline_model.model`) and compute SHAP values for the test rows.
print("Computing SHAP values (this may take a minute)...")
fitted_estimator = baseline_model.model
explainer = shap.TreeExplainer(fitted_estimator)
shap_values = explainer.shap_values(X_test)
print("✓ SHAP values computed")

In [None]:
# SHAP summary plot (bar chart of per-feature SHAP importance, top 20 features)
#
# summary_plot resizes the current figure itself (plot_size defaults to "auto"),
# so a figsize passed to plt.figure alone is discarded. The intended 12x8 size
# is therefore also passed through `plot_size`.
plt.figure(figsize=(12, 8))
shap.summary_plot(
    shap_values,
    X_test,
    plot_type="bar",
    show=False,          # keep the figure open so the title can be added below
    max_display=20,
    plot_size=(12, 8),
)
plt.title("SHAP Feature Importance", fontsize=14, pad=20)
plt.tight_layout()
plt.show()

In [None]:
# SHAP summary plot (detailed): per-sample SHAP values for the top 20 features
#
# summary_plot resizes the current figure itself (plot_size defaults to "auto"),
# so a figsize passed to plt.figure alone is discarded. The intended 12x8 size
# is therefore also passed through `plot_size`.
plt.figure(figsize=(12, 8))
shap.summary_plot(
    shap_values,
    X_test,
    show=False,          # keep the figure open so the title can be added below
    max_display=20,
    plot_size=(12, 8),
)
plt.title("SHAP Summary Plot - Feature Impact on Predictions", fontsize=14, pad=20)
plt.tight_layout()
plt.show()

In [None]:
# SHAP dependence plot for top feature
#
# The top feature is taken from the model-reported importance table.
top_feature = importance_df.iloc[0]['Feature']

# dependence_plot opens its own figure when no axes are supplied, which would
# leave a blank 10x6 figure behind and ignore the requested size. Creating the
# axes up front and passing them in draws the plot on the figure we control.
dep_fig, dep_ax = plt.subplots(figsize=(10, 6))
shap.dependence_plot(top_feature, shap_values, X_test, ax=dep_ax, show=False)
dep_ax.set_title(f"SHAP Dependence Plot - {top_feature}", fontsize=14, pad=20)
dep_fig.tight_layout()
plt.show()

## 6. Error Analysis

In [None]:
# Attach ground truth and model outputs to the test features so that
# misclassified rows can be inspected together with their inputs.
results_df = X_test.assign(
    true_label=y_test.values,
    predicted_label=y_pred,
    predicted_proba=y_pred_proba,
)
results_df['correct'] = results_df['true_label'] == results_df['predicted_label']

# Boolean masks for the actual and predicted classes.
actual_negative = results_df['true_label'] == False
actual_positive = results_df['true_label'] == True
predicted_negative = results_df['predicted_label'] == False
predicted_positive = results_df['predicted_label'] == True

# False positives: flagged as churn, but the customer stayed.
false_positives = results_df[actual_negative & predicted_positive]

# False negatives: churners the model missed.
false_negatives = results_df[actual_positive & predicted_negative]

print(f"False Positives: {len(false_positives)} ({len(false_positives)/len(results_df):.2%})")
print(f"False Negatives: {len(false_negatives)} ({len(false_negatives)/len(results_df):.2%})")

# Missed churners are the costlier error, so summarise their profile.
if len(false_negatives):
    profile_columns = ['tenure_months', 'monthly_charges', 'contract_type',
                       'churn_risk_score', 'predicted_proba']
    print("\nFalse Negative Characteristics (Missed Churners):")
    print(false_negatives[profile_columns].describe())

In [None]:
# Overlaid histograms of predicted churn probability, one per actual class.
fig = go.Figure()

class_traces = ((False, 'Actual No Churn'), (True, 'Actual Churn'))
for actual_value, trace_name in class_traces:
    class_mask = results_df['true_label'] == actual_value
    class_probabilities = results_df.loc[class_mask, 'predicted_proba']
    fig.add_trace(go.Histogram(
        x=class_probabilities,
        name=trace_name,
        opacity=0.7,
        nbinsx=30,
    ))

fig.update_layout(
    title='Prediction Confidence Distribution',
    xaxis_title='Predicted Probability of Churn',
    yaxis_title='Count',
    barmode='overlay',
)
fig.show()

## 7. Optimization Opportunities

### Areas to Improve Model Performance:

#### 1. **Hyperparameter Tuning**
- `iterations`: Try [200, 500, 1000] for better convergence
- `learning_rate`: Experiment with [0.01, 0.03, 0.05, 0.1]
- `depth`: Test [4, 6, 8, 10] to balance complexity
- `l2_leaf_reg`: Add regularization [1, 3, 5, 7, 9]
- `border_count`: Try [32, 64, 128, 254] for better splits

#### 2. **Class Imbalance Handling**
- Current churn rate: ~26.5%
- Options:
  - `auto_class_weights` (e.g. `'Balanced'`): Auto-balance classes, or set explicit `class_weights`
  - `scale_pos_weight`: Boost minority class
  - SMOTE oversampling
  - Adjust decision threshold (optimize for F1 or recall)

#### 3. **Feature Engineering**
- Create interaction features (tenure × charges)
- Polynomial features for numeric columns
- Temporal features (if date data available)
- Customer segmentation clusters

#### 4. **Feature Selection**
- Remove low-importance features (<1% importance)
- Check for multicollinearity
- Use recursive feature elimination

#### 5. **Ensemble Methods**
- Stack CatBoost with LightGBM/XGBoost
- Voting classifier
- Use cross-validation for robustness

#### 6. **Business-Focused Optimization**
- Optimize for **recall** if cost of missing churner is high
- Optimize for **precision** if retention campaign cost is high
- Use custom loss function based on business metrics

## 8. Hyperparameter Tuning Example

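Let's try a stronger configuration informed by the ideas above: more iterations, a lower learning rate, deeper trees, L2 regularization, and early stopping.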

In [None]:
# Optimized model configuration
optimized_model = CatBoostModel(
    iterations=500,                # More iterations
    learning_rate=0.03,            # Lower learning rate for stability
    depth=8,                       # Deeper trees
    l2_leaf_reg=3,                 # L2 regularization
    random_state=42,
    verbose=False,
    # Additional params
    eval_metric='AUC',
    early_stopping_rounds=50,
)

# Early stopping chooses the number of trees from the eval_set score. Using the
# test set for that would tune the model on the same data it is later scored
# on, making the reported test metrics optimistic. A validation slice of the
# training data is used instead, so X_test stays unseen until evaluation.
X_fit, X_val, y_fit, y_val = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

print("Training optimized model...")
optimized_model.fit(X_fit, y_fit, eval_set=(X_val, y_val))
print("✓ Training complete")

In [None]:
# Score the optimized model on the same test set as the baseline.
y_pred_opt = optimized_model.predict(X_test)
y_pred_proba_opt = optimized_model.predict_proba(X_test)[:, 1]

# Label-based scores first, then ROC-AUC from probabilities (same order as `metrics`).
label_scorers = {
    'Accuracy': accuracy_score,
    'Precision': precision_score,
    'Recall': recall_score,
    'F1 Score': f1_score,
}
optimized_metrics = {
    name: scorer(y_test, y_pred_opt) for name, scorer in label_scorers.items()
}
optimized_metrics['ROC-AUC'] = roc_auc_score(y_test, y_pred_proba_opt)

# Side-by-side table with absolute and relative change versus the baseline.
comparison = pd.DataFrame({
    'Metric': list(metrics),
    'Baseline': list(metrics.values()),
    'Optimized': list(optimized_metrics.values()),
})
absolute_gain = comparison['Optimized'] - comparison['Baseline']
comparison['Improvement'] = absolute_gain
comparison['Improvement %'] = (absolute_gain / comparison['Baseline'] * 100).round(2)

banner_rule = "="*70
print("\n" + banner_rule)
print("MODEL COMPARISON")
print(banner_rule)
print(comparison.to_string(index=False))
print(banner_rule)

## 9. Recommendations for Production

### Next Steps:

1. **Implement proper hyperparameter search**:
   - Use Optuna or Ray Tune for automated tuning
   - Track experiments in MLflow
   - Use stratified k-fold cross-validation (e.g., 5 folds)

2. **Monitor model performance**:
   - Set up data drift detection
   - Track prediction distributions
   - Monitor feature importance shifts

3. **Business integration**:
   - Define actionable threshold for retention campaigns
   - Calculate expected ROI per customer segment
   - A/B test model predictions vs baseline

4. **Model improvements**:
   - Add more temporal features if data available
   - Experiment with neural networks for comparison
   - Build calibrated probability estimates

5. **Deployment considerations**:
   - Export model in production format (ONNX, pickle)
   - Set up prediction API
   - Implement monitoring and alerting
   - Version control models in MLflow
