# Naive Bayes

## Import dependencies.

In [8]:
from collections import defaultdict
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Make common scripts visible
import sys
sys.path.append('../common/')

from reuters_parser import load_data
from sentence_utils import remove_stop_words_and_lemmatize
from conversion import convert_dictionary_to_array
from classification import run_bernoulli_naive_bayes
from lookup_tables import topic_code_to_topic_dict, topic_code_to_int, int_to_topic_code

## Load the train and test data.

Load the articles.

In [2]:
def print_number_of_articles_per_topic(dataset, dataset_name):
    """
    Print a per-topic article count for ``dataset`` followed by the grand total.

    :param dataset: mapping of topic code -> collection of articles. Every
        topic code must be a key of ``topic_code_to_topic_dict``, which is
        used to look up the printed topic name.
    :param dataset_name: label shown in the banner above the counts.
    """
    banner = '------------------ {} ------------------'.format(dataset_name)
    print('')
    print(banner)
    print('')

    # Accumulate the total while reporting each topic, in the dataset's own
    # iteration order.
    article_counts = []
    for code, topic_articles in dataset.items():
        count = len(topic_articles)
        topic_name = topic_code_to_topic_dict[code]
        print('Number of articles for topic {}: {}'.format(topic_name, count))
        article_counts.append(count)

    print('')
    print('Total number of articles: {}'.format(sum(article_counts)))

# Load the articles between '19960820' and '19970819' (presumably YYYYMMDD
# dates - confirm in reuters_parser), grouped by topic code.
# For a quick trial run, use a shorter range, e.g. an end date of '19960830'.
year_data = load_data(
    '19960820',
    '19970819',
    '../../../downloads/reuters/rcv1/',
    topic_code_to_topic_dict,
)

print_number_of_articles_per_topic(
    dataset=year_data,
    dataset_name='Data for a Year August 96 to August 97',
)


------------------ Data for a Year August 96 to August 97 ------------------

Number of articles for topic CRIME, LAW ENFORCEMENT: 30276
Number of articles for topic SPORTS: 35200
Number of articles for topic RELIGION: 2287
Number of articles for topic ELECTIONS: 10940
Number of articles for topic ECONOMIC PERFORMANCE: 8452
Number of articles for topic HEALTH: 4999

Total number of articles: 92154


Lemmatize and remove stopwords from each news article.

In [3]:
def sanitise_each_topic(dataset):
    """
    Clean every article in ``dataset`` while keeping the grouping by topic.

    Each article is passed through ``remove_stop_words_and_lemmatize``.

    :param dataset: mapping of topic code -> iterable of articles.
    :return: ``defaultdict(list)`` mapping each topic code to its cleaned
        articles, in their original order. A topic with no articles does
        not appear in the result.
    """
    cleaned_by_topic = defaultdict(list)

    for code, topic_articles in dataset.items():
        cleaned_articles = [remove_stop_words_and_lemmatize(text) for text in topic_articles]
        # Only index into the defaultdict when there is something to store,
        # so that empty topics do not get a key created for them.
        if cleaned_articles:
            cleaned_by_topic[code].extend(cleaned_articles)

    return cleaned_by_topic


# Clean every article of every topic in the loaded year of data.
year_data_sanitised = sanitise_each_topic(dataset=year_data)

Convert the per-topic dictionary into `x` and `y` arrays, then split them into training (80%) and test (20%) sets.

In [4]:
# Seed NumPy's global RNG before the conversion.
# NOTE(review): presumably convert_dictionary_to_array shuffles the examples
# using this RNG - confirm in conversion.py, since the split below is a plain
# slice and relies on the rows already being in random order.
np.random.seed(42)

x, y = convert_dictionary_to_array(year_data_sanitised, topic_code_to_int)

# Split data into 80% train, 20% test: the first 80% of the rows are used for
# training and the remainder is held out for testing.
total_examples = len(y)
split_point = int(total_examples * 0.8)
train_x, test_x = x[:split_point], x[split_point:]
train_y, test_y = y[:split_point], y[split_point:]

## Assess Bernoulli Naive Bayes baseline classification performance.

Run Bernoulli Naive Bayes and report classification accuracy.

In [5]:
# ngram_range=(1, 1) presumably limits the features to single words
# (unigrams) - confirm against run_bernoulli_naive_bayes in classification.py.
predict_y = run_bernoulli_naive_bayes(
    train_x, train_y, test_x, test_y, ngram_range=(1, 1)
)

In [10]:
# Per-class precision / recall / F1 on the test set.
# NOTE(review): target_names are matched to the labels by position, so this
# relies on topic_code_to_topic_dict iterating in the same order as the
# integer labels from topic_code_to_int - confirm in lookup_tables.py.
print(
    classification_report(
        test_y,
        predict_y,
        target_names=topic_code_to_topic_dict.values(),
        digits=6,
    )
)

# Confusion matrix and overall accuracy for the same predictions.
print(confusion_matrix(test_y, predict_y))
print('')
print('Accuracy score of {}'.format(accuracy_score(test_y, predict_y)))

                        precision    recall  f1-score   support

CRIME, LAW ENFORCEMENT   0.881695  0.963499  0.920784      6219
  ECONOMIC PERFORMANCE   0.936842  0.950178  0.943463      1686
             ELECTIONS   0.848554  0.845450  0.847000      2187
                HEALTH   0.950237  0.796425  0.866559      1007
              RELIGION   0.888889  0.516129  0.653061       434
                SPORTS   0.993684  0.957959  0.975495      6898

             micro avg   0.926537  0.926537  0.926537     18431
             macro avg   0.916650  0.838273  0.867727     18431
          weighted avg   0.928635  0.926537  0.925313     18431

[[5992   24  153   20   18   12]
 [  16 1602   61    2    0    5]
 [ 251   61 1849    4    2   20]
 [ 158   19   23  802    3    2]
 [ 167    1   31    8  224    3]
 [ 212    3   62    8    5 6608]]

Accuracy score of 0.9265368129781346
