<a href="https://colab.research.google.com/github/deniskapel/autoskill/blob/main/catboost_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
pip install catboost
pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension

In [None]:
%%bash
mkdir data
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1lQW87pMibsYvweA65Ke3m8DpI_NWkLRi' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1lQW87pMibsYvweA65Ke3m8DpI_NWkLRi" -O data/vectorized_val.npy && rm -rf /tmp/cookies.txt
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1L8EGsbR40LxI6_BMdm_T-SGD-OtugOas' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1L8EGsbR40LxI6_BMdm_T-SGD-OtugOas" -O data/vectorized_train.npy && rm -rf /tmp/cookies.txt

In [3]:
import json
import pickle
from collections import Counter

import numpy as np
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import f1_score, accuracy_score
from joblib import dump, load
# models
from catboost import CatBoostClassifier, Pool
from sklearn.multiclass import OneVsRestClassifier
# for multilabel classification
# metrics
from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import classification_report
from tensorflow.keras.utils import to_categorical

In [4]:
Midas2ID = {
    "appreciation": 0, "command": 1, "comment": 2,"complaint": 3,
    "dev_command": 4, "neg_answer": 5, "open_question_factual": 6,
    "open_question_opinion": 7, "opinion": 8, "other_answers": 9,
    "pos_answer": 10, "statement": 11, "yes_no_question": 12,
}

ID2Midas = list(Midas2ID.keys())

Entity2ID = {
    'person': 0, 'location': 1, 'videoname': 2, 'organization': 3,
    'device': 4, 'sport': 5, 'duration': 6, 'number': 7, 'genre': 8,
    'sportteam': 9, 'position': 10, 'event': 11, 'softwareapplication': 12,
    'vehicle': 13, 'party': 14, 'year': 15, 'date': 16, 'gamename': 17,
    'songname': 18, 'bookname': 19}

ID2Entity = list(Entity2ID.keys())

In [5]:
with open('data/vectorized_train.npy', 'rb') as f:
    X_train = np.load(f)
    y_midas_train = np.load(f)
    y_entity_train = np.load(f)

with open('data/vectorized_val.npy', 'rb') as f:
    X_val = np.load(f)
    y_midas_val = np.load(f)
    y_entity_val = np.load(f)

In [6]:
X_train.shape, y_midas_train.shape, y_entity_train.shape

((179286, 1641), (179286,), (179286, 20))

In [7]:
X_val.shape, y_midas_val.shape, y_entity_val.shape

((39089, 1641), (39089,), (39089, 20))

In [8]:
def prediction_by_heuristic(
    probas: np.ndarray, top_n:int=1, num_classes=20) -> np.ndarray:
    """ extract top_n predictions from from given probabilities """
    preds = np.argsort(probas, axis=-1)[:,::-1][:,:top_n]
    preds = to_categorical(preds, num_classes=num_classes)
    
    if top_n > 1:
        preds = np.max(preds, axis=1)
    
    return preds

# Catboost

## SymmetricTree

### Midas

In [None]:
X_midas_train = Pool(np.float32(X_train), label=y_midas_train)
X_midas_val = Pool(np.float32(X_val), label=y_midas_val)

In [17]:
model_params = {
    'verbose': True,
    'random_seed': 42,
    'use_best_model': True,
    'devices':'0:1'
}

fit_params = {
    'use_best_model': True,
    'early_stopping_rounds': 5   
}

In [None]:
midas_clf = CatBoostClassifier(
    grow_policy='SymmetricTree', 
    loss_function='MultiClass', 
    eval_metric='Accuracy', 
    task_type='GPU', **model_params)

In [None]:
midas_clf.fit(
    X_midas_train, eval_set=X_midas_val, **fit_params)

In [None]:
midas_preds = midas_clf.predict(X_midas_val).squeeze()

In [None]:
Counter(midas_preds)

Counter({2: 272, 5: 44, 6: 5, 8: 22774, 10: 788, 11: 15205, 12: 1})

