Небольшой бейзлайн для задачки Оттока

In [1]:
import pandas as pd
import numpy as np
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import StratifiedKFold
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
import random
import warnings

warnings.simplefilter('ignore')

pd.options.display.max_columns = 100
pd.options.display.max_rows = 100

random.seed(42)
np.random.seed(42)

Объединим всех юзеров в 1 датафрейм, для создания фичей

In [2]:
main = pd.read_csv('data/train.csv')

sample = pd.read_csv('data/sample_submit_naive.csv').drop('predict', axis=1)
sample['target'] = -1

main = pd.concat([main, sample])

Откроем дополнительные файлы

In [3]:
clients = pd.read_csv('data/clients.csv')
report_dates = pd.read_csv('data/report_dates.csv', parse_dates=['report_dt'])

transactions = pd.read_csv('data/transactions.csv', parse_dates=['transaction_dttm'])
transactions = transactions.sort_values('transaction_dttm').reset_index(drop=True)

In [4]:
# Добавим информацию о клиенте, а также закодируем employee_count_nm

main = main.merge(clients, how='left', on='user_id')
main['employee_count_nm'] = LabelEncoder().fit_transform(main['employee_count_nm'].fillna('unknown'))

Для категорий, которые встречались больше 10 раз агрегируем информацию о транзакциях по пользователю и коду MCC

In [5]:
good_codes = transactions['mcc_code'].value_counts()
good_codes = good_codes[good_codes >= 10]

mcc_info = transactions[transactions.mcc_code.isin(good_codes)].pivot_table(
    index = 'user_id',
    values=['transaction_amt'],
    columns=['mcc_code'],
    aggfunc=['count', 'median', 'sum']
).fillna(0)
mcc_info.columns = ['main_' + '_'.join(map(str, x)) for x in mcc_info.columns]

count_cols = [x for x in mcc_info.columns if 'count' in x]
mcc_info['sum'] = mcc_info[count_cols].sum(axis=1)
for col in count_cols:
    mcc_info[f'{col}_norm'] = mcc_info[col] / mcc_info['sum']
mcc_info.drop('sum', axis=1, inplace=True)

main = main.merge(mcc_info, how='left', left_on='user_id', right_index=True)

pivot_table по валютам со сбором базовой информации

In [6]:
# currency_pivot = transactions.pivot_table(
#     index='user_id',
#     columns='currency_rk',
#     values='transaction_amt',
#     aggfunc=['min', 'max', 'median', 'count', 'std']
# ).fillna(0)
# currency_pivot.columns = [f'currency_{x[0]}_{x[1]}' for x in currency_pivot.columns]

# # Сборка нормированных значений о том в каких валютах были транзации у клиента
# currency_pivot['sum'] = currency_pivot[[f'currency_count_{x}' for x in range(4)]].sum(axis=1)
# for x in range(4):
#     currency_pivot[f'currency_count_{x}_norm'] = currency_pivot[f'currency_count_{x}'] / currency_pivot['sum']
# currency_pivot.drop('sum', axis=1, inplace=True)

# main = main.merge(currency_pivot, how='left', left_on='user_id', right_index=True)

Сбор информации о транзакциях в каждой валюте за последние 30 дней, а также за весь промежуток

In [7]:
df_more = transactions.merge(clients[['user_id', 'report']], how='left', on='user_id')
df_more = df_more.merge(report_dates, how='left', on='report')
df_more['days_to_report'] = (df_more['report_dt'] - df_more['transaction_dttm']).dt.days


