<a href="https://colab.research.google.com/github/nkrj01/Models-from-scratch/blob/main/Naivebayes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from sklearn.datasets import load_breast_cancer
# Dataset bundle; later cells read its "data" and "target" entries.
data_obj = load_breast_cancer()
# NOTE(review): plt is not referenced in the cells visible here.
import matplotlib.pyplot as plt
# scipy's normal distribution, aliased as `gaussian`; predict() uses it for
# the per-feature class-conditional likelihood.
from scipy.stats import norm as gaussian

In [2]:
# Feature matrix, and the target vector turned into a single column so it can
# be appended to the features as the last column of one combined array.
X = data_obj["data"]
y = data_obj["target"][:, np.newaxis]
data = np.concatenate((X, y), axis=1)
print(data.shape)

(569, 31)


In [3]:
def getParameters(data, min_std=1e-9):
  """Estimate Gaussian naive Bayes parameters from a labelled data matrix.

  Args:
    data: 2-D array whose last column holds the class label and whose
      remaining columns hold the features.
    min_std: lower bound applied to every per-class, per-feature standard
      deviation. A feature that is constant within a class has std == 0,
      for which the Gaussian density is undefined (scipy returns NaN), so
      the std is floored at this value. Pass 0 to get the raw estimates.

  Returns:
    dict with keys:
      "mean"   -> {label: per-feature mean of the rows with that label}
      "std"    -> {label: per-feature (population) std, floored at min_std}
      "prior"  -> {label: fraction of rows carrying that label}
      "labels" -> sorted array of the distinct labels (from np.unique)
  """
  data = np.asarray(data)
  X = data[:, :-1]
  y = data[:, -1]
  labels = np.unique(y)

  mean = {}
  std = {}
  prior = {}
  for label in labels:
    X_label = X[y == label]  # rows belonging to this class
    mean[label] = X_label.mean(axis=0)
    # Floor the spread so a zero-variance feature cannot produce NaN densities.
    std[label] = np.maximum(X_label.std(axis=0), min_std)
    prior[label] = X_label.shape[0] / len(y)

  parameters = {
      "mean" : mean,
      "std" : std,
      'prior' : prior,
      'labels': labels
  }

  return parameters


def predict(x, parameter):
  """Return the most probable class label for one sample (Gaussian naive Bayes).

  For every label the log-posterior (up to a shared constant) is
  sum_j log N(x_j | mean_j, std_j) + log(prior), and the label with the
  largest value is returned. On ties the label that comes first in
  parameter["labels"] wins.

  Args:
    x: 1-D sequence of feature values for a single sample.
    parameter: dict as produced by getParameters, with keys "mean", "std",
      "prior" (each a dict keyed by label) and "labels". The std entries
      are expected to be strictly positive; a zero std makes the Gaussian
      log-density NaN.

  Returns:
    The winning label, taken from parameter["labels"].

  Raises:
    ValueError: if parameter["labels"] is empty.
  """
  x = np.asarray(x, dtype=float)
  n_features = len(x)
  labels = parameter["labels"]
  if len(labels) == 0:
    raise ValueError("parameter['labels'] is empty; nothing to predict")

  posteriors = {}
  for label in labels:
    # Only the first len(x) parameters are used, matching a per-column loop
    # over range(len(x)).
    mean = np.asarray(parameter["mean"][label])[:n_features]
    std = np.asarray(parameter["std"][label])[:n_features]
    # logpdf is evaluated directly instead of log(pdf): far in the tails the
    # pdf underflows to 0, and substituting a fixed tiny value there would
    # rank an extreme outlier as *more* likely than a moderate one.
    log_likelihood = gaussian.logpdf(x, loc=mean, scale=std).sum()
    posteriors[label] = log_likelihood + np.log(parameter["prior"][label])

  # max() needs no magic sentinel, so arbitrarily negative log-posteriors
  # still produce an answer; the first maximal label wins on ties.
  return max(posteriors, key=posteriors.get)

In [4]:
# Fit the per-class Gaussian parameters, then classify every row of X
# one sample at a time.
parameter = getParameters(data)
m = X.shape[0]
y_predict = []
for row, x in enumerate(X):
  y_predict.append(predict(x, parameter))

In [5]:
# Fraction of samples whose predicted label equals the true label.
# Note: the parameters were fitted on these same rows, so this is
# training-set accuracy rather than a held-out estimate.
y_predict = np.array(y_predict)
y = y.flatten()
matching_elements = np.equal(y_predict, y)
count_matches = matching_elements.sum()
accuracy = count_matches / X.shape[0]
print(accuracy)

0.9402460456942003