In [None]:
print(
    classification_report(y_midas_val, midas_preds, target_names=ID2Midas)
)

                       precision    recall  f1-score   support

         appreciation       0.00      0.00      0.00      1032
              command       0.00      0.00      0.00       615
              comment       0.29      0.02      0.04      4035
            complaint       0.00      0.00      0.00       826
          dev_command       0.00      0.00      0.00        54
           neg_answer       0.25      0.01      0.02      1204
open_question_factual       0.20      0.00      0.00       809
open_question_opinion       0.00      0.00      0.00       497
              opinion       0.41      0.75      0.53     12549
        other_answers       0.00      0.00      0.00       309
           pos_answer       0.41      0.07      0.12      4639
            statement       0.37      0.52      0.44     10875
      yes_no_question       0.00      0.00      0.00      1645

             accuracy                           0.40     39089
            macro avg       0.15      0.11      0.09 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### Entity

In [9]:
X_entity_train = Pool(np.float32(X_train), label=y_entity_train)
X_entity_val = Pool(np.float32(X_val), label=y_entity_val)

In [None]:
entity_clf = CatBoostClassifier(
    grow_policy='SymmetricTree', 
    loss_function='MultiLogloss', 
    eval_metric='Accuracy',
    **model_params, task_type='GPU')

In [None]:
entity_clf.fit(X_entity_train, eval_set=X_entity_val, **fit_params)

Learning rate set to 0.114166
0:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 1m 15s	remaining: 20h 51m 31s
1:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 2m 25s	remaining: 20h 8m 6s
2:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 3m 33s	remaining: 19h 43m 6s
3:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 4m 46s	remaining: 19h 47m 53s
4:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 5m 52s	remaining: 19h 28m 51s
5:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 7m 3s	remaining: 19h 28m 40s
Stopped by overfitting detector  (5 iterations wait)

bestTest = 0.8457622349
bestIteration = 0

Shrink model to first 1 iterations.


<catboost.core.CatBoostClassifier at 0x7f8f95825a90>

In [None]:
entity_preds = entity_clf.predict(X_entity_val).squeeze()
entity_probas = entity_clf.predict_proba(X_entity_val)

In [None]:
print(
    classification_report(y_entity_val, entity_preds, target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.00      0.00      0.00      1336
           location       0.00      0.00      0.00      1007
          videoname       0.00      0.00      0.00       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=1),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.03      0.93      0.07      1336
           location       0.02      0.07      0.03      1007
          videoname       0.03      0.01      0.02       797
       organization       0.05      0.02      0.03       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=3),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.03      0.99      0.07      1336
           location       0.03      0.95      0.05      1007
          videoname       0.03      0.65      0.05       797
       organization       0.02      0.52      0.04       690
             device       0.02      0.06      0.02       486
              sport       0.02      0.01      0.01       376
           duration       0.02      0.04      0.02       455
             number       0.02      0.03      0.02       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.01      0.02      0.01       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
for_ovr = {
    'verbose': True,
    'random_seed': 42,
    'use_best_model': False,
    'devices':'0:1',
    }

ovr_clf = OneVsRestClassifier(
    estimator=CatBoostClassifier(
        grow_policy='SymmetricTree', loss_function='MultiClass', 
        eval_metric='TotalF1', task_type='GPU', **for_ovr)
)

In [None]:
ovr_clf.fit(X_train, y_entity_train)

In [None]:
entity_preds = ovr_clf.predict(X_entity_val).squeeze()
entity_probas = ovr_clf.predict_proba(X_entity_val)

