This is a guide to implementing a Naive Bayes classifier in Python 3.

In [2]:
# Import the necessary Python libraries

import sys
from time import perf_counter
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import model_selection
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.naive_bayes import GaussianNB

In [3]:
# Load the Iris data set from the UCI repository.
#
# Column names are supplied by hand; header=None is spelled out so it is clear
# that the first line of the file is read as data (pandas already behaves this
# way whenever names= is passed).
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = [
    'sepal-length',
    'sepal-width',
    'petal-length',
    'petal-width',
    'class',
]
dataset = pd.read_csv(url, header=None, names=names)


In [4]:
# Preview the first five records (head() returns 5 rows by default)

dataset.head()

Unnamed: 0,sepal-length,sepal-width,petal-length,petal-width,class
0,5.1,3.5,1.4,0.2,Iris-setosa
1,4.9,3.0,1.4,0.2,Iris-setosa
2,4.7,3.2,1.3,0.2,Iris-setosa
3,4.6,3.1,1.5,0.2,Iris-setosa
4,5.0,3.6,1.4,0.2,Iris-setosa


In [5]:
# Preview the last five records (tail() returns 5 rows by default)

dataset.tail()

Unnamed: 0,sepal-length,sepal-width,petal-length,petal-width,class
145,6.7,3.0,5.2,2.3,Iris-virginica
146,6.3,2.5,5.0,1.9,Iris-virginica
147,6.5,3.0,5.2,2.0,Iris-virginica
148,6.2,3.4,5.4,2.3,Iris-virginica
149,5.9,3.0,5.1,1.8,Iris-virginica


In [6]:
# View column names
# Displays the Index holding the column labels of the dataframe.

dataset.columns

Index(['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class'], dtype='object')

In [7]:
# Get general information about the dataframe:
# index range, per-column non-null counts and dtypes, and memory usage.

dataset.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 150 entries, 0 to 149
Data columns (total 5 columns):
sepal-length    150 non-null float64
sepal-width     150 non-null float64
petal-length    150 non-null float64
petal-width     150 non-null float64
class           150 non-null object
dtypes: float64(4), object(1)
memory usage: 7.0+ KB


In [8]:
# Report the dimensions of the dataframe.
# shape is a (rows, columns) tuple, so it is unpacked once and reused.

n_rows, n_cols = dataset.shape
print("Tota No of Rows:" + str(n_rows))
print("Tota No of Columns:" + str(n_cols))

Tota No of Rows:150
Tota No of Columns:5


In [9]:
# Get basic summary statistics (count, mean, std, min, quartiles, max)
# for all numerical columns; the string-valued 'class' column is left out.

dataset.describe()

Unnamed: 0,sepal-length,sepal-width,petal-length,petal-width
count,150.0,150.0,150.0,150.0
mean,5.843333,3.054,3.758667,1.198667
std,0.828066,0.433594,1.76442,0.763161
min,4.3,2.0,1.0,0.1
25%,5.1,2.8,1.6,0.3
50%,5.8,3.0,4.35,1.3
75%,6.4,3.3,5.1,1.8
max,7.9,4.4,6.9,2.5


In [10]:
# Find the unique values in a column: here, the distinct labels in 'class'.

dataset["class"].unique()

array(['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'], dtype=object)

In [11]:
# Count how many rows carry each distinct value of a column:
# here, the number of samples per class label.

dataset["class"].value_counts()

Iris-virginica     50
Iris-versicolor    50
Iris-setosa        50
dtype: int64

In [12]:
# Renaming column names
# rename() returns a new dataframe with the mapped labels; `dataset` itself
# keeps its original 'class' column.

renamed_columns = {'class': 'flower_type'}
dataset1 = dataset.rename(columns=renamed_columns)
print(dataset1.columns)



Index(['sepal-length', 'sepal-width', 'petal-length', 'petal-width',
       'flower_type'],
      dtype='object')


In [13]:
# Make an independent copy of the dataframe.
# deep=True is pandas' default for copy(); it is written out here to make the
# intent explicit.

dataset2 = dataset.copy(deep=True)


In [8]:
# Finding Missing values in a particular column
# The count is printed explicitly: Jupyter only displays a bare expression when
# it is the last statement of a cell, and this one is followed by more code, so
# without print() the result would be computed and silently discarded.

print(dataset['class'].isnull().sum())

# Finding Missing values in all columns

def num_missing(x):
    """Return the number of missing (null/NaN) entries in ``x``.

    Parameters
    ----------
    x : pandas.Series
        Column to inspect, e.g. a single column handed over by
        ``DataFrame.apply``.

    Returns
    -------
    int
        Number of entries for which ``x.isnull()`` is True; 0 for an empty
        Series.
    """
    # Vectorized reduction instead of the builtin sum(), which would walk the
    # boolean mask element by element in Python. int() keeps the return value
    # a plain Python integer.
    return int(x.isnull().sum())

# Apply num_missing to every column and show the resulting per-column counts.
missing_per_column = dataset.apply(num_missing)
print(missing_per_column)
   

sepal-length    0
sepal-width     0
petal-length    0
petal-width     0
class           0
dtype: int64


In [15]:
# Split the dataset into Training and Validation
#
# dataset.values on a frame that mixes float columns with the string 'class'
# column yields a single object-dtype array. The four measurement columns are
# therefore cast back to float explicitly so the model receives a genuine
# numeric feature matrix instead of an array of Python objects.

array = dataset.values
X = array[:, 0:4].astype(float)  # features: the four measurement columns
Y = array[:, 4]                  # target: the class label column
validation_size = 0.20           # fraction of rows held out for validation
seed = 7                         # fixed random_state so the split is repeatable
X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(
    X, Y, test_size=validation_size, random_state=seed
)

In [17]:
# Build the model
#
# Fit a Gaussian Naive Bayes classifier on the training split, predict the
# validation split, and report how long each step took plus the accuracy.
# perf_counter() is used for the timings: it is a monotonic, high-resolution
# clock intended for measuring short durations, whereas time() is wall-clock
# time that can be coarse and can jump if the system clock is adjusted.

clf = GaussianNB()

t0 = perf_counter()
clf.fit(X_train, Y_train)
print("Training time:", round(perf_counter() - t0, 3), "s")

t1 = perf_counter()
pred = clf.predict(X_validation)
print("Prediction time:", round(perf_counter() - t1, 3), "s")

# Fraction of validation samples whose predicted class equals the true class.
accuracy = accuracy_score(Y_validation, pred)
print(accuracy)

Training time: 0.002 s
Prediction time: 0.001 s
0.833333333333


In [24]:
# Classify one hand-written sample. The values follow the feature column order
# used for X (sepal-length, sepal-width, petal-length, petal-width).
# NOTE(review): these values are far outside the ranges shown by describe()
# above; the classifier still returns one of the known classes.
sample = [[100000000, 1000, 1000, 1000]]
clf.predict(sample)

array(['Iris-virginica'], 
      dtype='<U15')