In [None]:
# Extensions: benchmark several tree-based regressors with cross-validation
# and report feature importances for the best one.
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import KFold, cross_validate
import warnings
warnings.filterwarnings('ignore')  # keeps the notebook output readable (global for the session)

# Candidate regressors, keyed by the display name used in the printed report.
# Insertion order matters: it is the print order and the tie-break order when
# picking the best model later on.
models = dict(
    RandomForest=RandomForestRegressor(n_estimators=200, random_state=42),
    GradientBoosting=GradientBoostingRegressor(random_state=42),
)

# XGBoost is optional. Any failure while importing or constructing it only
# drops it from the comparison.
try:
    import xgboost as xgb
    models['XGBoost'] = xgb.XGBRegressor(n_estimators=100, random_state=42, verbosity=0)
except Exception:
    print('XGBoost not available; skipping.')

# Score every candidate with 5-fold CV on the TRAINING split only.
# The holdout (X_test, y_test) is reserved for the final evaluation of the
# selected model. Letting it take part in model selection would bias the
# reported holdout score upwards (data leakage).
cv = KFold(n_splits=5, shuffle=True, random_state=42)
# sklearn scorers are "greater is better", so MAE is exposed negated;
# the sign is flipped back below.
scoring = {'MAE': 'neg_mean_absolute_error', 'R2': 'r2'}
results = {}  # model name -> (mean CV MAE, mean CV R2)
for name, m in models.items():
    # cross_validate clones the estimator for each split, so the shared
    # `preprocess` and model objects are not fitted by this loop.
    p = Pipeline([('preprocess', preprocess), ('model', m)])
    res = cross_validate(p, X_train, y_train, cv=cv, scoring=scoring, return_train_score=False)
    mae = -res['test_MAE'].mean()
    r2 = res['test_R2'].mean()
    results[name] = (mae, r2)
    print(f"{name}: MAE={mae:.2f}, R2={r2:.3f}")

# Choose the candidate with the lowest mean CV MAE (first one wins on ties),
# refit it on the training split and score it once on the holdout set.
best = min(results, key=lambda model_name: results[model_name][0])
print('Best model by CV MAE:', best)
best_model = models[best]
# Note: fitting this pipeline fits the shared `preprocess` and `best_model`
# objects in place, so they stay fitted for later cells.
pipe_best = Pipeline(steps=[('preprocess', preprocess), ('model', best_model)])
pipe_best.fit(X_train, y_train)
pred_best = pipe_best.predict(X_test)
holdout_mae = mean_absolute_error(y_test, pred_best)
holdout_r2 = r2_score(y_test, pred_best)
print('Holdout MAE:', round(holdout_mae, 2))
print('Holdout R2:', round(holdout_r2, 3))

# Feature importances for tree-based models.
# Output column names come from the fitted ColumnTransformer itself, so they
# follow its real output order and include any extra columns (e.g. a
# passthrough remainder). Hand-assembling `num_cols + one-hot names` silently
# mislabels importances when the 'cat' transformer precedes 'num', so that
# assembly is kept only as a fallback.
try:
    preprocessor = pipe_best.named_steps['preprocess']
    fitted_model = pipe_best.named_steps['model']
    try:
        raw_names = preprocessor.get_feature_names_out()
    except AttributeError:
        # Some inner transformer cannot report output names. Fall back to the
        # manual assembly, which assumes numeric columns come first.
        ohe = preprocessor.named_transformers_['cat'].named_steps['onehot']
        feature_names = list(num_cols) + list(ohe.get_feature_names_out(cat_cols))
    else:
        # ColumnTransformer prefixes names with '<transformer>__' by default.
        # Strip that prefix so labels match the plain column / one-hot names.
        # Transformer names cannot contain '__', so the first '__' ends the prefix.
        prefixes = tuple(f'{t_name}__' for t_name in preprocessor.named_transformers_)
        feature_names = [
            n[n.index('__') + 2:] if n.startswith(prefixes) else n
            for n in raw_names
        ]
    importances = fitted_model.feature_importances_
    fi = pd.Series(importances, index=feature_names).sort_values(ascending=False)
    print('\nFeature importances:\n', fi)
except Exception as e:
    # Best-effort reporting step: e.g. the model exposes no
    # `feature_importances_`, or names/importances lengths disagree.
    print('Could not extract feature importances:', e)