In [48]:
import numpy as np
import itertools as it

class NaiveBayes:
    """Gaussian naive Bayes classifier for integer class labels ``0 .. classes-1``.

    Training data is a 2-D array whose column 0 holds the class label and
    whose remaining columns hold numeric attributes. For every class the
    model stores the per-attribute mean and variance plus the class prior
    (fraction of training rows carrying that label).
    """

    # Floor applied to a variance whenever a density is evaluated. An attribute
    # that is constant within a class has variance 0, which would otherwise
    # yield inf/nan densities and make every prediction collapse to class 0.
    _MIN_VAR = 1e-9

    def __init__(self, data, classes=2):
        """Allocate the per-class statistics and fit them to ``data``.

        Args:
            data: 2-D array-like; column 0 is the label, the rest are attributes.
            classes: number of classes; labels are expected to be 0..classes-1.
        """
        data = np.asarray(data)
        self._class_cnt = classes
        self._attr_cnt = data.shape[1] - 1
        self._classes = list(range(self._class_cnt))
        self._attrs = list(range(self._attr_cnt))
        self._vars = np.zeros((classes, self._attr_cnt), dtype='float32')
        self._avgs = np.zeros((classes, self._attr_cnt), dtype='float32')
        self._prior = np.zeros((self._class_cnt), dtype='float32')

        self.train(data)

    def train(self, data):
        """Estimate mean, variance and prior of every class from ``data``.

        A class with no rows in ``data`` gets prior 0 (and zeroed statistics)
        instead of raising, so it is simply never predicted.
        """
        data = np.asarray(data)
        labels = data[:, 0]
        features = data[:, 1:]
        total = len(data)

        for k in self._classes:
            rows = features[labels == k]
            self._prior[k] = len(rows) / total

            if len(rows) == 0:
                # Nothing to estimate from; reset so a re-train leaves no stale values.
                self._avgs[k] = 0
                self._vars[k] = 0
                continue

            self._avgs[k] = rows.mean(axis=0)
            self._vars[k] = rows.var(axis=0)

    def predict(self, x):
        """Return the index of the most probable class for attribute vector ``x``.

        Scores are accumulated as log-prior plus the sum of log-densities.
        Multiplying raw densities underflows to 0 for every class once there
        are many attributes, which made the argmax meaningless.

        Raises:
            IndexError: if ``x`` has fewer values than the model has attributes.
        """
        values = list(x)
        if len(values) < self._attr_cnt:
            raise IndexError(
                'expected at least {} attribute values, got {}'.format(
                    self._attr_cnt, len(values)))
        # Only the first _attr_cnt entries are used; extras are ignored.
        values = np.asarray(values[:self._attr_cnt], dtype='float64')

        var = np.maximum(self._vars.astype('float64'), self._MIN_VAR)
        avg = self._avgs.astype('float64')

        # Log of the Gaussian density, shape (classes, attributes).
        log_density = -0.5 * (np.log(2 * np.pi * var) + (values - avg) ** 2 / var)

        # A class with prior 0 gets a score of -inf; silence the log(0) warning.
        with np.errstate(divide='ignore'):
            log_prior = np.log(self._prior.astype('float64'))

        return np.argmax(log_density.sum(axis=1) + log_prior)

    def posterior(self, a_i, i, k):
        """Gaussian density of value ``a_i`` for attribute ``i`` under class ``k``.

        Despite the name this is the class-conditional likelihood p(a_i | k),
        not a posterior. The stored variance is floored at ``_MIN_VAR`` so the
        result stays finite for attributes that are constant within a class.
        """
        var = max(float(self._vars[k, i]), self._MIN_VAR)
        avg = float(self._avgs[k, i])

        return (2 * np.pi * var) ** -0.5 * np.exp(-(a_i - avg) ** 2 / (2 * var))


In [126]:
from sklearn import datasets

# Load the iris measurements and pack them into rows of
# [label, attr_1, ..., attr_n], the layout NaiveBayes expects (label in column 0).
iris = datasets.load_iris()
x = iris.data
y = iris.target

# Build the table in one step instead of copying each column row by row;
# the row count now follows the dataset rather than a hard-coded 150.
full_data = np.column_stack((y, x)).astype('float32')

# Shuffle the rows in place with a fixed seed so the train/test split, and
# therefore the reported accuracy, is the same on every run.
rng = np.random.default_rng(0)
rng.shuffle(full_data)

n_train = 130
training_data = full_data[:n_train, :]
testing_data = full_data[n_train:, :]

test = NaiveBayes(training_data, classes=3)
# `row[1:]` drops the label column; the loop variable is named `row` so it
# does not shadow the feature matrix `x` defined above.
predictions = [test.predict(row[1:]) for row in testing_data]
predictions2 = [test.predict(row[1:]) for row in full_data]

In [129]:
# Count the predictions that agree with the true label kept in column 0.
# NOTE: predictions2 covers every row of full_data, including the rows the
# model was trained on, so this figure is not a held-out test accuracy.
true_labels = full_data[:, 0]
correct = sum(1 for guess, truth in zip(predictions2, true_labels) if guess == truth)

print(correct)
print(len(full_data))

print('{}'.format(correct / len(full_data)))

144
150
0.96
