In [None]:
import optuna

import pandas as pd
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score
)
from optuna import Trial
from optuna.samplers import TPESampler
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBRegressor, XGBClassifier
from catboost import CatBoostRegressor
from lightgbm import LGBMRegressor, LGBMClassifier
from sklearn.preprocessing import LabelEncoder
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier


In [None]:
seed = 42
iteration = 3000
threshold = 0.055

In [None]:
df_train = pd.read_csv("train.csv")
df_test = pd.read_csv("submission_original.csv").drop(columns='id')
df_all = pd.concat([df_train, df_test]).reset_index(drop=True)

In [None]:
# inquiry_type

df_all.loc[(df_all['product_category'] == 'others') | (df_all['product_category'] == 'etc.'), ['product_category']] = 'other'
df_all.loc[(df_all['inquiry_type'] == 'other_') | (df_all['inquiry_type'] == 'ETC.') | (df_all['inquiry_type'] == 'Etc.') | (df_all['inquiry_type'] == 'Other') | (df_all['inquiry_type'] == 'Others') | (df_all['inquiry_type'] == 'etc.'), ['inquiry_type']] = 'other'
df_all.loc[(df_all['inquiry_type'] == 'Sales inquiry'), ['inquiry_type']] = 'Sales Inquiry'
df_all.loc[(df_all['inquiry_type'] == 'Quotation or purchase consultation'), ['inquiry_type']] = 'Quotation or Purchase consultation'
df_all.loc[(df_all['inquiry_type'] == 'quotation_or_purchase_consultation'), ['inquiry_type']] = 'Quotation or Purchase consultation'
df_all.loc[(df_all['inquiry_type'] == 'Quotation or Purchase Consultation'), ['inquiry_type']] = 'Quotation or Purchase consultation'
df_all.loc[(df_all['inquiry_type'] == 'Purchase or Quotation'), ['inquiry_type']] = 'Quotation or Purchase consultation'
df_all.loc[(df_all['inquiry_type'] == 'Request for quotation or purchase'), ['inquiry_type']] = 'Quotation or Purchase consultation'

df_all.loc[(df_all['inquiry_type'] == 'Quotation or Purchase Consultation'), ['inquiry_type']] = 'Quotation or Purchase consultation'
df_all.loc[(df_all['inquiry_type'] == 'usage_or_technical_consultation'), ['inquiry_type']] = 'Usage or technical consultation'
df_all.loc[(df_all['inquiry_type'] == 'usage or technical consultation'), ['inquiry_type']] = 'Usage or technical consultation'
df_all.loc[(df_all['inquiry_type'] == 'Usage or Technical Consultation'), ['inquiry_type']] = 'Usage or technical consultation'
df_all.loc[(df_all['inquiry_type'] == 'technical_consultation'), ['inquiry_type']] = 'Usage or technical consultation'
df_all.loc[(df_all['customer_type'] == 'Specifier/ Influencer'), ['customer_type']] = 'Specifier / Influencer'
df_all.loc[(df_all['customer_type'] == 'End Customer'), ['customer_type']] = 'End-Customer'
df_all.loc[(df_all['customer_type'] == 'homeowner'), ['customer_type']] = 'Home Owner'
df_all.loc[(df_all['customer_type'] == 'Software/Solution Provider'), ['customer_type']] = 'Software / Solution Provider'
df_all.loc[((df_all['customer_type'] == 'other') | (df_all['customer_type'] == 'Etc.') | (df_all['customer_type'] == 'Other') | (df_all['customer_type'] == 'Others')), ['customer_type']] = 'other'

In [133]:
# customer_country

