In [1]:
import os
from xgboost import XGBClassifier
from data_preprocess import *
from cross_validate import *
from parse_and_create_trees import *
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Project root that holds the `law/` and `tz/` dataset folders.
# The XGBOOST_DIPLOMA_ROOT environment variable overrides the location so the
# notebook can run on another machine; the original path stays the default.
root = os.environ.get('XGBOOST_DIPLOMA_ROOT',
                      '/home/kirb/work/ISP-projects/xgboost_diploma/')

In [3]:
# Load the prepared law and tz datasets from under `root`; the first CSV
# column becomes the index.
# NOTE(review): `pd` is not imported explicitly in the import cell —
# presumably it arrives through one of the wildcard imports; confirm.
law, tz = (
    pd.read_csv(os.path.join(root, relative_csv), index_col=0)
    for relative_csv in (
        'law/dataset_prepared/dataset.csv',
        'tz/dataset_prepared/dataset.csv',
    )
)

In [4]:
# Run data_preprocessing on each raw frame with its matching `type` value,
# law first and then tz, so the helper's status messages keep their order.
law_prepared, tz_prepared = (
    data_preprocessing(raw_frame, type=corpus)
    for raw_frame, corpus in ((law, 'law'), (tz, 'tz'))
)

law data successfully preprocessed
tz data successfully preprocessed


In [8]:
# Display every column label of the prepared tz frame as a plain list.
tz_prepared.columns.tolist()

