# Lightgbm Model

In [309]:
from tqdm import tqdm
import pandas as pd
import numpy as np

In [310]:
# cross category
def cross_category(cat_cols,data,col1,col2):
    column = f'{col1}_{col2}'
    new_cate = []
    for cat1,cat2 in zip(data[col1].values,data[col2].values):
        new_cate.append(f'{cat1}_{cat2}')
    data[column] = new_cate
    cat_cols.append(column)

In [311]:
clean_features = []
cat_features = []

## BaseLine Data Handle

In [312]:

base_info = pd.read_csv('./data/train/base_info.csv')
# 把缺失数量作为一种编码
base_info_clean = base_info
base_info_clean['nan_num'] = base_info.isnull().sum(axis=1)

nums, shapes = base_info_clean.shape
# 删除缺失 70%以上的数据
for name, count in base_info_clean.isnull().sum().items():
    if count * 1.0 / nums >= 0.70:
        base_info_clean.drop([name], axis=1, inplace=True)

# 删除类别相同的数据
for name, count in base_info_clean.nunique().items():
    if count == 0:
        base_info_clean.drop([name], axis=1, inplace=True)


In [313]:
base_info_clean.drop('dom',axis=1,inplace=True)
# 正则化分词，先去除掉括号里面的内容
print(base_info_clean['opscope'].head(2))
opscope = base_info_clean['opscope']
opscope.str.split(r',|、|。|;|，',expand = True).head(2)
# 这一行先删掉，还没想好怎么处理,感觉与类别强相关
base_info_clean.drop('opscope',axis=1,inplace=True)

0    纳米新材料、机械设备、五金配件加工、销售及技术推广服务，道路货物运输。（依法须经批准的项目，...
1                    健身服务。（依法须经批准的项目，经相关部门批准后方可开展经营活动）
Name: opscope, dtype: object


In [314]:
# 处理日期类
date_cols = ['opfrom','opto']
for col in tqdm(date_cols):
    base_info_clean[f'{col}_year'] = pd.to_datetime(base_info[col]).dt.year.fillna(-1)
base_info_clean['dt'] = base_info_clean['opto_year'] -  base_info_clean['opfrom_year']
base_info_clean['dt'] = np.maximum(base_info_clean['dt'].values,-1)
base_info_clean.drop(date_cols,axis=1,inplace=True)

100%|██████████| 2/2 [00:00<00:00, 73.14it/s]


In [315]:
# 处理 category 类
base_info_clean['opform'] = base_info_clean['opform'].replace('01', '01-以个人财产出资').replace('02', '02-以家庭共有财产作为个人出资')

cat_cols = ['oplocdistrict','industryphy','industryco','enttype','enttypeitem',
              'state','orgid','jobid',
              'adbusign','townsign','regtype',
              'compform','opform','venind','oploc','enttypegb']

cat_len = len(cat_cols)
for i in tqdm(range(cat_len)):
    for j in range(i+1,cat_len,1):
        # 类别交叉
        cross_category(cat_cols,base_info_clean,cat_cols[i],cat_cols[j])
        

100%|██████████| 16/16 [00:03<00:00,  4.29it/s]


In [316]:
# 类别编码
for cat_col in tqdm(cat_cols):
    base_info_clean[cat_col] = base_info_clean[cat_col].astype('category').cat.codes

100%|██████████| 136/136 [00:03<00:00, 39.12it/s]


In [317]:
# 数值数据进行分桶处理

# 对于 nan_num 采用
# base_info_clean = base_info_clean.sort_values(by='nan_num',ascending=False)
# 手动分箱 {7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
# <= 9, <= 11, 12, 13, >= 14
base_info_clean['nan_num_bin'] = 1
base_info_clean.loc[base_info_clean['nan_num'] > 9,'nan_num_bin'] = 2
base_info_clean.loc[base_info_clean['nan_num'] >= 11,'nan_num_bin'] = 3
base_info_clean.loc[base_info_clean['nan_num'] >= 12,'nan_num_bin'] = 4
base_info_clean.loc[base_info_clean['nan_num'] >= 13,'nan_num_bin'] = 5
base_info_clean.loc[base_info_clean['nan_num'] >= 14,'nan_num_bin'] = 6
cat_cols.append('nan_num_bin')
print("nan_num 分桶完毕 ......... ")

