In [5]:
from pathlib import Path
import pandas as pd
import numpy as np
import joblib
import h3
from tqdm import tqdm
import geocoder
from collections import defaultdict
from xgboost import XGBClassifier
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import warnings
import optuna
warnings.filterwarnings('ignore')

In [6]:
data_root = Path('data')
hexses_target_path = data_root / 'hexses_target.lst'
hexses_data_path = data_root / 'hexses_data.lst'
hexses_suburb_path = data_root / 'hexses_suburb.parquet'

train_data_fn   = data_root / 'transactions.parquet'
train_target_fn = data_root / 'target.parquet'

moscow_osm_path = data_root / 'moscow.parquet'

In [7]:
with open(hexses_target_path, "r") as f:
    hexses_target = [x.strip() for x in f.readlines()]
len(hexses_target)

1657

In [8]:
with open(hexses_data_path, "r") as f:
    hexses_data = [x.strip() for x in f.readlines()]
len(hexses_data)

8154

In [9]:
def extract_hexses_latlng(hexses_data, hexses_target):
    """
    Creates a dictionary of a form h3_index:(lat, lng).
    """
    hexses_latlng = dict()

    for hex in tqdm(hexses_data):
        latlng = h3.h3_to_geo(hex)
        hexses_latlng[hex] = latlng

    for hex in tqdm(hexses_target):
        latlng = h3.h3_to_geo(hex)
        hexses_latlng[hex] = latlng

    return hexses_latlng

hexses_latlng = extract_hexses_latlng(hexses_data, hexses_target)

100%|██████████| 8154/8154 [00:00<00:00, 595679.71it/s]
100%|██████████| 1657/1657 [00:00<00:00, 699191.32it/s]


In [10]:
# def extract_hexses_suburb(hexses_latlng):
#     hexses_suburb = dict()

#     for h3_09 in tqdm(hexses_latlng.keys()):
#         location = geocoder.osm([hexses_latlng[h3_09][0], hexses_latlng[h3_09][1]], method='reverse').json['raw']['address']
#         if location.get('suburb') is not None:
#             hexses_suburb[h3_09] = location['suburb']
#         elif location.get('municipality') is not None:
#             hexses_suburb[h3_09] = location['municipality']
#         else:
#             hexses_suburb[h3_09] = 'Другое'

#     return hexses_suburb

# hexses_suburb = extract_hexses_suburb(hexses_latlng)
# df_hexses_suburb = pd.DataFrame(pd.Series(hexses_suburb), columns=['suburb'])
# df_hexses_suburb.to_parquet('hexses_suburb.parquet')

In [14]:
hexses_suburb = pd.read_parquet(hexses_suburb_path)['suburb'].to_dict()

In [16]:
all_hexses = list(set(hexses_data) | set(hexses_target))
all_hexses = pd.DataFrame({"h3_09" : all_hexses})
all_hexses[['lat', 'lng']] = all_hexses['h3_09'].apply(lambda x: pd.Series(hexses_latlng[x]))
all_hexses['suburb'] = all_hexses['h3_09'].apply(lambda x: hexses_suburb[x])

In [17]:
city_center = '8911aa7abcbffff'

In [18]:
transactions = pd.read_parquet(train_data_fn)
target = (
    pd.read_parquet(train_target_fn)
    .assign(customer_id = lambda x: x.customer_id.astype(int))
    .pipe(lambda x: pd.pivot(x.assign(v = 1.), index='customer_id', columns='h3_09', values='v'))
    .pipe(lambda x: x.reindex(sorted(x.columns), axis=1))
    .sort_values(by='customer_id')
    .fillna(0)
)
display(transactions.head(10))
display(target.head(10))

