# 🧠 Advanced Model Training
Compare XGBoost, LightGBM, and CatBoost using engineered features

In [None]:
!pip install xgboost lightgbm catboost scikit-learn pandas matplotlib shap


In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
import shap


In [None]:
# Load enhanced dataset; the 'date' column is parsed into datetimes on read.
df = pd.read_csv('../data/labeled_events_enhanced.csv', parse_dates=['date'])

# Keep only rows with a real, non-neutral label.
# Rows with a missing label must be dropped explicitly: NaN != 'neutral' is
# True, so they would survive the filter and `.cat.codes` would encode them as
# -1, silently creating an invalid extra class for the classifiers.
# The explicit .copy() gives an independent frame to add columns to.
has_label = df['event_label'].notna()
not_neutral = df['event_label'] != 'neutral'
df = df[has_label & not_neutral].copy()

# Cast to category *after* filtering so the integer codes are contiguous
# (0..n_classes-1) over the labels that remain.
df['event_label'] = df['event_label'].astype('category')
df['target'] = df['event_label'].cat.codes


In [None]:
# Feature selection
# Model input columns; the order here is the column order of X.
features = [
    'open', 'close', 'high', 'low', 'volume',
    'avg_sentiment', 'tx_spike',
    'daily_return', 'volatility', 'sentiment_volatility',
    'tweet_count', 'whale_tx_count', 'bot_tx_flag',
    'rsi', 'bollinger_upper', 'bollinger_lower',
    'daily_return_lag1', 'volume_lag1', 'avg_sentiment_lag1',
]

X = df[features]
y = df['target']

# 80/20 hold-out split, stratified on the target, with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)


In [None]:
# Train XGBoost
# `use_label_encoder` is intentionally not passed: the parameter was deprecated
# and later removed from XGBoost, where passing it only produces an
# "unused parameter" warning. The target is already integer category codes, so
# no label encoding is needed.
xgb_model = XGBClassifier(eval_metric='mlogloss', random_state=42)
xgb_model.fit(X_train, y_train)

# Evaluate on the held-out split.
xgb_pred = xgb_model.predict(X_test)
print("XGBoost Report:\n", classification_report(y_test, xgb_pred))


In [None]:
# Train LightGBM
# Default hyperparameters with a fixed seed; fit() returns the fitted estimator.
lgb_model = LGBMClassifier(random_state=42).fit(X_train, y_train)

# Score the held-out split and print per-class metrics.
lgb_pred = lgb_model.predict(X_test)
lgb_report = classification_report(y_test, lgb_pred)
print("LightGBM Report:\n", lgb_report)


In [None]:
# Train CatBoost
# verbose=0 silences the per-iteration training log; the seed is fixed.
cat_model = CatBoostClassifier(random_state=42, verbose=0)
cat_model.fit(X_train, y_train)

# Score the held-out split and print per-class metrics.
cat_pred = cat_model.predict(X_test)
cat_report = classification_report(y_test, cat_pred)
print("CatBoost Report:\n", cat_report)


In [None]:
# SHAP explanation (using the CatBoost model as the example)
# The training split is passed to the explainer as its background data;
# attributions are then computed for the held-out rows.
explainer = shap.Explainer(cat_model, masker=X_train)
shap_values = explainer(X_test)

# Bar-style summary plot of the attributions for the test rows.
shap.summary_plot(
    shap_values,
    features=X_test,
    plot_type="bar",
)