nan_num 分桶完毕 ......... 


In [318]:
#注册资本分桶
base_info_clean['regcap']=base_info_clean['regcap'].fillna(base_info_clean['regcap'].median())
base_info_clean = base_info_clean.sort_values(by='regcap')
base_info_clean['regcap_bin']=pd.qcut(base_info_clean['regcap'],6,labels = False)
cat_cols.append('regcap_bin')
print("注册资本 regcap_bin 分桶完毕 ......... ")

注册资本 regcap_bin 分桶完毕 ......... 


In [319]:
# empnum 分桶
base_info_clean['empnum']=base_info_clean['empnum'].fillna(base_info_clean['empnum'].median())
base_info_clean = base_info_clean.sort_values(by='empnum')
base_info_clean['empnum_bin']=pd.cut(base_info_clean['empnum'],4,labels = False)
cat_cols.append('empnum_bin')
print("empnum_bin 分桶完毕 ......... ")

empnum_bin 分桶完毕 ......... 


In [320]:
# dt split bin
base_info_clean['dt_bin'] = 1
base_info_clean.loc[base_info_clean['dt'] >= 0,'dt_bin'] = 2
base_info_clean.loc[base_info_clean['dt'] >= 30,'dt_bin'] = 3
base_info_clean.loc[base_info_clean['dt'] >= 50,'dt_bin'] = 4
base_info_clean.loc[base_info_clean['dt'] > 50,'dt_bin'] = 5
cat_cols.append('dt_bin')
print("dt_bin 分桶完毕 ......... ")



dt_bin 分桶完毕 ......... 


In [321]:
clean_features.append(base_info_clean)
cat_features.extend(cat_cols)

In [322]:
print(' baseline handle finish --------')

 baseline handle finish --------


## annual_report_info data handle

In [323]:
annual_report_info=pd.read_csv('./data/train/annual_report_info.csv')#企业的年报基本信息
count, shapes = annual_report_info.shape
#空值大于0.7的列都删除掉
annual_report_info_clean=annual_report_info.dropna(thresh=annual_report_info.shape[0]*0.7,how='all',axis=1)

In [324]:
# 人数的信息
num_cols = ['EMPNUM','COLGRANUM','RETSOLNUM','DISPERNUM','UNENUM','COLEMPLNUM','RETEMPLNUM','DISEMPLNUM','UNEEMPLNUM']
annual_report_info_clean[num_cols] = annual_report_info_clean[num_cols].fillna(-1,axis = 1)
annual_report_info_clean[num_cols].head()

Unnamed: 0,EMPNUM,COLGRANUM,RETSOLNUM,DISPERNUM,UNENUM,COLEMPLNUM,RETEMPLNUM,DISEMPLNUM,UNEEMPLNUM
0,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,4.0,3.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
3,3.0,1.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0
4,10.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [325]:
# delete 0 and -1 >= 0.8
for num_col in tqdm(num_cols):
    num  = (annual_report_info_clean[num_col].values <= 0).sum()
    if num*1.0/count >= 0.8:
        annual_report_info_clean.drop(num_col,axis = 1,inplace=True)

100%|██████████| 9/9 [00:00<00:00, 765.82it/s]


In [326]:
num_cols = list(set(num_cols) & set(annual_report_info_clean.columns))
annual_report_info_clean[num_cols].head()

Unnamed: 0,COLEMPLNUM,COLGRANUM,EMPNUM
0,0.0,0.0,10.0
1,0.0,0.0,2.0
2,1.0,3.0,4.0
3,2.0,1.0,3.0
4,0.0,0.0,10.0