Unnamed: 0,h3_09,customer_id,datetime_id,count,sum,avg,min,max,std,count_distinct,mcc_code
0,8911aa4c62fffff,1,3,1,3346.65,3346.65,3346.65,3346.65,,1,13
1,8911aa7b5b3ffff,4,3,1,450.0,450.0,450.0,450.0,,1,8
2,8911aa63623ffff,5,3,10,11035.69,1103.569,59.0,3620.18,1190.530333,6,13
3,8911aa48577ffff,9,2,2,628.0,314.0,295.0,333.0,26.870058,2,5
4,8911aa78297ffff,11,2,1,4155.0,4155.0,4155.0,4155.0,,1,10
5,8911aa78dc7ffff,12,2,2,94.0,47.0,47.0,47.0,0.0,2,9
6,8911aa4ec93ffff,13,3,25,9089.44,363.5776,176.99,707.96,162.457572,2,0
7,8911aa6a4c3ffff,15,1,10,1400.0,140.0,140.0,140.0,0.0,1,8
8,891181b6507ffff,16,2,1,10998.0,10998.0,10998.0,10998.0,,1,10
9,8911aa7ab0fffff,17,2,2,2760.0,1380.0,1380.0,1380.0,0.0,2,8


h3_09,8911818610bffff,89118195133ffff,8911819513bffff,891181b2827ffff,891181b2957ffff,891181b2b83ffff,891181b2b97ffff,891181b2ba7ffff,891181b2bb3ffff,891181b2d1bffff,...,8911aa7b5a7ffff,8911aa7b5b3ffff,8911aa7b617ffff,8911aa7b637ffff,8911aa7b643ffff,8911aa7b65bffff,8911aa7b663ffff,8911aa7b67bffff,8911aa7b687ffff,8911aa7b68fffff
customer_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
9,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
12,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
13,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
15,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
16,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
17,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# Generating home-related features

In [19]:
transactions = pd.merge(transactions, all_hexses, on="h3_09")

In [20]:
transactions['mcc_code'].value_counts() # assume that mcc 13 is supermarkets

mcc_code
13    1512867
8      912094
5      422039
9      284659
0      208830
6      178608
10     163521
11      95126
17      90401
12      49357
4       38132
18      38103
14      37212
3       33450
1       26964
20      22288
16      12599
22       6149
15       5952
2        4486
7        2969
21       2917
23       1257
19       1116
Name: count, dtype: int64

In [21]:
# Assume that home is located where a customer buys more food
customers_homes = (transactions[transactions['mcc_code'] == 13]
 .groupby(['customer_id', 'h3_09', 'lat', 'lng'], as_index=False)['count'].sum()
 .sort_values(by=['customer_id', 'count'], ascending=False)
 .groupby('customer_id')
 .first()
 .reset_index())
customers_no_homes = set(transactions.customer_id.unique()) - set(customers_homes.customer_id)
estimated_customers_homes = (transactions[transactions['customer_id'].isin(customers_no_homes)]
                         .groupby(['customer_id', 'h3_09', 'lat', 'lng'], as_index=False)['count'].sum()
                         .sort_values(by=['customer_id', 'count'], ascending=False)
                         .groupby('customer_id', as_index=False)
                         .first())
customers_homes = pd.concat([customers_homes, estimated_customers_homes])
customers_homes

Unnamed: 0,customer_id,h3_09,lat,lng,count
0,1,891181948a3ffff,55.368683,37.226170,10
1,4,8911aa4e587ffff,55.629198,37.341031,51
2,5,8911aa6339bffff,55.796306,37.696827,13
3,9,8911aa48567ffff,55.471859,37.300095,123
4,11,8911aa7a117ffff,55.750370,37.536372,29
...,...,...,...,...,...
5,70520,8911aa7a967ffff,55.780645,37.633646,26
6,79081,8911aa7a977ffff,55.780939,37.628494,65
7,90572,8911aa7abd3ffff,55.756868,37.615785,94
8,92781,8911aa68e43ffff,55.698043,37.913812,38


In [22]:
customers_homes['center_lat'] = hexses_latlng[city_center][0]
customers_homes['center_lng'] = hexses_latlng[city_center][1]
customers_homes['home_center_dist'] = customers_homes.apply(lambda x: h3.point_dist((x['lat'], x['lng']),
                                                                                    (x['center_lat'], x['center_lng']),
                                                                                    unit='km'), axis=1)
