# CatBoost Classifier for KKBox Churn Prediction

> **데이터**: `kkbox_train_feature_v2.parquet` (86개 피처, 860,966 샘플)  
> **목표**: Recall 최적화를 통한 이탈 고객 탐지


In [None]:
# imports + seed
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
    classification_report,
    recall_score,
    precision_score,
    f1_score,
)

import matplotlib.pyplot as plt
import seaborn as sns

# Single seed shared by NumPy, the train/valid/test splits, and CatBoost.
RANDOM_STATE = 719
np.random.seed(RANDOM_STATE)

# Show every column when a DataFrame is displayed; cap displayed rows at 100.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", 100)


## 1. 데이터 로드


In [None]:
# Load the v2 feature table (v2: duplicate columns have been removed).
df = pd.read_parquet('../data/kkbox_train_feature_v2.parquet')
print(f"Shape: {df.shape}")
print(f"Columns: {len(df.columns)}")
# Last expression of the cell: the notebook renders the first rows.
df.head()


In [None]:
# Inspect the target distribution: absolute class counts first,
# then the same counts normalized to proportions.
print("=== Target Distribution ===")
print(df["is_churn"].value_counts())
print()
print(df["is_churn"].value_counts(normalize=True))


## 2. 피처 준비


In [None]:
# Columns that must never be used as model inputs: the user id and the target.
EXCLUDE_COLS = ["msno", "is_churn"]

# Keep every remaining column except datetime / period typed ones, which are
# dropped for CatBoost compatibility. The Period check uses isinstance()
# because pandas' is_period_dtype helper is deprecated in pandas 2.x.
feature_cols = [
    col
    for col in df.columns
    if col not in EXCLUDE_COLS
    and not pd.api.types.is_datetime64_any_dtype(df[col])
    and not isinstance(df[col].dtype, pd.PeriodDtype)
]

print(f"Feature columns: {len(feature_cols)}")
print(feature_cols)


In [None]:
# Identify the categorical columns to hand to CatBoost via `cat_features`.
# A feature qualifies when it has pandas' categorical dtype or object dtype.
# isinstance() replaces is_categorical_dtype, which is deprecated in
# pandas 2.x.
cat_cols = [
    col
    for col in feature_cols
    if isinstance(df[col].dtype, pd.CategoricalDtype)
    or pd.api.types.is_object_dtype(df[col])
]

print(f"Categorical columns ({len(cat_cols)}): {cat_cols}")


## 3. 데이터 분할 (Train / Valid / Test)


In [None]:
# Separate features (X) and target (y). X is copied so later changes to it
# do not touch `df`.
X = df[feature_cols].copy()
y = df["is_churn"].astype(int)

# First split: Train (70%) / Temp (30%), stratified on the target so the
# class ratio is preserved in both parts.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y,
    test_size=0.30,
    stratify=y,
    random_state=RANDOM_STATE,
)

# Second split: halve Temp into Valid (15% of all rows) / Test (15%),
# again stratified.
X_valid, X_test, y_valid, y_test = train_test_split(
    X_temp, y_temp,
    test_size=0.50,
    stratify=y_temp,
    random_state=RANDOM_STATE,
)

print(f"Train: {X_train.shape}")
print(f"Valid: {X_valid.shape}")
print(f"Test:  {X_test.shape}")


## 4. CatBoost 모델 학습


In [None]:
# Compute scale_pos_weight to compensate for class imbalance:
# the ratio of negative to positive samples in the training split.
neg_count = (y_train == 0).sum()
pos_count = (y_train == 1).sum()
scale_pos_weight = neg_count / pos_count
print(f"scale_pos_weight: {scale_pos_weight:.2f}")

# Define the CatBoost model. Training minimizes Logloss while AUC is the
# metric reported on the eval set; at most 500 iterations with early stopping
# after 50 rounds. `cat_features` is None when no categorical columns were
# detected.
model = CatBoostClassifier(
    loss_function="Logloss",
    eval_metric="AUC",
    learning_rate=0.05,
    depth=6,
    l2_leaf_reg=3.0,
    iterations=500,
    early_stopping_rounds=50,
    scale_pos_weight=scale_pos_weight,
    random_seed=RANDOM_STATE,
    verbose=100,
    cat_features=cat_cols if cat_cols else None,
)