outlier_dict = {'fl': 'unitedstates', 'nevada': 'unitedstates', '85wmainstsuitec,canton,ga30114,': 'unitedstates', '230highlandave,suite531somervillema2143,': 'unitedstates',
                'ironhorsecustomsllc4443genellawaynorthlasvegas,nv89031,': 'unitedstates', '5301stevenscreekblvdsantaclaraca95051': 'unitedstates', '7700westsunriseblvdplantationfl33322': 'unitedstates', 'mo64108.': 'unitedstates', '3440rockefellerctwaldorf,md20602': 'unitedstates',
                'ca91915-6002': 'unitedstates', '2877prospectrd,fortlauderdale,fl33309': 'unitedstates', '9landsdownestreetbostonma2215': 'unitedstates', '160gouldstste300,needhamheights,massachusetts02494needhamma2721': 'unitedstates', 'oneconstitutionroadbostonma2129': 'unitedstates',
                '36marginstpeabodyma1960': 'unitedstates', '400centrestreetnewtonma2458': 'unitedstates', '1755northbrownrd.suite200lawrenceville,ga30043': 'unitedstates', '6kimballlanelynnfieldma1940': 'unitedstates', '1275sistergroverdvanalstyne,tx75495': 'unitedstates',
                '3nassonavenue': 'unitedstates', '810nkingstondrpeoria,il61604-2145': 'unitedstates', 'mi48827': 'unitedstates', '45n200wwillardut84340': 'unitedstates', '1385nweberrd,romeoville,il60446,': 'unitedstates',
                'ma02780': 'unitedstates', '152bowdoinstreet': 'unitedstates', '9820huntersvillenc28078': 'unitedstates', '3000montourchurchroad': 'unitedstates', '1380enterprisedr': 'unitedstates',
                '6601carrollhighlandsrd': 'unitedstates', '275mishawumroad': 'unitedstates', '222maxinedr': 'unitedstates', '2900highway280suite250birminghamal35223': 'unitedstates', 'il60069': 'unitedstates',
                '100vestaviaparkwaybirminghamal35216': 'unitedstates', '1100itbprovout84602': 'unitedstates', 'us': 'unitedstates', '700patrooncreekblvdalbanyny12206': 'unitedstates', '252beechavenuemelrosema2176': 'unitedstates',
                'caymanislands': 'unitedkingdom', 'britishvirginislands': 'unitedkingdom', 'bermuda': 'unitedkingdom', 'anguilla': 'unitedkingdom',
                'hanoi': 'vietnam', 'turkey': 'türkiye', '1605ave.poncedeleón,suite400sanjuan,00909,puertorico': 'puertorico', 'riodejaneiro': 'brazil',
                'mumbai': 'india', 'gujarat': 'india', 'telangana': 'india', 'unitedstates': 'unitedstates', 'θέσηπέτσαβακαλοπούλουβιοπαπαλλήνης15351': 'greece', 'madrid': 'spain',
                'agost,alicante': 'spain', '9631libertyrdb,randallstown,md21133': 'unitedstates', 'benbrook,tx76126': 'unitedstates', 'sc29555': 'unitedstates', '6252egrantrdsuite150tucson,az85712': 'unitedstates',
                ',mo64802': 'unitedstates', 'in46601,ee.uu.': 'unitedstates', 'warren,oh44483.': 'unitedstates', '1600rosecransavebldg7ste101,manhattanbeach,ca90266,': 'unitedstates', 'or97128': 'unitedstates', 'dourados': 'brazil',
                'fozdeiguaçu-pravenidatancredoneves6731jardimitaipu': 'brazil', 'recife': 'brazil', 'sãopaulo,pinheiros.': 'brazil', 'cartagena': 'colombia', 'co80127': 'unitedstates', 'cuiabá': 'brazil',
                'colombiacartagena': 'colombia', 'ny': 'unitedstates', 'carrera11a94-46edificiochico3000piso3bogota': 'colombia', 'va22209': 'unitedstates', 'ohio': 'unitedstates', 'caceres': 'spain',
                'ga31405': 'unitedstates', 'il61615': 'unitedstates', 'squareat,2662gatewayrdsuite165,carlsbad,ca92009': 'unitedstates', 'mo63103': 'unitedstates', 'country': np.nan, 'sd57751': 'unitedstates',
                'lucknow': 'india', 'chennai': 'india', 'tx75098': 'unitedstates', 'mo64506': 'unitedstates', 'ga30039': 'unitedstates', 'nv89119': 'unitedstates', 'nm': 'unitedstates', 'nj': 'unitedstates', 'elche,alicante': 'spain',
                '48201': np.nan, 'rj': 'unitedstates', 'colombiac2:soloinformación': 'colombia', 'bari,italy': 'italy', 'stcloud,mn56303': 'unitedstates', '609medicalcenterdr.decatur,texas,76234': 'unitedstates', '275johnhancockrd.taunton,ma,02780': 'unitedstates', 'pune': 'india', 'firenze,italy': 'italy',
                '3centerplzsuite330boston,ma02108': 'unitedstates', '1919minnesotact,mississauga,onl5n': 'canada', '233southbeaudryavenuelosangelesca': 'unitedstates', 'tx': 'unitedstates', 'englewood,co': 'unitedstates', '5003ladyofthelakedrraleighnc': 'unitedstates', 'ne': 'unitedstates', 'aparecida': 'brazil',
                'colombia-cartagena': 'colombia', 'capãodacanoa': 'brazil', 'bucaramanga': 'colombia', 'sãopaulo': 'brazil', 'joãopessoa': 'brazil', 'centrodeproduçãoaudiovisual-sescsãopaulo': 'brazil', 'saopaulo': 'brazil', '2367n2650wfarrwestut84404': 'unitedstates',
                '4880stevenscreekblvdsanjoseca95129': 'unitedstates', '450riverchasepkwybirminghamal35186': 'unitedstates', '1397etowahdriveatlantaga30319': 'unitedstates', 'tn38120': 'unitedstates', '823gatewaycenterway,sandiego,ca92102': 'unitedstates', 'isleofman': 'unitedkingdom', 'odisha': 'india', 'belohorizonte': 'brazil',
                'barranquilla': 'colombia', '9716mcfarringdrfortworthtx76244': 'unitedstates', '1999sbascombavesuite1000campbellca95008': 'unitedstates', 'pobox112292carrolltontx75011': 'unitedstates', '30winterstreetboston,ma02108': 'unitedstates', '1036nnalderst': np.nan, '1hoagdr.': 'unitedstates',
                '545west111thstsuite7c': 'unitedstates', '21903ranierln': 'unitedstates', '200cabellouisvilleky40206': 'unitedstates', '9820northcrosscentercthuntersvillenc28078': 'unitedstates', '5555': np.nan, '1503lbjparkwaysuite700farmersbranchtx75234': 'unitedstates',
                '6111wplanopkwy#2100planotx75093': 'unitedstates', '100sabineriverdr.huttotx78634': 'unitedstates', '6300harryhinesblvdste.1400dallastx75235': 'unitedstates', '100firststsanfranciscoca94015': 'unitedstates', '899kiferroadsunnyvaleca94086': 'unitedstates', '829jacksonave': 'unitedstates', '3027westbayvillaave': 'unitedstates', '126diabloranchcourt': 'unitedstates',
                '1537rollinghillsdr.': 'unitedstates', '1715forestcovedrive,apt201': 'unitedstates', '6005commercedr.ste.300': 'unitedstates', '17215welbyway': 'unitedstates', '416panzanodrive': 'unitedstates', '15806longshipct': 'unitedstates', '5501headquartersdrplanotx75024': 'unitedstates', '825eastlakeavee': 'unitedstates',
                '602sabercreekdrive': 'unitedstates', 'vt05672': 'unitedstates', 'nd': 'unitedstates', 'br': 'unitedstates', 'kerela': 'india', 'uttarpradesh': 'india', 'anandvihardelhi': 'india', 'hyderabad': 'india',
                'bhilwara': 'india', '9110forestcrossingthewoodlandstx77381': 'unitedstates', 'arlington,ny12603': 'unitedstates', 'grancanariasplayadelingles': 'spain', '24082carmeldr': 'unitedstates', '41720thstnbirminghamal35203': 'unitedstates', '594howardstsanfranciscoca94105': 'unitedstates', '594howardstsanfranciscoca94106': 'unitedstates',
                '305johnstreet': 'unitedstates', 'cra.51#12sur-75,sanfernando,itagüi,medellín,guayabal,medellín,antioquia,colombia': 'colombia', 'il60191': 'unitedstates', 'zip98433': 'unitedstates', '2475washingtonblvdogdenut84401': 'unitedstates', '723svalleyway,palmer,ak99645': 'unitedstates', 'tx77024': 'unitedstates', '1skyviewdrfortworthtx76155': 'unitedstates',
                '1209derbyruncarrollton,tx75007carrolltontx75007': 'unitedstates', '9111cypresswatersblvddallastx75038': 'unitedstates', '2350airportfrwybedfordtx76022': 'unitedstates', '117bernalrdste70-422sanjosesanjoseca95119': 'unitedstates', '1808lithgowrdcelinatx75009': 'unitedstates', '1909forestknolldrhooveral35244': 'unitedstates', '106lakeviewdrhomewoodal35209': 'unitedstates', '65grovestreet,suite204watertown,ma02472': 'unitedstates',
                '13854lakesidecirsterlingheights,mi48313': 'unitedstates', 'indore': 'india', 'ks66217': 'unitedstates', 'nj07013': 'unitedstates', 'ny11358': 'unitedstates', 'gurgaon': 'india', '1112badgervinelanearlingtontx76005': 'unitedstates', '14700caribbeanway': 'unitedstates',
                '335leaguests,sulphursprings,tx75482': 'unitedstates', '300eastparkdrive': 'unitedstates', '6005thstreet': 'unitedstates', 'ma01851': 'unitedstates', '1001mainst': 'unitedstates', 'gurgaon': 'india', 'viadell\'informatica10-37036sanmartinobuonalbergo(veneto),italy': 'italy', '1100leeave,lafayette,la70501,': 'unitedstates',
                '7105northlandterracen,minneapolis,mn55428': 'unitedstates', 'ca95814': 'unitedstates', '6564headquartersdrplanotx75051': 'unitedstates', '1275sistergroverdvanalstyne,tx75495': 'unitedstates', 'ca': 'unitedstates', '101metlifeway,cary,nc,27513–met1': 'unitedstates', '230highlandave,suite531somervillema2143': 'unitedstates', 'netherlandsantilles': 'netherlands',
                'valencia': 'spain', '3100shoredrivevirginiabeach,va23451': 'unitedstates', 'nicolosi(ct),italy': 'italy', '450riverchasepkwybirminghamal35186': 'unitedstates', '7673hempstoncir': 'unitedstates', 'sc29555': 'unitedstates', '2047wsummerdaleave': 'unitedstates', 'mn55024': 'unitedstates', '2266palmerdr': 'unitedstates', '955powellavesw': 'unitedstates', 'a': np.nan,
               '603heritagedrivemountjuliet': 'unitedstates', '136sindustrialsalinemi48176': 'unitedstates', '463industrialparkrd,elysburg,pa17824,us': 'unitedstates', '3131briarparkdrsuite200houstontx77042': 'unitedstates', 'fl33013': 'unitedstates',
               'p.o.box291992,portorange,fl32129': 'unitedstates', 'busshed,6501redhookrd#201,nazareth,stthomas00802,u.s.virginislands': 'unitedstates', 'fl32703': 'unitedstates', 'fl33025': 'unitedstates', '1800congressave.,austin,tx78701': 'unitedstates',
               'w126n7449flintdrivemenomoneefallsva': 'unitedstates', 'manaus': 'brazil', 'sãopaulo,pinheiros': 'brazil', '8003rdave3rdfloor,newyork,ny10022': 'unitedstates', 'usvirginislands': 'unitedstates',
               '30cambriaave,pleasantville,nj08232': 'unitedstates', '9800s.monroestreetsandyut84070': 'unitedstates', '11330clayrdhoustontx77041': 'unitedstates', '9420westsamhoustonpkwynhoustontx77018': 'unitedstates', '2217houstondrivemelissatx75454': 'unitedstates',
               'fl33442.': 'unitedstates', 'fl33716': 'unitedstates', 'ca95618': 'unitedstates', '9420westsamhoustonpkwynhoustontx77018': 'unitedstates', 'newhampshire': 'unitedstates',
               '2266palmerdr.': 'unitedstates', '8454muirwoodtrlfortworthtx76137': 'unitedstates', '410baylorstaustintx78703': 'unitedstates', '9420westsamhoustonpkwynhoustontx77018': 'unitedstates', 'ca92078': 'unitedstates',
               '750floridacentralparkwaysuite#100longwood,fl32750': 'unitedstates', 'fl33404': 'unitedstates', '724wbusinessushighway60,dexter,mo63841,': 'unitedstates', '400centrestnewtonma2458': 'unitedstates', '77massachusettsavecambridgema2139': 'unitedstates',
               '210route4eastfl4 ': 'unitedstates', 'fl33404': 'unitedstates', 'viae.deamicis,23.90044carini(pa)': 'italy', 'ironhorsecustomsllc4443genellawaynorthlasvegas,nv89031': 'unitedstates', 'herndon,va20170': 'unitedstates',
               '3801ewillowst,longbeach,ca90815,ee.uu.': 'unitedstates', '12718kittentrail,hudson,fl34669': 'unitedstates', 'jacksonvilleflorida': 'unitedstates', '4278sbuffalostorchardpark,ny14127': 'unitedstates', '1156warmitageavesuiteb,chicago,il60614,us.': 'unitedstates',
               'turksandcaicosislands': 'unitedkingdom', 'fl33772': 'unitedstates', 
                }

