# Project 2: NB Classifier

### Course: CS 5420

### Author: Cooper Wooley

In [None]:
import tarfile
import tempfile
import shutil
import os
import atexit
from collections import defaultdict, Counter
import random

### Helper Functions for Managing Dataset

In [9]:
def extract_dataset(zip_path):
    # Create a temporary directory to extract into
    temp_dir = tempfile.mkdtemp(prefix="dataset_")

    # Extract contents
    with tarfile.open(zip_path, 'r:gz') as tar_ref:
        tar_ref.extractall(temp_dir)

    # Register cleanup handler so even if program crashes, data is removed
    atexit.register(lambda: cleanup_dataset(temp_dir))

    # Find the first subdirectory inside extracted directory
    contents = [os.path.join(temp_dir, d) for d in os.listdir(temp_dir)]
    subdirs = [d for d in contents if os.path.isdir(d)]

    if len(subdirs) == 1:
        data_root = subdirs[0]
    else:
        data_root = temp_dir # fallback if already data root

    return data_root

def cleanup_dataset(directory):
    if os.path.exists(directory):
        shutil.rmtree(directory)
        print(f"Cleaned up dataset directory: {directory}")


### Split Data

In [10]:
tar_path = "20_newsgroups.tar.gz"
extracted_path = extract_dataset(tar_path)
print(f"Dataset extracted to : {extracted_path}")

def split_dataset(base_dir, train_ratio=0.5, seed=42):
    random.seed(seed)

    train_files = []
    test_files = []
    train_labels = []
    test_labels = []

    for d in os.listdir(base_dir):
        d_path = os.path.join(base_dir, d)
        if not os.path.isdir(d_path):
            continue

        files = [
            os.path.join(d_path, f)
            for f in os.listdir(d_path)
            if os.path.isfile(os.path.join(d_path, f))
        ]

        random.shuffle(files)
        split_index = int(len(files) * train_ratio)

        train_files.extend(files[:split_index])
        test_files.extend(files[split_index:])
        train_labels.extend([d] * split_index)
        test_labels.extend([d] * split_index)

    return train_files, test_files, train_labels, test_labels # train_X, test_Y, train_Y, test_Y

  tar_ref.extractall(temp_dir)


Dataset extracted to : C:\Users\coope\AppData\Local\Temp\dataset_24i5zz9m\20_newsgroups


## NB Classifier

### Training

In [11]:
def train_naive_bayes(train_files, train_labels):
    word_counts = defaultdict(Counter) # class: {word: count}
    classes = set(train_labels)
    total_docs = len(train_labels)

    for path, label in zip(train_files, train_labels):
        with open(path, 'r', errors='ignore') as f:
            words = f.read().lower().split()
            word_counts[label].update(words)

    # Compute P(Y)
    priors = {}
    classes = set(train_labels)
    for cls in classes:
        priors[cls] = train_labels.count(cls) / total_docs

    # Compute P(X|Y)
    likelihoods = {}

    for cls, words in word_counts.items():
        total_words = sum(words.values())
        class_likelihoods = {}
        for word, count in words.items():
            class_likelihoods[word] = count / total_words
        likelihoods[cls] = class_likelihoods
        
    return priors, likelihoods

### Predicting

In [12]:
def predict(text, priors, likelihoods):
    with open(text, 'r', errors='ignore') as f:
        words = f.read().lower().split()

    score = {}
    for cls, _ in priors.items():
        score[cls] = priors[cls]
        for word in words:
            if word not in likelihoods[cls]:
                continue
            score[cls] *= likelihoods[cls][word]
            
    return max(score, key=score.get)
    

### Evaluation

In [13]:
def evaluate(test_files, test_labels, priors, likelihoods):
    correct = 0
    for file, y in zip(test_files, test_labels):
        y_hat = predict(file, priors, likelihoods)
        if y_hat == y:
            correct += 1
    return correct / len(test_files)


In [14]:
train_files, test_files, train_labels, test_labels = split_dataset(extracted_path)

priors, likelihoods = train_naive_bayes(train_files, train_labels)

accuracy = evaluate(test_files, test_labels, priors, likelihoods)
print(f"Accuracy: {accuracy:.2f}")

Accuracy: 0.03


In [15]:
cleanup_dataset(extracted_path)

Cleaned up dataset directory: C:\Users\coope\AppData\Local\Temp\dataset_24i5zz9m\20_newsgroups
