In [1]:
import pandas as pd
from catboost import CatBoostClassifier , Pool
from sklearn.model_selection import train_test_split
import numpy as np
from catboost.utils import eval_metric

In [2]:
# df = pd.read_csv('data.csv')

In [3]:
# dataset:
# https://cogcomp.seas.upenn.edu/Data/QA/QC/
# Training set 1(1000 labeled questions)
# Load data
with open('data.csv', 'r') as f:
    lines = f.readlines()
    df = pd.DataFrame()
    for line in lines:
        # find first : position in line
        pos = line.find(':')
        # split line by :
        # first element is ENTY
        # second element is DESC
        row = {'category': line[:pos], 'text': line[pos+1:].strip()}
        df = df.append(row, ignore_index=True)

In [4]:
filter_ = ~(df.isin([np.nan, np.inf, -np.inf]).any(axis=1))
df = df.loc[filter_]
X_tr, X_eval = train_test_split(df, test_size=0.5)

y_tr = X_tr.category
y_eval = X_eval.category

features = [col_name for col_name in X_tr.columns]
cat_features = [col_name for col_name in features if X_tr[col_name].dtype == 'object']

train_dataset = Pool(X_tr, y_tr, feature_names=list(X_tr.columns), cat_features=cat_features)

model_params = {
    # 'iterations': 500, 
    'loss_function': 'MultiClass', 
    'train_dir': 'crossentropy',
    'allow_writing_files': False,
    'random_seed': 42,
    'task_type': "GPU",    
}

model = CatBoostClassifier(**model_params)

