In [23]:
import joblib
import mlflow
import pandas as pd
import catboost as ctb
from kedro.io import PickleLocalDataSet
from catboost import CatBoostClassifier, cv, Pool
from sklearn.preprocessing import StandardScaler
from twitter_bot_detection.helpers import log_running_time
from sklearn.model_selection import train_test_split, StratifiedKFold
from eli5 import show_weights, explain_prediction, explain_weights_catboost
from sklearn.metrics import accuracy_score, classification_report, f1_score

In [24]:
# Load the pre-split model inputs (features and labels for train/test).
# The files are read in the order: X_train, X_test, y_train, y_test.
X_train, X_test, y_train, y_test = map(
    pd.read_pickle,
    (
        "data/05_model_input/X_train.pkl",
        "data/05_model_input/X_test.pkl",
        "data/05_model_input/y_train.pkl",
        "data/05_model_input/y_test.pkl",
    ),
)

In [25]:
# @log_running_time
def train_catboost(
    X_train: PickleLocalDataSet,
    X_test: PickleLocalDataSet,
    y_train: PickleLocalDataSet,
    y_test: PickleLocalDataSet,
    log=False,
) -> PickleLocalDataSet:
    """Train a CatBoost classifier, persist it, and optionally log the run to MLflow.

    The model is fitted on ``(X_train, y_train)`` with ``(X_test, y_test)`` as
    the evaluation set, then used to predict ``X_test``. A classification
    report is printed and the fitted model is dumped with joblib to
    ``data/06_models/catboost.pkl``.

    Args:
        X_train: Training features. ``X_train.columns`` is read, so a
            DataFrame-like object is expected.
        X_test: Evaluation features, also used for the reported metrics.
        y_train: Training labels.
        y_test: Evaluation labels.
        log: When true, record tags, hyperparameters, the weighted F1 score
            and the ``X_test`` pickle in a child run nested under the existing
            MLflow run named ``catboost`` on the Databricks tracking server.

    Returns:
        The fitted ``CatBoostClassifier``.

    Raises:
        KeyError: If ``log`` is true and no run named ``catboost`` is found in
            the configured experiment.

    NOTE(review): the ``PickleLocalDataSet`` annotations do not match how the
    body uses its arguments (``.columns``) or what it returns (the fitted
    model); they look like leftovers from a pipeline-node signature. Confirm
    before relying on them.
    """
    # NOTE: `"cat_features": ["created_at_time"]` was tried here and is
    # currently disabled.
    params = {
        "iterations": 2500,
        "learning_rate": 0.02,
        "loss_function": 'Logloss',
        "random_seed": 1,
        # Iteration-based overfitting detector with a 30-iteration wait.
        "od_wait": 30,
        "od_type": "Iter",
        "thread_count": 8,
    }
    model = CatBoostClassifier(**params)
    features = X_train.columns.values

    # NOTE: a stratified 3-fold `cv(Pool(...), ...)` experiment was prototyped
    # at this point and is currently disabled.

    # NOTE(review): the test split is used both as the eval set for the
    # overfitting detector and for the metrics reported below, so the reported
    # scores may be optimistic. Consider a separate validation split.
    model.fit(
        X_train, y_train,
        eval_set=(X_test, y_test),
        verbose=200,
        plot=False,
    )

    y_pred = model.predict(X_test)
    f1 = f1_score(y_test, y_pred, average="weighted")

    # NOTE: `model.save_model(..., format="cbm")` was considered as an
    # alternative to the joblib pickle.
    joblib.dump(model, 'data/06_models/catboost.pkl')

    print(classification_report(y_test, y_pred, digits=5))

    if log:
        mlflow.set_tracking_uri("databricks")
        mlflow.set_experiment("/Users/firefly.eugene@gmail.com/twitter-bot-detection")

        # NOTE(review): this experiment ID is hard-coded separately from the
        # experiment path above; presumably both refer to the same experiment.
        parent_runs = mlflow.search_runs(
            experiment_ids=["3889491181315524"],  # the API documents a list of IDs
            filter_string="tags.`mlflow.runName`='catboost'",
            run_view_type=1,
        )
        if parent_runs.empty:
            # Fail with a clear message instead of an opaque `KeyError: 0`.
            raise KeyError(
                "No MLflow run named 'catboost' found in experiment 3889491181315524"
            )
        parent_run_id = parent_runs["run_id"].iloc[0]

        # Context managers guarantee both runs are ended even when logging
        # raises; otherwise the parent run stays active in the kernel and the
        # next call fails when it tries to start a non-nested run.
        with mlflow.start_run(run_id=parent_run_id, nested=False):
            with mlflow.start_run(nested=True):
                # MLflow stringifies tag values, so the feature array is
                # stored as its string representation.
                mlflow.set_tags({
                    "lib": "catboost",
                    "features": features,
                })

                mlflow.log_params(params)
                mlflow.log_metric("f1", f1, 1)
                mlflow.log_artifact('data/05_model_input/X_test.pkl')

    return model

In [26]:
# Train the model on the loaded splits and log the run to MLflow.
m = train_catboost(
    X_train=X_train,
    X_test=X_test,
    y_train=y_train,
    y_test=y_test,
    log=True,
)

0:	learn: 0.6716830	test: 0.6717144	best: 0.6717144 (0)	total: 23ms	remaining: 57.5s
200:	learn: 0.2102663	test: 0.2164879	best: 0.2164879 (200)	total: 3.67s	remaining: 42s
400:	learn: 0.1849778	test: 0.1979392	best: 0.1979392 (400)	total: 7.46s	remaining: 39.1s
600:	learn: 0.1693293	test: 0.1902297	best: 0.1902297 (600)	total: 11.2s	remaining: 35.4s
800:	learn: 0.1569460	test: 0.1850337	best: 0.1850337 (800)	total: 14.9s	remaining: 31.7s
1000:	learn: 0.1476537	test: 0.1819563	best: 0.1819537 (999)	total: 18.6s	remaining: 27.9s
1200:	learn: 0.1395202	test: 0.1794074	best: 0.1794074 (1200)	total: 22.4s	remaining: 24.3s
1400:	learn: 0.1320143	test: 0.1776228	best: 0.1776183 (1392)	total: 26.1s	remaining: 20.5s
1600:	learn: 0.1248483	test: 0.1761560	best: 0.1761324 (1595)	total: 30s	remaining: 16.8s
1800:	learn: 0.1185706	test: 0.1753527	best: 0.1753527 (1800)	total: 33.7s	remaining: 13.1s
2000:	learn: 0.1126907	test: 0.1746323	best: 0.1746305 (1998)	total: 37.5s	remaining: 9.35s
2200:	le

In [30]:
# Render eli5's feature-weight table for the trained model `m`,
# limited to at most the top 100 features.
explain_weights_catboost(m, top=100)

Weight,Feature
0.2226,is_retweet_mean
0.0574,followers_count
0.0409,tweets
0.0375,quotes_mean
0.037,tweets_per_day
0.0341,replies_mean
0.0324,statuses_count
0.0318,verified
0.0307,favourites_count
0.0304,account_active_for_days


In [13]:
# TODO: eli5, Shapley values