lats = []
lngs = []
for h3_09 in hexses_data:
    lats.append(hexses_latlng[h3_09][0])
    lngs.append(hexses_latlng[h3_09][1])

customers_homes['median_lat'] = np.median(lats)
customers_homes['median_lng'] = np.median(lngs)

customers_homes['home_median_latlng_dist'] = customers_homes.apply(lambda x: h3.point_dist((x['lat'], x['lng']),
                                                                                    (x['median_lat'], x['median_lng']),
                                                                                    unit='km'), axis=1)

homes_distances = customers_homes[['customer_id', 'home_center_dist', 'home_median_latlng_dist']]
homes_distances

Unnamed: 0,customer_id,home_center_dist,home_median_latlng_dist
0,1,49.552419,45.087388
1,4,22.447249,17.821823
2,5,6.617184,11.272511
3,9,37.326697,32.807649
4,11,5.430813,3.526764
...,...,...,...
5,70520,3.069571,7.412017
6,79081,3.048361,7.270265
7,90572,0.570304,4.612182
8,92781,19.236200,21.779543


In [23]:
homes_transactions = transactions.merge(customers_homes[['customer_id', 'h3_09']], on=['customer_id', 'h3_09'], how='right')
homes_transactions = (homes_transactions
                     .groupby('customer_id')
                     .agg({'sum': 'sum', 'avg':'median', 'max': 'max', 'min': 'min', 'count': 'sum'})
                     .add_prefix('home_')
                     .reset_index())
homes_transactions

Unnamed: 0,customer_id,home_sum,home_avg,home_max,home_min,home_count
0,1,2801.08,260.664167,576.60,75.08,10
1,4,22807.86,99.000000,1791.90,13.73,78
2,5,4089.46,193.738000,862.56,61.81,13
3,9,39117.00,316.093333,1809.00,8.00,123
4,11,81664.40,506.644444,7900.00,40.00,88
...,...,...,...,...,...,...
69332,98388,65438.58,625.957500,11500.00,62.99,91
69333,98397,29717.29,937.822000,5119.00,8.49,53
69334,98409,9872.05,276.143750,1106.89,69.99,31
69335,98432,5362.47,312.171000,855.96,100.28,17


In [24]:
homes_suburbs = customers_homes[['customer_id', 'h3_09']]
homes_suburbs['home_suburb'] = homes_suburbs.apply(lambda x: hexses_suburb[x['h3_09']], axis=1)
homes_suburbs = pd.get_dummies(homes_suburbs, columns=['home_suburb']).drop('h3_09', axis=1)
homes_suburbs

Unnamed: 0,customer_id,home_suburb_1-й микрорайон,home_suburb_3-й микрорайон,home_suburb_Академический район,home_suburb_Алексеевский район,home_suburb_Алтуфьевский район,home_suburb_Бабушкинский район,home_suburb_Барыши,home_suburb_Басманный район,home_suburb_Бауманка,...,home_suburb_район Чертаново Северное,home_suburb_район Чертаново Центральное,home_suburb_район Чертаново Южное,home_suburb_район Черёмушки,home_suburb_район Щукино,home_suburb_район Южное Бутово,home_suburb_район Южное Медведково,home_suburb_район Южное Тушино,home_suburb_район Якиманка,home_suburb_район Ясенево
0,1,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
1,4,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
2,5,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3,9,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,11,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5,70520,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
6,79081,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
7,90572,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
8,92781,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


In [25]:
home_features = homes_suburbs.merge(homes_distances, on='customer_id').merge(homes_transactions, on='customer_id')

In [26]:
home_features