In [5]:
model.fit(train_dataset, verbose=True, plot=True, use_best_model=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

You should provide test set for use best model. use_best_model parameter has been switched to false value.


Learning rate set to 0.056347
0:	learn: 1.5640667	total: 5.35ms	remaining: 5.34s
1:	learn: 1.3891520	total: 10.2ms	remaining: 5.07s
2:	learn: 1.2480291	total: 14.5ms	remaining: 4.83s
3:	learn: 1.1307513	total: 18.7ms	remaining: 4.67s
4:	learn: 1.0312062	total: 23.2ms	remaining: 4.61s
5:	learn: 0.9445521	total: 27.9ms	remaining: 4.62s
6:	learn: 0.8692775	total: 32.3ms	remaining: 4.58s
7:	learn: 0.8034062	total: 37ms	remaining: 4.58s
8:	learn: 0.7440964	total: 41.2ms	remaining: 4.53s
9:	learn: 0.6908525	total: 45.7ms	remaining: 4.53s
10:	learn: 0.6427841	total: 50.2ms	remaining: 4.52s
11:	learn: 0.5992256	total: 55.1ms	remaining: 4.54s
12:	learn: 0.5597764	total: 59.3ms	remaining: 4.5s
13:	learn: 0.5246897	total: 63.4ms	remaining: 4.46s
14:	learn: 0.4920057	total: 67.6ms	remaining: 4.44s
15:	learn: 0.4617532	total: 71.9ms	remaining: 4.42s
16:	learn: 0.4341418	total: 76.7ms	remaining: 4.43s
17:	learn: 0.4094401	total: 81.1ms	remaining: 4.42s
18:	learn: 0.3855067	total: 86ms	remaining: 4.4

163:	learn: 0.0123956	total: 754ms	remaining: 3.84s
164:	learn: 0.0122888	total: 758ms	remaining: 3.84s
165:	learn: 0.0121831	total: 764ms	remaining: 3.84s
166:	learn: 0.0120885	total: 769ms	remaining: 3.83s
167:	learn: 0.0119954	total: 773ms	remaining: 3.83s
168:	learn: 0.0119039	total: 778ms	remaining: 3.83s
169:	learn: 0.0118052	total: 783ms	remaining: 3.82s
170:	learn: 0.0117135	total: 787ms	remaining: 3.81s
171:	learn: 0.0116261	total: 791ms	remaining: 3.81s
172:	learn: 0.0115371	total: 795ms	remaining: 3.8s
173:	learn: 0.0114495	total: 799ms	remaining: 3.79s
174:	learn: 0.0113658	total: 803ms	remaining: 3.79s
175:	learn: 0.0112749	total: 807ms	remaining: 3.78s
176:	learn: 0.0111910	total: 812ms	remaining: 3.77s
177:	learn: 0.0111110	total: 816ms	remaining: 3.77s
178:	learn: 0.0110239	total: 820ms	remaining: 3.76s
179:	learn: 0.0109388	total: 824ms	remaining: 3.75s
180:	learn: 0.0108550	total: 828ms	remaining: 3.75s
181:	learn: 0.0107724	total: 833ms	remaining: 3.74s
182:	learn: 0

350:	learn: 0.0047611	total: 1.55s	remaining: 2.87s
351:	learn: 0.0047438	total: 1.56s	remaining: 2.87s
352:	learn: 0.0047277	total: 1.56s	remaining: 2.87s
353:	learn: 0.0047153	total: 1.57s	remaining: 2.86s
354:	learn: 0.0046984	total: 1.57s	remaining: 2.85s
355:	learn: 0.0046826	total: 1.58s	remaining: 2.85s
356:	learn: 0.0046669	total: 1.58s	remaining: 2.85s
357:	learn: 0.0046517	total: 1.58s	remaining: 2.84s
358:	learn: 0.0046362	total: 1.59s	remaining: 2.83s
359:	learn: 0.0046213	total: 1.59s	remaining: 2.83s
360:	learn: 0.0046094	total: 1.6s	remaining: 2.83s
361:	learn: 0.0045933	total: 1.6s	remaining: 2.82s
362:	learn: 0.0045781	total: 1.6s	remaining: 2.81s
363:	learn: 0.0045621	total: 1.61s	remaining: 2.81s
364:	learn: 0.0045477	total: 1.61s	remaining: 2.81s
365:	learn: 0.0045333	total: 1.62s	remaining: 2.8s
366:	learn: 0.0045231	total: 1.62s	remaining: 2.8s
367:	learn: 0.0045073	total: 1.63s	remaining: 2.79s
368:	learn: 0.0044927	total: 1.63s	remaining: 2.79s
369:	learn: 0.004

542:	learn: 0.0028399	total: 2.35s	remaining: 1.98s
543:	learn: 0.0028342	total: 2.36s	remaining: 1.98s
544:	learn: 0.0028280	total: 2.37s	remaining: 1.97s
545:	learn: 0.0028216	total: 2.37s	remaining: 1.97s
546:	learn: 0.0028159	total: 2.37s	remaining: 1.96s
547:	learn: 0.0028102	total: 2.38s	remaining: 1.96s
548:	learn: 0.0028043	total: 2.38s	remaining: 1.96s
549:	learn: 0.0027980	total: 2.38s	remaining: 1.95s
550:	learn: 0.0027923	total: 2.39s	remaining: 1.95s
551:	learn: 0.0027861	total: 2.39s	remaining: 1.94s
552:	learn: 0.0027799	total: 2.4s	remaining: 1.94s
553:	learn: 0.0027753	total: 2.4s	remaining: 1.93s
554:	learn: 0.0027692	total: 2.4s	remaining: 1.93s
555:	learn: 0.0027643	total: 2.41s	remaining: 1.92s
556:	learn: 0.0027581	total: 2.41s	remaining: 1.92s
557:	learn: 0.0027521	total: 2.42s	remaining: 1.92s
558:	learn: 0.0027460	total: 2.42s	remaining: 1.91s
559:	learn: 0.0027403	total: 2.43s	remaining: 1.91s
560:	learn: 0.0027346	total: 2.43s	remaining: 1.9s
561:	learn: 0.00

736:	learn: 0.0019995	total: 3.16s	remaining: 1.13s
737:	learn: 0.0019963	total: 3.17s	remaining: 1.13s
738:	learn: 0.0019929	total: 3.17s	remaining: 1.12s
739:	learn: 0.0019900	total: 3.18s	remaining: 1.12s
740:	learn: 0.0019871	total: 3.18s	remaining: 1.11s
741:	learn: 0.0019841	total: 3.19s	remaining: 1.11s
742:	learn: 0.0019812	total: 3.19s	remaining: 1.1s
743:	learn: 0.0019782	total: 3.19s	remaining: 1.1s
744:	learn: 0.0019750	total: 3.2s	remaining: 1.09s
745:	learn: 0.0019718	total: 3.2s	remaining: 1.09s
746:	learn: 0.0019689	total: 3.21s	remaining: 1.09s
747:	learn: 0.0019658	total: 3.21s	remaining: 1.08s
748:	learn: 0.0019627	total: 3.21s	remaining: 1.08s
749:	learn: 0.0019595	total: 3.22s	remaining: 1.07s
750:	learn: 0.0019564	total: 3.22s	remaining: 1.07s
751:	learn: 0.0019534	total: 3.23s	remaining: 1.06s
752:	learn: 0.0019506	total: 3.23s	remaining: 1.06s
753:	learn: 0.0019474	total: 3.23s	remaining: 1.05s
754:	learn: 0.0019446	total: 3.24s	remaining: 1.05s
755:	learn: 0.00

920:	learn: 0.0015538	total: 3.96s	remaining: 340ms
921:	learn: 0.0015520	total: 3.97s	remaining: 336ms
922:	learn: 0.0015502	total: 3.97s	remaining: 331ms
923:	learn: 0.0015481	total: 3.98s	remaining: 327ms
924:	learn: 0.0015461	total: 3.98s	remaining: 323ms
925:	learn: 0.0015444	total: 3.98s	remaining: 318ms
926:	learn: 0.0015426	total: 3.99s	remaining: 314ms
927:	learn: 0.0015412	total: 3.99s	remaining: 310ms
928:	learn: 0.0015392	total: 4s	remaining: 305ms
929:	learn: 0.0015373	total: 4s	remaining: 301ms
930:	learn: 0.0015354	total: 4.01s	remaining: 297ms
931:	learn: 0.0015336	total: 4.01s	remaining: 293ms
932:	learn: 0.0015321	total: 4.02s	remaining: 288ms
933:	learn: 0.0015300	total: 4.02s	remaining: 284ms
934:	learn: 0.0015282	total: 4.03s	remaining: 280ms
935:	learn: 0.0015263	total: 4.03s	remaining: 276ms
936:	learn: 0.0015244	total: 4.04s	remaining: 271ms
937:	learn: 0.0015227	total: 4.04s	remaining: 267ms
938:	learn: 0.0015214	total: 4.04s	remaining: 263ms
939:	learn: 0.0015

<catboost.core.CatBoostClassifier at 0x7f77933a6550>

In [6]:
# evaluate model score
response = ''
pred = model.predict(X_eval)
params = model.get_params()
response += str(params)+'\n'
#response += '\n'+params['loss_function']+' loss: '+ str(eval_metric(y_eval.to_numpy(), pred, params['loss_function']))
response += '\nFitted: '+str(model.is_fitted())
response += '\nModel score:\n'+str(model.score(X_tr,y_tr))
response += '\nFeature importance:'
try:
    importance = model.get_feature_importance()
    for i in range(len(model.feature_names_)):
        response += '\n'+str(np.round(importance[i],2)) + ' ' + model.feature_names_[i]
except Exception as e:
    response += '\n'+str(e)
print(response)

{'loss_function': 'MultiClass', 'random_seed': 42, 'train_dir': 'crossentropy', 'allow_writing_files': False, 'task_type': 'GPU'}

Fitted: True
Model score:
1.0
Feature importance:
100.0 category
0.0 text


In [7]:
# Get first n rows and predict
n = 3
pred = model.predict(X_eval.head(n))
for i in range(n):
    print(X_eval.iloc[i]['text'], pred[i])

animal What is Mississippi 's state animal ? ['ENTY']
def What does the name Shawn mean ? ['DESC']
ind Who killed Gandhi ? ['HUM']


In [8]:
X_eval.head(n)

Unnamed: 0,category,text
817,ENTY,animal What is Mississippi 's state animal ?
750,DESC,def What does the name Shawn mean ?
13,HUM,ind Who killed Gandhi ?
