# Module 3: Driver Ratings Prediction (Regression)

This notebook covers:
- Training regression models to predict driver ratings
- Model comparison (Linear Regression baseline vs. Random Forest; **TODO:** XGBoost is planned but not yet implemented in this notebook)
- Feature importance and explainability

In [None]:
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the driver-level feature table; each row is counted as one driver below.
data_dir = Path('../data')
csv_path = data_dir.joinpath('driver_features.csv')
driver_features = pd.read_csv(csv_path)
n_drivers = driver_features.shape[0]
print(f"Loaded {n_drivers} drivers")

In [None]:
# Define features and target
feature_cols = ['speed_mean', 'speed_std', 'speed_max', 'hard_brake_count', 
                'overspeed_count', 'harsh_turn_count', 'avg_trip_duration', 
                'total_distance', 'trip_count']
X = driver_features[feature_cols]
y = driver_features['avg_rating']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


def regression_metrics(y_true, y_pred):
    """Score regression predictions against ground truth.

    Args:
        y_true: true target values.
        y_pred: predicted target values, same length as ``y_true``.

    Returns:
        dict with keys 'RMSE', 'MAE' and 'R²'.
    """
    return {
        'RMSE': np.sqrt(mean_squared_error(y_true, y_pred)),
        'MAE': mean_absolute_error(y_true, y_pred),
        'R²': r2_score(y_true, y_pred),
    }


# Baseline: ordinary least squares on standardized features. The scaler is fit on
# the training split only so no test-set statistics leak into the model. Scaling
# does not change OLS predictions, but it puts the coefficients on a comparable scale.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
lr_model = LinearRegression()
lr_model.fit(X_train_scaled, y_train)
lr_pred = lr_model.predict(X_test_scaled)

# Random Forest: tree splits are invariant to feature scaling, so it uses raw features.
rf_model = RandomForestRegressor(n_estimators=200, max_depth=10, random_state=42)
rf_model.fit(X_train, y_train)
rf_pred = rf_model.predict(X_test)

# Side-by-side test-set comparison (one row per model); shown as the cell output.
model_results = pd.DataFrame({
    'Linear Regression': regression_metrics(y_test, lr_pred),
    'Random Forest': regression_metrics(y_test, rf_pred),
}).T
model_results.round(4)

In [None]:
# Feature importance from the fitted Random Forest (impurity-based / MDI).
# Caveat: scikit-learn documents that MDI importances can be biased toward
# high-cardinality features; permutation importance on the test split is a
# less biased alternative if these rankings drive decisions.
importance = (
    pd.DataFrame({'feature': feature_cols, 'importance': rf_model.feature_importances_})
    .sort_values('importance', ascending=True)  # ascending so the largest bar is on top
)

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(importance['feature'], importance['importance'], color='steelblue')
ax.set_xlabel('Feature Importance')
ax.set_title('Random Forest Feature Importance for Rating Prediction')
fig.tight_layout()
plt.show()

# Save the model. Create the destination directory first so a fresh checkout
# without ../src does not fail with FileNotFoundError inside joblib.dump.
model_path = Path('../src/rating_model.joblib')
model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(rf_model, model_path)
print(f"✓ Saved {model_path.name}")