for idx, value in enumerate(df_all['customer_country']):
    if type(value) != float:
        new_value = value.split('/')[-1].replace(' ', '').lower()
        if new_value == '':
            df_all.loc[idx, ['customer_country']]=np.nan
        elif '.com' in new_value:
            df_all.loc[idx, ['customer_country']] = np.nan
        elif '.net' in new_value:
            df_all.loc[idx, ['customer_country']] = np.nan
        elif 'usa' in new_value:
            df_all.loc[idx, ['customer_country']] = 'unitedstates'
        elif 'unitedstates' in new_value:
            df_all.loc[idx, ['customer_country']] = 'unitedstates'
        else:
            df_all.loc[idx, ['customer_country']] = new_value
        for key in outlier_dict:
            if key == new_value:
                df_all.loc[idx, ['customer_country']]= outlier_dict[key]          

# country = df_all['customer_country'].value_counts()
# ban_list = country[country==1].index
# df_all.loc[(df_all['customer_country'].isin(ban_list)), ['customer_country']]=np.nan

In [134]:
df_all = df_all.drop(columns=['customer_country.1'])
df_train = df_all[:len(df_train)]
df_test = df_all[len(df_train):]

In [135]:
def label_encoding(series: pd.Series) -> pd.Series:
    my_dict = {}
    series = series.fillna(-999).astype(str)
    my_dict['-999']=0
    unique = sorted(series.unique())
    if '-999' in unique:
        unique.remove('-999') 
    for idx, value in enumerate(unique):
        my_dict[value] = idx+1
    
    series = series.map(my_dict)

    return series

def label_decoding(original: pd.Series, series: pd.Series) -> pd.Series:
    su = sorted(series.unique())
    ou = sorted(original.astype(str).unique())
    if 0 in su:
        su.remove(0)
    if 'nan' in ou:
        ou.remove('nan')
    my_dict = dict(zip(su, ou))
    my_dict[0] = np.nan
    series = series.map(my_dict)
    return series
    

label_columns = [
    "customer_country",
    "business_subarea",
    "business_area",
    "business_unit",
    "customer_type",
    "enterprise",
    "customer_job",
    "inquiry_type",
    "product_category",
    "product_subcategory",
    "product_modelname",
    "customer_position",
    "response_corporate",
    "expected_timeline",
]

new_df = df_all.copy()

for label in label_columns:
    new_df[label] = label_encoding(df_all[label])

In [136]:
x_train, x_val, y_train, y_val = train_test_split(
    new_df[new_df["business_area"]!=0],
    new_df[new_df["business_area"]!=0]['business_area'],
    test_size = 0.2,
    shuffle = True,
    random_state = seed,
)

In [137]:
# # predict business_area

# # predict
# model_xbgclassifier = XGBClassifier(random_state=seed)

# x_ba = new_df[new_df['business_area']!=0].drop(columns = ['id_strategic_ver', 'it_strategic_ver', 'idit_strategic_ver', 'ver_cus', 'ver_pro', 'ver_win_rate_x', 'ver_win_ratio_per_bu', 'business_area', 'is_converted', 'com_reg_ver_win_rate', 'customer_type'])
# y_ba = new_df[new_df['business_area']!=0]['business_area']
# model_xbgclassifier.fit(x_ba, y_ba-1)

