In [3]:
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

In [7]:
class NaiveBayes:
    """Bernoulli Naive Bayes classifier for binary (0/1) feature vectors.

    Every feature is modelled as an independent Bernoulli variable given the
    class. After ``fit`` the instance holds:

    * ``_classes`` -- sorted unique labels found in ``y``.
    * ``_psis``    -- array of shape (n_classes, n_features); ``_psis[k, j]``
      is the mean of feature ``j`` over the training rows of class ``k``
      (for 0/1 features, the fraction of rows with the feature set),
      clipped to ``[1e-14, 1 - 1e-14]``.
    * ``_phis``    -- array of shape (n_classes,); fraction of training rows
      that belong to each class.

    A leading underscore marks an attribute or method as private by
    convention (meant to be used only inside the class).
    """

    def fit(self, X, y):
        """Estimate class priors and per-class feature probabilities.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Dense training matrix; the model assumes entries are 0 or 1.
        y : array-like of shape (n_samples,)
            Class label of each row. Labels may be arbitrary values (they do
            not have to be the integers 0..n_classes-1).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        n_samples, n_features = X.shape
        self._classes = np.unique(y)
        n_classes = len(self._classes)
        self._psis = np.zeros((n_classes, n_features), dtype=float)
        self._phis = np.zeros(n_classes, dtype=float)
        # Select rows by the actual label value, not by its position in
        # ``_classes``; the two only coincide when labels are 0..K-1.
        for k, label in enumerate(self._classes):
            X_k = X[y == label]
            self._psis[k] = X_k.mean(axis=0)
            self._phis[k] = X_k.shape[0] / float(n_samples)
        # Keep probabilities strictly inside (0, 1) so that the logarithms
        # taken at prediction time stay finite.
        self._psis = self._psis.clip(1e-14, 1 - 1e-14)

    def predict(self, X):
        """Return the most probable class label for every row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
            Labels taken from the values seen in ``fit``.

        Raises
        ------
        ValueError
            If ``X`` is non-empty and not two-dimensional.
        """
        X = np.asarray(X)
        if X.ndim >= 1 and X.shape[0] == 0:
            return np.empty(0, dtype=self._classes.dtype)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        best = np.argmax(self._joint_log_likelihood(X), axis=1)
        return self._classes[best]

    def _predict(self, x):
        """Return the most probable class label for one feature vector ``x``."""
        x = np.asarray(x)
        scores = self._joint_log_likelihood(x[np.newaxis, :])[0]
        return self._classes[np.argmax(scores)]

    def _joint_log_likelihood(self, X):
        """Return ``log p(x, y=k)`` for each row and class, shape (n_samples, n_classes)."""
        log_on = np.log(self._psis)
        log_off = np.log(1 - self._psis)
        # sum_j [x_j * log(psi_j) + (1 - x_j) * log(1 - psi_j)]
        #   = x . (log(psi) - log(1 - psi)) + sum_j log(1 - psi_j)
        # which turns the per-sample, per-class loop into one matrix product.
        return X @ (log_on - log_off).T + log_off.sum(axis=1) + np.log(self._phis)

In [23]:
# Restrict the 20 Newsgroups corpus to four topics.
categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
]
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
)

# binary=True stores word presence (0/1) instead of counts; the vocabulary is
# capped at 10000 terms. The sparse result is densified with .toarray().
count_vect = CountVectorizer(binary=True, max_features=10000)
X_train = count_vect.fit_transform(twenty_train.data).toarray()
y_train = twenty_train.target

In [24]:
# Fit the classifier on the binary bag-of-words training matrix.
nb = NaiveBayes()
nb.fit(X=X_train, y=y_train)

In [25]:
# Accuracy on the same data the model was fitted on.
y_pred = nb.predict(X_train)
np.mean(y_pred == y_train)

0.9920248116969429

In [26]:
# Test split of the same four newsgroups.
twenty_test = fetch_20newsgroups(
    subset='test',
    categories=categories,
    shuffle=True,
    random_state=42,
)

# transform() (not fit_transform()) reuses the vocabulary learned on the
# training documents, so train and test columns line up.
X_test = count_vect.transform(twenty_test.data).toarray()
y_test = twenty_test.target

In [28]:
# Accuracy on the test split.
y_pred_test = nb.predict(X_test)
np.mean(y_pred_test == y_test)

0.9161118508655126