In [1]:
from scribe_classifier.data.canada import AllCodes, TitleSet, SimpleModel, CombinedModels
from sklearn import metrics

In [2]:
# Notebook-wide settings: the level substituted into every pickle filename
# below, and the label reused later for the empty-string class.
target_level = 3
emptyset_label = "NA"

# Code table that is handed to CombinedModels further down.
all_codes = AllCodes.load_from_pickle(
    "./source_data/pickles/canada/tidy_sets/all_codes.P",
    is_path=True,
)

# Two pickled SimpleModel instances for the chosen level (SGD and naive Bayes
# variants, going by the file names).
sgd_pickle = "./source_data/pickles/canada/trained_models/simple.lvl%d.P" % target_level
nb_pickle = "./source_data/pickles/canada/trained_models/simple.lvl%d.bayes.P" % target_level
mdl_sgd = SimpleModel.load_from_pickle(sgd_pickle, is_path=True)
mdl_nb = SimpleModel.load_from_pickle(nb_pickle, is_path=True)

# Show the level stored on each loaded model; presumably both should equal
# target_level -- verify if the printed values ever differ.
for loaded_model in (mdl_sgd, mdl_nb):
    print(loaded_model.target_level)



3
3


In [3]:
# Load the three pickled TitleSet splits for the configured level.
train = TitleSet.load_from_pickle(
    'source_data/pickles/canada/test_sets/train.set.lvl%d.P' % target_level,
    is_path=True,
)
valid = TitleSet.load_from_pickle(
    'source_data/pickles/canada/test_sets/valid.set.lvl%d.P' % target_level,
    is_path=True,
)
test = TitleSet.load_from_pickle(
    'source_data/pickles/canada/test_sets/test.set.lvl%d.P' % target_level,
    is_path=True,
)

# Only the two evaluation splits get the extra empty-string class; `train` is
# left as loaded. Both calls share the same arguments, and the valid-then-test
# order is kept in case the helper samples records randomly.
empty_class_kwargs = dict(label=emptyset_label, prop_records=0.25)
valid = valid.copy_and_append_empty_string_class(**empty_class_kwargs)
test = test.copy_and_append_empty_string_class(**empty_class_kwargs)

In [5]:
# Combine the two pre-loaded SimpleModel instances into one CombinedModels
# estimator and fit it on the training split.
#
# The empty-set label comes from the shared `emptyset_label` variable rather
# than a repeated "NA" literal, so the model always agrees with the label used
# when the validation/test splits were extended with the empty-string class.
cmb_mdl = CombinedModels(
    target_level=target_level,
    emptyset_label=emptyset_label,
    trained_simple_sgd=mdl_sgd,
    trained_simple_multinom_nb=mdl_nb,
    all_codes=all_codes,
)

# NOTE(review): the repr displayed for this cell shows `all_codes=None` even
# though `all_codes` is passed above -- confirm CombinedModels stores it.
# The call is left as the cell's last expression so that repr is displayed.
cmb_mdl.fit_titleset(train)

CombinedModels(all_codes=None, emptyset_label='NA', target_level=3,
        trained_simple_multinom_nb=SimpleModel(emptyset_label='NA', target_level=3, use_bayes=True),
        trained_simple_sgd=SimpleModel(emptyset_label='NA', target_level=3, use_bayes=False))

In [6]:
# Predict with the combined model on the validation split first, then the test
# split; unpacking the map consumes it in that order.
valid_pred, test_pred = map(cmb_mdl.predict_titleset, (valid, test))

# Print both prediction vectors (validation, then test).
# To compare against the true codes, print
# `valid.get_code_vec(target_level=target_level)` or
# `test.get_code_vec(target_level=target_level)` alongside.
for predictions in (valid_pred, test_pred):
    print(predictions)

['961' '961' '753' ..., '961' '961' '961']
['122' '941' '844' ..., '961' '961' '961']


In [7]:
# Per-class precision / recall / F1 for the validation split.
# The header is printed before the report is computed so that any warning
# raised while scoring appears after it.
print("Validation Set:")
valid_true = valid.get_code_vec(target_level=target_level)
valid_report = metrics.classification_report(valid_true, valid_pred)
print(valid_report)
# Optional deeper look (disabled):
# print(metrics.confusion_matrix(valid.Y, valid_pred))

Validation Set:
             precision    recall  f1-score   support

        001       0.97      0.46      0.63        82
        011       0.89      0.72      0.80        57
        012       0.98      0.76      0.85        58
        013       1.00      0.47      0.64        15
        021       1.00      0.80      0.89        40
        031       0.93      0.54      0.68        26
        041       0.88      0.77      0.82        82
        042       0.74      0.69      0.71        54
        043       1.00      0.33      0.50        18
        051       0.92      0.84      0.88        58
        060       0.00      0.00      0.00         8
        062       1.00      0.39      0.56        28
        063       0.94      0.75      0.83        20
        065       0.93      0.70      0.80        20
        071       0.00      0.00      0.00        33
        073       0.93      0.52      0.67        27
        081       0.14      0.67      0.23         9
        082       0.92      0

  'precision', 'predicted', average, warn_for)


In [8]:
# Per-class precision / recall / F1 for the held-out test split.
# The header is printed before the report is computed so that any warning
# raised while scoring appears after it.
print("Test Set:")
test_true = test.get_code_vec(target_level=target_level)
test_report = metrics.classification_report(test_true, test_pred)
print(test_report)

Test Set:
             precision    recall  f1-score   support

        001       1.00      0.46      0.63        82
        011       0.88      0.79      0.83        57
        012       0.94      0.75      0.83        59
        013       1.00      0.56      0.72        16
        021       0.97      0.68      0.80        41
        031       1.00      0.58      0.73        26
        041       0.96      0.80      0.87        82
        042       0.90      0.70      0.79        53
        043       1.00      0.50      0.67        18
        051       0.94      0.81      0.87        57
        060       0.00      0.00      0.00         8
        062       1.00      0.46      0.63        28
        063       0.85      0.85      0.85        20
        065       0.93      0.68      0.79        19
        071       0.00      0.00      0.00        33
        073       0.82      0.52      0.64        27
        081       0.19      0.80      0.30        10
        082       1.00      0.77   

  'precision', 'predicted', average, warn_for)