In [None]:
print(
    classification_report(y_entity_val, entity_preds, target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.50      0.01      0.03      1336
           location       0.91      0.03      0.06      1007
          videoname       0.50      0.01      0.02       797
       organization       0.33      0.00      0.01       690
             device       0.55      0.02      0.05       486
              sport       0.71      0.01      0.03       376
           duration       0.77      0.09      0.17       455
             number       0.68      0.05      0.09       455
              genre       0.40      0.01      0.02       339
          sportteam       0.56      0.05      0.09       199
           position       0.00      0.00      0.00       158
              event       1.00      0.01      0.02       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.75      0.08      0.14       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=1),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.07      0.61      0.13      1336
           location       0.06      0.52      0.11      1007
          videoname       0.08      0.38      0.14       797
       organization       0.07      0.36      0.12       690
             device       0.09      0.48      0.15       486
              sport       0.11      0.43      0.17       376
           duration       0.06      0.39      0.11       455
             number       0.07      0.31      0.12       455
              genre       0.13      0.34      0.19       339
          sportteam       0.14      0.36      0.21       199
           position       0.08      0.13      0.10       158
              event       0.13      0.13      0.13       159
softwareapplication       0.07      0.16      0.10       198
            vehicle       0.10      0.36      0.16       116
              party       0.18      0.18      0.18        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=3),
        target_names=ID2Entity)
)

  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

             person       0.05      0.93      0.09      1336
           location       0.04      0.85      0.07      1007
          videoname       0.04      0.75      0.08       797
       organization       0.04      0.69      0.07       690
             device       0.05      0.65      0.09       486
              sport       0.07      0.68      0.13       376
           duration       0.03      0.62      0.06       455
             number       0.03      0.55      0.06       455
              genre       0.06      0.64      0.11       339
          sportteam       0.08      0.64      0.14       199
           position       0.05      0.51      0.09       158
              event       0.06      0.37      0.11       159
softwareapplication       0.06      0.41      0.10       198
            vehicle       0.06      0.59      0.11       116
              party       0.08      0.41      0.14        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))


## Depthwise

### Midas

In [None]:
midas_clf = CatBoostClassifier(
    grow_policy='Depthwise', 
    loss_function='MultiClass', 
    eval_metric='Accuracy', 
    task_type='GPU', **model_params)

In [None]:
midas_clf.fit(
    X_midas_train, eval_set=X_midas_val, **fit_params)

Learning rate set to 0.177599
0:	learn: 0.3875707	test: 0.3844560	best: 0.3844560 (0)	total: 780ms	remaining: 12m 59s
1:	learn: 0.3894504	test: 0.3862724	best: 0.3862724 (1)	total: 1.52s	remaining: 12m 39s
2:	learn: 0.3907500	test: 0.3876538	best: 0.3876538 (2)	total: 2.24s	remaining: 12m 25s
3:	learn: 0.3916145	test: 0.3888050	best: 0.3888050 (3)	total: 2.94s	remaining: 12m 10s
4:	learn: 0.3928528	test: 0.3887283	best: 0.3888050 (3)	total: 3.64s	remaining: 12m 5s
5:	learn: 0.3945093	test: 0.3898539	best: 0.3898539 (5)	total: 4.48s	remaining: 12m 22s
6:	learn: 0.3957810	test: 0.3911842	best: 0.3911842 (6)	total: 5.2s	remaining: 12m 17s
7:	learn: 0.3967348	test: 0.3920796	best: 0.3920796 (7)	total: 5.95s	remaining: 12m 17s
8:	learn: 0.3976384	test: 0.3920540	best: 0.3920796 (7)	total: 6.63s	remaining: 12m 9s
9:	learn: 0.3992615	test: 0.3924889	best: 0.3924889 (9)	total: 7.42s	remaining: 12m 14s
10:	learn: 0.3999978	test: 0.3933332	best: 0.3933332 (10)	total: 8.17s	remaining: 12m 14s
11:

<catboost.core.CatBoostClassifier at 0x7f8f958c6890>

In [None]:
midas_preds = midas_clf.predict(X_midas_val).squeeze()

In [None]:
Counter(midas_preds)

Counter({2: 368, 5: 144, 6: 68, 7: 6, 8: 22826, 10: 1164, 11: 14473, 12: 40})

In [None]:
print(
    classification_report(y_midas_val, midas_preds, target_names=ID2Midas)
)

                       precision    recall  f1-score   support

         appreciation       0.00      0.00      0.00      1032
              command       0.00      0.00      0.00       615
              comment       0.30      0.03      0.05      4035
            complaint       0.00      0.00      0.00       826
          dev_command       0.00      0.00      0.00        54
           neg_answer       0.20      0.02      0.04      1204
