In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from joblib import dump, load
from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (ConfusionMatrixDisplay, f1_score,
                             plot_confusion_matrix, roc_auc_score)
from sklearn.model_selection import train_test_split

## Setup Parameters

In [None]:
# Input data location.
churn_filepath = Path('data') / 'Churn_Modelling.csv'

# Model inputs: no categorical columns are used; every feature comes from num_cols.
cat_cols = []
num_cols = [
    'CreditScore', 'Age', 'Tenure', 'Balance',
    'NumOfProducts', 'HasCrCard', 'IsActiveMember', 'EstimatedSalary',
]
targ_col = 'Exited'

# Train/test split settings; random_state is also passed to the estimators.
test_size = 0.25
random_state = 42

# Where the fitted model is saved; the directory is created if missing.
models_dir = Path('models')
models_dir.mkdir(exist_ok=True)
model_fname = 'model.joblib'

In [None]:
# papermill parameter
# Defaults below may be overridden by papermill at execution time; keep them
# as plain literal assignments so the parameters cell stays inspectable.
model_type = 'random-forest'  # 'random-forest' or 'lightgbm'; any other value raises in the training cell
n_estimators = 50  # number of trees, passed to the estimator via train_params
max_depth = 5  # maximum tree depth, passed to the estimator via train_params

In [None]:
# Hyperparameters forwarded as keyword arguments to the estimator constructor
# in the training cell.
train_params = dict(n_estimators=n_estimators, max_depth=max_depth)

## Read Data

In [None]:
# Load the raw churn dataset and preview its first rows.
df = pd.read_csv(churn_filepath)
df.head()

## Data Exploration

In [None]:
# (rows, columns) of the raw dataset.
df.shape

In [None]:
# Are there missing values?
# Per-column count of missing (NaN) entries.
df.isna().sum()

In [None]:
# Row count per Geography value. Note: Geography is in neither cat_cols nor
# num_cols, so it is explored here but not used as a model feature.
df['Geography'].value_counts()

## Data Preprocessing and Splitting

In [None]:
# Feature matrix (categorical columns first, then numeric) and the target column.
X = df[cat_cols + num_cols]
y = df[targ_col]

In [None]:
# Mean of the target column; if Exited is coded 0/1 (TODO confirm) this is the
# positive-class (churn) rate, i.e. a quick check for class imbalance.
y.mean()

In [None]:
# Hold out a `test_size` fraction of rows for evaluation; `random_state` makes
# the shuffle reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=random_state,
    test_size=test_size,
)

## Train Model

In [None]:
# Instantiate the estimator selected by the `model_type` parameter, forwarding
# `train_params` as constructor keyword arguments, then fit on the training split.
estimator_classes = {
    'random-forest': RandomForestClassifier,
    'lightgbm': LGBMClassifier,
}
if model_type not in estimator_classes:
    # ValueError (a subclass of Exception) names the offending value and the
    # valid choices, so a bad papermill parameter is easy to diagnose.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; "
        f"expected one of {sorted(estimator_classes)}"
    )
clf = estimator_classes[model_type](random_state=random_state, **train_params)
clf.fit(X_train, y_train)

In [None]:
# Persist the fitted estimator to models/model.joblib (path built from the setup cell).
dump(clf, models_dir/model_fname)

In [None]:
# Reload the estimator from disk so the evaluation below runs on the persisted
# artifact rather than the in-memory object. joblib.load unpickles, so only
# load files from a trusted source.
clf = load(models_dir/model_fname)

## Model Evaluation

In [None]:
# Row-normalized confusion matrix on the held-out split, saved to eval_plots/cm.png.
# `plot_confusion_matrix` was deprecated in scikit-learn 1.0 and removed in 1.2;
# `ConfusionMatrixDisplay.from_estimator` (scikit-learn >= 1.0) replaces it.
eval_plots_dir = Path('eval_plots')
# Nothing earlier in the notebook creates this directory, so on a fresh run
# savefig would raise FileNotFoundError without this.
eval_plots_dir.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots()
ConfusionMatrixDisplay.from_estimator(
    clf, X_test, y_test, normalize='true', cmap=plt.cm.Blues, ax=ax
)
ax.set_title('Normalized confusion matrix (test set)')
fig.savefig(eval_plots_dir / 'cm.png')

In [None]:
# Class probabilities on the held-out split; column 1 corresponds to
# clf.classes_[1], i.e. the positive class assuming 0/1 labels (TODO confirm).
y_prob = clf.predict_proba(X_test)
# Probability cut-off for labelling a row as positive.
decision_threshold = 0.5
# Cast the boolean mask to 0/1 ints so predictions use the same label values as y_test.
y_pred = (y_prob[:, 1] >= decision_threshold).astype(int)

In [None]:
# NOTE(review): this import belongs in the top imports cell.
from dvclive import Live
# Log ROC-curve data (true test labels vs. positive-class probability) into eval_plots.
# NOTE(review): `log_plot` is the older dvclive API; newer releases appear to name
# it `log_sklearn_plot` — confirm against the pinned dvclive version.
live = Live("eval_plots")
live.log_plot("roc", y_test, y_prob[:, 1])

In [None]:
# Headline metrics: F1 on the thresholded predictions, ROC AUC on the raw
# positive-class probabilities.
f1 = f1_score(y_true=y_test, y_pred=y_pred)
roc_auc = roc_auc_score(y_true=y_test, y_score=y_prob[:, 1])
metrics = dict(f1=f1, roc_auc=roc_auc)
metrics

In [None]:
# Write the metrics dict to metrics.json (sorted keys, 4-space indent).
# The context manager guarantees the file is flushed and closed; the bare
# open() left closing the handle to garbage collection.
with open('metrics.json', 'w') as metrics_file:
    json.dump(metrics, metrics_file, indent=4, sort_keys=True)