Unnamed: 0,customer_id,home_suburb_1-й микрорайон,home_suburb_3-й микрорайон,home_suburb_Академический район,home_suburb_Алексеевский район,home_suburb_Алтуфьевский район,home_suburb_Бабушкинский район,home_suburb_Барыши,home_suburb_Басманный район,home_suburb_Бауманка,...,home_suburb_район Южное Тушино,home_suburb_район Якиманка,home_suburb_район Ясенево,home_center_dist,home_median_latlng_dist,home_sum,home_avg,home_max,home_min,home_count
0,1,False,False,False,False,False,False,False,False,False,...,False,False,False,49.552419,45.087388,2801.08,260.664167,576.60,75.08,10
1,4,False,False,False,False,False,False,False,False,False,...,False,False,False,22.447249,17.821823,22807.86,99.000000,1791.90,13.73,78
2,5,False,False,False,False,False,False,False,False,False,...,False,False,False,6.617184,11.272511,4089.46,193.738000,862.56,61.81,13
3,9,False,False,False,False,False,False,False,False,False,...,False,False,False,37.326697,32.807649,39117.00,316.093333,1809.00,8.00,123
4,11,False,False,False,False,False,False,False,False,False,...,False,False,False,5.430813,3.526764,81664.40,506.644444,7900.00,40.00,88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69332,70520,False,False,False,False,False,False,False,False,False,...,False,False,False,3.069571,7.412017,1456.00,56.000000,56.00,56.00,26
69333,79081,False,False,False,False,False,False,False,False,False,...,False,False,False,3.048361,7.270265,12809.85,197.298879,500.00,25.00,65
69334,90572,False,False,False,False,False,False,False,False,False,...,False,False,False,0.570304,4.612182,36483.58,160.273411,7200.00,55.00,94
69335,92781,False,False,False,False,False,False,False,False,False,...,False,False,False,19.236200,21.779543,61958.00,1272.133333,5950.00,312.00,38


# Generating other features

In [27]:
class BaseTransform:
    def __init__(self, filepath=None):
        if filepath:
            self.load(filepath)
        pass

    def fit(self, transactions):
        pass

    def transform(self, transactions):
        return transactions

    def save(self, filepath):
        pass

    def load(self, filepath):
        pass

class SimpleFeaturesTransform(BaseTransform):
    def __init__(self, filepath=None):
        if filepath:
            self.load(filepath)
        pass

    def fit(self, transactions):
        self.mcc_codes = set(transactions.mcc_code.unique())
        self.datetime_ids = set(transactions.datetime_id.unique())
        self.hexses_data = set(hexses_data)
        self.suburbs = set(transactions.suburb.unique())
        chunk = dict()
        # for mcc_code in self.mcc_codes:
        #     chunk[f"mcc_{ mcc_code }_count"] = 0
        #     chunk[f"mcc_{ mcc_code }_sum"] = 0
        # for datetime_id in self.datetime_ids:
        #     chunk[f"dt_{ datetime_id }_count"] = 0
        #     chunk[f"dt_{ datetime_id }_sum"] = 0
        for h3_09 in self.hexses_data:
            chunk[f"hex_{ h3_09 }_count"] = 0
        for suburb in self.suburbs:
            chunk[f"{ suburb }_count"] = 0
        self.template = chunk

    def transform(self, transactions):
        features = []
        row_labels = []
        gb = transactions.groupby(by="customer_id")
        for customer_id, group in tqdm(gb):
            row_labels.append(customer_id)
            chunk = self.template.copy()

            # chunk['lat_median'] = group['lat'].median()
            # chunk['lng_median'] = group['lng'].median()

            # for mcc_code, subgroup in group.groupby(by='mcc_code'):
            #     if mcc_code in self.mcc_codes:
            #         chunk[f"mcc_{ mcc_code }_count"] = subgroup["count"].sum()
            #         chunk[f"mcc_{ mcc_code }_sum"] = subgroup["sum"].sum()

            # for datetime_id, subgroup in group.groupby(by='datetime_id'):
            #     if datetime_id in self.datetime_ids:
            #         chunk[f"dt_{ datetime_id }_count"] = subgroup["count"].sum()
            #         chunk[f"dt_{ datetime_id }_count"] = 0

            for h3_09, subgroup in group.groupby(by='h3_09'):
                if h3_09 in self.hexses_data:
                    chunk[f"hex_{ h3_09 }_count"] = subgroup['count'].sum()

            for suburb, subgroup in group.groupby(by='suburb'):
                if suburb in self.suburbs:
                    chunk[f"{ suburb }_count"] = subgroup['count'].sum()


            features.append(chunk)
        return pd.DataFrame(features, index=row_labels)

    def save(self, filepath):
        pass

    def load(self, filepath):
        pass

