In [1]:
from scribe_classifier.data.canada import AllCodes, TitleSet, SimpleModel, CombinedModels
from sklearn import metrics

In [2]:
# Settings shared by the rest of the notebook. `target_level` is interpolated
# into the pickle file names below; `emptyset_label` is the label given to the
# empty-string class appended to the evaluation sets.
target_level = 3
emptyset_label = "NA"

# Code table, later handed to the combined model.
all_codes = AllCodes.load_from_pickle(
    "./source_data/pickles/canada/tidy_sets/all_codes.P",
    is_path=True,
)

# Two previously pickled simple models for this level: the default one and the
# ".bayes" variant.
mdl_sgd = SimpleModel.load_from_pickle(
    "./source_data/pickles/canada/trained_models/simple.lvl%d.P" % target_level,
    is_path=True,
)
mdl_nb = SimpleModel.load_from_pickle(
    "./source_data/pickles/canada/trained_models/simple.lvl%d.bayes.P" % target_level,
    is_path=True,
)

In [3]:
# Load the three pickled title sets for the chosen code level.
train = TitleSet.load_from_pickle(
    'source_data/pickles/canada/test_sets/train.set.lvl%d.P' % target_level,
    is_path=True,
)
valid = TitleSet.load_from_pickle(
    'source_data/pickles/canada/test_sets/valid.set.lvl%d.P' % target_level,
    is_path=True,
)
test = TitleSet.load_from_pickle(
    'source_data/pickles/canada/test_sets/test.set.lvl%d.P' % target_level,
    is_path=True,
)

# Only the two evaluation sets get the extra empty-string class; `train` is
# used as loaded. The appends run valid first, then test, in case the helper
# draws from shared random state -- TODO confirm.
# NOTE(review): prop_records=0.25 is presumably the share of records added;
# verify against copy_and_append_empty_string_class.
valid = valid.copy_and_append_empty_string_class(
    label=emptyset_label,
    prop_records=0.25,
)
test = test.copy_and_append_empty_string_class(
    label=emptyset_label,
    prop_records=0.25,
)

In [4]:
# Build the combined classifier on top of the two pre-trained simple models.
# NOTE(review): the recorded output of this cell shows `all_codes=None` even
# though `all_codes` is passed here -- confirm the constructor keeps it.
cmb_mdl = CombinedModels(
    target_level=target_level,
    # Reuse the notebook-wide constant instead of a second hard-coded "NA", so
    # the model's label cannot drift from the one appended to valid/test.
    emptyset_label=emptyset_label,
    trained_simple_sgd=mdl_sgd,
    trained_simple_multinom_nb=mdl_nb,
    all_codes=all_codes,
)

# Kept as the last expression on purpose: its return value is what the cell
# displays as output.
cmb_mdl.fit_titleset(train)

CombinedModels(all_codes=None, emptyset_label='NA', target_level=3,
        trained_simple_multinom_nb=SimpleModel(emptyset_label='NA', target_level=3, use_bayes=True),
        trained_simple_sgd=SimpleModel(emptyset_label='NA', target_level=3, use_bayes=False))

In [5]:
# Predict on both evaluation sets, validation first and test second.
valid_pred, test_pred = (
    cmb_mdl.predict_titleset(title_set) for title_set in (valid, test)
)

# Print each prediction array on its own line. To eyeball them against the
# true labels, also print <set>.get_code_vec(target_level=target_level).
print(valid_pred, test_pred, sep="\n")

['733' '946' '082' ..., '961' '961' '961']
['152' '941' '761' ..., '961' '961' '961']


In [6]:
# Per-class precision / recall / F1 on the validation set.
print("Validation Set:")

# True labels at the evaluated code level, compared with the predictions.
valid_true = valid.get_code_vec(target_level=target_level)
print(metrics.classification_report(y_true=valid_true, y_pred=valid_pred))

Validation Set:
             precision    recall  f1-score   support

        001       0.94      0.80      0.87        82
        011       0.84      0.82      0.83        57
        012       0.96      0.78      0.86        58
        013       0.64      0.47      0.54        15
        021       0.74      0.88      0.80        40
        031       0.41      0.42      0.42        26
        041       0.92      0.85      0.89        82
        042       0.52      0.59      0.56        54
        043       0.00      0.00      0.00        18
        051       0.56      0.86      0.68        58
        060       0.02      0.12      0.04         8
        062       0.26      0.39      0.31        28
        063       0.00      0.00      0.00        20
        065       0.35      0.65      0.46        20
        071       0.60      0.09      0.16        33
        073       0.00      0.00      0.00        27
        081       0.00      0.00      0.00         9
        082       0.34      0

  'precision', 'predicted', average, warn_for)


In [7]:
# Per-class precision / recall / F1 on the test set.
print("Test Set:")

# True labels at the evaluated code level, compared with the predictions.
test_true = test.get_code_vec(target_level=target_level)
print(metrics.classification_report(y_true=test_true, y_pred=test_pred))

Test Set:
             precision    recall  f1-score   support

        001       0.89      0.80      0.85        82
        011       0.86      0.88      0.87        57
        012       0.91      0.83      0.87        59
        013       0.89      0.50      0.64        16
        021       0.69      0.88      0.77        41
        031       0.38      0.50      0.43        26
        041       0.89      0.80      0.85        82
        042       0.69      0.62      0.65        53
        043       0.00      0.00      0.00        18
        051       0.57      0.89      0.70        57
        060       0.04      0.25      0.07         8
        062       0.33      0.54      0.41        28
        063       0.00      0.00      0.00        20
        065       0.28      0.58      0.37        19
        071       0.80      0.12      0.21        33
        073       0.00      0.00      0.00        27
        081       0.00      0.00      0.00        10
        082       0.32      0.77   

  'precision', 'predicted', average, warn_for)