for day_diff in [30, 1000]:

    # Информация о размерах транзакций в различных валютах
    currency_pivot = df_more[df_more['days_to_report'] < day_diff + 100].pivot_table(
        index='user_id',
        columns='currency_rk',
        values='transaction_amt',
        aggfunc=['sum', 'mean', 'median', 'count']
    ).fillna(0)
    currency_pivot.columns = [f'currency_daydiff_{day_diff}_{x[0]}_{x[1]}' for x in currency_pivot.columns]

    currency_pivot['sum'] = currency_pivot[[x for x in currency_pivot.columns if 'count' in x]].sum(axis=1)
    for x in range(4):
        currency_pivot[f'currency_daydiff_{day_diff}_count_{x}_norm'] = currency_pivot[f'currency_daydiff_{day_diff}_count_{x}'] / currency_pivot['sum']
    currency_pivot.drop('sum', axis=1, inplace=True)

    main = main.merge(currency_pivot, how='left', left_on='user_id', right_index=True)


    general_trans_info = df_more[df_more['days_to_report'] < day_diff + 100].groupby('user_id')['transaction_amt'].agg(['sum', 'count', 'median'])
    general_trans_info[['sum', 'count']] = general_trans_info[['sum', 'count']].fillna(0)
    general_trans_info.columns = [f'general_trans_info_{day_diff}_{x}' for x in general_trans_info]
    main = main.merge(general_trans_info, how='left', left_on='user_id', right_index=True)

    general_trans_info = df_more[(df_more['days_to_report']<day_diff + 100)&(df_more['transaction_amt']>0)].groupby('user_id')['transaction_amt'].agg(['sum', 'count', 'median'])
    general_trans_info[['sum', 'count']] = general_trans_info[['sum', 'count']].fillna(0)
    general_trans_info.columns = [f'positive_general_trans_info_{day_diff}_{x}' for x in general_trans_info]
    main = main.merge(general_trans_info, how='left', left_on='user_id', right_index=True)

    general_trans_info = df_more[(df_more['days_to_report']<day_diff + 100)&(df_more['transaction_amt']<0)].groupby('user_id')['transaction_amt'].agg(['sum', 'count', 'median'])
    general_trans_info[['sum', 'count']] = general_trans_info[['sum', 'count']].fillna(0)
    general_trans_info.columns = [f'negative_general_trans_info_{day_diff}_{x}' for x in general_trans_info]
    main = main.merge(general_trans_info, how='left', left_on='user_id', right_index=True)


# Анализируем кол-во транзакций в последние n дней / кол-во транзакций до последних n дней
for x in [5, 30]:
    prev = df_more[df_more['days_to_report'] > x + 100].groupby('user_id')['report'].agg(['count']).reset_index().rename({'count': f'num_transaction_before_{x}_days'}, axis=1)
    last = df_more[df_more['days_to_report'] <= x + 100].groupby('user_id')['report'].agg(['count']).reset_index().rename({'count': f'num_transaction_last_{x}_days'}, axis=1)

    main = main.merge(prev, how='left', on='user_id')
    main = main.merge(last, how='left', on='user_id')
    main[f'num_transaction_last_{x}_days'].fillna(0, inplace=True)
    main[f'num_transaction_before_{x}_days'].fillna(0, inplace=True)
    main[f'percent_last_{x}'] = main[f'num_transaction_last_{x}_days'] / main[f'num_transaction_before_{x}_days']

In [8]:
# Кол-во уникальных MCC кодов, валют, а также уникальных дней, в которые были транзакции
main = main.merge(df_more.groupby('user_id')['days_to_report'].nunique(), how='left', on='user_id').rename({'days_to_report': 'nunique_days'}, axis=1)
main = main.merge(df_more.groupby('user_id')['mcc_code'].nunique(), how='left', on='user_id').rename({'mcc_code': 'nunique_mcc_codes'}, axis=1)
main = main.merge(df_more.groupby('user_id')['currency_rk'].nunique(), how='left', on='user_id').rename({'currency_rk': 'nunique_currency'}, axis=1)

main

