# Naive Bayes

## Load imports.

In [1]:
from sklearn.metrics import classification_report, confusion_matrix

# Make common scripts visible
import sys
sys.path.append('../common/')

from loader import load_preprocessed_data
from classification import run_bernoulli_naive_bayes
from lookup_tables import topic_code_to_topic_dict

## Load the train and test data.

Use the already-lemmatized data.

In [2]:
# Load the preprocessed (lemmatized) dataset.
# NOTE(review): x is presumably the documents and y the topic labels -- confirm
# against load_preprocessed_data in ../common/loader.py.
x, y = load_preprocessed_data('data/rcv1_lemmatized.csv')

# Positional 80/20 hold-out: the leading 80% of examples become the training
# set and the trailing 20% the test set. Nothing is shuffled in this cell, so
# the split follows whatever order the loader returns.
total_examples = len(y)
split_point = int(total_examples * 0.8)
train_x, test_x = x[:split_point], x[split_point:]
train_y, test_y = y[:split_point], y[split_point:]

## Assess Bernoulli Naive Bayes baseline classification performance.

Run Bernoulli Naive Bayes on unigram features, then report per-class precision, recall, and F1 score together with the confusion matrix.

In [3]:
# Run the project's Bernoulli Naive Bayes helper on the train/test split.
# The returned value is used in the next cell as the predicted labels for
# test_x.
predict_y = run_bernoulli_naive_bayes(
    train_x,
    train_y,
    test_x,
    test_y,
    ngram_range=(1, 1),  # unigrams only
)

In [4]:
# Per-class precision / recall / F1 on the held-out set, printed to six
# decimal places. target_names are matched to the labels by position.
# NOTE(review): this assumes the iteration order of topic_code_to_topic_dict
# matches the label order sklearn uses -- confirm.
print(
    classification_report(
        y_true=test_y,
        y_pred=predict_y,
        digits=6,
        target_names=topic_code_to_topic_dict.values(),
    )
)

# Confusion matrix: rows are true classes, columns are predicted classes.
print(confusion_matrix(y_true=test_y, y_pred=predict_y))

                        precision    recall  f1-score   support

CRIME, LAW ENFORCEMENT   0.887388  0.954924  0.919918      6123
  ECONOMIC PERFORMANCE   0.927976  0.952934  0.940290      1636
             ELECTIONS   0.844372  0.861124  0.852666      2117
                HEALTH   0.934783  0.783401  0.852423       988
              RELIGION   0.889362  0.494090  0.635258       423
                SPORTS   0.990688  0.960770  0.975500      6755

             micro avg   0.925729  0.925729  0.925729     18042
             macro avg   0.912428  0.834540  0.862676     18042
          weighted avg   0.927339  0.925729  0.924314     18042

[[5847   36  171   27   15   27]
 [  18 1559   53    2    0    4]
 [ 208   61 1823    6    1   18]
 [ 161   18   26  774    3    6]
 [ 162    1   38    7  209    6]
 [ 193    5   48   12    7 6490]]