In [327]:
# category 类别进行编码
cat_cols = ['STATE','EMPNUMSIGN','BUSSTNAME','FORINVESTSIGN','WEBSITSIGN','FORINVESTSIGN','PUBSTATE']
# 类别信息 取最新非nan数据
# 雇佣人数，取均值，增量，
grouped = annual_report_info_clean.sort_values(by='ANCHEYEAR',ascending= False).groupby('id')
clean_infos = []
for name, group_info in tqdm(grouped):
     clean_info = {'id':name}
     clean_info['ANCHEYEAR'] = group_info['ANCHEYEAR'].values[-1]
     clean_info['ANCHEYEAR_DT'] = group_info['ANCHEYEAR'].values[-1] - group_info['ANCHEYEAR'].values[0]
     clean_info['REPORT_NUM'] = len(group_info) 
     clean_info['HAS_REPORT'] = 1.0
     for cat_col in cat_cols:
         clean_info[cat_col] = group_info[cat_col].values[-1]
     for num_col in num_cols:
         clean_info[f'{num_col}'] = group_info[num_col].values[-1]
         clean_info[f'{num_col}_MEAN'] = group_info[num_col].values.mean()
         clean_info[f'{num_col}_ADD'] = 0.0
         if len(group_info) > 1:
             clean_info[f'{num_col}_ADD'] = group_info[num_col].values[-1] - group_info[num_col].values[-2]     
     clean_infos.append(clean_info)
annual_report_info_clean = pd.DataFrame(clean_infos)

100%|██████████| 8937/8937 [00:05<00:00, 1699.50it/s]


In [328]:
# num split bin
# num_cols 分桶
for num_col in tqdm(num_cols):
    annual_report_info_clean[num_col]=annual_report_info_clean[num_col].fillna(annual_report_info_clean[num_col].median())
    annual_report_info_clean = annual_report_info_clean.sort_values(by=num_col)
    annual_report_info_clean[f'{num_col}_bin']=pd.cut(annual_report_info_clean[num_col],3,labels = False)
    cat_cols.append(f'{num_col}_bin')   
    print(f"{num_col} 分桶完毕 ......... ")

100%|██████████| 3/3 [00:00<00:00, 154.88it/s]COLEMPLNUM 分桶完毕 ......... 
COLGRANUM 分桶完毕 ......... 
EMPNUM 分桶完毕 ......... 



In [329]:
cat_features.extend(cat_cols)
clean_features.append(annual_report_info_clean)
print("annual_report_info_clean handle end ......... ")

annual_report_info_clean handle end ......... 


## tax_info data handle

In [330]:
tax_info = pd.read_csv('./data/train/tax_info.csv')
count, shapes = tax_info.shape
#空值大于0.7的列都删除掉
tax_info_clean=tax_info.dropna(thresh= count*0.7,how='all',axis=1)

In [331]:
groups = tax_info_clean.groupby('id')
tax_cleans = []
tax_cols = ['TAX_AMOUNT']
for name, group_info in tqdm(groups):
    tax_clean ={'id':name}
    tax_clean['TAX_NUM'] = len(group_info)
    tax_clean['HAS_TAX'] = 1.0
    for tax_col in tax_cols:
        tax_clean[f'{tax_col}_MEAN'] = group_info[tax_col].dropna().values.mean()
    tax_cleans.append(tax_clean)
tax_info_clean = pd.DataFrame(tax_cleans)

100%|██████████| 808/808 [00:00<00:00, 3414.62it/s]


In [332]:
clean_features.append(tax_info_clean)
cat_features.append('HAS_TAX')
print("tax_info_clean handle end ......... ")

tax_info_clean handle end ......... 


## news_info data handle

In [333]:
news_info = pd.read_csv('./data/train/news_info.csv')
count, shape = news_info.shape

In [334]:
news_info['public_date'] = news_info['public_date'].replace('\d+\D+前$','2020-10-01',regex=True)

In [335]:
news_info['public_date'].head(20)