open_question_factual       0.24      0.02      0.04       809
open_question_opinion       0.17      0.00      0.00       497
              opinion       0.41      0.75      0.53     12549
        other_answers       0.00      0.00      0.00       309
           pos_answer       0.37      0.09      0.15      4639
            statement       0.38      0.51      0.44     10875
      yes_no_question       0.30      0.01      0.01      1645

             accuracy                           0.40     39089
            macro avg       0.18      0.11      0.10 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### Entity

In [None]:
entity_clf = CatBoostClassifier(
    grow_policy='Depthwise', 
    loss_function='MultiLogloss', 
    eval_metric='Accuracy',
    **model_params, task_type='CPU')

In [None]:
entity_clf.fit(X_entity_train, eval_set=X_entity_val, **fit_params)

Custom logger is already specified. Specify more than one logger at same time is not thread safe.

Learning rate set to 0.114166
0:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 10.6s	remaining: 2h 57m 3s
1:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 21.6s	remaining: 2h 59m 25s
2:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 32.5s	remaining: 3h 16s
3:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 43.5s	remaining: 3h 19s
4:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 54s	remaining: 2h 59m 12s
5:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 1m 4s	remaining: 2h 58m 29s
Stopped by overfitting detector  (5 iterations wait)

bestTest = 0.8457622349
bestIteration = 0

Shrink model to first 1 iterations.


<catboost.core.CatBoostClassifier at 0x7f8f9453fb50>

In [None]:
entity_preds = entity_clf.predict(X_entity_val).squeeze()
entity_probas = entity_clf.predict_proba(X_entity_val)

In [None]:
print(
    classification_report(y_entity_val, entity_preds, target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.00      0.00      0.00      1336
           location       0.00      0.00      0.00      1007
          videoname       0.00      0.00      0.00       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=1),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.03      1.00      0.07      1336
           location       0.00      0.00      0.00      1007
          videoname       0.00      0.00      0.00       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=3),
        target_names=ID2Entity)
)

  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

             person       0.03      1.00      0.07      1336
           location       0.03      1.00      0.05      1007
          videoname       0.02      1.00      0.04       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))


In [10]:
for_ovr = {
    'verbose': False,
    'random_seed': 42,
    'use_best_model': False,
    'devices':'0:1',
    }

ovr_clf = OneVsRestClassifier(
    estimator=CatBoostClassifier(
        grow_policy='Depthwise', loss_function='MultiClass', 
        eval_metric='TotalF1', task_type='GPU', **for_ovr)
)

In [11]:
ovr_clf.fit(X_train, y_entity_train)

OneVsRestClassifier(estimator=<catboost.core.CatBoostClassifier object at 0x7fa9f30d1ad0>)

In [12]:
entity_preds = ovr_clf.predict(X_val).squeeze()
entity_probas = ovr_clf.predict_proba(X_val)

In [13]:
print(
    classification_report(y_entity_val, entity_preds, target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.72      0.02      0.04      1336
           location       0.84      0.04      0.07      1007
          videoname       0.82      0.02      0.03       797
       organization       0.73      0.01      0.02       690
             device       0.84      0.05      0.10       486
              sport       0.73      0.02      0.04       376
           duration       0.85      0.14      0.24       455
             number       0.81      0.08      0.14       455
              genre       0.73      0.02      0.05       339
          sportteam       0.64      0.05      0.08       199
           position       0.00      0.00      0.00       158
              event       1.00      0.04      0.07       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.82      0.08      0.14       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [14]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=1),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.07      0.60      0.13      1336
           location       0.06      0.53      0.10      1007
          videoname       0.08      0.37      0.13       797
       organization       0.07      0.34      0.11       690
             device       0.08      0.47      0.14       486
              sport       0.12      0.44      0.19       376
           duration       0.06      0.39      0.11       455
             number       0.07      0.32      0.12       455
              genre       0.14      0.28      0.18       339
          sportteam       0.17      0.30      0.22       199
           position       0.08      0.11      0.09       158
              event       0.14      0.11      0.12       159
softwareapplication       0.09      0.14      0.11       198
            vehicle       0.11      0.29      0.16       116
              party       0.13      0.09      0.11        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [15]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=3),
        target_names=ID2Entity)
)

  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

             person       0.05      0.92      0.09      1336
           location       0.03      0.85      0.07      1007
          videoname       0.04      0.77      0.08       797
       organization       0.04      0.72      0.07       690
             device       0.04      0.68      0.08       486
              sport       0.08      0.69      0.14       376
           duration       0.03      0.58      0.05       455
             number       0.03      0.55      0.06       455
              genre       0.07      0.58      0.12       339
          sportteam       0.09      0.61      0.15       199
           position       0.06      0.55      0.10       158
              event       0.06      0.29      0.10       159
