# Knowledge Base Classification

## Import dependencies

In [1]:
# Make the common scripts and the knowledge base classifier code importable
import sys
sys.path.append('../common/')
sys.path.append('../kb-classifier/')

import numpy as np

from loader import load_preprocessed_data
from kb_classifier import KnowledgeBasePredictor
from kb_common import wiki_topics_to_index
from lookup_tables import topic_to_int, int_to_topic

## Load the data

In [2]:
# The CSVs hold text that has already been lowercased and lemmatised,
# so no further preprocessing is applied here.
train_csv_path = 'data/ohsumed_lemmatized_train.csv'
test_csv_path = 'data/ohsumed_lemmatized_test.csv'

# Each call yields a (documents, labels) pair for the given split.
train_x, train_y = load_preprocessed_data(train_csv_path)
test_x, test_y = load_preprocessed_data(test_csv_path)

## Initialise and tune class probabilities for knowledge base classifier

In [None]:
# Seed NumPy's global RNG first so the rest of the notebook is reproducible.
np.random.seed(42)

# Candidate topic names, and the number of top-level predictions to request
# (one per entry in wiki_topics_to_index).
candidate_topics = topic_to_int.keys()
num_top_level_predictions = len(wiki_topics_to_index.keys())

kb_predictor = KnowledgeBasePredictor(
    candidate_topics,
    topic_depth='all',
    top_level_prediction_number=num_top_level_predictions,
)

# Only the first 100 training documents are used for tuning.
tuning_documents = train_x[:100]
tuning_labels = train_y[:100]
kb_predictor.train(tuning_documents, tuning_labels, balanced_classes=False)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96


## Assess knowledge base classifier performance.

In [None]:
# Report how many test documents are about to be classified.
num_test_documents = len(test_y)
print('Making predictions for {} documents'.format(num_test_documents))

# Predict a label for every test document, then score against the true labels.
predict_y = kb_predictor.predict(test_x)
evaluation = kb_predictor.get_classification_report(test_y, predict_y)
classification_report, confusion_matrix = evaluation

# Print the report first, followed by the confusion matrix.
for summary in (classification_report, confusion_matrix):
    print(summary)

## Find examples where predictions went wrong

In [None]:
# Convert to NumPy arrays so boolean masks can be used for selection.
# np.asarray leaves predict_y untouched if predict() already returned an array,
# and makes the masking below work if it returned a plain list.
test_x = np.array(test_x)
test_y = np.array(test_y)
predictions = np.asarray(predict_y)

# Maximum number of misclassified documents to display per topic.
max_examples = 5

for topic, topic_index in topic_to_int.items():
    # Select documents whose true label is this topic but whose predicted
    # label is something else. A single combined mask keeps the documents
    # and their (wrong) predictions aligned.
    is_misclassified = (test_y == topic_index) & (predictions != topic_index)
    wrong_documents = test_x[is_misclassified]
    wrong_predictions = predictions[is_misclassified]

    # Nothing to show when every document of this topic was classified correctly.
    if len(wrong_documents) == 0:
        continue

    # Sample WITHOUT replacement so the same document is never printed twice,
    # and cap the sample size when fewer than max_examples errors exist.
    sample_size = min(max_examples, len(wrong_documents))
    sampled_positions = np.random.choice(len(wrong_documents),
                                         size=sample_size,
                                         replace=False)

    # The header reports the actual number of examples shown.
    print('------ {} random erroneous predictions for {} ------'.format(sample_size, topic))
    print('')
    # A distinct loop variable avoids shadowing the outer topic index.
    for position in sampled_positions:
        print(wrong_documents[position])
        print('')
        print('Above classified as {}'.format(int_to_topic[wrong_predictions[position]]))
        print('')
    print('')