0     2016-12-30
1     2017-08-09
2     2016-02-29
3     2018-06-08
4     2015-06-29
5     2015-06-15
6     2019-10-26
7     2017-11-01
8     2018-04-20
9     2018-01-08
10    2017-12-14
11    2015-05-12
12    2017-11-28
13    2016-10-17
14    2019-03-29
15    2019-04-18
16    2018-04-11
17    2016-07-14
18    2018-04-02
19    2016-07-20
Name: public_date, dtype: object

In [336]:
news_info['positive_negtive'] = news_info['positive_negtive'].fillna('-1')
news_info['public_date_year'] = pd.to_datetime(news_info['public_date']).dt.year.fillna(-1)
groups = news_info.sort_values(by = 'public_date_year',ascending= False).groupby('id')
# 最近的情感色彩，以及最多的情感色彩 和次数，缺失默认为-1
code_map = {'中立':2,'消极':1,'积极':3,'-1':-1}
news_info_cleans = []
for name, group in tqdm(groups):
    news_info_clean = {'id':name}
    news_info_clean['public_date_year'] = group['public_date_year'].values[0]
    news_info_clean['public_date_year_dt'] = group['public_date_year'].values[0] - group['public_date_year'].values[-1]
    news_info_clean['positive_negtive_mode'] = group['positive_negtive'].mode().values[0]
    news_info_clean['positive_negtive_last'] = group['positive_negtive'].values[0]
    news_info_clean['positive_negtive_num'] = len(group)
    news_info_clean['has_news_info'] = 1.0
    news_info_cleans.append(news_info_clean)
news_info_clean = pd.DataFrame(news_info_cleans)

100%|██████████| 927/927 [00:00<00:00, 2674.55it/s]


In [337]:
cat_features.append('has_news_info')

In [338]:
news_info_clean.head()

Unnamed: 0,id,public_date_year,public_date_year_dt,positive_negtive_mode,positive_negtive_last,positive_negtive_num,has_news_info
0,09912c34159b1720558a419983a989f1dd2e0ed69a044ca3,2016,0,中立,中立,6,1.0
1,175ebe5f059ec050afbd65251ecdd3b512bfbe5e62d041b0,2020,3,积极,中立,7,1.0
2,216bd2aaf4d079240c3ac0b76f0ef4aa355d443880ba78db,2020,0,积极,积极,3,1.0
3,216bd2aaf4d079240f5823e63d24b44dd2c58e3281b822f6,2020,0,中立,中立,2,1.0
4,216bd2aaf4d0792410725ba5e7ca1dc32ce55767372f2030,2014,0,消极,消极,1,1.0


In [339]:
clean_features.append(news_info_clean)
print("news_info_clean handle end ......... ")

news_info_clean handle end ......... 


## change_info data handle

In [340]:
change_info = pd.read_csv('./data/train/change_info.csv')
count, shape = change_info.shape

In [341]:
change_info['bgrq'] = change_info['bgrq'].astype(str).str[0:4].astype(int)

In [342]:
groups = change_info.sort_values(by='bgrq',ascending= False).groupby('id')
change_info_cleans = []
for name, group in tqdm(groups):
    change_info_clean = {}
    change_info_clean['id'] = name
    change_info_clean['has_change_info'] = 1.0
    change_info_clean['bgxmdm'] = group['bgxmdm'].values[0]
    change_info_clean['bgrq'] = group['bgrq'].values[0]
    change_info_cleans.append(change_info_clean)
change_info_clean = pd.DataFrame(change_info_cleans)

100%|██████████| 8726/8726 [00:01<00:00, 5661.62it/s]


In [343]:
cat_features.append('has_change_info')
cat_features.append('bgxmdm')

In [344]:
change_info_clean.describe()

Unnamed: 0,has_change_info,bgxmdm,bgrq
count,8726.0,8726.0,8726.0
mean,1.0,175.499312,2017.816984
std,0.0,207.917334,2.169936
min,1.0,110.0,1999.0
25%,1.0,112.0,2017.0
50%,1.0,113.0,2018.0
75%,1.0,121.0,2019.0
max,1.0,939.0,2020.0


In [345]:
clean_features.append(change_info_clean)
print("change_info_clean handle end ......... ")