In [None]:
# Train the model. The validation split is passed as the eval set, and
# use_best_model=True is requested so the best iteration on it is kept.
model.fit(
    X_train, y_train,
    eval_set=(X_valid, y_valid),
    use_best_model=True,
)
print(f"\nBest iteration: {model.best_iteration_}")


## 5. 모델 평가


In [None]:
def evaluate_model(model, X, y, threshold=0.5, dataset_name="Dataset"):
    """Score a fitted binary classifier on one dataset and print a summary.

    Args:
        model: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is used as the positive-class score.
        X: Feature matrix passed to ``model.predict_proba``.
        y: True binary labels (0 / 1) aligned with ``X``.
        threshold: Scores greater than or equal to this value are predicted
            as class 1.
        dataset_name: Label shown in the printed header.

    Returns:
        dict with keys ``roc_auc``, ``pr_auc``, ``recall``, ``precision``,
        ``f1``, ``specificity`` and ``confusion_matrix`` (a 2x2 array laid
        out as ``[[tn, fp], [fn, tp]]``).
    """
    y_proba = model.predict_proba(X)[:, 1]
    y_pred = (y_proba >= threshold).astype(int)

    # Threshold-independent ranking metrics.
    roc_auc = roc_auc_score(y, y_proba)
    pr_auc = average_precision_score(y, y_proba)

    # zero_division=0 returns the same 0.0 as the default but without an
    # UndefinedMetricWarning when a threshold yields no positive predictions.
    recall = recall_score(y, y_pred, zero_division=0)
    precision = precision_score(y, y_pred, zero_division=0)
    f1 = f1_score(y, y_pred, zero_division=0)

    # Pin the label order so the matrix is always 2x2 and the unpacking below
    # cannot fail when only one label value appears in y and y_pred.
    cm = confusion_matrix(y, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    negatives = tn + fp
    # Specificity is undefined without negative samples; report NaN instead
    # of dividing by zero.
    specificity = tn / negatives if negatives else float("nan")

    print(f"=== {dataset_name} (threshold={threshold}) ===")
    print(f"ROC-AUC:     {roc_auc:.4f}")
    print(f"PR-AUC:      {pr_auc:.4f}")
    print(f"Recall:      {recall:.4f}")
    print(f"Precision:   {precision:.4f}")
    print(f"F1-Score:    {f1:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print()

    return {
        "roc_auc": roc_auc,
        "pr_auc": pr_auc,
        "recall": recall,
        "precision": precision,
        "f1": f1,
        "specificity": specificity,
        "confusion_matrix": cm,
    }


In [None]:
# Evaluate validation and test performance at the default threshold of 0.5.
valid_results = evaluate_model(model, X_valid, y_valid, threshold=0.5, dataset_name="Validation")
test_results = evaluate_model(model, X_test, y_test, threshold=0.5, dataset_name="Test")


## 6. Threshold 최적화 (Recall 기준)


In [None]:
# Recall / precision trade-off per decision threshold, measured on the
# validation split.
y_proba_valid = model.predict_proba(X_valid)[:, 1]

# Grid 0.10, 0.15, ..., 0.85 (16 points). linspace + round yields the exact
# grid values; a float step in np.arange accumulates rounding error
# (e.g. 0.15000000000000002).
thresholds = np.round(np.linspace(0.10, 0.85, 16), 2)
results = []

for thr in thresholds:
    y_pred = (y_proba_valid >= thr).astype(int)
    # zero_division=0 keeps the score at 0 without a warning when a threshold
    # yields no positive predictions.
    recall = recall_score(y_valid, y_pred, zero_division=0)
    precision = precision_score(y_valid, y_pred, zero_division=0)
    f1 = f1_score(y_valid, y_pred, zero_division=0)
    results.append({"threshold": thr, "recall": recall, "precision": precision, "f1": f1})

results_df = pd.DataFrame(results)

# Plot the three metrics against the threshold.
plt.figure(figsize=(10, 5))
plt.plot(results_df["threshold"], results_df["recall"], label="Recall", marker="o")
plt.plot(results_df["threshold"], results_df["precision"], label="Precision", marker="s")
plt.plot(results_df["threshold"], results_df["f1"], label="F1-Score", marker="^")
plt.xlabel("Threshold")
plt.ylabel("Score")
plt.title("Threshold vs Recall/Precision/F1")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Among thresholds reaching recall >= 0.95, pick the one with the highest
# precision. best_thr is set to None when no threshold qualifies, so the name
# is always defined for later cells.
high_recall_df = results_df[results_df["recall"] >= 0.95]
if not high_recall_df.empty:
    best_thr = high_recall_df.loc[high_recall_df["precision"].idxmax(), "threshold"]
    print(f"Recall >= 95% 중 Precision 최대 threshold: {best_thr:.2f}")
else:
    best_thr = None
    print("Recall >= 95%를 만족하는 threshold 없음")


In [None]:
# Re-evaluate with the chosen decision threshold.
# NOTE(review): 0.35 is hard-coded and is not derived from `best_thr`
# computed in the threshold sweep; it has to be updated by hand.
OPTIMAL_THRESHOLD = 0.35  # adjust based on the threshold analysis above if needed

print("=" * 50)
print(f"OPTIMAL THRESHOLD = {OPTIMAL_THRESHOLD}")
print("=" * 50)

valid_results_opt = evaluate_model(model, X_valid, y_valid, threshold=OPTIMAL_THRESHOLD, dataset_name="Validation")
test_results_opt = evaluate_model(model, X_test, y_test, threshold=OPTIMAL_THRESHOLD, dataset_name="Test")


## 7. Confusion Matrix 시각화


In [None]:
def plot_confusion_matrix(cm, title, labels=("Non-churn (0)", "Churn (1)")):
    """Draw a confusion matrix as two side-by-side heatmaps.

    The left panel shows raw counts; the right panel shows each row divided
    by its row total (rows are the true classes).

    Args:
        cm: Confusion matrix (array-like of counts), rows = true class,
            columns = predicted class.
        title: Text placed at the top of both panel titles.
        labels: Tick labels used for both axes, in matrix order.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    cm = np.array(cm, dtype=int)

    # Row-normalize. Rows with no samples stay at 0 instead of producing
    # NaN and a divide-by-zero RuntimeWarning.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))

    # Raw counts
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"{title}\n(Counts)")
    axes[0].set_xlabel("Predicted")
    axes[0].set_ylabel("True")

    # Normalized; the color scale is fixed to [0, 1].
    sns.heatmap(cm_norm, annot=True, fmt=".2%", cmap="Blues", cbar=True,
                xticklabels=labels, yticklabels=labels, ax=axes[1], vmin=0, vmax=1)
    axes[1].set_title(f"{title}\n(Row-normalized)")
    axes[1].set_xlabel("Predicted")
    axes[1].set_ylabel("True")

    plt.tight_layout()
    plt.show()

# Plot the validation and test confusion matrices obtained at
# OPTIMAL_THRESHOLD.
plot_confusion_matrix(valid_results_opt["confusion_matrix"], f"Validation (thr={OPTIMAL_THRESHOLD})")
plot_confusion_matrix(test_results_opt["confusion_matrix"], f"Test (thr={OPTIMAL_THRESHOLD})")


## 8. Feature Importance


In [None]:
# Extract feature importances from the fitted model, pair them with the
# feature names, and sort from most to least important.
fi = pd.DataFrame({
    "feature": feature_cols,
    "importance": model.feature_importances_,
}).sort_values("importance", ascending=False)

# Take the top 20 and reverse the order for the horizontal bar chart.
fi_top20 = fi.head(20).iloc[::-1]

plt.figure(figsize=(10, 8))
plt.barh(fi_top20["feature"], fi_top20["importance"], color="steelblue")
plt.xlabel("Importance")
plt.ylabel("Feature")
plt.title("CatBoost Feature Importance (Top 20)")
plt.tight_layout()
plt.show()

# Show the top 20 as a table (last expression of the cell is rendered).
print("=== Top 20 Features ===")
fi.head(20)


## 9. 모델 저장


In [None]:
# Persist the trained model together with the feature list it was trained on.
import json
from pathlib import Path

# Create the output directory first; open() does not create missing parent
# directories, so saving would otherwise fail on a fresh checkout.
Path("../models").mkdir(parents=True, exist_ok=True)

model.save_model("../models/catboost_v2.cbm")
print("Model saved to: ../models/catboost_v2.cbm")

# Save the feature column names in training order.
with open("../models/feature_cols_v2.json", "w", encoding="utf-8") as f:
    json.dump(feature_cols, f, indent=2)
print("Feature columns saved to: ../models/feature_cols_v2.json")


## Appendix: v1 Baseline (One-Hot Encoding + CatBoost)


# imports + seed


In [None]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder

from catboost import CatBoostClassifier
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
    classification_report,
)

# Single seed shared by NumPy, the data splits, and CatBoost.
RANDOM_STATE = 719
np.random.seed(RANDOM_STATE)

# Remove the display limits: show all columns and all rows of a DataFrame.
pd.set_option("display.max_columns", None)
pd.set_option('display.max_rows', None)


In [None]:
# Load the v1 feature table.
df = pd.read_parquet('../data/kkbox_train_feature_v1.parquet')


In [None]:
# Preview the first five rows.
df.head(5)


Unnamed: 0,msno,city,bd,gender,registered_via,registration_init_time,bd_clean,registration_month,is_churn,num_days_active_w7,total_secs_w7,avg_secs_per_day_w7,std_secs_w7,num_songs_w7,avg_songs_per_day_w7,num_unq_w7,num_25_w7,num_100_w7,short_play_w7,skip_ratio_w7,completion_ratio_w7,short_play_ratio_w7,variety_ratio_w7,num_days_active_w14,total_secs_w14,avg_secs_per_day_w14,std_secs_w14,num_songs_w14,avg_songs_per_day_w14,num_unq_w14,num_25_w14,num_100_w14,short_play_w14,skip_ratio_w14,completion_ratio_w14,short_play_ratio_w14,variety_ratio_w14,num_days_active_w21,total_secs_w21,avg_secs_per_day_w21,std_secs_w21,num_songs_w21,avg_songs_per_day_w21,num_unq_w21,num_25_w21,num_100_w21,short_play_w21,skip_ratio_w21,completion_ratio_w21,short_play_ratio_w21,variety_ratio_w21,num_days_active_w30,total_secs_w30,avg_secs_per_day_w30,std_secs_w30,num_songs_w30,avg_songs_per_day_w30,num_unq_w30,num_25_w30,num_100_w30,short_play_w30,skip_ratio_w30,completion_ratio_w30,short_play_ratio_w30,variety_ratio_w30,secs_trend_w7_w30,secs_trend_w14_w30,days_trend_w7_w14,days_trend_w7_w30,songs_trend_w7_w30,songs_trend_w14_w30,skip_trend_w7_w30,completion_trend_w7_w30,recency_secs_ratio,recency_songs_ratio,days_since_last_payment,has_ever_paid,days_since_last_cancel,has_ever_cancelled,is_auto_renew_last,last_plan_days,last_payment_method,is_free_user,total_payment_count,total_amount_paid,avg_amount_per_payment,unique_plan_count,subscription_months_est,payment_count_last_30d,payment_count_last_90d
0,+tJonkh+O1CA796Fm5X60UMOtB6POHAwPjbTRVl/EuU=,1,0,unknown,7,2011-09-14,,2011-09,0,7,75448.625,10778.375,9128.514648,338,48.285713,159,39,271,54,0.115385,0.801775,0.159763,0.470414,14,177639.296875,12688.521484,10458.754883,842,60.142857,480,127,641,157,0.150831,0.761283,0.186461,0.570071,20,238367.421875,11918.371094,9021.441406,1156,57.799999,663,170,863,220,0.147059,0.74654,0.190311,0.573529,30,358554.0,11951.799805,7876.637695,1776,59.200001,1040,277,1296,355,0.155968,0.72973,0.199887,0.585586,0.210425,0.495432,0.5,0.233333,0.190315,0.474099,-0.040584,0.072045,0.210425,0.190315,5,1,999,0,1,30,41,0,1,129,129.0,1,1.0,1,1
1,yLkV2gbZ4GLFwqTOXLVHz0VGrMYcgBGgKZ3kj9RiYu8=,4,30,male,9,2011-09-16,30.0,2011-09,0,6,123668.695312,20611.449219,9505.349609,557,92.833336,67,14,518,22,0.025135,0.929982,0.039497,0.120287,6,123668.695312,20611.449219,9505.349609,557,92.833336,67,14,518,22,0.025135,0.929982,0.039497,0.120287,6,123668.695312,20611.449219,9505.349609,557,92.833336,67,14,518,22,0.025135,0.929982,0.039497,0.120287,6,123668.695312,20611.449219,9505.349609,557,92.833336,67,14,518,22,0.025135,0.929982,0.039497,0.120287,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,1.0,1.0,1,1,999,0,1,30,39,0,2,298,149.0,1,2.0,1,2
2,I0yFvqMoNkM8ZNHb617e1RBzIS/YRKemHO7Wj13EtA0=,13,63,male,9,2011-09-18,63.0,2011-09,0,3,16989.527344,5663.175781,1434.409424,70,23.333334,65,2,65,3,0.028571,0.928571,0.042857,0.928571,10,50269.140625,5026.914062,3101.173584,249,24.9,182,43,195,47,0.172691,0.783133,0.188755,0.730924,15,63667.992188,4244.532715,2992.634277,352,23.466667,273,77,206,107,0.21875,0.585227,0.303977,0.775568,18,80453.320312,4469.628906,2823.026123,416,23.111111,337,77,269,108,0.185096,0.646635,0.259615,0.810096,0.211172,0.624824,0.3,0.166667,0.168269,0.598558,-0.156525,0.281937,0.211172,0.168269,5,1,999,0,1,30,40,0,1,149,149.0,1,1.0,1,1
3,OoDwiKZM+ZGr9P3fRivavgOtglTEaNfWJO4KaJcTTts=,1,0,unknown,7,2011-09-18,,2011-09,1,1,6168.049805,6168.049805,0.0,23,23.0,23,0,22,0,0.0,0.956522,0.0,1.0,2,8142.378906,4071.189453,2965.408447,35,17.5,34,2,30,4,0.057143,0.857143,0.114286,0.971429,2,8142.378906,4071.189453,2965.408447,35,17.5,34,2,30,4,0.057143,0.857143,0.114286,0.971429,3,8613.391602,2871.130615,2952.498535,38,12.666667,37,3,31,5,0.078947,0.815789,0.131579,0.973684,0.7161,0.945316,0.5,0.333333,0.605263,0.921053,-0.078947,0.140732,0.7161,0.605263,6,1,999,0,1,30,41,0,1,149,149.0,1,1.0,1,1
4,4De1jAxNRABoyRBDZ82U0yEmzYkqeOugRGVNIf92Xb8=,4,28,female,9,2011-09-20,28.0,2011-09,0,2,5703.128906,2851.564453,2644.321289,29,14.5,24,5,24,5,0.172414,0.827586,0.172414,0.827586,5,15160.677734,3032.135498,1988.283691,90,18.0,42,6,82,6,0.066667,0.911111,0.066667,0.466667,8,19365.294922,2420.661865,1723.983154,118,14.75,55,9,105,10,0.076271,0.889831,0.084746,0.466102,10,22494.763672,2249.476318,1725.134766,134,13.4,68,12,117,13,0.089552,0.873134,0.097015,0.507463,0.253531,0.673965,0.4,0.2,0.216418,0.671642,0.082862,-0.045548,0.253531,0.216418,29,1,999,0,1,30,36,0,1,180,180.0,1,1.0,1,1


In [None]:
# 1) `bd` and `bd_clean` are duplicates: drop the raw `bd` column when the
#    cleaned version is present.
if "bd" in df.columns and "bd_clean" in df.columns:
    df = df.drop(columns=["bd"])

# 2) Turn the registration timestamp into a numeric year feature.
#    Guarded so the cell can be re-run after the source column was dropped.
if "registration_init_time" in df.columns:
    df["reg_year"] = df["registration_init_time"].dt.year.astype("Int64")

# Drop the original columns to avoid Period/Datetime errors downstream.
# errors="ignore" keeps the cell idempotent on a second run.
df = df.drop(
    columns=["registration_init_time", "registration_month"],
    errors="ignore",
)


# train/valid/test split (stratify 유지)


In [None]:
# Both the id column and the target must be present before splitting.
assert "msno" in df.columns and "is_churn" in df.columns

# Hold out 15% of all rows as the test set, stratified on the target.
trainval_df, test_df = train_test_split(
    df,
    test_size=0.15,
    random_state=RANDOM_STATE,
    stratify=df["is_churn"],
)

# 0.15 / 0.85 of the remaining rows equals 15% of the full data set,
# so validation and test end up the same size.
valid_size = 0.15 / 0.85
train_df, valid_df = train_test_split(
    trainval_df,
    test_size=valid_size,
    random_state=RANDOM_STATE,
    stratify=trainval_df["is_churn"],
)

# Every column except the id and the target is used as a feature.
feature_cols = [c for c in df.columns if c not in ["msno", "is_churn"]]

X_train, y_train = train_df[feature_cols], train_df["is_churn"].astype(int)
X_valid, y_valid = valid_df[feature_cols], valid_df["is_churn"].astype(int)
X_test,  y_test  = test_df[feature_cols],  test_df["is_churn"].astype(int)

# Sanity check: churn rate per split and the resulting shapes.
print("churn rate:", y_train.mean(), y_valid.mean(), y_test.mean())
print(X_train.shape, X_valid.shape, X_test.shape)


churn rate: 0.09460141104009451 0.09459909404158116 0.09459909404158116
(602676, 86) (129145, 86) (129145, 86)


# column split + preprocess (OHE)


In [None]:
from pandas.api.types import is_numeric_dtype

# Partition the training columns by dtype: anything that is not numeric is
# treated as categorical, everything else as numeric.
cat_cols = [name for name in X_train.columns if not is_numeric_dtype(X_train[name])]
num_cols = [name for name in X_train.columns if name not in cat_cols]

print("num:", len(num_cols), "cat:", len(cat_cols))
print("cat example:", cat_cols[:10])

# Numeric branch: fill missing values with the column median.
numeric_pipeline = Pipeline([
    ("imputer", SimpleImputer(strategy="median")),
])

# Categorical branch: fill missing values with the most frequent value, then
# one-hot encode; categories unseen during fit are ignored at transform time.
categorical_pipeline = Pipeline([
    ("imputer", SimpleImputer(strategy="most_frequent")),
    ("ohe", OneHotEncoder(handle_unknown="ignore", sparse_output=True)),
])

# Columns not listed in either branch are dropped.
preprocess = ColumnTransformer(
    transformers=[
        ("num", numeric_pipeline, num_cols),
        ("cat", categorical_pipeline, cat_cols),
    ],
    remainder="drop",
)


num: 85 cat: 1
cat example: ['gender']


# 공통 평가 함수


In [None]:
def eval_binary(y_true, p_pred, prefix="", thr=0.5):
    """Print and return binary-classification metrics for predicted scores.

    Args:
        y_true: True binary labels.
        p_pred: Predicted positive-class scores aligned with ``y_true``.
        prefix: Text prepended to every printed section.
        thr: Scores greater than or equal to this value count as class 1.

    Returns:
        Tuple ``(roc_auc, average_precision, confusion_matrix)``.
    """
    # Ranking metrics are computed on the raw scores.
    roc = roc_auc_score(y_true, p_pred)
    ap = average_precision_score(y_true, p_pred)

    # Hard predictions at the requested threshold feed the remaining reports.
    y_hat = (p_pred >= thr).astype(int)
    cm = confusion_matrix(y_true, y_hat)
    cr = classification_report(y_true, y_hat, digits=4)

    report_sections = (
        f"{prefix}ROC-AUC: {roc:.6f} | PR-AUC(AP): {ap:.6f} | thr={thr}",
        f"{prefix}Confusion matrix:\n{cm}",
        f"{prefix}Classification report:\n{cr}",
    )
    for section in report_sections:
        print(section)

    return roc, ap, cm


# CatBoost 학습

- loss_function: "Logloss"
- eval_metric: "AUC"
- scale_pos_weight로 불균형 대응 (Recall 최적화)
- early_stopping_rounds=50 적용
- thread_count=-1로 풀코어


In [None]:
# Compute scale_pos_weight (for recall optimization): the ratio of negative
# to positive samples in the training split.
scale_pos_weight = (y_train == 0).sum() / (y_train == 1).sum()
print(f"scale_pos_weight: {scale_pos_weight:.2f}")

# Apply preprocessing: fit on the training split only, then reuse the fitted
# transformer for validation and test.
X_train_prep = preprocess.fit_transform(X_train)
X_valid_prep = preprocess.transform(X_valid)
X_test_prep = preprocess.transform(X_test)

print(f"Preprocessed shapes: {X_train_prep.shape}, {X_valid_prep.shape}, {X_test_prep.shape}")


scale_pos_weight: 9.57
Preprocessed shapes: (602676, 88), (129145, 88), (129145, 88)


In [None]:
# CatBoost on the preprocessed (imputed + one-hot encoded) features.
# Up to 1000 iterations with early stopping after 50 rounds;
# thread_count=-1 is intended to use all cores.
cb_model = CatBoostClassifier(
    loss_function="Logloss",
    eval_metric="AUC",
    learning_rate=0.05,
    depth=6,
    l2_leaf_reg=3.0,
    iterations=1000,
    early_stopping_rounds=50,
    scale_pos_weight=float(scale_pos_weight),
    random_seed=RANDOM_STATE,
    thread_count=-1,
    verbose=100,
)

# The validation split is passed as the eval set, and use_best_model=True is
# requested so the best iteration on it is kept.
cb_model.fit(
    X_train_prep,
    y_train,
    eval_set=(X_valid_prep, y_valid),
    use_best_model=True,
)

print(f"\nBest iteration: {cb_model.best_iteration_}")


0:	test: 0.9674497	best: 0.9674497 (0)	total: 208ms	remaining: 3m 27s
100:	test: 0.9875600	best: 0.9875600 (100)	total: 5.59s	remaining: 49.8s
200:	test: 0.9885004	best: 0.9885004 (200)	total: 11s	remaining: 43.8s
300:	test: 0.9888783	best: 0.9888783 (300)	total: 16.5s	remaining: 38.3s
400:	test: 0.9890667	best: 0.9890667 (400)	total: 22.1s	remaining: 32.9s
500:	test: 0.9891445	best: 0.9891472 (493)	total: 27.7s	remaining: 27.6s
600:	test: 0.9891831	best: 0.9891843 (595)	total: 33.3s	remaining: 22.1s
700:	test: 0.9892343	best: 0.9892343 (700)	total: 39.3s	remaining: 16.8s
800:	test: 0.9892411	best: 0.9892411 (800)	total: 45.3s	remaining: 11.2s
900:	test: 0.9892585	best: 0.9892596 (899)	total: 51.4s	remaining: 5.65s
Stopped by overfitting detector  (50 iterations wait)

bestTest = 0.9892643162
bestIteration = 920

Shrink model to first 921 iterations.

Best iteration: 920


In [None]:
# Predict positive-class probabilities and evaluate at the default
# threshold of 0.5.
p_valid = cb_model.predict_proba(X_valid_prep)[:, 1]
p_test  = cb_model.predict_proba(X_test_prep)[:, 1]

print("=" * 60)
print("Threshold = 0.5")
print("=" * 60)
eval_binary(y_valid, p_valid, prefix="[CB valid] ")
# Last expression of the cell: its returned tuple is also rendered as output.
eval_binary(y_test,  p_test,  prefix="[CB test ] ")


Threshold = 0.5
[CB valid] ROC-AUC: 0.989264 | PR-AUC(AP): 0.933037 | thr=0.5
[CB valid] Confusion matrix:
[[112456   4472]
 [   839  11378]]
[CB valid] Classification report:
              precision    recall  f1-score   support

           0     0.9926    0.9618    0.9769    116928
           1     0.7179    0.9313    0.8108     12217

    accuracy                         0.9589    129145
   macro avg     0.8552    0.9465    0.8939    129145
weighted avg     0.9666    0.9589    0.9612    129145

[CB test ] ROC-AUC: 0.989417 | PR-AUC(AP): 0.933324 | thr=0.5
[CB test ] Confusion matrix:
[[112394   4534]
 [   806  11411]]
[CB test ] Classification report:
              precision    recall  f1-score   support

           0     0.9929    0.9612    0.9768    116928
           1     0.7156    0.9340    0.8104     12217

    accuracy                         0.9587    129145
   macro avg     0.8543    0.9476    0.8936    129145
weighted avg     0.9667    0.9587    0.9611    129145



(0.9894168489517845,
 0.9333237670024315,
 array([[112394,   4534],
        [   806,  11411]]))