# Notebook 04: Model Training

**Purpose**: XGBoost training with PRD compliance and ONNX export

**Pipeline**:
1. Load feature datasets from notebook 03
2. Convert Spark DataFrames to NumPy arrays
3. Train XGBoost with 5-fold cross-validation
4. Optimize threshold (100:1 FN:FP cost ratio)
5. Evaluate on test set
6. Check PRD compliance (recall ≥98%, equivalently FNR ≤2%; F1 and ROC-AUC targets from config)
7. Export to ONNX with validation
8. Save model artifacts to S3

**Key Features**:
- Stratified cross-validation
- Early stopping
- Cost-based threshold optimization
- PRD requirement validation
- ONNX export with parity check
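Stratified cross-validation keeps the share of heavy queries roughly the same in every fold, which matters when the positive class is rare. Below is a minimal sketch of stratified fold assignment in plain Python. `stratified_folds` is a hypothetical helper for illustration only. It is not the `ModelTrainer` implementation, and a library routine such as scikit-learn's `StratifiedKFold` does the same job.

```python
import random
from collections import defaultdict


def stratified_folds(labels, k=5, seed=42):
    """Assign each sample index to one of k folds, stratified by label.

    Hypothetical helper: shuffles the indices of each class separately,
    then deals them round-robin so every fold receives ~1/k of each class.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    rng = random.Random(seed)
    fold_of = [None] * len(labels)
    offset = 0
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for pos, idx in enumerate(indices):
            # Continue the round-robin across classes so fold sizes stay balanced
            fold_of[idx] = (offset + pos) % k
        offset += len(indices)
    return fold_of


# 1,000 samples with 2% positives ("heavy" queries)
labels = [1] * 20 + [0] * 980
folds = stratified_folds(labels, k=5)
for f in range(5):
    members = [i for i, fold in enumerate(folds) if fold == f]
    heavy = sum(labels[i] for i in members)
    print(f"fold {f}: {len(members)} samples, {heavy} heavy")
```

Each fold ends up with the same 2% heavy rate as the full dataset. A plain random split could leave a fold with almost no positives, which would make its recall estimate unreliable.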

**Prerequisites**:
- Notebook 03 completed (feature datasets available)

**Duration**: ~30-45 minutes

## 1. Spark Configuration

Copy the Spark configuration from notebook 00 output:

In [1]:
%%configure -f
{
    "pyFiles": [
        "s3://uip-datalake-bucket-prod/sf_trino/trino_query_predictor/code/query_predictor_latest.zip",
        "s3://uipds-108043591022/dataintelligence-dev/di-airflow-prod/dags/common/utils/ParseArgs.py"
    ],
    "driverMemory": "16G",
    "driverCores": 4,
    "executorMemory": "20G",
    "executorCores": 5,
    "conf": {
        "spark.driver.maxResultSize": "8G",
        "spark.dynamicAllocation.enabled": "true",
        "spark.dynamicAllocation.minExecutors": "2",
        "spark.dynamicAllocation.maxExecutors": "20"
    }
}



## 2. Import Dependencies

In [2]:
import sys
import yaml
import numpy as np
import boto3
import json
import tempfile
from datetime import datetime
from pyspark.sql import SparkSession

# Import training modules
from query_predictor.training.model_trainer import ModelTrainer
from query_predictor.training.prd_checker import PRDChecker
from query_predictor.training.onnx_validator import ONNXValidator

print(f"Python version: {sys.version}")
print(f"PySpark version: {spark.version}")
print("✅ All imports successful")

Python version: 3.11.13 (main, Jul 30 2025, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-5)]
PySpark version: 3.5.4-amzn-0
✅ All imports successful

## 3. Load Configuration

In [3]:
import boto3

# Download training configuration from S3
s3_client = boto3.client('s3')
s3_bucket = 'uip-datalake-bucket-prod'
s3_prefix = 'sf_trino/trino_query_predictor'
config_s3_key = f"{s3_prefix}/config/training_config_latest.yaml"
config_path = '/tmp/training_config.yaml'

print(f"Downloading config from S3: s3://{s3_bucket}/{config_s3_key}")
s3_client.download_file(s3_bucket, config_s3_key, config_path)
print(f"✅ Config downloaded to: {config_path}")

# Load training configuration
with open(config_path) as f:
    config = yaml.safe_load(f)

print("✅ Configuration loaded")
print(f"\n📋 Model Configuration:")
print(f"  Algorithm: {config['model']['algorithm']}")
print(f"  N estimators: {config['model']['n_estimators']}")
print(f"  CV folds: {config['model']['cv_folds']}")
print(f"  Cost FN: {config['model']['cost_fn']}")
print(f"  Cost FP: {config['model']['cost_fp']}")
print(f"\n📊 PRD Requirements:")
print(f"  Target recall: ≥{config['prd_requirements']['target_heavy_recall']}")
print(f"  Target FNR: ≤{config['prd_requirements']['target_fnr']}")
print(f"  Target F1: ≥{config['prd_requirements']['target_f1']}")
print(f"  Target ROC-AUC: ≥{config['prd_requirements']['target_roc_auc']}")

# OPTIONAL: Override config parameters after loading
# Example: Change model hyperparameters
# config['model']['n_estimators'] = 150  # More trees
# config['model']['max_depth'] = 8  # Deeper trees
# config['model']['learning_rate'] = 0.05  # Slower learning
# Example: Adjust cost ratio for threshold optimization
# config['model']['cost_fn'] = 200.0  # Higher cost for missing heavy queries
# config['model']['cost_fp'] = 1.0
# Example: Change PRD requirements for stricter validation
# config['prd_requirements']['target_heavy_recall'] = 0.99  # 99% recall target
# config['prd_requirements']['target_fnr'] = 0.01  # 1% FNR target
# Example: Change validation samples
# config['validation']['onnx_validation_samples'] = 500  # Fewer samples for faster validation
# Example: Adjust ONNX export settings
# config['validation']['onnx_opset_version'] = 13  # Different opset version

Downloading config from S3: s3://uip-datalake-bucket-prod/sf_trino/trino_query_predictor/config/training_config_latest.yaml
✅ Config downloaded to: /tmp/training_config.yaml
✅ Configuration loaded

📋 Model Configuration:
  Algorithm: xgboost
  N estimators: 100
  CV folds: 5
  Cost FN: 100.0
  Cost FP: 1.0

📊 PRD Requirements:
  Target recall: ≥0.98
  Target FNR: ≤0.02
  Target F1: ≥0.85
  Target ROC-AUC: ≥0.9

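The rest of the notebook indexes `config` directly, so a missing key only surfaces as a `KeyError` several cells later. Below is a minimal sketch of an up-front check, assuming the key layout this notebook reads. `REQUIRED_KEYS` and `missing_config_keys` are hypothetical and would need to be kept in sync with the real `training_config_latest.yaml`. The `validation.*` keys are left out because the notebook reads them with `.get()` defaults.

```python
# Dotted paths this notebook reads from config (hypothetical list; keep in sync with the YAML)
REQUIRED_KEYS = [
    "model.algorithm", "model.n_estimators", "model.cv_folds",
    "model.max_depth", "model.learning_rate",
    "model.cost_fn", "model.cost_fp", "model.models_path",
    "prd_requirements.target_heavy_recall", "prd_requirements.target_fnr",
    "prd_requirements.target_f1", "prd_requirements.target_roc_auc",
    "features.output_path",
    "data_loading.start_date", "data_loading.end_date",
    "s3.bucket", "s3.prefix",
]


def missing_config_keys(config, required=REQUIRED_KEYS):
    """Return the dotted paths from `required` that are absent from `config`."""
    missing = []
    for path in required:
        node = config
        for part in path.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(path)
                break
            node = node[part]
    return missing


# Partial config containing only the values printed above
example = {
    "model": {"algorithm": "xgboost", "n_estimators": 100, "cv_folds": 5,
              "cost_fn": 100.0, "cost_fp": 1.0},
    "prd_requirements": {"target_heavy_recall": 0.98, "target_fnr": 0.02,
                         "target_f1": 0.85, "target_roc_auc": 0.9},
}
print(missing_config_keys(example))
```

In the notebook this could run right after `yaml.safe_load`, raising a single error that lists every missing key at once.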
## 4. Load Feature Datasets

In [4]:
# Define paths
features_path = config['features']['output_path']
date_range = f"{config['data_loading']['start_date']}_to_{config['data_loading']['end_date']}"
base_path = f"{features_path}/{date_range}"

train_path = f"{base_path}/train"
val_path = f"{base_path}/val"
test_path = f"{base_path}/test"

print(f"Loading feature datasets from S3...")
print(f"  Base path: {base_path}")

# Load datasets
train_df = spark.read.parquet(train_path)
val_df = spark.read.parquet(val_path)
test_df = spark.read.parquet(test_path)

print(f"\n✅ Datasets loaded:")
print(f"  Train: {train_df.count():,} queries")
print(f"  Val:   {val_df.count():,} queries")
print(f"  Test:  {test_df.count():,} queries")

Loading feature datasets from S3...
  Base path: s3://uip-datalake-bucket-prod/sf_trino/trino_query_predictor/features/2025-08-01_to_2025-10-01

✅ Datasets loaded:
  Train: 8,782,474 queries
  Val:   14,969,526 queries
  Test:  15,099,750 queries
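The three split paths are derived from `features.output_path` and the configured date range. Below is a small sketch of the same derivation as a pure function, which makes the layout easy to unit-test. `split_paths` is a hypothetical name, and the example inputs reproduce the base path printed above. Unlike the notebook cell, the sketch also strips a trailing slash from `features_path`.

```python
def split_paths(features_path, start_date, end_date, splits=("train", "val", "test")):
    """Build {split: path} using the <features_path>/<start>_to_<end>/<split> layout."""
    base = f"{features_path.rstrip('/')}/{start_date}_to_{end_date}"
    return {name: f"{base}/{name}" for name in splits}


paths = split_paths(
    "s3://uip-datalake-bucket-prod/sf_trino/trino_query_predictor/features",
    "2025-08-01",
    "2025-10-01",
)
for name, path in paths.items():
    print(f"{name:5s} {path}")
```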

## 5. Convert to NumPy Arrays

Convert Spark DataFrames to NumPy arrays for sklearn/XGBoost training.
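`collect()` materializes every row on the driver, and `spark.driver.maxResultSize` caps the serialized result of each action. It is therefore worth estimating array sizes against `driverMemory` (16G) and `maxResultSize` (8G) from section 1. Below is a back-of-the-envelope sketch, assuming dense float32 features. The feature count of 100 is a hypothetical placeholder, not a value from the feature pipeline. The Python `Row` objects built by `collect()` take considerably more memory than the final arrays, so treat the result as a lower bound.

```python
ROW_COUNTS = {  # from the "Datasets loaded" output in section 4
    "train": 8_782_474,
    "val": 14_969_526,
    "test": 15_099_750,
}
BYTES_PER_FLOAT32 = 4


def dense_matrix_gib(n_rows, n_features, bytes_per_value=BYTES_PER_FLOAT32):
    """Size of an n_rows x n_features dense matrix in GiB."""
    return n_rows * n_features * bytes_per_value / 2**30


N_FEATURES = 100  # hypothetical; substitute the real feature count
for split, n in ROW_COUNTS.items():
    print(f"{split:5s}: {dense_matrix_gib(n, N_FEATURES):6.2f} GiB")
total = sum(dense_matrix_gib(n, N_FEATURES) for n in ROW_COUNTS.values())
print(f"total: {total:6.2f} GiB")
```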

In [None]:
print("Converting DataFrames to NumPy arrays...")
print("This may take 5-10 minutes...")

# Collect features and labels; free each list of Row objects once it has been
# copied into NumPy, so only the arrays stay resident on the driver
print("\n[1/3] Collecting train data...")
train_data = train_df.select('features', 'is_heavy').collect()
X_train = np.array([row['features'] for row in train_data], dtype=np.float32)
y_train = np.array([row['is_heavy'] for row in train_data], dtype=np.int32)
del train_data

print("[2/3] Collecting val data...")
val_data = val_df.select('features', 'is_heavy').collect()
X_val = np.array([row['features'] for row in val_data], dtype=np.float32)
y_val = np.array([row['is_heavy'] for row in val_data], dtype=np.int32)
del val_data

print("[3/3] Collecting test data...")
test_data = test_df.select('features', 'is_heavy').collect()
X_test = np.array([row['features'] for row in test_data], dtype=np.float32)
y_test = np.array([row['is_heavy'] for row in test_data], dtype=np.int32)
del test_data

print(f"\n✅ Conversion complete:")
print(f"  X_train shape: {X_train.shape}")
print(f"  X_val shape:   {X_val.shape}")
print(f"  X_test shape:  {X_test.shape}")
print(f"\nClass distribution:")
print(f"  Train: {np.mean(y_train):.2%} heavy")
print(f"  Val:   {np.mean(y_val):.2%} heavy")
print(f"  Test:  {np.mean(y_test):.2%} heavy")
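If a full `collect()` exceeds driver memory, one alternative is to stream rows with `DataFrame.toLocalIterator()` and copy them into preallocated arrays, so the driver never holds all `Row` objects at once. Below is a minimal sketch under two assumptions: `features` holds a fixed-length numeric sequence (array or ML vector) per row, and the row count is known up front (e.g. from the `count()` in section 4). `fill_arrays` is a hypothetical helper. Streaming trades lower peak memory for one job per partition.

```python
import numpy as np


def fill_arrays(rows, n_rows, n_features, feature_col="features", label_col="is_heavy"):
    """Copy an iterable of row-like mappings into preallocated float32/int32 arrays.

    In the notebook, `rows` could be
    `df.select(feature_col, label_col).toLocalIterator()`;
    any iterable of dicts works for testing.
    """
    X = np.empty((n_rows, n_features), dtype=np.float32)
    y = np.empty(n_rows, dtype=np.int32)
    i = -1
    for i, row in enumerate(rows):
        if i >= n_rows:
            raise ValueError(f"more than {n_rows} rows received")
        X[i] = np.asarray(row[feature_col], dtype=np.float32)
        y[i] = row[label_col]
    if i + 1 != n_rows:
        raise ValueError(f"expected {n_rows} rows, got {i + 1}")
    return X, y


# Tiny stand-in for Spark rows
rows = [{"features": [0.5, 1.0, 2.0], "is_heavy": 0},
        {"features": [3.0, 4.0, 5.0], "is_heavy": 1}]
X_demo, y_demo = fill_arrays(rows, n_rows=2, n_features=3)
print(X_demo.shape, y_demo.tolist())
```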

## 6. Train XGBoost Model

Train with cross-validation and threshold optimization.

In [None]:
print("="*70)
print("TRAINING XGBOOST MODEL")
print("="*70)

# Initialize trainer
trainer = ModelTrainer(config)

# Train model
print("\nStarting training...")
print("This may take 20-30 minutes...\n")

training_results = trainer.train(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val
)

print("\n" + "="*70)
print("TRAINING COMPLETE")
print("="*70)
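Cost-based threshold optimization picks the decision threshold that minimizes `cost_fn * FN + cost_fp * FP`. With the configured 100:1 ratio, a missed heavy query costs 100 times as much as a false alarm. Below is a brute-force sketch of the idea, assuming predicted probabilities and binary labels as plain lists and candidate thresholds on a fixed grid. `best_cost_threshold` is hypothetical. `ModelTrainer` may search differently, for example on the validation set with a finer grid.

```python
def best_cost_threshold(y_true, y_score, cost_fn=100.0, cost_fp=1.0, n_steps=100):
    """Return (threshold, cost) minimizing cost_fn * FN + cost_fp * FP.

    A sample is predicted heavy when score >= threshold. Candidate
    thresholds are i / n_steps for i = 0..n_steps.
    """
    best = None
    for i in range(n_steps + 1):
        t = i / n_steps
        fn = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s < t)
        fp = sum(1 for y, s in zip(y_true, y_score) if y == 0 and s >= t)
        cost = cost_fn * fn + cost_fp * fp
        # Strict "<" keeps the lowest threshold among ties
        if best is None or cost < best[1]:
            best = (t, cost)
    return best


y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_score = [0.05, 0.10, 0.20, 0.35, 0.40, 0.60, 0.30, 0.90]
print(best_cost_threshold(y_true, y_score))               # 100:1 costs
print(best_cost_threshold(y_true, y_score, cost_fn=1.0))  # 1:1 costs
```

Raising `cost_fn` relative to `cost_fp` pushes the threshold down, trading extra false positives for fewer missed heavy queries. On this toy data the 100:1 ratio selects 0.21 (three false positives, no misses), while equal costs select 0.61 (one miss, no false positives).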

## 7. Evaluate on Test Set

In [None]:
print("\nEvaluating on test set...")

test_metrics = trainer.evaluate(
    model=training_results.model,
    X_test=X_test,
    y_test=y_test,
    threshold=training_results.optimal_threshold
)

print(f"\n📊 Test Set Metrics:")
print(f"  Threshold: {test_metrics['threshold']:.4f}")
print(f"  Precision: {test_metrics['precision']:.4f}")
print(f"  Recall:    {test_metrics['recall']:.4f}")
print(f"  F1 Score:  {test_metrics['f1']:.4f}")
print(f"  ROC-AUC:   {test_metrics['roc_auc']:.4f}")
print(f"  FNR:       {test_metrics['fnr']:.4f}")
print(f"  FPR:       {test_metrics['fpr']:.4f}")

cm = test_metrics['confusion_matrix']
print(f"\n📈 Confusion Matrix:")
print(f"  TN: {cm['tn']:,}  FP: {cm['fp']:,}")
print(f"  FN: {cm['fn']:,}  TP: {cm['tp']:,}")
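All of the threshold-dependent metrics above follow from the four confusion-matrix counts. ROC-AUC is the exception, since it is computed from the scores across all thresholds. Below is a sketch of the standard definitions, useful for sanity-checking `test_metrics`. `metrics_from_counts` is a hypothetical helper, and it is an assumption that `trainer.evaluate` guards zero denominators the same way.

```python
def metrics_from_counts(tn, fp, fn, tp):
    """Precision, recall, F1, FNR and FPR from confusion-matrix counts."""
    def ratio(num, den):
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)        # true positive rate
    f1 = ratio(2 * precision * recall, precision + recall)
    fnr = ratio(fn, fn + tp)           # equals 1 - recall when positives exist
    fpr = ratio(fp, fp + tn)
    return {"precision": precision, "recall": recall, "f1": f1,
            "fnr": fnr, "fpr": fpr}


# Illustrative counts, not results from this model
m = metrics_from_counts(tn=900, fp=50, fn=1, tp=49)
for name, value in m.items():
    print(f"{name:9s} {value:.4f}")
```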

## 8. PRD Compliance Check

Validate that model meets PRD requirements.

In [None]:
print("\n" + "="*70)
print("PRD COMPLIANCE CHECK")
print("="*70)

# Initialize PRD checker
prd_checker = PRDChecker(config)

# Check compliance
prd_report = prd_checker.check_compliance(test_metrics, strict=True)

# Generate text report
report_text = prd_checker.generate_report_text(prd_report)
print(report_text)

# Raise error if critical requirements not met
if not prd_report.summary['critical_requirements_met']:
    print("\n❌ CRITICAL: Model does not meet PRD requirements!")
    print("Cannot proceed with ONNX export.")
    raise ValueError("Model failed PRD compliance check")

print("\n✅ PRD compliance validated - proceeding with ONNX export")
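Each PRD requirement compares a metric against a target in one direction. Recall, F1 and ROC-AUC must be at least their targets, and FNR must be at most its target. Since FNR = 1 − recall, the recall and FNR requirements are two views of the same constraint. Below is a minimal sketch of such a check, using the targets printed in section 3. `check_requirements` and the tuple format are hypothetical and do not model how `PRDChecker` decides which requirements are critical.

```python
import operator

# (metric name, comparison, target); targets as printed from the training config
REQUIREMENTS = [
    ("recall", operator.ge, 0.98),
    ("fnr", operator.le, 0.02),
    ("f1", operator.ge, 0.85),
    ("roc_auc", operator.ge, 0.90),
]


def check_requirements(metrics, requirements=REQUIREMENTS):
    """Return {metric: passed} for each requirement; missing metrics fail."""
    results = {}
    for name, compare, target in requirements:
        value = metrics.get(name)
        results[name] = value is not None and compare(value, target)
    return results


# Illustrative metrics, not results from this model
metrics = {"recall": 0.985, "fnr": 0.015, "f1": 0.81, "roc_auc": 0.97}
results = check_requirements(metrics)
print(results)
print("all met:", all(results.values()))
```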

## 9. Export to ONNX

Convert XGBoost model to ONNX format for production inference.

In [None]:
print("\n" + "="*70)
print("ONNX EXPORT")
print("="*70)

# Initialize ONNX validator from config
onnx_validator = ONNXValidator(config=config)

# Create temporary ONNX file
with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp:
    onnx_path = tmp.name

# Export to ONNX (opset version from config)
input_size = X_train.shape[1]
onnx_validator.export_to_onnx(
    model=training_results.model,
    output_path=onnx_path,
    input_size=input_size
)

print(f"\n✅ ONNX export complete: {onnx_path}")

## 10. Validate ONNX Predictions

Check that ONNX predictions match XGBoost predictions on a sample of test rows (`validation.onnx_validation_samples`, default 1000).

In [None]:
print("\n" + "="*70)
print("ONNX VALIDATION")
print("="*70)

# Get validation config
validation_config = config.get('validation', {})
n_samples = validation_config.get('onnx_validation_samples', 1000)

# Validate ONNX
onnx_result = onnx_validator.validate_onnx(
    xgb_model=training_results.model,
    onnx_path=onnx_path,
    X_test=X_test,
    n_samples=n_samples
)

# Generate report
validation_report = onnx_validator.generate_validation_report(onnx_result)
print(validation_report)

# Raise error if validation failed
if not onnx_result.passed:
    print("\n❌ CRITICAL: ONNX validation failed!")
    print("ONNX predictions differ significantly from XGBoost.")
    raise ValueError("ONNX validation failed")

print("✅ ONNX validation passed")
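The parity check compares the two models' scores on the same sample of test rows. Below is a sketch of the summary statistics under stated assumptions. `max_difference` and `mean_difference` are taken as absolute differences between predicted heavy-class probabilities. `mismatch_rate` is taken as the percentage of rows whose thresholded label differs, since the summary cell prints it with a `%`. `parity_stats` is hypothetical, and `ONNXValidator` may define these fields differently, for example using a tolerance instead of the decision threshold.

```python
def parity_stats(p_ref, p_onnx, threshold=0.5):
    """Compare reference (XGBoost) and ONNX probabilities for the same rows."""
    if len(p_ref) != len(p_onnx) or not p_ref:
        raise ValueError("need two non-empty score lists of equal length")
    diffs = [abs(a - b) for a, b in zip(p_ref, p_onnx)]
    # A mismatch is a row where the two models land on different sides of the threshold
    mismatches = sum((a >= threshold) != (b >= threshold) for a, b in zip(p_ref, p_onnx))
    return {
        "max_difference": max(diffs),
        "mean_difference": sum(diffs) / len(diffs),
        "mismatch_rate": 100.0 * mismatches / len(diffs),  # percent
    }


xgb_scores = [0.10, 0.4999999, 0.80, 0.95]
onnx_scores = [0.10, 0.5000001, 0.80, 0.95]
print(parity_stats(xgb_scores, onnx_scores))
```

The toy input shows why both kinds of statistic matter. A 2e-7 score difference is negligible in absolute terms, but it still flips one label when the score sits right at the threshold.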

## 11. Save Model Artifacts to S3

In [None]:
print("\nSaving model artifacts to S3...")

s3_client = boto3.client('s3')
models_path = config['model']['models_path']
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
model_version = f"{date_range}_{timestamp}"

# 1. Save XGBoost model
print("\n[1/4] Saving XGBoost model...")
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as tmp:
    trainer.save_model(training_results.model, tmp.name)
    xgb_s3_key = f"{config['s3']['prefix']}/models/xgboost_{model_version}.pkl"
    s3_client.upload_file(tmp.name, config['s3']['bucket'], xgb_s3_key)
    xgb_s3_path = f"s3://{config['s3']['bucket']}/{xgb_s3_key}"
    print(f"  ✅ Uploaded: {xgb_s3_path}")

# 2. Save ONNX model
print("[2/4] Saving ONNX model...")
onnx_s3_key = f"{config['s3']['prefix']}/models/model_{model_version}.onnx"
s3_client.upload_file(onnx_path, config['s3']['bucket'], onnx_s3_key)
onnx_s3_path = f"s3://{config['s3']['bucket']}/{onnx_s3_key}"
print(f"  ✅ Uploaded: {onnx_s3_path}")

# Also save as "latest"
latest_onnx_key = f"{config['s3']['prefix']}/models/model_latest.onnx"
s3_client.copy_object(
    Bucket=config['s3']['bucket'],
    CopySource={'Bucket': config['s3']['bucket'], 'Key': onnx_s3_key},
    Key=latest_onnx_key
)
print(f"  ✅ Updated latest: s3://{config['s3']['bucket']}/{latest_onnx_key}")

# 3. Save threshold
print("[3/4] Saving optimal threshold...")
threshold_data = {
    'optimal_threshold': float(training_results.optimal_threshold),
    'model_version': model_version,
    'timestamp': timestamp
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
    json.dump(threshold_data, tmp, indent=2)
    tmp.flush()
    threshold_s3_key = f"{config['s3']['prefix']}/models/threshold_{model_version}.json"
    s3_client.upload_file(tmp.name, config['s3']['bucket'], threshold_s3_key)
    print(f"  ✅ Uploaded: s3://{config['s3']['bucket']}/{threshold_s3_key}")

# Also save as "latest"
latest_threshold_key = f"{config['s3']['prefix']}/models/threshold_latest.json"
s3_client.copy_object(
    Bucket=config['s3']['bucket'],
    CopySource={'Bucket': config['s3']['bucket'], 'Key': threshold_s3_key},
    Key=latest_threshold_key
)

# 4. Save training metadata
print("[4/4] Saving training metadata...")
metadata = {
    'model_version': model_version,
    'timestamp': timestamp,
    'date_range': date_range,
    'config': {
        'n_estimators': config['model']['n_estimators'],
        'max_depth': config['model']['max_depth'],
        'learning_rate': config['model']['learning_rate'],
        'cost_fn': config['model']['cost_fn'],
        'cost_fp': config['model']['cost_fp']
    },
    'dataset_sizes': {
        'train': int(X_train.shape[0]),
        'val': int(X_val.shape[0]),
        'test': int(X_test.shape[0])
    },
    'cv_results': {
        'mean_metrics': training_results.cv_results.mean_metrics,
        'best_iteration': training_results.cv_results.best_iteration
    },
    'optimal_threshold': float(training_results.optimal_threshold),
    'test_metrics': test_metrics,
    'prd_compliance': prd_checker.export_report_json(prd_report),
    'onnx_validation': {
        'passed': onnx_result.passed,
        'max_difference': onnx_result.max_difference,
        'mean_difference': onnx_result.mean_difference,
        'mismatch_rate': onnx_result.mismatch_rate
    },
    's3_paths': {
        'xgboost_model': xgb_s3_path,
        'onnx_model': onnx_s3_path,
        'threshold': f"s3://{config['s3']['bucket']}/{threshold_s3_key}"
    }
}

with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
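    # default= converts NumPy scalars/arrays (e.g. np.float32 metrics), which json cannot serialize natively
    json.dump(metadata, tmp, indent=2,
              default=lambda o: o.tolist() if isinstance(o, (np.ndarray, np.generic)) else str(o))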
    tmp.flush()
    metadata_s3_key = f"{config['s3']['prefix']}/metadata/training_{model_version}.json"
    s3_client.upload_file(tmp.name, config['s3']['bucket'], metadata_s3_key)
    print(f"  ✅ Uploaded: s3://{config['s3']['bucket']}/{metadata_s3_key}")

print("\n✅ All artifacts saved to S3")
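Every artifact key is derived from `s3.prefix` and a `model_version` built from the date range plus a run timestamp. Below is a pure-function sketch of that naming scheme, which makes it easy to test and to reuse from a deployment script. `artifact_keys` is a hypothetical helper. The example prefix is the `s3_prefix` from section 3, assumed to equal `config['s3']['prefix']`, and the run time is illustrative.

```python
from datetime import datetime


def artifact_keys(prefix, date_range, run_time):
    """Reproduce this notebook's S3 key layout for one training run."""
    model_version = f"{date_range}_{run_time.strftime('%Y%m%d_%H%M%S')}"
    models = f"{prefix}/models"
    return {
        "model_version": model_version,
        "xgboost_model": f"{models}/xgboost_{model_version}.pkl",
        "onnx_model": f"{models}/model_{model_version}.onnx",
        "onnx_latest": f"{models}/model_latest.onnx",
        "threshold": f"{models}/threshold_{model_version}.json",
        "threshold_latest": f"{models}/threshold_latest.json",
        "metadata": f"{prefix}/metadata/training_{model_version}.json",
    }


keys = artifact_keys(
    "sf_trino/trino_query_predictor",
    "2025-08-01_to_2025-10-01",
    datetime(2025, 10, 15, 9, 30, 0),  # illustrative run time
)
for name, key in keys.items():
    print(f"{name:16s} {key}")
```

`model_latest.onnx` and `threshold_latest.json` are updated by two separate `copy_object` calls. A consumer that reads both between those calls could pair a new model with the previous threshold. Pinning a versioned `model_version` avoids that.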

## 12. Training Summary

In [None]:
print("\n" + "="*70)
print("TRAINING SUMMARY")
print("="*70)

print(f"\n📦 Model Version: {model_version}")
print(f"\n📊 Cross-Validation Results:")
cv_metrics = training_results.cv_results.mean_metrics
cv_std = training_results.cv_results.std_metrics
for metric, value in cv_metrics.items():
    print(f"  {metric}: {value:.4f} ± {cv_std[metric]:.4f}")

print(f"\n🎯 Optimal Threshold: {training_results.optimal_threshold:.4f}")

print(f"\n📈 Test Set Performance:")
print(f"  Recall:    {test_metrics['recall']:.4f}")
print(f"  Precision: {test_metrics['precision']:.4f}")
print(f"  F1 Score:  {test_metrics['f1']:.4f}")
print(f"  ROC-AUC:   {test_metrics['roc_auc']:.4f}")
print(f"  FNR:       {test_metrics['fnr']:.4f}")

print(f"\n✅ PRD Compliance: {prd_report.overall_status.value}")
print(f"  Critical requirements met: {prd_report.summary['critical_requirements_met']}")
print(f"  All requirements met: {prd_report.summary['all_requirements_met']}")

print(f"\n✅ ONNX Validation: {'PASSED' if onnx_result.passed else 'FAILED'}")
print(f"  Max difference: {onnx_result.max_difference:.9f}")
print(f"  Mismatch rate: {onnx_result.mismatch_rate:.2f}%")

print(f"\n☁️  S3 Artifacts:")
print(f"  ONNX model:    {onnx_s3_path}")
print(f"  XGBoost model: {xgb_s3_path}")
print(f"  Threshold:     s3://{config['s3']['bucket']}/{threshold_s3_key}")
print(f"  Metadata:      s3://{config['s3']['bucket']}/{metadata_s3_key}")

print("\n" + "="*70)
print("MODEL TRAINING COMPLETE!")
print("="*70)

print("\n✅ Next Steps:")
print("1. Deploy ONNX model to production inference service")
print("2. Configure threshold in service config")
print("3. Monitor production metrics")
print("4. (Optional) Run notebook 05 for offline replay validation")
print("="*70)