# value = model_xbgclassifier.predict(new_df.drop(columns = ['id_strategic_ver', 'it_strategic_ver', 'idit_strategic_ver', 'ver_cus', 'ver_pro', 'ver_win_rate_x', 'ver_win_ratio_per_bu', 'business_area', 'is_converted', 'com_reg_ver_win_rate', 'customer_type']))+1
# value = label_decoding(df_all['business_area'], pd.Series(value))
# df_all['business_area'] = value

# df_all['business_area']

In [138]:
# id_strategic_ver, it_strategic_ver, idit_strategic_ver

df_all.loc[((df_all['business_unit']=='ID') & (df_all['business_area'].isin(['corporate / office', 'hotel & accommodation']))), ['id_strategic_ver']]=1
df_all.loc[~((df_all['business_unit']=='ID') & (df_all['business_area'].isin(['corporate / office', 'hotel & accommodation']))), ['id_strategic_ver']]=0

df_all.loc[(df_all['business_unit']=='IT') & (df_all['business_area'].isin(['corporate / office', 'hotel & accommodation'])), ['it_strategic_ver']]=1
df_all.loc[~((df_all['business_unit']=='IT') & (df_all['business_area'].isin(['corporate / office', 'hotel & accommodation']))), ['it_strategic_ver']]=0

df_all.loc[((df_all['id_strategic_ver']==1) & (df_all['it_strategic_ver']==1)), ['idit_strategic_ver']]=1
df_all.loc[~((df_all['id_strategic_ver']==1) & (df_all['it_strategic_ver']==1)), ['idit_strategic_ver']]=0

In [139]:
# ver_cus

df_all.loc[((df_all['business_area'].isin(['corporate / office', 'hotel & accommodation', 'education', 'retail']))
             &(df_all['customer_type']=='End-Customer')), ['ver_cus']]=1
df_all.loc[~((df_all['business_area'].isin(['corporate / office', 'hotel & accommodation', 'education', 'retail']))
             &(df_all['customer_type']=='End-Customer')), ['ver_cus']]=0

In [140]:
# ver_pro

co = ['standard signage','high brightness signage','interactive signage','video wall signage','led signage','signage care solution','oled signage','special signage','uhd signage','smart tv signage','signage care solutions','digital signage','monitor signage,commercial tv,monior/monitor tv','monitor signage,monior/monitor tv','monitor signage,commercial tv,monior/monitor tv,projector,tv','monitor signage,commercial tv,monior/monitor tv,tv','monitor signage,commercial tv,solar,ess,monior/monitor tv,pc,projector,robot,system ac,ems,rac,chill','tv signage','signage','monitor signage,tv']
ha = ['hotel tv']
rt = ['led signage','video wall signage','high brightness signage','standard signage','oled signage','interactive signage','special signage','smart tv signage','uhd signage','tv signage','signage care solution','ultra stretch signage','monitor signage,monior/monitor tv','monitor signage,commercial tv,solar,ess,monior/monitor tv,pc,projector,robot,system ac,ems,rac,chill','monitor signage,commercial tv,monior/monitor tv,pc,tv,home beauty,audio/video','monitor signage,monior/monitor tv,tv,audio/video','signage','digital signage','signage care solutions','monitor signage,monior/monitor tv,vacuum cleaner,tv,home beauty,commercial tv,pc,refrigerator,styler']

df_all.loc[((df_all['business_area']=='corporate / office')&(df_all['product_category'].isin(co))|(df_all['business_area']=='hotel & accommodation')&(df_all['product_category'].isin(ha))|(df_all['business_area']=='retail')&(df_all['product_category'].isin(rt))), ['ver_pro']]=1
df_all.loc[~((df_all['business_area']=='corporate / office')&(df_all['product_category'].isin(co))|(df_all['business_area']=='hotel & accommodation')&(df_all['product_category'].isin(ha))|(df_all['business_area']=='retail')&(df_all['product_category'].isin(rt))), ['ver_pro']]=0

In [141]:
# ver_win_rate_x
vwrx = {'corporate / office': 0.0030792876608617, 'education': 0.0005719551277132, 'hotel & accommodation': 0.0007167734380046, 'hospital & health care': 6.044033666058328e-05,
        'special purpose': 0.0005432224318428, 'residential (home)': 0.0002983104051378, 'government department': 9.65915660650443e-05, 'retail': 0.0011827288932506,
        'factory': 0.0002153634176709, 'power plant / renewable energy': 2.3159381337232847e-06, 'transportation': 1.2765902883450302e-05, 'public facility': 2.5889552307882245e-05}

for vw in vwrx.keys():
    df_all.loc[(df_all['business_area']==vw), ['ver_win_rate_x']]=vwrx[vw]    

In [142]:
# ver_win_ratio_per_bu

def new_ratio(df: pd.DataFrame):
    solution = {
        'corporate / office':{'AS': 0.0268456375838926, 'ID': 0.0645661157024793, 'Solution': 0.0344827586206896},
        'education':{'ID': 0.048629531388152, 'AS': 0.0514705882352941},
        'special purpose':{'ID': 0.0640703517587939, 'AS': 0.022633744855967},
        'hospital & health care':{'ID': 0.1311475409836065, 'AS': 0.1285714285714285},
        'residential (home)':{'ID': 0.0354838709677419, 'AS': 0.0201207243460764},
        'government department':{'ID': 0.0794117647058823, 'AS': 0.0227272727272727},
        'retail':{'ID': 0.0498402555910543, 'AS': 0.0115830115830115},
        'hotel & accommodation':{'ID': 0.071345029239766},
        'factory':{'ID': 0.0369127516778523, 'AS': 0.0609243697478991},
        'power plant / renewable energy':{'ID': 0.2857142857142857, 'AS': 0.2272727272727272},
        'transportation':{'ID': 0.0535714285714285},
        'public facility':{'ID': 0.031578947368421, 'AS': 0.0287769784172661}
    }
    
    for business_area in solution.keys():
        for business_unit in solution[business_area].keys():
            df.loc[((df['business_area']==business_area) & (df['business_unit']==business_unit)), ['ver_win_ratio_per_bu']]=solution[business_area][business_unit]
    
    df.loc[(df['ver_win_ratio_per_bu'].isna()), ['ver_win_ratio_per_bu']]=-1

    return df

df_all = new_ratio(df_all)

In [143]:
# com_reg_ver_win_rate