softwareapplication       0.06      0.36      0.10       198
            vehicle       0.06      0.48      0.10       116
              party       0.10      0.32      0.15        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))


## Lossguide

In [None]:
midas_clf = CatBoostClassifier(
    grow_policy='Lossguide', 
    loss_function='MultiClass', 
    eval_metric='Accuracy', 
    task_type='GPU', **model_params)

In [None]:
midas_clf.fit(
    X_midas_train, eval_set=X_midas_val, **fit_params)

In [None]:
midas_preds = midas_clf.predict(X_midas_val).squeeze()

In [None]:
Counter(midas_preds)

Counter({2: 590, 5: 184, 6: 94, 7: 2, 8: 22746, 10: 1205, 11: 14212, 12: 56})

In [None]:
print(
    classification_report(y_midas_val, midas_preds, target_names=ID2Midas)
)

                       precision    recall  f1-score   support

         appreciation       0.00      0.00      0.00      1032
              command       0.00      0.00      0.00       615
              comment       0.30      0.04      0.08      4035
            complaint       0.00      0.00      0.00       826
          dev_command       0.00      0.00      0.00        54
           neg_answer       0.23      0.04      0.06      1204
open_question_factual       0.21      0.02      0.04       809
open_question_opinion       0.00      0.00      0.00       497
              opinion       0.41      0.75      0.53     12549
        other_answers       0.00      0.00      0.00       309
           pos_answer       0.36      0.09      0.15      4639
            statement       0.39      0.51      0.44     10875
      yes_no_question       0.32      0.01      0.02      1645

             accuracy                           0.40     39089
            macro avg       0.17      0.11      0.10 

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### Entity

In [18]:
entity_clf = CatBoostClassifier(
    grow_policy='Lossguide', 
    loss_function='MultiLogloss', 
    eval_metric='Accuracy',
    **model_params, task_type='CPU')

In [19]:
entity_clf.fit(X_entity_train, eval_set=X_entity_val, **fit_params)

Learning rate set to 0.114166
0:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 11.7s	remaining: 3h 14m 54s
1:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 23.6s	remaining: 3h 15m 57s
2:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 35.4s	remaining: 3h 15m 53s
3:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 47.1s	remaining: 3h 15m 30s
4:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 58.8s	remaining: 3h 14m 59s
5:	learn: 0.8449516	test: 0.8457622	best: 0.8457622 (0)	total: 1m 10s	remaining: 3h 14m 52s
Stopped by overfitting detector  (5 iterations wait)

bestTest = 0.8457622349
bestIteration = 0

Shrink model to first 1 iterations.


<catboost.core.CatBoostClassifier at 0x7fa9e8640890>

In [20]:
entity_preds = entity_clf.predict(X_entity_val).squeeze()
entity_probas = entity_clf.predict_proba(X_entity_val)