Unnamed: 0,user_id,target,time,report,employee_count_nm,bankemplstatus,customer_age,main_count_transaction_amt_10,main_count_transaction_amt_11,main_count_transaction_amt_12,main_count_transaction_amt_15,main_count_transaction_amt_16,main_count_transaction_amt_17,main_count_transaction_amt_18,main_count_transaction_amt_22,main_count_transaction_amt_23,main_count_transaction_amt_26,main_count_transaction_amt_28,main_count_transaction_amt_29,main_count_transaction_amt_31,main_count_transaction_amt_32,main_count_transaction_amt_33,main_count_transaction_amt_34,main_count_transaction_amt_39,main_count_transaction_amt_42,main_count_transaction_amt_44,main_count_transaction_amt_50,main_count_transaction_amt_51,main_count_transaction_amt_53,main_count_transaction_amt_54,main_count_transaction_amt_55,main_count_transaction_amt_56,main_count_transaction_amt_58,main_count_transaction_amt_59,main_count_transaction_amt_63,main_count_transaction_amt_65,main_count_transaction_amt_66,main_count_transaction_amt_72,main_count_transaction_amt_76,main_count_transaction_amt_77,main_count_transaction_amt_78,main_count_transaction_amt_81,main_count_transaction_amt_82,main_count_transaction_amt_85,main_count_transaction_amt_92,main_count_transaction_amt_95,main_count_transaction_amt_105,main_count_transaction_amt_111,main_count_transaction_amt_119,main_count_transaction_amt_122,...,currency_daydiff_30_count_1_norm,currency_daydiff_30_count_2_norm,currency_daydiff_30_count_3_norm,general_trans_info_30_sum,general_trans_info_30_count,general_trans_info_30_median,positive_general_trans_info_30_sum,positive_general_trans_info_30_count,positive_general_trans_info_30_median,negative_general_trans_info_30_sum,negative_general_trans_info_30_count,negative_general_trans_info_30_median,currency_daydiff_1000_sum_0,currency_daydiff_1000_sum_1,currency_daydiff_1000_sum_2,currency_daydiff_1000_sum_3,currency_daydiff_1000_mean_0,currency_daydiff_1000_mean_1,currency_daydiff_1000_mean_2,currency_daydiff_1000_mean_3,currency_daydiff_1000_median_0,currency_daydiff_1000_median_1,currency_daydiff_1000_median_2,currency_daydiff_1000_median_3,currency_daydiff_1000_count_0,currency_daydiff_1000_count_1,currency_daydiff_1000_count_2,currency_daydiff_1000_count_3,currency_daydiff_1000_count_0_norm,currency_daydiff_1000_count_1_norm,currency_daydiff_1000_count_2_norm,currency_daydiff_1000_count_3_norm,general_trans_info_1000_sum,general_trans_info_1000_count,general_trans_info_1000_median,positive_general_trans_info_1000_sum,positive_general_trans_info_1000_count,positive_general_trans_info_1000_median,negative_general_trans_info_1000_sum,negative_general_trans_info_1000_count,negative_general_trans_info_1000_median,num_transaction_before_5_days,num_transaction_last_5_days,percent_last_5,num_transaction_before_30_days,num_transaction_last_30_days,percent_last_30,nunique_days,nunique_mcc_codes,nunique_currency
0,3,0,77.0,2,4,0,3,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,175726.502930,4.0,33163.771484,175726.502930,4.0,33163.771484,,,,0.000000,13706.416641,0.0,0.0,0.000000,1246.037876,0.0,0.0,0.000000,4549.455078,0.0,0.0,0.0,11.0,0.0,0.0,0.000000,1.000000,0.0,0.0,13706.416641,11,4549.455078,186108.229797,7.0,5386.999023,-172401.813156,4.0,-9175.519287,11.0,0.0,0.000000,7.0,4.0,0.571429,8,4,1
1,13,0,86.0,6,8,0,2,0.0,0.0,6.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-5588.771484,2.0,-2794.385742,10805.421875,1.0,10805.421875,-16394.193359,1.0,-16394.193359,10772.799805,-135490.178955,0.0,0.0,10772.799805,-6451.913284,0.0,0.0,10772.799805,-10642.210938,0.0,0.0,1.0,21.0,0.0,0.0,0.045455,0.954545,0.0,0.0,-124717.379150,22,-10529.004883,128766.684326,8.0,10789.110840,-253484.063477,14.0,-16423.615234,22.0,0.0,0.000000,20.0,2.0,0.100000,18,4,2
2,37,0,89.0,5,1,0,2,4.0,0.0,0.0,2.0,1.0,0.0,1.0,5.0,0.0,0.0,7.0,0.0,0.0,4.0,0.0,0.0,2.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-48048.763298,41.0,-267.658539,,,,-48048.763298,41.0,-267.658539,0.000000,-331859.599463,0.0,0.0,0.000000,-1053.522538,0.0,0.0,0.000000,-236.420776,0.0,0.0,0.0,315.0,0.0,0.0,0.000000,1.000000,0.0,0.0,-331859.599463,315,-236.420776,10738.788574,2.0,5369.394287,-342598.388037,313.0,-236.546936,313.0,2.0,0.006390,273.0,42.0,0.153846,130,28,1
3,41,0,57.0,1,4,0,2,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-8045.445801,2.0,-4022.722900,,,,-8045.445801,2.0,-4022.722900,0.000000,-108586.614166,0.0,0.0,0.000000,-6786.663385,0.0,0.0,0.000000,-6328.293701,0.0,0.0,0.0,16.0,0.0,0.0,0.000000,1.000000,0.0,0.0,-108586.614166,16,-6328.293701,,,,-108586.614166,16.0,-6328.293701,14.0,2.0,0.142857,14.0,2.0,0.142857,12,5,1
4,42,0,84.0,12,3,0,3,2.0,4.0,7.0,3.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,23617.963127,13.0,-187.755142,51152.999939,4.0,2277.169434,-27535.036812,9.0,-1037.949341,0.000000,11429.587215,0.0,0.0,0.000000,193.721817,0.0,0.0,0.000000,-321.756958,0.0,0.0,0.0,59.0,0.0,0.0,0.000000,1.000000,0.0,0.0,11429.587215,59,-321.756958,72779.679138,11.0,2706.099609,-61350.091923,48.0,-528.145752,49.0,10.0,0.204082,46.0,13.0,0.282609,38,20,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95995,561362,-1,,12,0,0,3,0.0,0.0,0.0,1.0,1.0,0.0,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-9802.526207,14.0,-434.536697,,,,-9802.526207,14.0,-434.536697,0.000000,-71254.860472,0.0,0.0,0.000000,-719.746065,0.0,0.0,0.000000,-467.705963,0.0,0.0,0.0,99.0,0.0,0.0,0.000000,1.000000,0.0,0.0,-71254.860472,99,-467.705963,,,,-71254.860472,99.0,-467.705963,99.0,0.0,0.000000,84.0,15.0,0.178571,49,12,1
95996,561419,-1,,12,0,0,3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-6571.396210,6.0,-1069.419952,,,,-6571.396210,6.0,-1069.419952,0.000000,-3495.118294,0.0,0.0,0.000000,-48.543310,0.0,0.0,0.000000,-474.838638,0.0,0.0,0.0,72.0,0.0,0.0,0.000000,1.000000,0.0,0.0,-3495.118294,72,-474.838638,69579.185341,5.0,109.552872,-73074.303635,67.0,-542.066284,72.0,0.0,0.000000,65.0,7.0,0.107692,54,9,1
95997,561895,-1,,12,0,0,2,1.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-179350.281706,19.0,-1134.675049,,,,-179350.281706,19.0,-1134.675049,0.000000,-717608.803839,0.0,0.0,0.000000,-18400.225739,0.0,0.0,0.000000,-1422.766357,0.0,0.0,0.0,39.0,0.0,0.0,0.000000,1.000000,0.0,0.0,-717608.803839,39,-1422.766357,,,,-717608.803839,39.0,-1422.766357,36.0,3.0,0.083333,19.0,20.0,1.052632,24,15,1
95998,561908,-1,,12,0,0,2,2.0,1.0,0.0,1.0,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,1.0,0.0,0.0,-88454.045708,20.0,-391.408051,54908.425781,1.0,54908.425781,-143362.471489,19.0,-460.255188,0.000000,778253.475967,0.0,0.0,0.000000,13897.383499,0.0,0.0,0.000000,-223.971062,0.0,0.0,0.0,56.0,0.0,0.0,0.000000,1.000000,0.0,0.0,778253.475967,56,-223.971062,936792.233398,10.0,53917.962891,-158538.757431,46.0,-315.139267,44.0,12.0,0.272727,36.0,20.0,0.555556,24,12,1


