# Naive Bayes

## Import dependencies.

In [1]:
from sklearn.metrics import classification_report, confusion_matrix

# Make common scripts visible
import sys
sys.path.append('../common/')

from loader import load_preprocessed_data
from classification import run_bernoulli_naive_bayes
from lookup_tables import topic_code_to_topic_dict

## Load the train and test data.

Load the already-lemmatized data, then split it in file order into 80% training and 20% test examples.

In [2]:
# Read the pre-lemmatized corpus; x holds the inputs and y the matching labels.
x, y = load_preprocessed_data('data/rcv1_lemmatized.csv')

# Sequential (unshuffled) split: the leading 80% of examples are used for
# training and the trailing 20% are held out for testing.
total_examples = len(y)
split_point = int(0.8 * total_examples)
train_x, test_x = x[:split_point], x[split_point:]
train_y, test_y = y[:split_point], y[split_point:]

## Assess Bernoulli Naive Bayes baseline classification performance.

Run Bernoulli Naive Bayes with `ngram_range = (1, 1)` and report per-class precision, recall, and F1 score, followed by the confusion matrix.

In [3]:
# Run the baseline helper on the train/test split and keep what it returns;
# the result is compared against test_y when the metrics are printed.
# NOTE(review): ngram_range=(1, 1) presumably limits the helper's features to
# unigrams — confirm against ../common/classification.py.
predict_y = run_bernoulli_naive_bayes(
    train_x, train_y, test_x, test_y, ngram_range=(1, 1)
)

In [4]:
# Human-readable class names for the report rows.
# NOTE(review): target_names is matched positionally against the sorted label
# values, so this relies on the dict's value order lining up with that sort
# order — TODO confirm against lookup_tables.
topic_names = list(topic_code_to_topic_dict.values())

# Per-class precision / recall / F1 and support, printed to six decimal places.
report = classification_report(test_y, predict_y, target_names=topic_names, digits=6)
print(report)

# Confusion matrix: rows are true classes, columns are predicted classes.
matrix = confusion_matrix(test_y, predict_y)
print(matrix)

                        precision    recall  f1-score   support

CRIME, LAW ENFORCEMENT   0.882540  0.953143  0.916484      6125
  ECONOMIC PERFORMANCE   0.926600  0.950602  0.938448      1660
             ELECTIONS   0.842675  0.846263  0.844465      2114
                HEALTH   0.951456  0.800000  0.869180       980
              RELIGION   0.882353  0.497630  0.636364       422
                SPORTS   0.989170  0.959603  0.974163      6758

             micro avg   0.923861  0.923861  0.923861     18059
             macro avg   0.912466  0.834540  0.863184     18059
          weighted avg   0.925562  0.923861  0.922544     18059

[[5838   41  184   22   17   23]
 [  19 1578   49    4    1    9]
 [ 237   64 1789    1    2   21]
 [ 154   14   17  784    2    9]
 [ 165    0   32    6  210    9]
 [ 202    6   52    7    6 6485]]
