# Naive Bayes

Baseline topic classification of the lemmatized RCV1 subset using Bernoulli Naive Bayes.

## Imports

In [1]:
from sklearn.metrics import classification_report, confusion_matrix

# Make common scripts visible
import sys
sys.path.append('../common/')

from loader import load_preprocessed_data
from classification import run_bernoulli_naive_bayes
from lookup_tables import topic_code_to_topic_dict

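## Load the train and test data.

Load the already-lemmatized RCV1 data. Split it by position: the first 80% of rows are used for training and the remaining 20% for testing.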

In [2]:
# Fraction of rows (taken from the start of the file) used for training;
# the remainder is held out as the test set.
TRAIN_FRACTION = 0.8

x, y = load_preprocessed_data('data/rcv1_lemmatized.csv')

# NOTE(review): this is a positional split with no shuffling, so it assumes the
# rows of the CSV are already in random order. If they are grouped by topic,
# the train and test topic mixes will differ. Confirm before trusting the scores.
total_examples = len(y)
split_point = int(total_examples * TRAIN_FRACTION)

train_x, test_x = x[:split_point], x[split_point:]
train_y, test_y = y[:split_point], y[split_point:]

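## Assess baseline Bernoulli Naive Bayes classification performance.

Train Bernoulli Naive Bayes on the training split with `ngram_range=(1, 1)` (unigrams). Then report per-topic precision, recall and F1, plus the confusion matrix, on the test split.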

In [3]:
# Fit on the training split and predict a topic for every test example.
# ngram_range=(1, 1) presumably selects unigram features in the helper's
# vectorizer. TODO confirm against ../common/classification.py.
predict_y = run_bernoulli_naive_bayes(
    train_x, train_y,
    test_x, test_y,
    ngram_range=(1, 1),
)

In [4]:
# Use one explicit label order for both the report and the confusion matrix.
# Without `labels=`, sklearn sorts the labels found in test_y/predict_y and pairs
# them by position with `target_names`. That is correct only if the dict happens
# to be stored in that same sorted order; otherwise rows are silently mislabeled.
# NOTE(review): assumes the keys of topic_code_to_topic_dict are exactly the label
# values that appear in y. Confirm against ../common/lookup_tables.py.
report_labels = list(topic_code_to_topic_dict.keys())
report_names = [topic_code_to_topic_dict[code] for code in report_labels]

print(classification_report(test_y, predict_y,
                            labels=report_labels,
                            target_names=report_names,
                            digits=6))
# Rows are true topics and columns are predicted topics, both in report_labels order.
print(confusion_matrix(test_y, predict_y, labels=report_labels))

                        precision    recall  f1-score   support

CRIME, LAW ENFORCEMENT   0.886601  0.955770  0.919887      6127
  ECONOMIC PERFORMANCE   0.922444  0.951515  0.936754      1650
             ELECTIONS   0.842251  0.859699  0.850885      2124
                HEALTH   0.936636  0.820383  0.874664       991
              RELIGION   0.884000  0.522459  0.656761       423
                SPORTS   0.992131  0.951324  0.971299      6759

             micro avg   0.924864  0.924864  0.924864     18074
             macro avg   0.910677  0.843525  0.868375     18074
          weighted avg   0.926808  0.924864  0.923906     18074

[[5856   40  168   28   12   23]
 [  20 1570   55    2    0    3]
 [ 215   65 1826    4    1   13]
 [ 137   20   14  813    2    5]
 [ 156    1   30    8  221    7]
 [ 221    6   75   13   14 6430]]