mydict= {
    0.0037878787878787:	['AS',	'residential (home)',	['argentina', 'bolivia', 'brazil', 'chile', 'colombia', 'dominicanrepublic', 'elsalvador', 'honduras', 'mexico', 'nicaragua', 'panama', 'peru', 'stkitts', 'trinidadandtobago', 'unitedkingdom', 'uruguay']],
    0.0039370078740157:	['AS',	'corporate / office',	['antigua', 'argentina', 'bahamas', 'belize', 'brazil', 'chile', 'colombia', 'dominicanrepublic', 'elsalvador', 'guatemala', 'honduras', 'jamaica', 'mexico', 'panama', 'peru', 'uruguay']],
    0.0118577075098814:	['AS',	'special purpose',	['argentina', 'brazil', 'chile', 'colombia', 'dominicanrepublic', 'ecuador', 'elsalvador', 'guatemala', 'jamaica', 'mexico', 'panama', 'paraguay', 'peru', 'unitedkingdom', 'uruguay']],
    0.0135135135135135:	['ID',	'special purpose',	['albania', 'belgium', 'bosniaandherzegovina', 'bulgaria', 'croatia', 'cyprus', 'czech', 'france', 'germany', 'greece', 'ireland', 'italy', 'latvia', 'netherlands', 'poland', 'portugal', 'romania', 'slovenia', 'spain', 'switzerland', 'unitedkingdom']],
    0.0151515151515151:	['AS',	'education',	['argentina', 'brazil', 'chile', 'colombia', 'mexico', 'panama', 'peru']],
    0.0175438596491228:	['ID',	'education',	['france', 'germany', 'greece', 'italy', 'kosovo', 'poland', 'portugal', 'spain', 'unitedkingdom']],
    0.0181818181818181:	['AS',	'special purpose',	['australia', 'indonesia', 'japan', 'papuanewguinea', 'philippines', 'singapore', 'thailand', 'vietnam']],
    0.0199004975124378:	['ID',	'corporate / office',	['albania', 'belgium', 'bulgaria', 'croatia', 'cyprus', 'czech', 'denmark', 'france', 'germany', 'greece', 'hungary', 'ireland', 'italy', 'luxembourg', 'malta', 'netherlands', 'poland', 'portugal', 'romania', 'serbia', 'slovenia', 'spain', 'sweden', 'switzerland', 'unitedkingdom']],
    0.0311958405545927:	['ID',	'education',	['india', 'unitedstates']],
    0.032258064516129:	['ID',	'education',	['afghanistan', 'bahrain', 'canada', 'egypt', 'ghana', 'israel', 'jordan', 'kenya', 'kuwait', 'nigeria', 'oman', 'saudiarabia', 'southafrica', 'türkiye', 'u.a.e', 'unitedstates', 'yemen']],
    0.0327868852459016:	['ID',	'education',	['argentina', 'brazil', 'chile', 'colombia', 'ecuador', 'guatemala', 'jamaica', 'mexico', 'panama', 'peru', 'unitedkingdom']],
    0.0330578512396694:	['ID',	'special purpose',	['india']],
    0.0408163265306122:	['AS',	'corporate / office',	['afghanistan', 'democraticrepublicofthecongo', 'egypt', 'ethiopia', 'ghana', 'iraq', 'israel', 'jordan', 'kenya', 'mauritania', 'mozambique', 'nigeria', 'oman', 'pakistan', 'qatar', 'saudiarabia', 'southafrica', 'togo', 'türkiye', 'u.a.e', 'unitedrepublicoftanzania']],
    0.043103448275862:	['ID',	'special purpose',	['afghanistan', 'angola', 'armenia', 'botswana', 'centralafricanrepublic', 'egypt', 'ethiopia', 'ghana', 'iraq', 'israel', 'kenya', 'kuwait', 'lebanon', 'libya', 'nigeria', 'oman', 'pakistan', 'palestine', 'qatar', 'saudiarabia', 'southafrica', 'türkiye', 'u.a.e', 'zambia']],
    0.0434782608695652:	['IT',	'corporate / office',	['algeria', 'canada', 'egypt', 'kuwait', 'morocco', 'saudiarabia', 'türkiye', 'u.a.e', 'uganda', 'unitedstates']],
    0.0446428571428571:	['ID',	'corporate / office',	['australia', 'brazil', 'canada', 'chile', 'china', 'clinton,ok73601', 'colombia', 'costarica', 'dominicanrepublic', 'ecuador', 'guatemala', 'honduras', 'mexico', 'panama', 'peru', 'saudiarabia', 'unitedstates']],
    0.0485436893203883:	['AS',	'special purpose',	['egypt', 'iraq', 'jordan', 'kenya', 'lebanon', 'nigeria', 'oman', 'pakistan', 'saudiarabia', 'southafrica', 'sudan', 'swaziland', 'tunisia', 'türkiye', 'u.a.e', 'unitedrepublicoftanzania', 'zambia']],
    0.0575342465753424:	['ID',	'corporate / office',	['india']],
    0.0666666666666666:	['AS',	'corporate / office',	['australia', 'bangladesh', 'china', 'india', 'indonesia', 'papuanewguinea', 'philippines', 'singapore', 'srilanka', 'thailand', 'vietnam']],
    0.0888888888888888:	['AS',	'corporate / office',	['india']],
    0.004:	['AS',	'retail',	['antigua', 'argentina', 'bolivia', 'brazil', 'chile', 'colombia', 'ecuador', 'elsalvador', 'guatemala', 'honduras', 'mexico', 'nicaragua', 'panama', 'paraguay', 'peru', 'puertorico', 'trinidadandtobago', 'unitedkingdom', 'venezuela']],
    0.0109890109890109:	['AS',	'residential (home)',	['australia', 'bangladesh', 'indonesia', 'myanmar', 'newzealand', 'papuanewguinea', 'philippines', 'singapore', 'thailand', 'vietnam']],
    0.0169491525423728:	['ID',	'hotel & accommodation',	['albania', 'belgium', 'czech', 'denmark', 'france', 'germany', 'greece', 'hungary', 'italy', 'malta', 'netherlands', 'norway', 'poland', 'portugal', 'romania', 'serbia', 'spain', 'sweden', 'switzerland', 'unitedkingdom']],
    0.0172413793103448:	['ID',	'residential (home)',	['india']],
    0.0196078431372549:	['AS',	'retail',	['algeria', 'egypt', 'iraq', 'kenya', 'kuwait', 'mauritius', 'nigeria', 'oman', 'pakistan', 'palestine', 'qatar', 'saudiarabia', 'serbia', 'southafrica', 'türkiye', 'u.a.e', 'yemen']],
    0.0202020202020202:	['AS',	'factory',	['algeria', 'burkinafaso', 'egypt', 'iran', 'iraq', 'kenya', 'morocco', 'nigeria', 'oman', 'pakistan', 'saudiarabia', 'senegal', 'southafrica', 'türkiye', 'u.a.e']],
    0.0227272727272727:	['AS',	'retail',	['australia', 'bangladesh', 'indonesia', 'papuanewguinea', 'philippines', 'singapore', 'thailand', 'vietnam']],
    0.025:	['ID',	'transportation',	['belgium', 'bulgaria', 'france', 'germany', 'hungary', 'italy', 'poland', 'portugal', 'slovenia', 'spain', 'sweden', 'switzerland', 'unitedkingdom']],
    0.0289256198347107:	['ID',	'retail',	['india']],
    0.0289855072463768:	['AS',	'retail',	['india']],
    0.036036036036036:	['ID',	'factory',	['india']],
    0.037037037037037:	['AS',	'residential (home)',	['azerbaijan', 'belgium', 'bosniaandherzegovina', 'bulgaria', 'egypt', 'ethiopia', 'france', 'germany', 'greece', 'hungary', 'israel', 'oman', 'poland', 'portugal', 'saudiarabia', 'türkiye', 'u.a.e', 'unitedkingdom', 'unitedrepublicoftanzania', 'zimbabwe']],
    0.04:	['AS',	'government department',	['afghanistan', 'algeria', 'australia', 'bangladesh', 'brazil', 'chile', 'colombia', 'egypt', 'ghana', 'honduras', 'indonesia', 'iraq', 'israel', 'kenya', 'kuwait', 'libya', 'mauritius', 'mexico', 'morocco', 'nigeria', 'oman', 'pakistan', 'peru', 'philippines', 'qatar', 'saudiarabia', 'singapore', 'southafrica', 'srilanka', 'thailand', 'türkiye', 'u.a.e', 'uganda', 'yemen']],
    0.0416666666666666:	['AS',	'government department',	['albania', 'bulgaria', 'egypt', 'france', 'germany', 'greece', 'hongkong', 'hungary', 'iraq', 'italy', 'jordan', 'kuwait', 'netherlands', 'nigeria', 'poland', 'portugal', 'romania', 'saudiarabia', 'spain', 'sweden', 'switzerland', 'u.a.e', 'uganda', 'unitedkingdom']],
    0.0422535211267605:	['ID',	'public facility',	['argentina', 'brazil', 'chile', 'colombia', 'guatemala', 'mexico', 'peru', 'puertorico']],
    0.0476190476190476:	['ID',	'transportation',	['argentina', 'bahamas', 'brazil', 'chile', 'colombia', 'costarica', 'dominicanrepublic', 'ecuador', 'guatemala', 'mexico', 'panama', 'peru', 'puertorico', 'stmaarten', 'unitedkingdom']],
    0.0491803278688524:	['ID',	'retail',	['afghanistan', 'azerbaijan', 'egypt', 'ethiopia', 'georgia', 'ghana', 'israel', 'kenya', 'kuwait', 'libya', 'morocco', 'mozambique', 'namibia', 'nigeria', 'oman', 'pakistan', 'qatar', 'rwanda', 'saudiarabia', 'southafrica', 'sudan', 'tunisia', 'türkiye', 'u.a.e', 'uganda', 'unitedrepublicoftanzania', 'zimbabwe']],
    0.0496894409937888:	['AS',	'residential (home)',	['bahrain', 'burkinafaso', 'egypt', 'gabon', 'ghana', 'guinea', 'iran', 'iraq', 'israel', 'jordan', 'kuwait', 'lebanon', 'nigeria', 'saudiarabia', 'southafrica', 'syria', 'türkiye', 'u.a.e', 'unitedrepublicoftanzania']],
    0.0531914893617021:	['ID',	'government department',	['india']],
    0.0538922155688622:	['ID',	'residential (home)',	['argentina', 'brazil', 'chile', 'colombia', 'costarica', 'ecuador', 'elsalvador', 'mexico', 'panama', 'peru', 'trinidadandtobago', 'unitedkingdom']],
    0.0544217687074829:	['ID',	'government department',	['argentina', 'bahamas', 'brazil', 'chile', 'colombia', 'mexico', 'nicaragua', 'panama', 'peru', 'unitedkingdom']],
    0.0555555555555555:	['AS',	'education',	['bahrain', 'egypt', 'ghana', 'jordan', 'nigeria', 'saudiarabia', 'türkiye', 'u.a.e']],
    0.0677966101694915:	['AS',	'residential (home)',	['india']],
    0.0681818181818181:	['AS',	'hospital & health care',	['brazil', 'chile', 'colombia', 'ecuador', 'mexico', 'panama', 'peru', 'venezuela']],
    0.0695652173913043:	['ID',	'factory',	['argentina', 'brazil', 'chile', 'colombia', 'dominicanrepublic', 'elsalvador', 'mexico', 'panama', 'peru', 'unitedkingdom']],
    0.0714285714285714:	['AS',	'government department',	['india']],
    0.0732484076433121:	['ID',	'retail',	['argentina', 'bolivia', 'brazil', 'chile', 'colombia', 'dominicanrepublic', 'ecuador', 'guatemala', 'honduras', 'jamaica', 'mexico', 'panama', 'peru', 'puertorico', 'trinidadandtobago', 'unitedkingdom', 'venezuela']],
    0.0749486652977412:	['ID',	'corporate / office',	['argentina', 'bolivia', 'brazil', 'chile', 'colombia', 'dominicanrepublic', 'ecuador', 'elsalvador', 'guatemala', 'honduras', 'jamaica', 'mexico', 'nicaragua', 'panama', 'paraguay', 'peru', 'puertorico', 'unitedkingdom', 'venezuela']],
    0.075:	['ID',	'corporate / office',	['afghanistan', 'algeria', "coted'ivoire", 'egypt', 'ethiopia', 'ghana', 'iran', 'israel', 'kenya', 'kuwait', 'morocco', 'namibia', 'nigeria', 'oman', 'pakistan', 'qatar', 'saudiarabia', 'senegal', 'southafrica', 'tunisia', 'türkiye', 'u.a.e', 'uganda', 'unitedrepublicoftanzania', 'yemen']],
    0.0806916426512968:	['ID',	'special purpose',	['argentina', 'bolivia', 'brazil', 'chile', 'colombia', 'guatemala', 'jamaica', 'mexico', 'panama', 'peru', 'puertorico', 'unitedkingdom']],
    0.0833333333333333:	['ID',	'hospital & health care',	['india']],
    0.0843373493975903:	['ID',	'corporate / office',	['australia', 'china', 'fiji', 'hongkong', 'indonesia', 'japan', 'malaysia', 'maldives', 'philippines', 'singapore', 'srilanka', 'taiwan', 'vietnam']],
    0.0869565217391304:	['ID',	'government department',	['democraticrepublicofthecongo', 'egypt', 'ethiopia', 'israel', 'kuwait', 'nigeria', 'qatar', 'saudiarabia', 'southafrica', 'türkiye', 'u.a.e', 'zambia']],
    0.1052631578947368:	['ID',	'government department',	['bulgaria', 'czech', 'germany', 'ireland', 'italy', 'malta', 'poland', 'portugal', 'spain', 'unitedkingdom']],
    0.1136363636363636:	['ID',	'factory',	['azerbaijan', 'egypt', 'iran', 'israel', 'kenya', 'nigeria', 'saudiarabia', 'southafrica', 'türkiye', 'u.a.e', 'yemen']],
    0.1162790697674418:	['ID',	'special purpose',	['australia', 'china', 'hongkong', 'indonesia', 'japan', 'laos', 'malaysia', 'newzealand', 'papuanewguinea', 'philippines', 'singapore', 'srilanka', 'taiwan', 'thailand']],
    0.1184210526315789:	['ID',	'retail',	['australia', 'brunei', 'cambodia', 'china', 'hongkong', 'indonesia', 'japan', 'malaysia', 'maldives', 'nepal', 'philippines', 'singapore', 'taiwan', 'thailand', 'unitedstates', 'vietnam']],
    0.1186440677966101:	['ID',	'retail',	['australia', 'canada', 'hongkong', 'unitedstates']],
    0.1241217798594847:	['ID',	'hotel & accommodation',	['argentina', 'bahamas', 'brazil', 'chile', 'colombia', 'ecuador', 'guatemala', 'honduras', 'mexico', 'panama', 'peru', 'puertorico', 'saintlucia', 'unitedkingdom', 'unitedstates']],
    0.125:	['IT',	'retail',	['egypt', 'saudiarabia', 'southafrica', 'u.a.e']],
    0.1363636363636363:	['AS',	'factory',	['australia', 'china', 'indonesia', 'philippines', 'singapore', 'vietnam']],
    0.1470588235294117:	['AS',	'education',	['india']],
    0.1666666666666666:	['ID',	'residential (home)',	['canada', 'unitedstates']],
    0.1818181818181818:	['IT',	'hospital & health care',	['china', 'egypt', 'hongkong', 'indonesia', 'japan', 'malaysia', 'philippines', 'qatar', 'singapore', 'southafrica', 'taiwan', 'türkiye', 'u.a.e', 'vietnam', 'yemen']],
    0.2:	['IT',	'public facility',	['costarica', 'dominicanrepublic', 'guatemala']],
    0.2142857142857142:	['ID',	'education',	['australia', 'bangladesh', 'brunei', 'hongkong', 'indonesia', 'japan', 'myanmar', 'philippines', 'singapore', 'taiwan', 'thailand', 'vietnam']],
    0.2307692307692307:	['IT',	'hospital & health care',	['austria', 'bulgaria', 'cyprus', 'france', 'germany', 'italy', 'poland', 'spain', 'switzerland', 'unitedkingdom']],
    0.25:	['ID',	'power plant / renewable energy',	['argentina', 'brazil', 'chile', 'mexico', 'peru']],
    0.2692307692307692:	['ID',	'hospital & health care',	['argentina', 'brazil', 'chile', 'colombia', 'costarica', 'mexico', 'peru']],
    0.3333333333333333:	['ID',	'special purpose',	['brazil', 'chile', 'colombia', 'germany', 'india', 'mexico', 'peru', 'unitedstates']],
    0.3636363636363636:	['IT',	'residential (home)',	['colombia', 'costarica', 'ecuador', 'mexico', 'peru']],
    0.3902439024390244:	['ID',	'education',	['canada', 'unitedstates']],
    0.4:	['ID',	'public facility',	['canada', 'unitedstates']],
    0.4444444444444444:	['AS',	'power plant / renewable energy',	['mozambique', 'nigeria', 'qatar', 'senegal', 'türkiye', 'u.a.e']],
    0.4615384615384615:	['AS',	'hospital & health care',	['india']],
    0.5:	['IT',	'factory',	['canada', 'unitedstates']],
    0.6153846153846154:	['ID',	'government department',	['canada', 'israel', 'unitedstates']],
    0.6428571428571429:	['IT',	'hospital & health care',	['canada', 'unitedstates']],
    0.8333333333333334:	['ID',	'transportation',	['canada', 'unitedstates']],
    1.0:	['ID',	'power plant / renewable energy',	['australia', 'brazil', 'indonesia', 'malaysia', 'philippines', 'vietnam']],
}