Информация о количестве и размере транзакций в разрезе часов.

In [9]:
tmp = transactions.copy()
tmp['hour'] = tmp['transaction_dttm'].dt.hour
pivot_table = tmp.pivot_table(
    index='user_id',
    columns='hour',
    values='transaction_amt',
    aggfunc=['count', 'median']
).fillna(0)
pivot_table.columns = [f'hour_{x[0]}_{x[1]}' for x in pivot_table.columns]

count_cols = [x for x in pivot_table.columns if 'count' in x]
pivot_table['sum'] = pivot_table[count_cols].sum(axis=1)
for col in count_cols:
    pivot_table[f'{col}_norm'] = pivot_table[col] / pivot_table['sum']
pivot_table.drop('sum', axis=1, inplace=True)

main = main.merge(pivot_table, how='left', left_on='user_id', right_index=True)

Фичи, основанные на временных отрезках

In [10]:
cur = transactions.groupby('user_id')['transaction_dttm'].agg(['min', 'max']).reset_index()
cur = cur.merge(clients[['user_id', 'report']], how='left', on='user_id')
cur = cur.merge(report_dates, how='left', on='report')

cur['min_diff_dttm'] = (cur['report_dt'] - cur['min']).dt.days
cur['days_to_report'] = (cur['report_dt'] - cur['max']).dt.days
cur['max_min_diff_dttm'] = cur['days_to_report'] - cur['min_diff_dttm']

