# Naive Bayes Topic Classification of Reuters RCV1 Articles

## Load imports.

In [7]:
from collections import defaultdict
import numpy as np

# Make common scripts visible
import sys
sys.path.append('../common/')

from reuters_parser import load_data
from sentence_utils import remove_stop_words_and_lemmatize
from conversion import convert_dictionary_to_array
from classification import run_bernoulli_naive_bayes

## Useful lookup tables.

In [2]:
topic_code_to_topic_dict = {
    'GCRIM': 'CRIME, LAW ENFORCEMENT',
    'E11': 'ECONOMIC PERFORMANCE',
    'GVOTE': 'ELECTIONS',
    'GHEA': 'HEALTH',
    'GREL': 'RELIGION',
    'GSPO': 'SPORTS'
}

# Topic codes in class-label order: position i is integer class label i.
# Both integer mappings below are derived from this single listing so they
# cannot drift out of sync with each other.
topic_codes = ('GCRIM', 'E11', 'GVOTE', 'GHEA', 'GREL', 'GSPO')
assert set(topic_codes) == set(topic_code_to_topic_dict), \
    'topic_codes must list exactly the codes in topic_code_to_topic_dict'

# Topic code -> integer class label (GCRIM -> 0, ..., GSPO -> 5).
topic_code_to_int = {code: label for label, code in enumerate(topic_codes)}

# Integer class label -> topic code (inverse of topic_code_to_int).
int_to_topic_code = dict(enumerate(topic_codes))

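## Load the data.

Load one year of Reuters RCV1 articles (20 August 1996 to 19 August 1997) for the six topics above and print the number of articles per topic. The train/test split happens further below.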

In [3]:
def print_number_of_articles_per_topic(dataset, dataset_name):
    """
    Print how many articles each topic has, followed by the grand total.

    Args:
        dataset: mapping of topic code -> collection of articles (anything
            supporting ``len``).
        dataset_name: label shown in the banner printed above the counts.

    Topic codes are displayed by their human-readable names, looked up in the
    module-level ``topic_code_to_topic_dict``.
    """
    banner = '------------------ {} ------------------'.format(dataset_name)
    for header_line in ('', banner, ''):
        print(header_line)

    running_total = 0
    for code, topic_articles in dataset.items():
        topic_name = topic_code_to_topic_dict[code]
        article_count = len(topic_articles)
        print('Number of articles for topic {}: {}'.format(topic_name, article_count))
        running_total += article_count

    print('')
    print('Total number of articles: {}'.format(running_total))

# Location of the RCV1 download, relative to this notebook.
RCV1_DIR = '../../../downloads/reuters/rcv1/'

# Date range (YYYYMMDD strings) handed to load_data.
START_DATE = '19960820'
# For a quick smoke-test run over a few days instead of the full year,
# set END_DATE = '19960830'.
END_DATE = '19970819'

year_data = load_data(START_DATE, END_DATE, RCV1_DIR, topic_code_to_topic_dict)

print_number_of_articles_per_topic(year_data, 'Data for a Year August 96 to August 97')


------------------ Data for a Year August 96 to August 97 ------------------

Number of articles for topic CRIME, LAW ENFORCEMENT: 30276
Number of articles for topic SPORTS: 35200
Number of articles for topic RELIGION: 2287
Number of articles for topic ELECTIONS: 10940
Number of articles for topic ECONOMIC PERFORMANCE: 8452
Number of articles for topic HEALTH: 4999

Total number of articles: 92154


Lemmatize and remove stopwords from each news article.

In [4]:
def sanitise_each_topic(dataset):
    """
    Return a cleaned copy of ``dataset`` with the same topic grouping.

    Every article is passed through ``remove_stop_words_and_lemmatize``. The
    cleaned articles are stored under their original topic code, in their
    original order.

    Returns:
        A ``defaultdict(list)``. A topic whose article collection is empty
        never receives an article and therefore does not appear as a key.
    """
    cleaned_by_topic = defaultdict(list)

    for code, topic_articles in dataset.items():
        cleaned = [remove_stop_words_and_lemmatize(text) for text in topic_articles]
        # Only touch the defaultdict when there is something to add, so that
        # empty topics are not materialised as keys.
        if cleaned:
            cleaned_by_topic[code].extend(cleaned)

    return cleaned_by_topic


# Stop-word-filtered, lemmatized articles, grouped under the same topic codes as year_data.
year_data_sanitised = sanitise_each_topic(year_data)

Convert the per-topic dictionary into feature and label arrays, then split them 80% / 20% into training and test sets.

In [5]:
RANDOM_SEED = 42
TRAIN_FRACTION = 0.8  # remaining (1 - TRAIN_FRACTION) goes to the test set

# Seed numpy's global RNG before building the arrays.
# NOTE(review): presumably convert_dictionary_to_array shuffles with np.random,
# which is what makes the slice-based split below random. Confirm this. If it
# does not shuffle, the seed has no effect and the test set would be the tail
# of whatever order the topics were concatenated in.
np.random.seed(RANDOM_SEED)

x, y = convert_dictionary_to_array(year_data_sanitised, topic_code_to_int)
assert len(x) == len(y), 'features and labels must be aligned before splitting'

# Split into TRAIN_FRACTION train / remainder test (80% / 20% by default).
total_examples = len(y)
split_point = int(total_examples * TRAIN_FRACTION)
train_x, test_x = x[:split_point], x[split_point:]
train_y, test_y = y[:split_point], y[split_point:]
assert len(train_y) + len(test_y) == total_examples

## Assess Bernoulli Naive Bayes baseline classification performance.

Run Bernoulli Naive Bayes on unigram features and report per-topic precision, recall and F1 score.

In [8]:
# Human-readable topic names listed in integer class-label order (0, 1, ...).
# Built from int_to_topic_code instead of topic_code_to_topic_dict.values(),
# whose insertion order only coincidentally matches topic_code_to_int.
# NOTE(review): presumably this argument feeds sklearn's classification_report
# target_names, which are matched to labels in sorted order. Confirm against
# classification.run_bernoulli_naive_bayes.
target_names = [topic_code_to_topic_dict[int_to_topic_code[label]]
                for label in sorted(int_to_topic_code)]

report = run_bernoulli_naive_bayes(train_x,
                                   train_y,
                                   test_x,
                                   test_y,
                                   target_names,
                                   ngram_range=(1, 1))  # unigrams only
print(report)

                        precision    recall  f1-score   support

CRIME, LAW ENFORCEMENT   0.882214  0.963499  0.921067      6219
  ECONOMIC PERFORMANCE   0.936295  0.950178  0.943185      1686
             ELECTIONS   0.850554  0.843164  0.846843      2187
                HEALTH   0.943728  0.799404  0.865591      1007
              RELIGION   0.864865  0.516129  0.646465       434
                SPORTS   0.992479  0.956509  0.974162      6898

             micro avg   0.925886  0.925886  0.925886     18431
             macro avg   0.911689  0.838147  0.866219     18431
          weighted avg   0.927625  0.925886  0.924657     18431