for rate in mydict.keys():
    df_all.loc[(df_all['business_unit']==mydict[rate][0])&(df_all['business_area']==mydict[rate][1])&(df_all['customer_country'].isin(mydict[rate][2])), ['com_reg_ver_win_rate']] = rate

In [144]:
df_train = df_all[:len(df_train)].reset_index(drop=True)
df_test = df_all[len(df_train):].reset_index(drop=True)

df_train.to_csv('new_train.csv', index=False)
df_test.to_csv('new_test.csv', index=False)

In [145]:
label_columns = [
    "customer_country",
    "business_subarea",
    "business_area",
    "business_unit",
    "customer_type",
    "enterprise",
    "customer_job",
    "inquiry_type",
    "product_category",
    "product_subcategory",
    "product_modelname",
    "customer_position",
    "response_corporate",
    "expected_timeline",
]

df_all = pd.concat([df_train[label_columns], df_test[label_columns]])

for col in label_columns:
    df_all[col] = label_encoding(df_all[col])
    
for col in label_columns:  
    df_train[col] = df_all.iloc[: len(df_train)][col]
    df_test[col] = df_all.iloc[len(df_train) :][col]

In [146]:

# x_train, x_val, y_train, y_val = train_test_split(
#     df_train.drop("is_converted", axis=1),
#     df_train["is_converted"],
#     test_size = 0.2,
#     shuffle = True,
#     random_state = seed,
# )