main = main.merge(cur[['user_id', 'min_diff_dttm','days_to_report','max_min_diff_dttm']], how='left', on='user_id')

In [11]:
main['плотность транзакций'] = main['max_min_diff_dttm'] / main['general_trans_info_1000_count']
main['плотность дней'] = (main['max_min_diff_dttm'] + 1) / main['nunique_days']

Статистические фичи

In [12]:
cur = main[main.target != -1]
cur = cur.groupby(['customer_age', 'employee_count_nm'])['target'].agg(['mean']).reset_index().rename({'mean':'group_employee_age_mean'}, axis=1)
main = main.merge(cur, how='left', on=['customer_age', 'employee_count_nm'])
main.loc[main.target != -1, 'group_report_employee_age_mean'] = -1

cur = main[main.target != -1]
cur = cur.groupby(['report', 'customer_age'])['target'].agg(['mean']).reset_index().rename({'mean':'group_report_age_mean'}, axis=1)
main = main.merge(cur, how='left', on=['report', 'customer_age'])
main.loc[main.target != -1, 'group_report_age_mean'] = -1

Смотрим на MCC код последних транзакций, берем средний средний таргет (есть небольшой лик:) )

In [13]:
cur = transactions.merge(main[['user_id', 'target']], on='user_id', how='left').groupby('user_id').last().reset_index()
last_transaction = cur[cur['target'] != -1].groupby('mcc_code')['target'].agg(['mean', 'count']).sort_values('mean').reset_index()
last_transaction = last_transaction[last_transaction['count'] > 100]

cur = cur[['user_id', 'mcc_code']].merge(last_transaction[['mcc_code', 'mean']], how='left', on='mcc_code').rename({'mean': 'mean_target_last_mcc_code'}, axis=1)
main = main.merge(cur[['user_id', 'mean_target_last_mcc_code']], how='left', on='user_id')

Подготовка данных к обучение модели

In [14]:
cat_cols = ['customer_age', 'employee_count_nm', 'report']
main[cat_cols] = main[cat_cols].astype(str)

main = main.sort_values('user_id').reset_index(drop=True)
train = main[main.target != -1]
test = main[main.target == -1]

Обучение модельки для того чтобы получить важные фичи

In [15]:
model = CatBoostClassifier(
    iterations = 1400,
    depth=5,
    learning_rate=0.03,

    eval_metric='AUC',
    cat_features = cat_cols,
    thread_count=6,
    early_stopping_rounds=200,
)
model.fit(train.drop(['user_id', 'target', 'time', 'group_employee_age_mean', 'group_report_age_mean'], axis=1), train['target'], verbose=100)


df_imp = pd.DataFrame({
    'name': train.drop(['user_id', 'target', 'time', 'group_employee_age_mean', 'group_report_age_mean'], axis=1).columns,
    'imp': model.get_feature_importance()
}).sort_values('imp', ascending=False)
# display(df_imp) # Можно посмотреть на предварительный feature_importance()

df_imp = df_imp[df_imp['imp'] > 0.3] # Берем все фичи, у которых важность больше 0.3

# Добавляем статистические фичи, их нельзя было использовать для тренировки здесь, т.к. получился бы лик в данных
good_cols = df_imp['name'].tolist() + ['group_employee_age_mean', 'group_report_age_mean']

0:	total: 92.2ms	remaining: 2m 9s
100:	total: 3.8s	remaining: 48.8s
200:	total: 7.7s	remaining: 45.9s
300:	total: 11.4s	remaining: 41.6s
400:	total: 15.1s	remaining: 37.6s
500:	total: 18.8s	remaining: 33.8s
600:	total: 22.4s	remaining: 29.8s
700:	total: 26.1s	remaining: 26s
800:	total: 29.8s	remaining: 22.3s
900:	total: 33.3s	remaining: 18.4s
1000:	total: 36.7s	remaining: 14.6s
1100:	total: 40.2s	remaining: 10.9s
1200:	total: 43.8s	remaining: 7.26s
1300:	total: 47.4s	remaining: 3.6s
1399:	total: 50.9s	remaining: 0us


Обучение основных моделей на 5 Фолдах. Стратификация по report, возможно следует попробовать что нибудь другое:)

In [16]:
strat_kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

