# Feature Drift Monitoring for A/B Tests

**This notebook accompanies Section 6.4.2 of Chapter 6, which covers detecting feature drift during model evaluation.**

Specifically, this notebook demonstrates:
1. **Statistical drift detection** using KS tests and Population Stability Index (PSI)
2. **Categorical feature drift** detection using Chi-square tests
3. **Production-ready monitoring** class for real A/B test pipelines
4. **Visualization and alerting** for drift detection

## Why This Matters
AI models rely on input features, and those features can drift over time. Feature drift occurs when the distribution of input features changes between training and production, which can cause models to perform unpredictably during A/B tests.

Note: While A/B tests will expose drift if it is present, it is far better to detect and address drift earlier through offline monitoring and validation, so that online experiments are not the first place this kind of issue is caught.

In [None]:
# Setup and Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 1. Simulate Training and Production Data

Let's create realistic training data for a movie recommendation model, then simulate production data with drift.

In [None]:
# Training data (what the model was trained on)
np.random.seed(42)
n_training = 10000

training_data = pd.DataFrame({
    'user_age': np.random.normal(35, 12, n_training),
    'session_duration': np.random.exponential(15, n_training),  # minutes
    'time_of_day': np.random.normal(19, 3, n_training) % 24,    # peak at 7 PM
    'device_type': np.random.choice(['mobile', 'desktop', 'tv'], n_training, p=[0.6, 0.3, 0.1]),
    'genres_explored': np.random.poisson(3, n_training),
    'date': pd.date_range('2024-01-01', periods=n_training, freq='h')  # 'H' is deprecated in recent pandas
})

print("Training Data Summary:")
print(training_data.describe())
print(f"\nTraining period: {training_data['date'].min()} to {training_data['date'].max()}")

In [None]:
# Production data with drift - users are younger, sessions shorter, more mobile usage
n_production = 5000

production_data = pd.DataFrame({
    'user_age': np.random.normal(28, 10, n_production),          # Younger users
    'session_duration': np.random.exponential(12, n_production), # Shorter sessions
    'time_of_day': np.random.normal(20, 4, n_production) % 24,   # Later viewing
    'device_type': np.random.choice(['mobile', 'desktop', 'tv'], n_production, p=[0.8, 0.15, 0.05]), # More mobile
    'genres_explored': np.random.poisson(2.5, n_production),     # Less exploration
    'date': pd.date_range('2024-06-01', periods=n_production, freq='h')  # 'H' is deprecated in recent pandas
})

print("Production Data Summary:")
print(production_data.describe())

## 2. Statistical Drift Detection Functions

We'll implement two key methods for detecting drift:
- **Kolmogorov-Smirnov Test**: Compares distributions statistically
- **Population Stability Index (PSI)**: Measures the shift in population distribution
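
For reference, the two statistics can be written as below. The notation is introduced here for illustration only: $F_{\text{train}}$ and $F_{\text{prod}}$ are the empirical CDFs of a feature in the two samples, $B$ is the number of bins, and $p_i$ and $q_i$ are the shares of training and production observations falling in bin $i$.

$$
D = \sup_x \left| F_{\text{train}}(x) - F_{\text{prod}}(x) \right|
\qquad
\mathrm{PSI} = \sum_{i=1}^{B} (q_i - p_i)\,\ln\frac{q_i}{p_i}
$$

Each PSI term is non-negative, because $(q_i - p_i)$ and $\ln(q_i / p_i)$ always share the same sign, so shifts in different bins cannot cancel each other out.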

In [None]:
def kolmogorov_smirnov_test(training_feature, production_feature, feature_name, threshold=0.05):
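    """
    Perform a two-sample KS test to detect distribution drift.
    `threshold` is the significance level applied to the p-value; the
    severity label is a rule of thumb based on the KS statistic.
    """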
    statistic, p_value = stats.ks_2samp(training_feature, production_feature)
    
    drift_detected = p_value < threshold
    
    result = {
        'feature': feature_name,
        'ks_statistic': statistic,
        'p_value': p_value,
        'drift_detected': drift_detected,
        'severity': 'HIGH' if statistic > 0.2 else 'MEDIUM' if statistic > 0.1 else 'LOW'
    }
    
    return result

def population_stability_index(training_feature, production_feature, bins=10):
    """
    Calculate Population Stability Index (PSI) to measure drift
    PSI < 0.1: No significant change
    0.1 ≤ PSI < 0.25: Some minor change
    PSI ≥ 0.25: Major shift in population
    """
    # Create equal-width bins based on training data
    _, bin_edges = np.histogram(training_feature, bins=bins)

    # Clip production values into the training range so that values outside
    # it land in the outermost bins instead of being silently dropped
    production_clipped = np.clip(production_feature, bin_edges[0], bin_edges[-1])

    # Count observations per bin
    training_counts, _ = np.histogram(training_feature, bins=bin_edges)
    production_counts, _ = np.histogram(production_clipped, bins=bin_edges)

    # Convert counts to proportions
    training_pct = training_counts / len(training_feature)
    production_pct = production_counts / len(production_feature)

    # Floor proportions at a small constant to avoid log(0) and division by zero
    training_pct = np.maximum(training_pct, 0.0001)
    production_pct = np.maximum(production_pct, 0.0001)
    
    # Calculate PSI
    psi = np.sum((production_pct - training_pct) * np.log(production_pct / training_pct))
    
    return psi
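
To make the PSI arithmetic concrete, here is a minimal worked example on three bins. The bin shares are made-up numbers chosen for illustration, not values from the simulated data in this notebook.

```python
import numpy as np

# Hypothetical bin shares for one feature (each vector sums to 1)
training_pct = np.array([0.5, 0.3, 0.2])
production_pct = np.array([0.4, 0.3, 0.3])

# Per-bin contribution: (q_i - p_i) * ln(q_i / p_i), which is never negative
contributions = (production_pct - training_pct) * np.log(production_pct / training_pct)
psi = contributions.sum()

print("Per-bin contributions:", contributions.round(4))
print(f"PSI: {psi:.4f}")  # about 0.063, below the 0.1 rule-of-thumb cutoff
```

The middle bin contributes nothing because its share did not change. The first and last bins each add a positive amount even though one lost share and the other gained it.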

## 3. Run Drift Detection Analysis

In [None]:
print("=== FEATURE DRIFT DETECTION RESULTS ===\n")

# Numerical features drift detection
numerical_features = ['user_age', 'session_duration', 'time_of_day', 'genres_explored']
drift_results = []

for feature in numerical_features:
    # KS Test
    ks_result = kolmogorov_smirnov_test(
        training_data[feature], 
        production_data[feature], 
        feature
    )
    
    # PSI
    psi_score = population_stability_index(
        training_data[feature], 
        production_data[feature]
    )
    
    ks_result['psi_score'] = psi_score
    
    if psi_score >= 0.25:
        ks_result['psi_interpretation'] = 'MAJOR SHIFT'
    elif psi_score >= 0.1:
        ks_result['psi_interpretation'] = 'MINOR CHANGE'
    else:
        ks_result['psi_interpretation'] = 'STABLE'
    
    drift_results.append(ks_result)
    
    print(f"Feature: {feature}")
    print(f"  KS Statistic: {ks_result['ks_statistic']:.4f}")
    print(f"  P-value: {ks_result['p_value']:.4f}")
    print(f"  PSI Score: {psi_score:.4f}")
    print(f"  Drift Status: {'🚨 DRIFT DETECTED' if ks_result['drift_detected'] else '✅ STABLE'}")
    print(f"  PSI Interpretation: {ks_result['psi_interpretation']}")
    print()
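
When several features are tested at once, the chance that at least one p-value falls below 0.05 by luck alone grows with the number of tests. One common remedy is Holm's step-down correction. The sketch below is an optional add-on rather than part of the original analysis. The function name `holm_reject` and the example p-values are assumptions made for illustration.

```python
import numpy as np

def holm_reject(p_values, alpha=0.05):
    """Return a boolean array that is True where Holm's procedure rejects."""
    p = np.asarray(p_values, dtype=float)
    order = np.argsort(p)          # indices from smallest to largest p-value
    m = len(p)
    reject = np.zeros(m, dtype=bool)
    for rank, idx in enumerate(order):
        # The k-th smallest p-value is compared with alpha / (m - k + 1)
        if p[idx] <= alpha / (m - rank):
            reject[idx] = True
        else:
            break                  # stop at the first non-rejection
    return reject

# Hypothetical p-values for four features
example_p_values = [0.001, 0.04, 0.03, 0.20]
print(holm_reject(example_p_values))
```

Three of these four p-values are below 0.05 on their own, but only the first survives the correction. Note also that with thousands of rows per sample, KS p-values become tiny even for small shifts, so the effect-size measures (the KS statistic and PSI) are usually the more informative numbers.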

## 4. Categorical Feature Drift Detection

In [None]:
def categorical_drift_detection(training_cat, production_cat, feature_name):
    """
    Detect drift in categorical features using Chi-square test
    """
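    # Get value counts
    training_counts = training_cat.value_counts()
    production_counts = production_cat.value_counts()
    
    # Align categories in a fixed order
    all_categories = sorted(set(training_counts.index) | set(production_counts.index))
    
    training_aligned = np.array([training_counts.get(cat, 0) for cat in all_categories], dtype=float)
    production_aligned = np.array([production_counts.get(cat, 0) for cat in all_categories], dtype=float)
    
    # Rescale training counts to the production sample size: stats.chisquare
    # requires observed and expected frequencies to have the same total
    expected = training_aligned / training_aligned.sum() * production_aligned.sum()
    
    # Chi-square goodness-of-fit test (assumes every category appears in training;
    # a category unseen in training gives an expected count of 0)
    chi2_stat, p_value = stats.chisquare(production_aligned, f_exp=expected)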
    
    return {
        'feature': feature_name,
        'chi2_statistic': chi2_stat,
        'p_value': p_value,
        'drift_detected': p_value < 0.05,
        'training_dist': training_counts / len(training_cat),
        'production_dist': production_counts / len(production_cat)
    }

# Test categorical drift
device_drift = categorical_drift_detection(
    training_data['device_type'], 
    production_data['device_type'], 
    'device_type'
)

print("Categorical Feature Analysis - Device Type:")
print(f"Chi-square statistic: {device_drift['chi2_statistic']:.4f}")
print(f"P-value: {device_drift['p_value']:.4f}")
print(f"Drift Status: {'🚨 DRIFT DETECTED' if device_drift['drift_detected'] else '✅ STABLE'}")
print("\nDistribution Comparison:")
print("Training Distribution:")
print(device_drift['training_dist'])
print("\nProduction Distribution:")
print(device_drift['production_dist'])
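
The function above treats the training proportions as fixed expected values (a goodness-of-fit test). An alternative that treats both samples as random, and that copes with categories missing from one side, is the chi-square test of homogeneity on a 2 x K contingency table. The following is a minimal sketch of that alternative. The function name `categorical_drift_contingency` and the toy counts are assumptions for illustration.

```python
import pandas as pd
from scipy import stats

def categorical_drift_contingency(training_cat, production_cat, alpha=0.05):
    """Chi-square test of homogeneity on a 2 x K table of category counts."""
    table = pd.DataFrame({
        'training': pd.Series(training_cat).value_counts(),
        'production': pd.Series(production_cat).value_counts(),
    }).fillna(0).T  # rows: samples, columns: categories (union of both)
    chi2_stat, p_value, dof, _ = stats.chi2_contingency(table.to_numpy())
    return {
        'chi2_statistic': chi2_stat,
        'p_value': p_value,
        'dof': dof,
        'drift_detected': p_value < alpha,
    }

# Toy samples: production is more mobile-heavy than training
toy_training = ['mobile'] * 600 + ['desktop'] * 300 + ['tv'] * 100
toy_production = ['mobile'] * 400 + ['desktop'] * 75 + ['tv'] * 25
toy_result = categorical_drift_contingency(toy_training, toy_production)
print(toy_result)
```

Because the table is built from the union of categories, a device type that appears only in production gets a zero count on the training row rather than a zero expected frequency.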

## 5. Visualization of Drift

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('Feature Drift Visualization: Training vs Production Data', fontsize=16)

# Plot distributions for numerical features
features_to_plot = ['user_age', 'session_duration', 'time_of_day', 'genres_explored']

for i, feature in enumerate(features_to_plot):
    ax = axes[i//2, i%2]
    
    # Plot histograms
    ax.hist(training_data[feature], alpha=0.6, label='Training', bins=30, density=True)
    ax.hist(production_data[feature], alpha=0.6, label='Production', bins=30, density=True)
    
    ax.set_title(f'{feature.replace("_", " ").title()} Distribution')
    ax.set_xlabel(feature.replace("_", " ").title())
    ax.set_ylabel('Density')
    ax.legend()
    
    # Add drift status
    drift_status = next(r for r in drift_results if r['feature'] == feature)
    status_text = f"PSI: {drift_status['psi_score']:.3f}\nStatus: {drift_status['psi_interpretation']}"
    ax.text(0.02, 0.98, status_text, transform=ax.transAxes, 
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()
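
Histograms depend on the choice of bins, whereas the KS statistic is defined on the empirical CDFs: it is the largest vertical gap between the two curves. The sketch below computes that gap by hand and compares it with `stats.ks_2samp`. The helper name `max_ecdf_gap` and the synthetic samples are assumptions for illustration.

```python
import numpy as np
from scipy import stats

def max_ecdf_gap(sample_a, sample_b):
    """Return (gap, x): the largest ECDF difference and where it occurs."""
    a = np.sort(np.asarray(sample_a, dtype=float))
    b = np.sort(np.asarray(sample_b, dtype=float))
    grid = np.concatenate([a, b])  # ECDFs only change at observed values
    # side='right' counts values <= x, which is the ECDF evaluated at x
    cdf_a = np.searchsorted(a, grid, side='right') / a.size
    cdf_b = np.searchsorted(b, grid, side='right') / b.size
    gaps = np.abs(cdf_a - cdf_b)
    i = int(np.argmax(gaps))
    return gaps[i], grid[i]

rng = np.random.default_rng(0)
train_sample = rng.normal(35, 12, 2000)
prod_sample = rng.normal(28, 10, 1000)

gap, x_at_gap = max_ecdf_gap(train_sample, prod_sample)
ks_stat = stats.ks_2samp(train_sample, prod_sample).statistic
print(f"Max ECDF gap {gap:.4f} at x = {x_at_gap:.1f}; ks_2samp statistic {ks_stat:.4f}")
```

The location of the gap shows where along the feature's range the two distributions disagree most, which can be a useful annotation on an ECDF plot.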

## 6. Production-Ready Drift Monitoring Class

This class can be integrated into your A/B test monitoring pipeline.

In [None]:
class FeatureDriftMonitor:
    """
    A production-ready class for monitoring feature drift during A/B tests
    """
    
    def __init__(self, reference_data, feature_columns, categorical_features=None):
        self.reference_data = reference_data[feature_columns].copy()
        self.feature_columns = feature_columns
        self.categorical_features = categorical_features or []
        self.numerical_features = [f for f in feature_columns if f not in self.categorical_features]
        
    def check_drift(self, current_data, ks_threshold=0.05, psi_threshold=0.25):
        """
        Check for drift in current data compared to reference
        """
        drift_report = {
            'timestamp': datetime.now(),
            'total_features': len(self.feature_columns),
            'features_with_drift': 0,
            'high_risk_features': [],
            'medium_risk_features': [],
            'detailed_results': []
        }
        
        # Check numerical features
        for feature in self.numerical_features:
            if feature in current_data.columns:
                ks_result = kolmogorov_smirnov_test(
                    self.reference_data[feature],
                    current_data[feature],
                    feature,
                    ks_threshold
                )
                
                psi_score = population_stability_index(
                    self.reference_data[feature],
                    current_data[feature]
                )
                
                ks_result['psi_score'] = psi_score
                
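                if ks_result['drift_detected'] or psi_score >= psi_threshold:
                    drift_report['features_with_drift'] += 1
                    
                    # High risk when PSI reaches the configured threshold;
                    # otherwise only the KS test fired
                    if psi_score >= psi_threshold: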
                        drift_report['high_risk_features'].append(feature)
                    else:
                        drift_report['medium_risk_features'].append(feature)
                
                drift_report['detailed_results'].append(ks_result)
        
        # Check categorical features
        for feature in self.categorical_features:
            if feature in current_data.columns:
                cat_result = categorical_drift_detection(
                    self.reference_data[feature],
                    current_data[feature],
                    feature
                )
                
                if cat_result['drift_detected']:
                    drift_report['features_with_drift'] += 1
                    drift_report['high_risk_features'].append(feature)
                
                drift_report['detailed_results'].append(cat_result)
        
        return drift_report
    
    def get_alert_summary(self, drift_report):
        """
        Generate human-readable alert summary
        """
        summary = f"""
🔍 FEATURE DRIFT MONITORING REPORT
Time: {drift_report['timestamp'].strftime('%Y-%m-%d %H:%M:%S')}

📊 OVERVIEW:
- Total features monitored: {drift_report['total_features']}
- Features with drift detected: {drift_report['features_with_drift']}

⚠️  HIGH RISK FEATURES: {', '.join(drift_report['high_risk_features']) if drift_report['high_risk_features'] else 'None'}

⚡ MEDIUM RISK FEATURES: {', '.join(drift_report['medium_risk_features']) if drift_report['medium_risk_features'] else 'None'}

🚨 RECOMMENDATION: 
"""
        
        if drift_report['features_with_drift'] == 0:
            summary += "✅ No significant drift detected. A/B test can continue safely."
        elif len(drift_report['high_risk_features']) > 0:
            summary += "🛑 High risk drift detected! Consider pausing A/B test and investigating model performance."
        else:
            summary += "⚠️  Minor drift detected. Monitor closely and consider retraining if performance degrades."
        
        return summary

## 7. Example Usage of Drift Monitor

In [None]:
# Initialize monitor with training data
monitor = FeatureDriftMonitor(
    reference_data=training_data,
    feature_columns=['user_age', 'session_duration', 'time_of_day', 'genres_explored', 'device_type'],
    categorical_features=['device_type']
)

# Check drift on production data
drift_report = monitor.check_drift(production_data)

# Print alert summary
print("=== PRODUCTION DRIFT MONITORING ALERT ===")
print(monitor.get_alert_summary(drift_report))
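
The check above compares all production data at once. During a running experiment it is often more useful to see how drift develops over time. The following sketch computes a KS statistic per time window against a fixed reference. The function name `ks_by_window`, the 7-day window, the minimum window size of 30 rows, and the synthetic data are all assumptions made for illustration.

```python
import numpy as np
import pandas as pd
from scipy import stats

def ks_by_window(reference, current, value_col, date_col='date', freq='7D', min_rows=30):
    """KS statistic of each time window of `current` against `reference`."""
    rows = []
    for window_start, chunk in current.groupby(pd.Grouper(key=date_col, freq=freq)):
        if len(chunk) < min_rows:  # skip windows too small to test
            continue
        res = stats.ks_2samp(reference[value_col], chunk[value_col])
        rows.append({
            'window_start': window_start,
            'n': len(chunk),
            'ks_statistic': res.statistic,
            'p_value': res.pvalue,
        })
    return pd.DataFrame(rows)

# Synthetic example: 28 days of hourly data whose mean age falls steadily
rng = np.random.default_rng(1)
n_hours = 28 * 24
reference_df = pd.DataFrame({'user_age': rng.normal(35, 12, 5000)})
current_df = pd.DataFrame({
    'date': pd.Timestamp('2024-06-01') + pd.to_timedelta(np.arange(n_hours), unit='h'),
    'user_age': rng.normal(35 - 15 * np.linspace(0, 1, n_hours), 12),
})

window_report = ks_by_window(reference_df, current_df, 'user_age')
print(window_report)
```

A rising KS statistic across consecutive windows suggests a gradual shift, whereas a single-window spike is more likely a one-off event such as a marketing campaign or a logging change.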

## 8. Integration with A/B Test Monitoring

Here's how to integrate drift monitoring with your A/B test pipeline:

In [None]:
def integrate_with_ab_test_monitoring():
    """
    Example of how to integrate drift monitoring with A/B test pipeline
    """
    print("""
💡 INTEGRATION RECOMMENDATIONS:

1. **Scheduled Monitoring**: Run drift checks every few hours during A/B test
   
2. **Automated Alerts**: Set up alerts when PSI > 0.25 or multiple features show drift
   
3. **Guardrail Metrics**: Include drift scores as guardrail metrics in your A/B test dashboard
   
4. **Action Triggers**: 
   - PSI > 0.25: Pause test, investigate
   - 3+ features with drift: Consider early termination
   - Categorical shift > 20%: Check targeting criteria
   
5. **Logging**: Store drift metrics alongside A/B test results for post-hoc analysis

Example code integration:
```python
# In your A/B test monitoring pipeline
if drift_report['features_with_drift'] > 2:
    send_alert_to_team(drift_report)
    
if len(drift_report['high_risk_features']) > 0:
    consider_pausing_experiment()
    
# Log drift metrics for analysis
log_drift_metrics(drift_report, experiment_id="movie_rec_v2")
```
    """)

integrate_with_ab_test_monitoring()
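
The action triggers listed above can be turned into a small decision function that a pipeline calls after each drift check. This is only a sketch of one possible policy. The name `decide_action`, the `max_drifting_features` parameter, and the three return labels are hypothetical, and the example reports are hand-written dictionaries that mimic the keys produced by `FeatureDriftMonitor.check_drift`.

```python
def decide_action(drift_report, max_drifting_features=2):
    """Map a drift report to 'pause', 'alert', or 'continue'."""
    # Any high-risk feature takes priority over the feature count
    if drift_report['high_risk_features']:
        return 'pause'
    # Many mildly drifting features still warrant a notification
    if drift_report['features_with_drift'] > max_drifting_features:
        return 'alert'
    return 'continue'

# Hand-written reports with the same keys as the monitor's output
clean_report = {'features_with_drift': 0, 'high_risk_features': [], 'medium_risk_features': []}
broad_report = {'features_with_drift': 3, 'high_risk_features': [],
                'medium_risk_features': ['user_age', 'time_of_day', 'genres_explored']}
severe_report = {'features_with_drift': 1, 'high_risk_features': ['device_type'],
                 'medium_risk_features': []}

for name, report in [('clean', clean_report), ('broad', broad_report), ('severe', severe_report)]:
    print(name, '->', decide_action(report))
```

Returning a label rather than calling an alerting service directly keeps the policy easy to unit test and leaves the choice of notification channel to the surrounding pipeline.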

## Summary

1. **Feature drift can silently undermine A/B test results** - Monitor proactively
2. **Monitor both numerical and categorical features** - Use appropriate statistical tests for each
3. **Set up automated alerts** - Don't wait for manual checks to catch drift
4. **Include drift as guardrail metrics** - Treat it as seriously as your success metrics
5. **High drift (PSI > 0.25) requires action** - Consider pausing the test to investigate

Feature drift monitoring is important if you want to trust the insights from your A/B test evaluation.