mydict=dict(zip(df_train.columns,[x for x in range(len(df_train.columns))]))
cat_features = []

for label in label_columns:
    cat_features.append(mydict[label])

In [147]:
# sampler = TPESampler(seed=10)

# def objective(trial):

#     cbrm_param = {
#         'iterations':trial.suggest_int("iterations", 1000, 20000),
#         'od_wait':trial.suggest_int('od_wait', 500, 2300),
#         'learning_rate' : trial.suggest_float('learning_rate',0.01, 1),
#         'reg_lambda': trial.suggest_float('reg_lambda',1e-5,100),
#         'subsample': trial.suggest_float('subsample', 0,1),
#         'random_strength': trial.suggest_float('random_strength',10,50),
#         'depth': trial.suggest_int('depth',1, 15),
#         'min_data_in_leaf': trial.suggest_int('min_data_in_leaf',1,30),
#         'leaf_estimation_iterations': trial.suggest_int('leaf_estimation_iterations',1,15),
#         'bagging_temperature' :trial.suggest_float('bagging_temperature', 0.01, 100.00, log=True),
#         'colsample_bylevel':trial.suggest_float('colsample_bylevel', 0.4, 1.0),
#     }
    
#     model_cbrm = CatBoostRegressor(cat_features=cat_features, **cbrm_param)
#     model_cbrm = model_cbrm.fit(x_train.fillna(-1), y_train.astype(int), eval_set=[(x_val.fillna(-1), y_val.astype(int))], 
#                            verbose=0, early_stopping_rounds=25)
#     f1 = f1_score(y_val.astype(int), (model_cbrm.predict(x_val.fillna(-1)) > threshold).astype(int), labels=[0, 1])
#     return f1