change_info_clean handle end ......... 


## Other_info data handle

In [346]:

other_info = pd.read_csv('./data/train/other_info.csv')
count, shape = other_info.shape
other_info.nunique(),count

(id                    1888
 legal_judgment_num      93
 brand_num               82
 patent_num             114
 dtype: int64,
 1890)

In [347]:
other_info_clean=other_info.dropna(thresh= count*0.5,how='all',axis=1)

In [348]:
groups = other_info_clean.groupby('id')
other_info_cleans = []
for name, group in tqdm(groups):
    other_info_clean = {'id':name}
    other_info_clean['has_other_info'] = 1.0
    other_info_clean['legal_judgment_num'] = group['legal_judgment_num'].sum()
    other_info_clean['has_legal_judgment'] = int(len(group['legal_judgment_num'].dropna()) > 0)
    other_info_cleans.append(other_info_clean)
other_info_clean = pd.DataFrame(other_info_cleans)

100%|██████████| 1888/1888 [00:00<00:00, 2792.45it/s]


In [349]:
cat_features.append('has_other_info')
cat_features.append('has_legal_judgment')

In [350]:
other_info_clean.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1888 entries, 0 to 1887
Data columns (total 4 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   id                  1888 non-null   object 
 1   has_other_info      1888 non-null   float64
 2   legal_judgment_num  1888 non-null   float64
 3   has_legal_judgment  1888 non-null   int64  
dtypes: float64(2), int64(1), object(1)
memory usage: 59.1+ KB


In [351]:
clean_features.append(other_info_clean)
print("other_info_clean handle end ......... ")

other_info_clean handle end ......... 


## Merhe features

In [352]:

features = clean_features[0]
for clean_feature in clean_features[1:]:
    features = pd.merge(features, clean_feature, on='id',how = 'left')
features.head()

Unnamed: 0,id,oplocdistrict,industryphy,industryco,enttype,enttypeitem,state,orgid,jobid,adbusign,...,positive_negtive_mode,positive_negtive_last,positive_negtive_num,has_news_info,has_change_info,bgxmdm,bgrq,has_other_info,legal_judgment_num,has_legal_judgment
0,82750f1b9d1223507d25fecaca05aec1cfdf7ceb97a535f1,1,14,244,0,2,1,40,156,0,...,,,,,,,,,,
1,f000950527a6feb6a005c09bbd2e2696880df8cd9b350ae2,1,14,246,0,2,1,30,137,0,...,,,,,1.0,115.0,2020.0,1.0,1.0,1.0
2,f000950527a6feb6ff31ae2cda4757703f52e853bb38b67e,1,17,318,0,2,2,32,201,0,...,,,,,1.0,115.0,2013.0,,,
3,47645761dc56bb8cabcff709f67c168821d276310d6e5210,10,10,138,0,5,1,62,124,0,...,,,,,1.0,113.0,2014.0,,,
4,516ab81418ed215dcd9df1915f3a2ae8098e3589b042ee45,12,17,329,9,18,1,70,280,0,...,,,,,1.0,128.0,2015.0,,,


In [353]:
features['addition_nan_num'] = features.isnull().sum(axis=1)
# 缺失值分桶
features['addition_nan_num_bin'] = 1
features.loc[features['addition_nan_num'] >= 8,'addition_nan_num_bin'] = 2
features.loc[features['addition_nan_num'] >= 10,'addition_nan_num_bin'] = 3
features.loc[features['addition_nan_num'] >= 11,'addition_nan_num_bin'] = 4
features.loc[features['addition_nan_num'] >= 26,'addition_nan_num_bin'] = 5
print("addition_nan_num 分桶完毕 ......... ")

addition_nan_num 分桶完毕 ......... 


In [354]:
cat_features.append('addition_nan_num_bin')

In [355]:
has_cols = ['has_other_info','has_news_info','HAS_TAX','has_change_info','HAS_REPORT']
features[has_cols] = features[has_cols].fillna(0)

In [356]:
code_map = {'中立':2,'消极':1,'积极':3,'-1':-1}
cols = ['positive_negtive_mode','positive_negtive_last']
features[cols] = features[cols].fillna('-1')
for col in tqdm(cols):
    features[col] = features[col].map(code_map)

100%|██████████| 2/2 [00:00<00:00, 208.69it/s]


In [357]:
features = features.fillna(-1)

In [358]:
features['new_empnum'] = (features['empnum'] + features['EMPNUM'] + 1)/2.0

In [359]:
for cat_col in tqdm(cat_features):
    features[cat_col] = features[cat_col].astype('category').cat.codes

100%|██████████| 157/157 [00:00<00:00, 852.74it/s]


In [360]:
features.describe()

Unnamed: 0,oplocdistrict,industryphy,industryco,enttype,enttypeitem,state,orgid,jobid,adbusign,townsign,...,has_news_info,has_change_info,bgxmdm,bgrq,has_other_info,legal_judgment_num,has_legal_judgment,addition_nan_num,addition_nan_num_bin,new_empnum
count,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,...,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0,24865.0
mean,5.764327,13.29069,229.647738,6.130505,4.112407,1.141283,32.801689,196.97776,0.002815,0.578444,...,0.037281,0.350935,2.930263,707.473638,0.07593,-0.325317,0.116308,27.81118,3.492821,3.736014
std,3.519292,2.166892,54.332704,7.378857,5.241884,0.382233,18.623124,127.305216,0.052985,0.493818,...,0.189454,0.477272,6.443869,963.526183,0.264891,11.407953,0.42842,11.19453,0.86549,19.400637
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,0.0,0.0,0.0
25%,3.0,12.0,199.0,0.0,0.0,1.0,18.0,89.0,0.0,0.0,...,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,15.0,3.0,1.5
50%,6.0,14.0,238.0,0.0,3.0,1.0,32.0,162.0,0.0,1.0,...,0.0,0.0,0.0,-1.0,0.0,-1.0,0.0,34.0,4.0,2.0
75%,8.0,14.0,260.0,16.0,5.0,1.0,44.0,323.0,0.0,1.0,...,0.0,1.0,4.0,2017.0,0.0,-1.0,0.0,37.0,4.0,3.5
max,15.0,19.0,345.0,16.0,31.0,5.0,77.0,433.0,1.0,1.0,...,1.0,1.0,35.0,2020.0,1.0,959.0,2.0,37.0,4.0,1296.5


In [361]:
features.to_csv('./features/lgb_features.csv',index=False)

In [375]:

import warnings
import lightgbm as lgb
import catboost as cab
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold,GridSearchCV,ParameterGrid
from sklearn.metrics import f1_score, precision_recall_fscore_support,make_scorer
from matplotlib import pyplot as plt


In [376]:
warnings.filterwarnings("ignore")


In [377]:
cat_features

['oplocdistrict',
 'industryphy',
 'industryco',
 'enttype',
 'enttypeitem',
 'state',
 'orgid',
 'jobid',
 'adbusign',
 'townsign',
 'regtype',
 'compform',
 'opform',
 'venind',
 'oploc',
 'enttypegb',
 'oplocdistrict_industryphy',
 'oplocdistrict_industryco',
 'oplocdistrict_enttype',
 'oplocdistrict_enttypeitem',
 'oplocdistrict_state',
 'oplocdistrict_orgid',
 'oplocdistrict_jobid',
 'oplocdistrict_adbusign',
 'oplocdistrict_townsign',
 'oplocdistrict_regtype',
 'oplocdistrict_compform',
 'oplocdistrict_opform',
 'oplocdistrict_venind',
 'oplocdistrict_oploc',
 'oplocdistrict_enttypegb',
 'industryphy_industryco',
 'industryphy_enttype',
 'industryphy_enttypeitem',
 'industryphy_state',
 'industryphy_orgid',
 'industryphy_jobid',
 'industryphy_adbusign',
 'industryphy_townsign',
 'industryphy_regtype',
 'industryphy_compform',
 'industryphy_opform',
 'industryphy_venind',
 'industryphy_oploc',
 'industryphy_enttypegb',
 'industryco_enttype',
 'industryco_enttypeitem',
 'industryco

In [378]:
features = pd.read_csv('./features/lgb_features.csv')
entprise_info = pd.read_csv('./data/train/entprise_info.csv')
data = pd.merge(features, entprise_info, how='left', on='id')
data[cat_features] = data[cat_features].astype(int)
# print(data.max())
train = data[data.label.notna()]
test = data[data.label.isnull()]

In [379]:
train_data, train_labels = train.drop(
        ['id', 'label'], axis=1), train['label']
test_data = test.drop(
        ['label'], axis=1)

In [380]:
def cross_val(model, train_data, train_labels, n_splits = 5):
    f1_scores = []
    sk = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=2022)
    for train, test in sk.split(train_data, train_labels):
        x_train = train_data.iloc[train]
        y_train = train_labels.iloc[train]
        x_valid = train_data.iloc[test]
        y_valid = train_labels.iloc[test]

        model.fit(x_train, y_train,eval_set=(x_valid, y_valid))
        y_pred = model.predict(x_valid)
        f1_score_k = f1_score(y_pred.round(),y_valid)
        f1_scores.append(f1_score_k)
    return sum(f1_scores)/n_splits

In [381]:
def cat_model(loss_function, lr, max_depth, l2_leaf_reg,cat_features):
    clf = cab.CatBoostClassifier(iterations=200,
                                 learning_rate=lr,
                                 depth=max_depth,
                                 loss_function = loss_function,
                                 l2_leaf_reg = l2_leaf_reg,
                                 silent=True,
                                 thread_count=8,
                                 task_type='GPU',
                                 cat_features=cat_features,
                                 early_stopping_rounds = 200,
                                 leaf_estimation_iterations = 10,
                                 )
    return clf

In [396]:
def catboost_GridSearchCV(train_data, train_labels, test_data, params, cat_features, n_splits=5):
    ps = {'f1':0,
          'param': [],
          'best_model':None,
    }
    for prms in tqdm(list(ParameterGrid(params)), ascii=True, desc='Params Tuning:\n'):
        print('cat_serachParm 搜索最佳参数 .......',prms)
        clf = cat_model(prms['loss_function'],prms['learning_rate'],prms['depth'],prms['l2_leaf_reg'],cat_features)             
        f1 = cross_val(clf,train_data, train_labels, n_splits=5)
        if acc>ps['f1']:
            ps['f1'] = f1
            ps['param'] = prms
            ps['best_model'] = clf
            print('f1: '+str(ps['f1']))
            print('Params: '+str(ps['param']))
    print('f1: '+str(ps['f1']))
    print('Params: '+str(ps['param']))
    return ps['best_model']

In [394]:
params = {
          'depth': [4, 5, 6,7,8],
          'learning_rate': [0.03, 0.035, 0.040, 0.045,0.05,0.055,0.06,0.065],
          'loss_function':  ['Logloss'],
          'l2_leaf_reg': [10],
         }
print('cat_serachParm 搜索最佳参数 .......',params)


cat_serachParm 搜索最佳参数 ....... {'depth': [4, 5, 6, 7, 8], 'learning_rate': [0.03, 0.035, 0.04, 0.045, 0.05, 0.055, 0.06, 0.065], 'loss_function': ['Logloss', 'CrossEntropy'], 'l2_leaf_reg': array([1.00000000e-20, 3.16227766e-20, 1.00000000e-19])}


In [395]:

model = catboost_GridSearchCV(train_data,train_labels,test_data,params,cat_features)

Params Tuning::   0%|          | 0/240 [00:00<?, ?it/s]cat_serachParm 搜索最佳参数 ....... {'depth': 4, 'l2_leaf_reg': 1e-20, 'learning_rate': 0.03, 'loss_function': 'Logloss'}
Params Tuning::   0%|          | 0/240 [00:20<?, ?it/s]


KeyboardInterrupt: 