In [28]:
transformer = SimpleFeaturesTransform()
transformer.fit(transactions)
train_data = transformer.transform(transactions)

100%|██████████| 69337/69337 [06:12<00:00, 186.25it/s]


In [29]:
train_data = train_data.reset_index(names='customer_id')
train_data

Unnamed: 0,customer_id,hex_8911aa705d7ffff_count,hex_891181b6693ffff_count,hex_8911aa7a8dbffff_count,hex_8911aa70aabffff_count,hex_8911aa7155bffff_count,hex_8911aa45a37ffff_count,hex_8911aa7946fffff_count,hex_8911aa6ada7ffff_count,hex_8911aa60687ffff_count,...,Алексеевский район_count,район Капотня_count,Град Московский_count,Красная горка_count,район Аэропорт_count,район Новогиреево_count,Южнопортовый район_count,район Северное Измайлово_count,Некрасовка_count,поселение Внуковское_count
0,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,4,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,35,226
2,5,0,0,0,0,0,0,0,0,0,...,2,0,0,0,0,0,0,0,6,0
3,9,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,33,0
4,11,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,13,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69332,98388,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,56,0
69333,98397,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,65,0
69334,98409,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
69335,98432,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [30]:
train_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 69337 entries, 0 to 69336
Columns: 8341 entries, customer_id to поселение Внуковское_count
dtypes: int64(8341)
memory usage: 4.3 GB


# Training models

In [33]:
train_data = train_data.merge(home_features, on='customer_id')

In [34]:
train_data = train_data.drop('customer_id', axis=1)
train_data

Unnamed: 0,hex_8911aa705d7ffff_count,hex_891181b6693ffff_count,hex_8911aa7a8dbffff_count,hex_8911aa70aabffff_count,hex_8911aa7155bffff_count,hex_8911aa45a37ffff_count,hex_8911aa7946fffff_count,hex_8911aa6ada7ffff_count,hex_8911aa60687ffff_count,hex_8911aa7229bffff_count,...,home_suburb_район Южное Тушино,home_suburb_район Якиманка,home_suburb_район Ясенево,home_center_dist,home_median_latlng_dist,home_sum,home_avg,home_max,home_min,home_count
0,0,0,0,0,0,0,0,0,0,0,...,False,False,False,49.552419,45.087388,2801.08,260.664167,576.60,75.08,10
1,0,0,0,0,0,0,0,0,0,0,...,False,False,False,22.447249,17.821823,22807.86,99.000000,1791.90,13.73,78
2,0,0,0,0,0,0,0,0,0,0,...,False,False,False,6.617184,11.272511,4089.46,193.738000,862.56,61.81,13
3,0,0,0,0,0,0,0,0,0,0,...,False,False,False,37.326697,32.807649,39117.00,316.093333,1809.00,8.00,123
4,0,0,0,0,0,0,0,0,0,0,...,False,False,False,5.430813,3.526764,81664.40,506.644444,7900.00,40.00,88
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
69332,0,0,0,0,0,0,0,0,0,0,...,False,False,False,0.570304,4.612182,65438.58,625.957500,11500.00,62.99,91
69333,0,0,0,0,0,0,0,0,0,0,...,False,False,False,0.570304,4.612182,29717.29,937.822000,5119.00,8.49,53
69334,0,0,0,0,0,0,0,0,0,0,...,False,False,False,11.249410,7.432063,9872.05,276.143750,1106.89,69.99,31
69335,0,0,0,0,0,0,0,0,0,0,...,False,False,False,22.622546,17.968165,5362.47,312.171000,855.96,100.28,17


