# MLflow Integration for Audit Tracking

This notebook demonstrates how to track Conformal-Drift audits with MLflow for experiment management and reproducibility.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/debu-sinha/conformaldrift/blob/main/examples/03_mlflow_tracking.ipynb)

In [None]:
# Install dependencies
!pip install conformal-drift mlflow matplotlib -q

In [None]:
import numpy as np
import mlflow
import matplotlib.pyplot as plt
from conformal_drift import ConformalDriftAuditor

# Fix NumPy's global RNG so the synthetic data below is identical on every run
np.random.seed(seed=42)

## 1. Set Up MLflow Experiment

In [None]:
# Configure MLflow: select (or create) the experiment that all audit runs are logged under.
# The name is held in one variable so the call and the printed message cannot drift apart.
experiment_name = "conformal-drift-audit"
mlflow.set_experiment(experiment_name)

print(f"Experiment: {experiment_name}")
print(f"Tracking URI: {mlflow.get_tracking_uri()}")

## 2. Generate Synthetic Data

In [None]:
# Sample sizes for the calibration and test splits
n_calibration, n_test = 500, 200

# Calibration scores: draws from Beta(2, 5)
calibration_scores = np.random.beta(2, 5, size=n_calibration)

# Test split: scores from the same Beta(2, 5), binary labels from Bernoulli(0.9).
# The draw order (calibration scores, test scores, test labels) is kept fixed so
# the values are reproducible under the global seed.
test_scores = np.random.beta(2, 5, size=n_test)
test_labels = np.random.binomial(1, 0.9, size=n_test)

# Bundle the test arrays under the keys passed to the auditor later on
test_data = dict(scores=test_scores, labels=test_labels)

print(f"Generated {n_calibration} calibration and {n_test} test samples")

## 3. Run Audit with MLflow Tracking

In [None]:
# Configuration
alpha = 0.1
nominal_coverage = 1 - alpha  # target coverage implied by the miscoverage level
shift_types = ['temporal', 'semantic', 'lexical']
shift_intensities = np.linspace(0, 1, 11)

# Thresholds on the maximum coverage gap that decide each nested run's status tag
critical_gap = 0.1
warning_gap = 0.05

# One parent run for the whole audit; each shift type gets its own nested child run.
with mlflow.start_run(run_name="comprehensive_audit") as parent_run:
    # Log parameters
    mlflow.log_param("alpha", alpha)
    mlflow.log_param("n_calibration", n_calibration)
    mlflow.log_param("n_test", n_test)
    mlflow.log_param("shift_types", shift_types)
    mlflow.log_param("n_shift_levels", len(shift_intensities))

    # Initialize auditor
    auditor = ConformalDriftAuditor(
        calibration_scores=calibration_scores,
        alpha=alpha
    )

    all_results = {}

    # Run audits for each shift type
    for shift_type in shift_types:
        print(f"\nRunning {shift_type} shift audit...")

        with mlflow.start_run(run_name=f"{shift_type}_shift", nested=True):
            mlflow.log_param("shift_type", shift_type)

            # Run audit
            results = auditor.audit(
                test_data=test_data,
                shift_type=shift_type,
                shift_intensity=shift_intensities
            )
            all_results[shift_type] = results

            # Log coverage at each shift level. MLflow steps are integers, so the
            # intensity is expressed in percent; round() rather than bare int()
            # avoids truncating values such as 56.99999999999999 down to 56.
            for intensity, coverage in zip(results.shift_intensities, results.coverage):
                step = int(round(float(intensity) * 100))
                mlflow.log_metric("coverage", coverage, step=step)

            # Log summary metrics
            mlflow.log_metric("max_coverage_gap", results.max_coverage_gap)
            mlflow.log_metric("baseline_coverage", results.coverage[0])
            mlflow.log_metric("final_coverage", results.coverage[-1])

            # Create and log coverage curve
            fig, ax = plt.subplots(figsize=(8, 5))
            ax.plot(results.shift_intensities, results.coverage, 'b-o', linewidth=2)
            ax.axhline(y=nominal_coverage, color='r', linestyle='--', label='Nominal')
            ax.set_xlabel('Shift Intensity')
            ax.set_ylabel('Coverage')
            ax.set_title(f'{shift_type.capitalize()} Shift Coverage')
            ax.legend()
            ax.grid(True, alpha=0.3)
            fig.savefig(f'{shift_type}_coverage.png', dpi=150)
            mlflow.log_artifact(f'{shift_type}_coverage.png')
            # Close this specific figure so figures do not accumulate across iterations
            plt.close(fig)

            # Tag based on results
            if results.max_coverage_gap > critical_gap:
                mlflow.set_tag("status", "CRITICAL")
            elif results.max_coverage_gap > warning_gap:
                mlflow.set_tag("status", "WARNING")
            else:
                mlflow.set_tag("status", "PASS")

            print(f"  Max gap: {results.max_coverage_gap:.3f}")

    # Create comparison plot in parent run
    fig, ax = plt.subplots(figsize=(10, 6))
    colors = {'temporal': 'blue', 'semantic': 'green', 'lexical': 'orange'}

    for shift_type, results in all_results.items():
        # .get() falls back to matplotlib's default colour cycle for any shift
        # type that has no entry in `colors`, instead of raising KeyError.
        ax.plot(results.shift_intensities, results.coverage, '-o',
                color=colors.get(shift_type), linewidth=2, label=shift_type.capitalize())

    ax.axhline(y=nominal_coverage, color='r', linestyle='--', linewidth=2, label='Nominal')
    ax.set_xlabel('Shift Intensity', fontsize=12)
    ax.set_ylabel('Coverage', fontsize=12)
    ax.set_title('Coverage Comparison Across Shift Types', fontsize=14)
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig('comparison.png', dpi=150)
    mlflow.log_artifact('comparison.png')
    plt.close(fig)

    # Log summary: one entry per shift type, cast to plain floats for JSON
    summary = {
        shift_type: {
            "max_gap": float(r.max_coverage_gap),
            "baseline": float(r.coverage[0]),
            "final": float(r.coverage[-1])
        }
        for shift_type, r in all_results.items()
    }
    mlflow.log_dict(summary, "summary.json")

    print("\n" + "="*50)
    print("AUDIT COMPLETE")
    # Use the handle returned by start_run rather than re-querying the active run
    print(f"Run ID: {parent_run.info.run_id}")

## 4. View Results in MLflow UI

Start the MLflow UI to view results:

```bash
mlflow ui
```

Then navigate to http://localhost:5000

In [None]:
# Fetch every run in the audit experiment whose status tag is PASS
runs = mlflow.search_runs(
    filter_string="tags.status = 'PASS'",
    experiment_names=["conformal-drift-audit"],
)

n_passing = len(runs)
print(f"Found {n_passing} passing runs")
if n_passing:
    # Show only the identifying columns for the first few matches
    report_columns = ['run_id', 'params.shift_type', 'metrics.max_coverage_gap']
    print(runs[report_columns].head())

## 5. Best Practices

1. **Always log parameters**: Alpha, sample sizes, shift types
2. **Log metrics at steps**: Use the shift intensity, scaled to an integer (e.g. percent), as the step so each coverage curve can be tracked; MLflow steps must be integers
3. **Save artifacts**: Coverage plots, detailed results JSON
4. **Use tags**: Mark runs as PASS/WARNING/CRITICAL
5. **Nested runs**: Use for multi-shift audits to organize results