In [21]:
print(
    classification_report(y_entity_val, entity_preds, target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.00      0.00      0.00      1336
           location       0.00      0.00      0.00      1007
          videoname       0.00      0.00      0.00       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [22]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=1),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.03      1.00      0.07      1336
           location       0.00      0.00      0.00      1007
          videoname       0.00      0.00      0.00       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [23]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=3),
        target_names=ID2Entity)
)

  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

             person       0.03      1.00      0.07      1336
           location       0.03      1.00      0.05      1007
          videoname       0.02      1.00      0.04       797
       organization       0.00      0.00      0.00       690
             device       0.00      0.00      0.00       486
              sport       0.00      0.00      0.00       376
           duration       0.00      0.00      0.00       455
             number       0.00      0.00      0.00       455
              genre       0.00      0.00      0.00       339
          sportteam       0.00      0.00      0.00       199
           position       0.00      0.00      0.00       158
              event       0.00      0.00      0.00       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.00      0.00      0.00       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))


In [24]:
for_ovr = {
    'verbose': False,
    'random_seed': 42,
    'use_best_model': False,
    'devices':'0:1',
    }

ovr_clf = OneVsRestClassifier(
    estimator=CatBoostClassifier(
        grow_policy='Lossguide', loss_function='MultiClass', 
        eval_metric='TotalF1', task_type='GPU', **for_ovr)
)

In [25]:
ovr_clf.fit(X_train, y_entity_train)

OneVsRestClassifier(estimator=<catboost.core.CatBoostClassifier object at 0x7fa9e85a2110>)

In [26]:
entity_preds = ovr_clf.predict(X_val).squeeze()
entity_probas = ovr_clf.predict_proba(X_val)

In [27]:
print(
    classification_report(y_entity_val, entity_preds, target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.52      0.02      0.03      1336
           location       0.87      0.03      0.06      1007
          videoname       0.69      0.01      0.02       797
       organization       1.00      0.01      0.01       690
             device       0.81      0.05      0.10       486
              sport       0.67      0.02      0.03       376
           duration       0.85      0.13      0.22       455
             number       0.79      0.07      0.12       455
              genre       0.77      0.03      0.06       339
          sportteam       0.73      0.04      0.08       199
           position       0.00      0.00      0.00       158
              event       1.00      0.04      0.07       159
softwareapplication       0.00      0.00      0.00       198
            vehicle       0.77      0.09      0.16       116
              party       0.00      0.00      0.00        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [28]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=1),
        target_names=ID2Entity)
)

                     precision    recall  f1-score   support

             person       0.07      0.66      0.12      1336
           location       0.06      0.56      0.10      1007
          videoname       0.08      0.36      0.13       797
       organization       0.07      0.34      0.12       690
             device       0.10      0.45      0.16       486
              sport       0.12      0.37      0.18       376
           duration       0.07      0.39      0.12       455
             number       0.08      0.29      0.13       455
              genre       0.17      0.25      0.20       339
          sportteam       0.20      0.28      0.23       199
           position       0.10      0.06      0.07       158
              event       0.16      0.09      0.12       159
softwareapplication       0.08      0.09      0.08       198
            vehicle       0.14      0.25      0.18       116
              party       0.50      0.12      0.19        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [29]:
print(
    classification_report(
        y_entity_val, 
        prediction_by_heuristic(entity_probas, top_n=3),
        target_names=ID2Entity)
)

  _warn_prf(average, modifier, msg_start, len(result))


                     precision    recall  f1-score   support

             person       0.04      0.95      0.08      1336
           location       0.03      0.89      0.06      1007
          videoname       0.04      0.78      0.08       797
       organization       0.04      0.73      0.07       690
             device       0.05      0.64      0.10       486
              sport       0.09      0.64      0.15       376
           duration       0.03      0.58      0.06       455
             number       0.03      0.53      0.06       455
              genre       0.08      0.54      0.13       339
          sportteam       0.10      0.55      0.17       199
           position       0.07      0.46      0.11       158
              event       0.08      0.25      0.12       159
softwareapplication       0.07      0.28      0.11       198
            vehicle       0.08      0.42      0.13       116
              party       0.17      0.21      0.18        34
               year    

  _warn_prf(average, modifier, msg_start, len(result))
