In [223]:
import pandas as pd
import numpy as np
import tarfile
import os
import seaborn as sns
import re
import random
import wget

#np.set_printoptions(suppress=True, linewidth=np.inf)
np.set_printoptions(suppress=True, formatter={'float': '{: 0.6f}'.format})

%reload_ext version_information
%version_information numpy, pandas, seaborn, re, random, wget

Software,Version
Python,3.7.4 64bit [GCC 7.3.0]
IPython,7.8.0
OS,Linux 4.15.0 66 generic x86_64 with debian buster sid
numpy,1.17.2
pandas,0.25.1
seaborn,0.9.0
re,2.2.1
random,The 'random' distribution was not found and is required by the application
wget,3.2
Mon Nov 11 12:52:50 2019 CET,Mon Nov 11 12:52:50 2019 CET


In [224]:
filename = "20news-18828"
url = 'http://qwone.com/~jason/20Newsgroups/20news-18828.tar.gz'
tar_file_name = "../data/" + filename + ".tar.gz"
untar_file_name = "../data/"

if not os.path.isfile(tar_file_name):
    wget.download(url, tar_file_name)
else:
    print("File already downloaded.")

if not os.path.isdir(untar_file_name + filename):
    tarfile.open(tar_file_name, "r:gz").extractall(untar_file_name)
else: 
    print("File already extracted.")

File already downloaded.
File already extracted.


In [225]:
newsgroups = ["alt.atheism", "comp.graphics", "sci.space", "talk.religion.misc"]
data_raw = []
labels = []
for path, _, files in os.walk(untar_file_name + filename):
    directory = path.split('/')[-1]
    if directory not in newsgroups:
        continue
    for file in files:
        with open(os.path.join(path, file), encoding="iso-8859-1") as myfile:
            data_raw.append(myfile.read())
            labels.append(directory)

In [226]:
n_entries = len(data_raw)
print("String-Einträge im Array: {}".format(n_entries))

String-Einträge im Array: 3387


In [227]:
data_partition = [data_raw[i].partition('\n\n')[-1] for i in range(len(data_raw))]

In [228]:
data_regex = [re.compile(r"(?u)\b[a-zA-Z]+\b").findall(data_partition[i].lower()) for i in range(len(data_partition))]

In [229]:
different_tokens = []
for tokens in data_regex:
    different_tokens.extend(tokens)
different_tokens = list(dict.fromkeys(different_tokens))

In [230]:
n_tokens = len(different_tokens)
print("Verschiedene Tokens: {}".format(n_tokens))

Verschiedene Tokens: 34588


In [231]:
def load_data(data_regex, different_tokens):
    data = np.zeros(shape=(n_entries, n_tokens), dtype=np.int)
    for j, article in enumerate(data_regex):
        if not ((j + 1) % 100):
            print("Article Nr.: {}\n...".format(j + 1))

        for i, token in enumerate(different_tokens):
            count = article.count(token)

            if count > 0:
                data[j][i] = count
                
    data = np.asarray(data)
    np.save("../data/data.npy", data)
    
    return data

In [232]:
try:
    data = np.load("../data/data.npy")
    print("Found saved data.")
except IOError:
    print("Found no saved data. Process data.")
    data = load_data(data_regex, different_tokens)

Found saved data.


In [233]:
n_train = int(n_entries * 0.6)
c = list(zip(data, labels))
random.shuffle(c)
data, labels = zip(*c)
train = data[:n_train]
train_lables = labels[:n_train]
test = data[n_train:]
test_lables = labels[n_train:]
n_test = len(test_lables)

In [234]:
Nij = {newsgroups[0] : 0, 
       newsgroups[1] : 0, 
       newsgroups[2] : 0, 
       newsgroups[3] : 0}

nij = {newsgroups[0] : np.zeros(n_tokens, dtype=np.int), 
       newsgroups[1] : np.zeros(n_tokens, dtype=np.int), 
       newsgroups[2] : np.zeros(n_tokens, dtype=np.int), 
       newsgroups[3] : np.zeros(n_tokens, dtype=np.int)}

pij = {newsgroups[0] : np.zeros(n_tokens), 
       newsgroups[1] : np.zeros(n_tokens), 
       newsgroups[2] :
       np.zeros(n_tokens), 
       newsgroups[3] : np.zeros(n_tokens)}

pi = {newsgroups[0] : 0.0, 
      newsgroups[1] : 0.0, 
      newsgroups[2] : 0.0, 
      newsgroups[3] : 0.0}

for i in range(n_train):
    Nij[labels[i]] += train[i].sum()

for i in range(n_train):
    nij[labels[i]] += train[i]

for group in newsgroups:
    pij[group] = (nij[group] + 1) / (Nij[group] + n_tokens)
    
for i in range(n_train):
    pi[labels[i]] += 1
    
for group in newsgroups:
    pi[group] /= n_train 
    
for group in newsgroups:
    print("Group(i): {}".format(group))
    print("pi(A-priori): {}".format(pi[group]))
    print("Nij: {}".format(Nij[group]))
    print("nij: {}".format(nij[group]))
    print("pij: {}".format(pij[group]))
    print()

Group(i): alt.atheism
pi(A-priori): 0.23917322834645668
Nij: 144743
nij: [2498  118  426 ...    0    0    1]
pij: [ 0.013935  0.000664  0.002381 ...  0.000006  0.000006  0.000011]

Group(i): comp.graphics
pi(A-priori): 0.2947834645669291
Nij: 145996
nij: [2176   63  259 ...    0    0    0]
pij: [ 0.012055  0.000354  0.001440 ...  0.000006  0.000006  0.000006]

Group(i): sci.space
pi(A-priori): 0.28641732283464566
Nij: 149868
nij: [1661   28  246 ...    0    0    0]
pij: [ 0.009010  0.000157  0.001339 ...  0.000005  0.000005  0.000005]

Group(i): talk.religion.misc
pi(A-priori): 0.1796259842519685
Nij: 127183
nij: [1952   72  415 ...    0    0    0]
pij: [ 0.012073  0.000451  0.002572 ...  0.000006  0.000006  0.000006]



In [238]:
corrects = np.full(shape=(n_test), fill_value=False)
for i, lable in enumerate(test_lables):
    max_p, max_group = -np.inf, None
    for group in newsgroups:
        p = np.sum(np.log(pij[group]) * test[i])
        
        if p > max_p:
            max_p = p
            max_group = group
            
    corrects[i] = (lable==max_group)
print("Correct: {}; Total: {}; Accuracy: {}%".format(corrects.sum(), n_test, np.round(corrects.sum() / n_test, decimals=4)*100))

Correct: 1232; Total: 1355; Accuracy: 90.92%
