In [1]:
import os
from tqdm import tqdm
from copy import deepcopy
from random import randint
import re
import numpy as np
import pandas as pd
from collections import Counter

from sklearn.metrics import classification_report, accuracy_score

In [2]:
def file_iniate():
  dataset = '20_newsgroups'
  for group in tqdm(os.listdir(dataset)):
    filepath = os.path.join(dataset, group)
    files = os.listdir(filepath)
    for file in files:
      with open(os.path.join(filepath, file), 'rb',) as f:
          lines=f.readlines()[4:]
  
      folder = os.path.join('dataset', group)
      if not os.path.exists(folder):
          os.makedirs(folder)

      with open(os.path.join(folder, file), 'wb') as f:
          f.writelines(lines)


def shuffle(lst):
  temp_lst = deepcopy(lst)
  m = len(temp_lst)
  while (m):
    m -= 1
    i = randint(0, m)
    temp_lst[m], temp_lst[i] = temp_lst[i], temp_lst[m]
  return temp_lst

def file2wordlist(filename):
  with open(filename, 'rb') as f:
    lines = f.readlines()
    
  string_lines = [line.decode('utf-8', errors='ignore').strip() for line in lines]
  cleaned_lines = [re.sub(r'[^a-zA-Z\s]', ' ', line.replace("'s", "")) for line in string_lines]
  result_string = ' '.join(cleaned_lines)
  words = re.findall(r'\b[a-zA-Z]+\b', result_string)

  words = list(map(lambda x: x.lower(), words))

  return words

def word2prob(word, wordsdict, distinctwords, laplace=1):
    if word not in distinctwords:
        return np.ones((len(wordsdict)))
    prob = np.ones((len(wordsdict)))
    for i, group in enumerate(wordsdict):
        if word in group[1]:
            prob[i] = (group[1][word] + laplace)/(group[0]+len(distinctwords))
        else:
            prob[i] = laplace/(group[0]+len(distinctwords))
    return prob

In [3]:
# load data from folder 'dataset' and split into train and test
dataset = 'dataset'
train = []
test = []

for group in os.listdir(dataset):
    filepath = os.path.join(dataset, group)
    files = os.listdir(filepath)
    files = [os.path.join(filepath, f) for f in files]

    shuffle(files)
    train = train + [files[:len(files)//2]]
    test = test + [files[len(files)//2:]]

In [4]:
# prior P
PriorP = np.zeros(len(train))
total = 0
for i in range(len(train)):
    PriorP[i] = len(train[i])
    total += len(train[i])
PriorP /= total

In [5]:
# Conditional P
distinctwords = set()
wordsdict = []
merged_counter = Counter()
for group in train:
    allwords = []
    tempdic = {}
    for ele in group:
        words = file2wordlist(ele)

        allwords.extend(words)
    for w in allwords:
        if w in tempdic:
            tempdic[w] += 1
        else:
            tempdic[w] = 1

    distinctwords.update(allwords)
    wordsdict.append([len(allwords), tempdic])
    merged_counter += Counter(tempdic)


In [6]:
# set stop words
stopwords = [ele[0] for ele in merged_counter.most_common(300)]

In [7]:
temwordsdic = deepcopy(wordsdict)
temdistin = deepcopy(distinctwords)

for i in range(len(temwordsdic)):
    temwordsdic[i][1] = {k: v for k, v in temwordsdic[i][1].items() if k not in stopwords}
temdistin = {ele for ele in temdistin if ele not in stopwords}

# predict
predictLabels = []
trueLabels = []
for i, group in enumerate(test):
    for file in group:
        words = file2wordlist(file)
        overallsum = np.log2(PriorP)
        for w in words:
            prob = np.log2(word2prob(w, temwordsdic, temdistin))
            overallsum += prob
        predictLabels.append(np.argmax(overallsum))
        trueLabels.append(i)

In [9]:
# results
report = classification_report(trueLabels, predictLabels)
avg_acc = accuracy_score(trueLabels, predictLabels)
print('average accuracy is ', avg_acc)
print(report)

average accuracy is  0.8052805280528053
              precision    recall  f1-score   support

           0       0.79      0.73      0.76       500
           1       0.62      0.76      0.68       500
           2       0.83      0.58      0.68       500
           3       0.62      0.76      0.68       500
           4       0.81      0.77      0.79       500
           5       0.83      0.75      0.79       500
           6       0.70      0.81      0.75       500
           7       0.86      0.89      0.87       500
           8       0.92      0.94      0.93       500
           9       0.94      0.90      0.92       500
          10       0.92      0.96      0.94       500
          11       0.91      0.90      0.91       500
          12       0.75      0.64      0.69       500
          13       0.90      0.82      0.86       500
          14       0.89      0.90      0.90       500
          15       0.94      0.99      0.97       499
          16       0.73      0.90      0.