In [2]:
# File: flows/b_training_ada.py
from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import AdaBoostClassifier
from xgboost import XGBClassifier

from exercises.d_TrainingAndEvaluation.generic_trainer import train_fraud
from domino_short_id import domino_short_id


# Training flow: fit several fraud classifiers on the same transformed dataset
# through the shared `train_fraud` helper and print each model's result.

# CSV file name handed to `train_fraud`; this script does not read it itself —
# presumably the helper loads it (TODO confirm in generic_trainer).
transformed_df_filename = 'transformed_cc_transactions.csv'

# Each entry pairs an unfitted estimator with the name passed to
# `train_fraud` as `model_name` (also used in the progress messages below).
ada_model = {
    # NOTE(review): 10 estimators with learning_rate=0.001 yields a very weakly
    # fitted ensemble — confirm these hyperparameters are intentional.
    'model': AdaBoostClassifier(
        n_estimators=10,
        learning_rate=0.001,
        algorithm="SAMME",
    ),
    'name': "AdaBoost",
}
gnb_model = {
    'model': GaussianNB(),
    'name': "GaussianNB",
}
xgb_model = {
    'model': XGBClassifier(
        n_estimators=200,
        learning_rate=0.05,
        max_depth=4,
        subsample=0.8,
        colsample_bytree=0.8,
        # NOTE(review): `use_label_encoder` is deprecated in newer XGBoost
        # releases — confirm it is still accepted by the installed version.
        use_label_encoder=False,
        eval_metric="auc",
    ),
    'name': "XGBoost",
}
# Backward-compatible alias for the original misspelled name, in case other
# notebook cells still reference it.
xbg_model = xgb_model

all_models = [ada_model, gnb_model, xgb_model]

for model_dict in all_models:
    model_name = model_dict['name']
    model_obj = model_dict['model']
    # Assumes train_fraud(model, name, filename) fits, evaluates and logs the
    # run, returning a metrics dict (as the captured output suggests).
    res = train_fraud(model_obj, model_name, transformed_df_filename)
    print(f"✅ Training {model_name} completed successfully")
    print(res)

# Only reached if every train_fraud call returned without raising.
print(f"{'✅' * len(all_models)} Trainings completed successfully")



training model AdaBoost
🏃 View run AdaBoost at: http://127.0.0.1:8768/#/experiments/1566/runs/e252926ca5d54f8a87acd1c9b78c0f4e
🧪 View experiment at: http://127.0.0.1:8768/#/experiments/1566
✅ Training AdaBoost completed successfully
{'model_name': 'AdaBoost', 'roc_auc': 0.7920946465652298, 'pr_auc': 0.6611633477127952, 'accuracy': 0.8233389457324181, 'precision_fraud': 0.8026649746192893, 'recall_fraud': 0.6788908765652951, 'f1_fraud': 0.7356076759061834, 'fit_time_sec': 1.3792343139648438}
training model GaussianNB
🏃 View run GaussianNB at: http://127.0.0.1:8768/#/experiments/1566/runs/5878ca9f00884fbe91e984889b3b5fc7
🧪 View experiment at: http://127.0.0.1:8768/#/experiments/1566
✅ Training GaussianNB completed successfully
{'model_name': 'GaussianNB', 'roc_auc': 0.8606071740107799, 'pr_auc': 0.7316422120373787, 'accuracy': 0.7477010749902863, 'precision_fraud': 0.7819573901464714, 'recall_fraud': 0.4202146690518784, 'f1_fraud': 0.5466604607865952, 'fit_time_sec': 0.0312349796295166}
