In [59]:
import numpy as np
import pandas as pd
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer

In [61]:
class NaiveBayes:
  """Bernoulli naive Bayes classifier for binary (0/1) feature matrices.

  For every class c seen in ``fit`` the model stores
    * phi_c  = P(y = c): the fraction of training samples labelled c, and
    * psi_cj = P(x_j = 1 | y = c): the (optionally smoothed) frequency with
      which feature j is 1 among samples labelled c.
  ``predict`` returns the class maximising the joint log-likelihood
    log phi_c + sum_j [ x_j log psi_cj + (1 - x_j) log(1 - psi_cj) ].

  Parameters
  ----------
  alpha : float, default 0.0
    Additive (Laplace/Lidstone) smoothing for the per-class feature
    frequencies: psi_cj = (count_cj + alpha) / (n_c + 2 * alpha).
    0.0 gives the plain maximum-likelihood estimate (the per-class mean).
  """

  def __init__(self, alpha=0.0):
    if alpha < 0:
      raise ValueError('alpha must be non-negative')
    self.alpha = alpha

  def fit(self, X, y):
    """Estimate class priors and per-class feature probabilities.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
      Feature matrix; the likelihood assumes entries are 0 or 1.
    y : array-like of shape (n_samples,)
      Class labels of any sortable dtype (ints need not be 0..K-1).

    Raises
    ------
    ValueError
      If X and y have a different number of rows.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    n_samples, n_features = X.shape
    if y.shape[0] != n_samples:
      raise ValueError('X and y must have the same number of samples')
    # Leading underscore marks these attributes as internal to the class.
    self._classes = np.unique(y)
    n_classes = len(self._classes)
    self._psis = np.zeros((n_classes, n_features), dtype=np.float64)
    self._phis = np.zeros(n_classes, dtype=np.float64)
    for k, label in enumerate(self._classes):
      # Select rows by the label value itself, not by its position k, so
      # arbitrary (non-contiguous, string, ...) labels work.
      X_k = X[y == label]
      n_k = X_k.shape[0]
      self._psis[k] = (X_k.sum(axis=0) + self.alpha) / (n_k + 2 * self.alpha)
      self._phis[k] = n_k / float(n_samples)
    # Keep every psi strictly inside (0, 1) so log(psi) and log(1 - psi)
    # stay finite even for features never / always seen in a class.
    self._psis = self._psis.clip(1e-14, 1-1e-14)

  def predict(self, X):
    """Predict a class label for every row of X.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)

    Returns
    -------
    numpy.ndarray of shape (n_samples,)
      Labels drawn from the classes seen in ``fit``.
    """
    jll = self._joint_log_likelihood(np.asarray(X, dtype=np.float64))
    # Map the winning column index back to the original class label.
    return self._classes[np.argmax(jll, axis=1)]

  def _joint_log_likelihood(self, X):
    """Return an (n_samples, n_classes) array of log P(x, y = c)."""
    log_psi = np.log(self._psis)
    log_not_psi = np.log(1 - self._psis)
    # One matrix product per term replaces the per-sample, per-class loop.
    return X @ log_psi.T + (1 - X) @ log_not_psi.T + np.log(self._phis)

  def _predict(self, x):
    """Predict the class label for a single 1-D sample x."""
    return self.predict(np.asarray(x)[np.newaxis, :])[0]

In [7]:
# Restrict the training data to four newsgroups.
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
)
# binary=True yields 0/1 word-presence indicators (matching the
# x*log(psi) + (1-x)*log(1-psi) likelihood in NaiveBayes); the vocabulary
# is capped at the 1000 most frequent terms.
count_vect = CountVectorizer(max_features=1000, binary=True)
X_train = count_vect.fit_transform(twenty_train.data).toarray()
y_train = twenty_train.target

In [62]:
# Fit the classifier on the binary document-term matrix and its labels.
nb = NaiveBayes()
nb.fit(X=X_train, y=y_train)

[0.21267169 0.25875055 0.26318121 0.26539654]


In [49]:
# Predicted labels for the training documents (numpy abbreviates the display).
nb.predict(X=X_train)

array([1, 1, 3, ..., 2, 2, 2])

In [50]:
# Ground-truth training labels, for comparison with the predictions of the previous cell.
y_train

array([1, 1, 3, ..., 2, 2, 2])

In [51]:
# Fraction of training documents whose predicted label matches the true one.
# NOTE: this is measured on the same data the model was fit on, so it is an
# optimistic estimate; held-out data (e.g. subset='test') is needed to gauge
# generalization.
y_pred = nb.predict(X_train)
train_accuracy = np.mean(y_pred == y_train)
train_accuracy

0.8692955250332299

In [63]:
docs_test = ['GPU is graphics processing unit', 'catholic church poland']
# transform (not fit_transform) reuses the vocabulary learned on the training set.
X_test = count_vect.transform(docs_test).toarray()
y_pred_test = nb.predict(X_test)
# NOTE(review): in the saved output the second phrase is classified as
# comp.graphics, which looks wrong for a religious phrase — worth investigating.
for text, label in zip(docs_test, y_pred_test):
  group_name = twenty_train.target_names[label]
  print('%r => %s' % (text, group_name))

'GPU is graphics processing unit' => comp.graphics
'catholic church poland' => comp.graphics