['day_month_regexp',
 'day_month_regexp_next_1',
 'day_month_regexp_next_2',
 'day_month_regexp_next_3',
 'day_month_regexp_prev_1',
 'day_month_regexp_prev_2',
 'day_month_regexp_prev_3',
 'dot_number_regexp',
 'dot_number_regexp_len',
 'dot_number_regexp_len_next_1',
 'dot_number_regexp_len_next_2',
 'dot_number_regexp_len_next_3',
 'dot_number_regexp_len_prev_1',
 'dot_number_regexp_len_prev_2',
 'dot_number_regexp_len_prev_3',
 'dot_number_regexp_max',
 'dot_number_regexp_max_next_1',
 'dot_number_regexp_max_next_2',
 'dot_number_regexp_max_next_3',
 'dot_number_regexp_max_prev_1',
 'dot_number_regexp_max_prev_2',
 'dot_number_regexp_max_prev_3',
 'dot_number_regexp_next_1',
 'dot_number_regexp_next_2',
 'dot_number_regexp_next_3',
 'dot_number_regexp_prev_1',
 'dot_number_regexp_prev_2',
 'dot_number_regexp_prev_3',
 'is_in_toc',
 'is_lower',
 'is_lower_next_1',
 'is_lower_next_2',
 'is_lower_next_3',
 'is_lower_prev_1',
 'is_lower_prev_2',
 'is_lower_prev_3',
 'is_toc_line',
 'is

In [7]:
# Display every column label of the prepared law frame as a plain list.
law_prepared.columns.tolist()

['application',
 'bracket_num',
 'bracket_num_next_1',
 'bracket_num_next_2',
 'bracket_num_next_3',
 'bracket_num_prev_1',
 'bracket_num_prev_2',
 'bracket_num_prev_3',
 'current_regexp',
 'current_regexp_next_1',
 'current_regexp_next_2',
 'current_regexp_next_3',
 'current_regexp_prev_1',
 'current_regexp_prev_2',
 'current_regexp_prev_3',
 'endswith_colon',
 'endswith_colon_next_1',
 'endswith_colon_next_2',
 'endswith_colon_next_3',
 'endswith_colon_prev_1',
 'endswith_colon_prev_2',
 'endswith_colon_prev_3',
 'endswith_comma',
 'endswith_comma_next_1',
 'endswith_comma_next_2',
 'endswith_comma_next_3',
 'endswith_comma_prev_1',
 'endswith_comma_prev_2',
 'endswith_comma_prev_3',
 'endswith_dot',
 'endswith_dot_next_1',
 'endswith_dot_next_2',
 'endswith_dot_next_3',
 'endswith_dot_prev_1',
 'endswith_dot_prev_2',
 'endswith_dot_prev_3',
 'endswith_semicolon',
 'endswith_semicolon_next_1',
 'endswith_semicolon_next_2',
 'endswith_semicolon_next_3',
 'endswith_semicolon_prev_1',
 

In [7]:
# Run split_data_by_docs on each prepared frame (law first, then tz).
law_dfs_of_every_doc, tz_dfs_of_every_doc = map(
    split_data_by_docs, (law_prepared, tz_prepared)
)

In [6]:
# Split each corpus into a train/validation part and a held-out test part.
# NOTE(review): test_size=0.8 looks large if it denotes the test fraction —
# confirm against the definition of my_train_test_split.
(law_train_val, law_test), (tz_train_val, tz_test) = (
    my_train_test_split(docs, test_size=0.8)
    for docs in (law_dfs_of_every_doc, tz_dfs_of_every_doc)
)

In [7]:
# Two identically configured baseline classifiers, one per corpus. The
# constructor is called once per iteration, so each name gets its own instance.
law_model, tz_model = (
    XGBClassifier(
        learning_rate=0.8,
        n_estimators=300,
        booster="gbtree",
        # NOTE(review): presumably requires a GPU-enabled XGBoost build — confirm.
        tree_method="gpu_hist",
        max_depth=5,
        random_state=42,
        verbosity=0,
    )
    for _ in range(2)
)

In [8]:
# Cross-validate each model on its own corpus, law first and then tz.
# `law_test` / `tz_test` are rebound to whatever my_cross_validate returns, so
# re-running this cell feeds the previous results back in as inputs.
(
    (law_best_model, law_train, law_test, law_metrics),
    (tz_best_model, tz_train, tz_test, tz_metrics),
) = (
    my_cross_validate(model, train_val, held_out)
    for model, train_val, held_out in (
        (law_model, law_train_val, law_test),
        (tz_model, tz_train_val, tz_test),
    )
)

starting cross validate
starting cross validate


In [38]:
# Key under which each metrics dict stores the per-metric test breakdown
# that the report cells iterate over.
main_key: str = 'other_test_metrics'

In [39]:
# Report the law baseline: headline accuracies first, then every metric stored
# under `main_key`, one line per averaging key.
print('baseline law results:\n')
print('best_train_accuracy: {}\n'.format(law_metrics['best_train_accuracy']))
print('test_accuracy: {}'.format(law_metrics['test_accuracy']))
for metric, scores_by_average in law_metrics[main_key].items():
    print(f'\n{metric}:')
    for average, score in scores_by_average.items():
        print(f'{average}: {score}')

baseline law results:

best_train_accuracy: 0.9875954198473282

test_accuracy: 0.9829529037850332

f1:
None: [0.84805654 0.98398398 0.98905814 0.88372093 0.89      ]
micro: 0.9829529037850332
macro: 0.9189639186697244
weighted: 0.9830514584945006

precision:
None: [0.82758621 0.9894313  0.98799829 0.9047619  0.86407767]
micro: 0.9829529037850332
macro: 0.9147710741006586
weighted: 0.9832150676479011

recall:
None: [0.86956522 0.97859632 0.99012027 0.86363636 0.91752577]
micro: 0.9829529037850332
macro: 0.9238887891426089
weighted: 0.9829529037850332

roc_auc:
macro: 0.9978144312951764
weighted: 0.9977834151840235


In [40]:
# Report the tz baseline: headline accuracies first, then every metric stored
# under `main_key`, one line per averaging key.
print('baseline tz results:\n')
print('best_train_accuracy: {}\n'.format(tz_metrics['best_train_accuracy']))
print('test_accuracy: {}'.format(tz_metrics['test_accuracy']))
for metric, scores_by_average in tz_metrics[main_key].items():
    print(f'\n{metric}:')
    for average, score in scores_by_average.items():
        print(f'{average}: {score}')

baseline tz results:

best_train_accuracy: 0.9690721649484536

test_accuracy: 0.9565217391304348

f1:
None: [0.93975904 0.78740157 0.9815818  0.92063492 0.98533724]
micro: 0.9565217391304348
macro: 0.922942914693523
weighted: 0.9564135501605822

precision:
None: [0.94660194 0.79365079 0.98264642 0.90625    0.97674419]
micro: 0.9565217391304348
macro: 0.9211786684538346
weighted: 0.9563906526724413

recall:
None: [0.93301435 0.78125    0.98051948 0.93548387 0.99408284]
micro: 0.9565217391304348
macro: 0.9248701091581788
weighted: 0.9565217391304348

roc_auc:
macro: 0.9918086784139388
weighted: 0.9950912366331469


In [37]:
# Write each selected model to a JSON file in the current working directory.
for best_model, json_name in (
    (law_best_model, 'law_best_model.json'),
    (tz_best_model, 'tz_best_model.json'),
):
    best_model.save_model(json_name)

In [None]:
# Export the train/test frames returned by cross-validation as CSV files in
# the current working directory, without writing the index column.
for split_frame, csv_name in (
    (law_train, 'law_train.csv'),
    (law_test, 'law_test.csv'),
    (tz_train, 'tz_train.csv'),
    (tz_test, 'tz_test.csv'),
):
    split_frame.to_csv(csv_name, index=False)