# optuna_cbrm = optuna.create_study(direction='maximize', sampler=sampler)
# optuna_cbrm.optimize(objective, n_trials=25)

In [148]:
# model_xgb = XGBRegressor(
#     n_estimators = iteration,
#     eta = 0.01,
#     min_child_weight = 50,
#     max_depth = 10,
#     colsample_bytree = 0.9,
#     subsample = 0.9,
#     random_state = seed,
#     objective = "binary:logistic",
#     eval_metric = 'auc',
# )


In [149]:
# model_xgb.fit(x_train.fillna(-1).values, y_train.astype(int).values.reshape(-1,1))

# optuna_cbrm.fit(x_train.fillna(-1), y_train.astype(int))
# cbrm_trial = optuna_cbrm.best_trial
# cbrm_trial_params = cbrm_trial.params
# print('Best Trial: score {},\nparams {}'.format(cbrm_trial.value, cbrm_trial_params))

In [150]:
def get_clf_eval(y_test, y_pred=None):
    confusion = confusion_matrix(y_test, y_pred, labels=[0, 1])
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, labels=[0, 1])
    recall = recall_score(y_test, y_pred)
    F1 = f1_score(y_test, y_pred, labels=[0, 1])

    print("오차행렬:\n", confusion)
    print("\n정확도: {:.4f}".format(accuracy))
    print("정밀도: {:.4f}".format(precision))
    print("재현율: {:.4f}".format(recall))
    print("F1: {:.4f}".format(F1))
    
    return confusion, accuracy, precision, recall, F1

In [151]:
# pred_xgb = model_xgb.predict(x_val.fillna(-1).values)
# pred_xgb = (pred_xgb > threshold).astype(int)
# pred_xgb

# get_clf_eval(y_val.astype(int), pred_xgb)

In [152]:
# print(sum(pred_xgb))
# print(len(pred_xgb))
# print(sum(y_val))

In [153]:
# x_test = df_test.drop(["is_converted"], axis=1)
# test_pred_xgb = model_xgb.predict(x_test.fillna(-1))
# test_pred_xgb = (test_pred_xgb > threshold).astype(int)
# test_pred_xgb
# print(f'True: {sum(test_pred_xgb)}') # True로 예측된 개수

In [154]:
# df_sub = pd.read_csv("submission_original.csv")
# df_sub["is_converted"] = test_pred_xgb

# df_sub.to_csv("submission.csv", index=False)

In [157]:
x = df_train.drop('is_converted', axis=1)
y = df_train['is_converted'].astype(int)
x_test = df_test.drop(["is_converted"], axis=1)
kf = StratifiedKFold(shuffle=True, random_state=seed)
mean_c = []
mean_a = []
mean_p = []
mean_r = []
mean_f = []
result = []

for idx, (t_index, v_index) in enumerate(kf.split(x, y)):
    x_train , x_valid = x.iloc[t_index], x.iloc[v_index]
    y_train, y_valid = y.iloc[t_index], y.iloc[v_index]
    model_xgb = XGBRegressor(
        n_estimators = iteration,
        eta = 0.01,
        min_child_weight = 50,
        max_depth = 10,
        colsample_bytree = 0.9,
        subsample = 0.9,
        random_state = seed,
        objective = "binary:logistic",
        eval_metric = 'auc',
    )
    model_xgb.fit(x_train.fillna(-1).values, y_train.astype(int).values.reshape(-1,1))
    pred_xgb = model_xgb.predict(x_valid.fillna(-1).values)
    pred_xgb = (pred_xgb > threshold).astype(int)
    pred_xgb
    
    print(f'\n########### {idx} fold validation result ##############\n')
    c, a, p ,r, f = get_clf_eval(y_valid.astype(int), pred_xgb)
    mean_c.append(c)
    mean_a.append(a)
    mean_p.append(p)
    mean_r.append(r)
    mean_f.append(f)
    
    test_pred_xgb = model_xgb.predict(x_test.fillna(-1))
    result.append(test_pred_xgb)

print(f'\n########### result ##############\n')
print(f'\n오차행렬:\n{np.mean(mean_c,axis=0).astype(int)}')
print(f'정확도:{np.round(np.mean(mean_a,axis=0),4)}')
print(f'정밀도:{np.round(np.mean(mean_p,axis=0),4)}')
print(f'재현률:{np.round(np.mean(mean_r,axis=0),4)}')
print(f'f1: {np.round(np.mean(mean_f,axis=0), 4)}')

result = np.mean(result, axis=0)
test_pred_xgb = (result > threshold).astype(int)
print(f'True: {sum(test_pred_xgb)}') # True로 예측된 개수


########### 0 fold validation result ##############

오차행렬:
 [[9866 1024]
 [  38  932]]

정확도: 0.9105
정밀도: 0.4765
재현율: 0.9608
F1: 0.6370

########### 1 fold validation result ##############

오차행렬:
 [[9814 1076]
 [  36  934]]

정확도: 0.9062
정밀도: 0.4647
재현율: 0.9629
F1: 0.6268

########### 2 fold validation result ##############

오차행렬:
 [[9755 1135]
 [  37  933]]

정확도: 0.9012
정밀도: 0.4512
재현율: 0.9619
F1: 0.6142

########### 3 fold validation result ##############

오차행렬:
 [[9832 1058]
 [  37  933]]

정확도: 0.9077
정밀도: 0.4686
재현율: 0.9619
F1: 0.6302

########### 4 fold validation result ##############

오차행렬:
 [[9801 1088]
 [  46  924]]

정확도: 0.9044
정밀도: 0.4592
재현율: 0.9526
F1: 0.6197

########### result ##############


오차행렬:
[[9813 1076]
 [  38  931]]
정확도:0.906
정밀도:0.464
재현률:0.96
f1: 0.6256
True: 2288


In [None]:
# #################### iter: 3000 ####################

# 오차행렬:
#  [[9862 1013]
#  [  57  928]]

# 정확도: 0.9098
# 정밀도: 0.4781
# 재현율: 0.9421
# F1: 0.6343

# 2279

# #################### iter: 8000 ####################

# 오차행렬:
#  [[9996  879]
#  [  63  922]]

# 정확도: 0.9206
# 정밀도: 0.5119
# 재현율: 0.9360
# F1: 0.6619

# True: 2095

# #################### iter: 10000 ####################

# 오차행렬:
#  [[10028   847]
#  [   67   918]]

# 정확도: 0.9229
# 정밀도: 0.5201
# 재현율: 0.9320
# F1: 0.6676