X, y = train.drop(['time', 'group_employee_age_mean', 'group_report_age_mean'], axis=1), train['target']
scores = []
models = []
for train_index, valid_index in strat_kfold.split(train, train['report']):
    
    X_train, X_val = X.iloc[train_index], X.iloc[valid_index]
    y_train, y_val, target_val = y.iloc[train_index], y.iloc[valid_index], train['target'].iloc[valid_index]

    
    # Для того чтобы не было лика, генерируем статистические фичи только на данных для тренировки
    cur = X_train[['employee_count_nm', 'customer_age', 'target']].groupby(['customer_age', 'employee_count_nm'])['target'].agg(['mean']).reset_index().rename({'mean':'group_employee_age_mean'}, axis=1)
    X_train = X_train.merge(cur, how='left', on=['customer_age', 'employee_count_nm'])
    X_val = X_val.merge(cur, how='left', on=['customer_age', 'employee_count_nm'])

    cur = X_train[['report', 'customer_age', 'target']].groupby(['customer_age', 'report'])['target'].agg(['mean']).reset_index().rename({'mean':'group_report_age_mean'}, axis=1)
    X_train = X_train.merge(cur, how='left', on=['customer_age', 'report']).drop(['user_id', 'target'], axis=1)
    X_val = X_val.merge(cur, how='left', on=['customer_age', 'report']).drop(['user_id', 'target'], axis=1)
    

    model = CatBoostClassifier(
        iterations = 2500,
        depth=4,
        learning_rate=0.03,
        eval_metric='AUC',
        cat_features = cat_cols,
        early_stopping_rounds=400,
    )

    model.fit(Pool(X_train[good_cols], y_train, cat_features=cat_cols),
              eval_set=Pool(X_val[good_cols], y_val, cat_features=cat_cols),
              verbose=100)
    models.append(model)
    
    pred = model.predict_proba(X_val[good_cols])[:, 1]
    scores.append(metrics.roc_auc_score(y_val, pred))

np.mean(scores)

0:	test: 0.5872496	best: 0.5872496 (0)	total: 12.4ms	remaining: 31s
100:	test: 0.7414132	best: 0.7414132 (100)	total: 1.24s	remaining: 29.4s
200:	test: 0.7488502	best: 0.7488502 (200)	total: 2.46s	remaining: 28.1s
300:	test: 0.7534694	best: 0.7534694 (300)	total: 3.78s	remaining: 27.6s
400:	test: 0.7559216	best: 0.7559896 (393)	total: 5.15s	remaining: 27s
500:	test: 0.7577898	best: 0.7577898 (500)	total: 6.43s	remaining: 25.7s
600:	test: 0.7583763	best: 0.7584230 (593)	total: 7.71s	remaining: 24.4s
700:	test: 0.7589259	best: 0.7589403 (699)	total: 8.99s	remaining: 23.1s
800:	test: 0.7594282	best: 0.7594587 (794)	total: 10.4s	remaining: 22.1s
900:	test: 0.7599972	best: 0.7600174 (899)	total: 11.7s	remaining: 20.8s
1000:	test: 0.7603364	best: 0.7603676 (990)	total: 13s	remaining: 19.4s
1100:	test: 0.7608296	best: 0.7608296 (1100)	total: 14.2s	remaining: 18.1s
1200:	test: 0.7609389	best: 0.7610706 (1180)	total: 15.4s	remaining: 16.7s
1300:	test: 0.7613880	best: 0.7614363 (1270)	total: 16.

0.7655268808805896

In [17]:
pd.DataFrame({
    'name': good_cols,
    'imp': model.get_feature_importance()
}).sort_values('imp', ascending=False)[:50]

Unnamed: 0,name,imp
83,group_employee_age_mean,18.108158
2,percent_last_30,4.727142
4,nunique_mcc_codes,3.754328
84,group_report_age_mean,3.691214
3,max_min_diff_dttm,3.260093
5,positive_general_trans_info_1000_median,2.363426
7,percent_last_5,2.212472
6,currency_daydiff_1000_sum_1,2.052549
20,main_sum_transaction_amt_12,1.866844
8,hour_median_3,1.820598


In [19]:
sample = pd.read_csv('data/sample_submit_naive.csv')
sample['predict'] = 0
for i in range(len(models)):
    sample['predict'] += models[i].predict_proba(test[good_cols])[:, 1]

sample.to_csv('submit_baseline.csv', index=False)