In [36]:
def MBCE(prediction, target):
    eps = 1e-8
    mbce = (-np.log(np.clip(prediction, eps, 1 - eps)) * target \
           - np.log(np.clip(1 - prediction, eps, 1 - eps)) * (1 - target)).sum(axis=1).mean()
    return mbce

In [None]:
MBCE_trial_scores = []

def objective(trial, X, y):
    params = {
        'objective': 'binary:logistic',
        'verbosity': 0,
        'random_state': 1337,
        'tree_method': 'gpu_hist',
        'sampling_method': 'gradient_based',
        'nthread': -1,
        'n_estimators': trial.suggest_int("n_estimators", 400, 1000),
        "learning_rate": trial.suggest_float("learning_rate", 0.075, 0.1),
        'gamma': trial.suggest_float('gamma', 0, 5),
        "reg_alpha": trial.suggest_float("reg_alpha", 1e-9, 10, log=True),
        "reg_lambda": trial.suggest_float("reg_lambda", 1e-9, 10, log=True),
        "max_depth": trial.suggest_int("max_depth", 4, 10),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.25, 0.9),
        "subsample": trial.suggest_float("subsample", 0.25, 0.9),
        "min_child_samples": trial.suggest_int("min_child_samples", 6, 100),
    }

    MBCE_val_scores = []

    skfold = MultilabelStratifiedKFold(n_splits=3, shuffle=True, random_state=1337)

    print(f'TRIAL: {trial.number}')
    for fold, (train_indices, test_indices) in enumerate(skfold.split(X, y)):
        print(f'Fold: {fold}')
        print(f"Train indices: {train_indices}, Test indices: {test_indices}")
        X_train, X_val = X.iloc[train_indices], X.iloc[test_indices]
        y_train, y_val = y.iloc[train_indices], y.iloc[test_indices]

        model = XGBClassifier(**params)

        model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=5, early_stopping_rounds=5)
        joblib.dump(model, f'model_{trial.number}_{fold}.pkl')

        train_preds = model.predict_proba(X_train)
        val_preds = model.predict_proba(X_val)

        train_score = MBCE(train_preds, y_train)
        val_score = MBCE(val_preds, y_val)

        print(f'Train MBCE: {train_score}')
        print(f'Test MBCE: {val_score}')

        MBCE_val_scores.append(val_score)

    score = np.array(MBCE_val_scores).mean()
    MBCE_trial_scores.append(score)
    return score

sampler = optuna.samplers.TPESampler(seed=1337)
study = optuna.create_study(direction='minimize', sampler=sampler)
study.optimize(lambda trial: objective(trial, train_data, target), n_trials=10)
display(study.best_params)

[I 2024-04-05 16:20:43,097] A new study created in memory with name: no-name-c14c3e9e-2c49-47bd-a74d-3d6573a038c8


TRIAL: 0
Fold: 0
Train indices: [    0     2     4 ... 69333 69334 69336], Test indices: [    1     3     6 ... 69318 69332 69335]
[0]	validation_0-logloss:0.12009
[5]	validation_0-logloss:0.08064
[10]	validation_0-logloss:0.05508
[15]	validation_0-logloss:0.03830
[20]	validation_0-logloss:0.02722
[25]	validation_0-logloss:0.01987
[30]	validation_0-logloss:0.01498
[35]	validation_0-logloss:0.01174
[40]	validation_0-logloss:0.00959
[45]	validation_0-logloss:0.00816
[50]	validation_0-logloss:0.00723
[55]	validation_0-logloss:0.00662
[60]	validation_0-logloss:0.00622
[65]	validation_0-logloss:0.00596
[70]	validation_0-logloss:0.00580
[75]	validation_0-logloss:0.00570
[80]	validation_0-logloss:0.00565
[85]	validation_0-logloss:0.00562
[90]	validation_0-logloss:0.00561
[95]	validation_0-logloss:0.00561
[97]	validation_0-logloss:0.00561
Train MBCE: 3.5585604045847714
Test MBCE: 9.287464237785262
Fold: 1
Train indices: [    1     3     4 ... 69332 69335 69336], Test indices: [    0     2     

[I 2024-04-05 21:41:06,981] Trial 0 finished with value: 9.295364696408702 and parameters: {'n_estimators': 557, 'learning_rate': 0.07896709930386164, 'gamma': 1.3906325974717988, 'reg_alpha': 3.918942326120748e-05, 'reg_lambda': 1.6218302824305035e-06, 'max_depth': 7, 'colsample_bytree': 0.42026290161734403, 'subsample': 0.8844554351703319, 'min_child_samples': 75}. Best is trial 0 with value: 9.295364696408702.


Train MBCE: 3.5802083699509852
Test MBCE: 9.2834926316722
TRIAL: 1
Fold: 0
Train indices: [    0     2     4 ... 69333 69334 69336], Test indices: [    1     3     6 ... 69318 69332 69335]
[0]	validation_0-logloss:0.11939
[5]	validation_0-logloss:0.07800
[10]	validation_0-logloss:0.05197
[15]	validation_0-logloss:0.03537
[20]	validation_0-logloss:0.02471
[25]	validation_0-logloss:0.01784
[30]	validation_0-logloss:0.01340
[35]	validation_0-logloss:0.01053
[40]	validation_0-logloss:0.00868
[45]	validation_0-logloss:0.00749
[50]	validation_0-logloss:0.00673
[55]	validation_0-logloss:0.00624
[60]	validation_0-logloss:0.00593
[65]	validation_0-logloss:0.00573
[70]	validation_0-logloss:0.00560
[75]	validation_0-logloss:0.00552
[80]	validation_0-logloss:0.00547
[85]	validation_0-logloss:0.00544
[90]	validation_0-logloss:0.00543
[95]	validation_0-logloss:0.00542
[100]	validation_0-logloss:0.00542
[105]	validation_0-logloss:0.00541
[110]	validation_0-logloss:0.00541
[115]	validation_0-logloss:0

[I 2024-04-06 07:00:46,216] Trial 1 finished with value: 8.965012398077521 and parameters: {'n_estimators': 469, 'learning_rate': 0.08465687671585899, 'gamma': 3.142505897698558, 'reg_alpha': 1.7806528672608995e-08, 'reg_lambda': 6.846774924343185, 'max_depth': 7, 'colsample_bytree': 0.7632129223155371, 'subsample': 0.7661770742645313, 'min_child_samples': 40}. Best is trial 1 with value: 8.965012398077521.


TRIAL: 2
Fold: 0
Train indices: [    0     2     4 ... 69333 69334 69336], Test indices: [    1     3     6 ... 69318 69332 69335]
[0]	validation_0-logloss:0.11874
[5]	validation_0-logloss:0.07566
[10]	validation_0-logloss:0.04929
[15]	validation_0-logloss:0.03290
[20]	validation_0-logloss:0.02264
[25]	validation_0-logloss:0.01620
[30]	validation_0-logloss:0.01214
[35]	validation_0-logloss:0.00959
[40]	validation_0-logloss:0.00800
[45]	validation_0-logloss:0.00700
[50]	validation_0-logloss:0.00638
[55]	validation_0-logloss:0.00601
[60]	validation_0-logloss:0.00578
[65]	validation_0-logloss:0.00564
[70]	validation_0-logloss:0.00557
[75]	validation_0-logloss:0.00553
[80]	validation_0-logloss:0.00551
[85]	validation_0-logloss:0.00551
[90]	validation_0-logloss:0.00550
[95]	validation_0-logloss:0.00550
[100]	validation_0-logloss:0.00550
[105]	validation_0-logloss:0.00550
Train MBCE: 5.078896701130004
Test MBCE: 9.119426585709885
Fold: 1
Train indices: [    1     3     4 ... 69332 69335 6933