# Naive Bayes

`scikit-learn` provides several implementations of Naive Bayes that differ in how the conditional probabilities P(feature | class) are estimated, so each implementation suits a different type of data:

- `CategoricalNB` works with categorical data once it has been encoded as integers using an `OrdinalEncoder`
- `GaussianNB` assumes the numerical features follow a Gaussian (normal) distribution within each class
- `BernoulliNB` works with binary (0/1) data
- `MultinomialNB` works with count data, e.g. word counts
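A minimal sketch of matching each variant to its data type. The four feature matrices below are small hypothetical toy arrays (not the Swimming data), chosen only so that each classifier receives the kind of input it models.

```python
import numpy as np
from sklearn.naive_bayes import BernoulliNB, CategoricalNB, GaussianNB, MultinomialNB

# Hypothetical toy labels: three examples per class
y = np.array([0, 0, 0, 1, 1, 1])

# Hypothetical toy feature matrices, one per data type
X_cat = np.array([[0, 1], [0, 2], [1, 1], [2, 0], [2, 2], [1, 0]])   # integer category codes
X_gauss = np.array([[1.0, 2.1], [1.2, 1.9], [0.9, 2.0],
                    [3.0, 0.5], [3.2, 0.4], [2.9, 0.6]])               # continuous measurements
X_bin = np.array([[1, 0, 1], [1, 0, 0], [1, 1, 1],
                  [0, 1, 0], [0, 1, 1], [0, 1, 0]])                    # presence/absence flags
X_counts = np.array([[3, 0, 1], [2, 1, 0], [4, 0, 2],
                     [0, 3, 1], [1, 4, 0], [0, 2, 2]])                 # e.g. word counts

models = {
    "CategoricalNB": (CategoricalNB(), X_cat),
    "GaussianNB": (GaussianNB(), X_gauss),
    "BernoulliNB": (BernoulliNB(), X_bin),
    "MultinomialNB": (MultinomialNB(), X_counts),
}

# Fit each classifier on the data type it is designed for and predict the training rows
for name, (model, X) in models.items():
    model.fit(X, y)
    print(name, model.predict(X))
```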

In [None]:
import pandas as pd
from sklearn.naive_bayes import MultinomialNB, GaussianNB, BernoulliNB, CategoricalNB
from sklearn.metrics import confusion_matrix 
from sklearn.preprocessing import OneHotEncoder

In [None]:
swim = pd.read_csv('data/Swimming.csv')
swim

## Categorical NB

In [None]:
from sklearn.preprocessing import OrdinalEncoder

In [None]:
y = swim.pop('Swimming').values # Set this as the y (target)
print(swim.columns)
print(y)

In [None]:
ord_encoder = OrdinalEncoder()
swimOE = ord_encoder.fit_transform(swim)
swimOE

What ordering does the `OrdinalEncoder` assign to the categories of each feature?   
Look up the documentation to see how you would fix this if you want the codes to follow the natural order of the categories.
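A minimal sketch of the default behaviour and of setting the order explicitly with the `categories` parameter. The `Rain` column and its levels are hypothetical and not taken from the Swimming data.

```python
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder

# Hypothetical column; not from the Swimming data
df = pd.DataFrame({"Rain": ["Light", "Heavy", "Moderate", "Light"]})

# Default: categories are sorted, so the codes follow alphabetical order
default_enc = OrdinalEncoder().fit(df)
print(default_enc.categories_)             # [array(['Heavy', 'Light', 'Moderate'], dtype=object)]
print(default_enc.transform(df).ravel())   # [1. 0. 2. 1.]

# Explicit order: one list per column, lowest level first
ordered_enc = OrdinalEncoder(categories=[["Light", "Moderate", "Heavy"]]).fit(df)
print(ordered_enc.transform(df).ravel())   # [0. 2. 1. 0.]
```

For `CategoricalNB` the order only affects how the codes read: the classifier treats each code as an unordered category, so relabelling the categories does not change its predictions.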

In [None]:
catNB = CategoricalNB(fit_prior=True, alpha=.0001)
swim_catNB = catNB.fit(swimOE,y)
y_dash = swim_catNB.predict(swimOE)  #predict training data
confusion = confusion_matrix(y, y_dash)
print("Confusion matrix:\n{}".format(confusion)) 

In [None]:
# Three query examples, one from the training data and two others.

squery = pd.DataFrame([["Moderate","Moderate","Warm","Light","Some"],
                       ["Moderate","Moderate","Cold","Moderate","Some"],
                       ["Moderate","Light","Warm","Light","None"]
                      ], columns=swim.columns)

In [None]:
X_query = ord_encoder.transform(squery)
X_query, X_query.shape

In [None]:
y_query = swim_catNB.predict(X_query)
y_query

In [None]:
q_probs = swim_catNB.predict_proba(X_query)  # get the probabilities of each class for each query
q_probs
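The probabilities returned by `predict_proba` can be reproduced from simple counts. `CategoricalNB` estimates `P(x_i = t | y = c) = (N_tic + alpha) / (N_c + alpha * n_i)`, where `N_tic` is how often feature `i` takes value `t` in class `c`, `N_c` is the number of class-`c` examples and `n_i` is the number of categories of feature `i`; it multiplies these by the class prior and normalises. A minimal sketch on hypothetical toy data (the helper `manual_proba` is illustrative, not part of scikit-learn), checked against the library:

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB

# Hypothetical ordinal-encoded data: feature 0 has 3 categories, feature 1 has 2
X = np.array([[0, 1], [1, 1], [2, 0], [0, 0], [1, 0], [2, 1]])
y = np.array(["no", "no", "no", "yes", "yes", "yes"])
alpha = 1.0

def manual_proba(X, y, x_query, alpha):
    """Categorical NB posterior for one query row, computed from counts."""
    classes = np.unique(y)            # sorted, matching CategoricalNB.classes_
    n_cats = X.max(axis=0) + 1        # assumes codes run 0..k-1, as CategoricalNB does by default
    log_scores = []
    for c in classes:
        Xc = X[y == c]
        score = np.log(len(Xc) / len(X))          # log prior P(y = c), as with fit_prior=True
        for i, t in enumerate(x_query):
            n_tic = np.sum(Xc[:, i] == t)         # times feature i took value t in class c
            score += np.log((n_tic + alpha) / (len(Xc) + alpha * n_cats[i]))
        log_scores.append(score)
    log_scores = np.array(log_scores)
    p = np.exp(log_scores - log_scores.max())     # subtract the max for numerical stability
    return p / p.sum()

model = CategoricalNB(alpha=alpha).fit(X, y)
query = np.array([0, 1])
manual = manual_proba(X, y, query, alpha)
library = model.predict_proba(query.reshape(1, -1))[0]
print(manual, library)  # both ≈ [0.6, 0.4]
```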

In [None]:
swim_catNB.get_params()  # list the hyperparameters; the CategoricalNB documentation explains each one
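Of these parameters, `alpha` (additive or Laplace smoothing) is the one set explicitly above, and `fit_prior=True` means the class priors are estimated from the class frequencies rather than assumed uniform. A minimal sketch on hypothetical toy data of why `alpha` matters: when a category never occurs with a class in the training data, a tiny `alpha` such as the `0.0001` used above pushes that class's probability almost to zero, while `alpha=1.0` keeps it well away from zero.

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB

# Hypothetical single feature with 3 categories; category 2 never occurs with class "yes"
X = np.array([[0], [1], [2], [0], [1], [1]])
y = np.array(["no", "no", "no", "yes", "yes", "yes"])
query = np.array([[2]])

probs = {}
for alpha in [1.0, 0.0001]:
    model = CategoricalNB(alpha=alpha, fit_prior=True).fit(X, y)
    probs[alpha] = model.predict_proba(query)[0]   # columns follow model.classes_: ['no', 'yes']
    print(alpha, probs[alpha])
# alpha=1.0    -> roughly [0.667, 0.333]
# alpha=0.0001 -> "yes" gets a probability of about 0.0001
```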

### One-Hot-Encode the training data
Here we use one-hot encoding to convert the Swimming dataset to a numeric format.   
This converts the data to a binary format, so it is valid to use `BernoulliNB` and arguably `MultinomialNB`; `GaussianNB`, which assumes normally distributed features, is a poor fit.

In [None]:
swim = pd.read_csv('data/Swimming.csv')
y = swim.pop('Swimming').values # Set this as the y (target)


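# `sparse_output` replaced the old `sparse` argument in scikit-learn 1.2 (`sparse` was later removed)
onehot_encoder = OneHotEncoder(sparse_output=False)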
swimOH = onehot_encoder.fit_transform(swim)
swimOH

In [None]:
onehot_encoder.get_feature_names_out(swim.columns)

In [None]:
gnb = GaussianNB()
swim_gNB = gnb.fit(swimOH,y)
y_dash = swim_gNB.predict(swimOH)
confusion = confusion_matrix(y, y_dash)
print("Confusion matrix:\n{}".format(confusion)) 

In [None]:
mnb = MultinomialNB()
swim_mNB = mnb.fit(swimOH,y)
y_dash = swim_mNB.predict(swimOH)
confusion = confusion_matrix(y, y_dash)
print("Confusion matrix:\n{}".format(confusion)) 

In [None]:
bnb = BernoulliNB()
swim_bNB = bnb.fit(swimOH,y)
y_dash = swim_bNB.predict(swimOH)
confusion = confusion_matrix(y, y_dash)
print("Confusion matrix:\n{}".format(confusion)) 
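`BernoulliNB` and `MultinomialNB` read the same one-hot data differently: `BernoulliNB` multiplies in `P(feature absent | class)` for every 0, while `MultinomialNB` only uses the columns that are present. A minimal sketch, assuming hypothetical toy data and the default `alpha=1.0`, that reproduces both `predict_proba` results by hand (the helper functions are illustrative, not part of scikit-learn):

```python
import numpy as np
from sklearn.naive_bayes import BernoulliNB, MultinomialNB

# Hypothetical one-hot data: two original features with 2 levels each -> 4 binary columns
X = np.array([[1, 0, 1, 0],
              [1, 0, 0, 1],
              [0, 1, 1, 0],
              [0, 1, 0, 1]])
y = np.array([0, 0, 1, 1])
x = np.array([1, 0, 1, 0])  # query row
alpha = 1.0

def normalise(log_scores):
    p = np.exp(log_scores - log_scores.max())
    return p / p.sum()

def bernoulli_manual(X, y, x, alpha):
    scores = []
    for c in np.unique(y):
        Xc = X[y == c]
        p = (Xc.sum(axis=0) + alpha) / (len(Xc) + 2 * alpha)  # P(feature present | c)
        # absent features contribute log(1 - p): absence is evidence too
        scores.append(np.log(len(Xc) / len(X)) + np.sum(x * np.log(p) + (1 - x) * np.log(1 - p)))
    return normalise(np.array(scores))

def multinomial_manual(X, y, x, alpha):
    scores = []
    for c in np.unique(y):
        Xc = X[y == c]
        counts = Xc.sum(axis=0) + alpha
        theta = counts / counts.sum()  # share of class-c counts falling in each column
        # only present features contribute; absent ones add 0 * log(theta) = 0
        scores.append(np.log(len(Xc) / len(X)) + np.sum(x * np.log(theta)))
    return normalise(np.array(scores))

bern = BernoulliNB(alpha=alpha).fit(X, y).predict_proba(x.reshape(1, -1))[0]
mult = MultinomialNB(alpha=alpha).fit(X, y).predict_proba(x.reshape(1, -1))[0]
print(bern, bernoulli_manual(X, y, x, alpha))    # both ≈ [0.9, 0.1]
print(mult, multinomial_manual(X, y, x, alpha))  # both ≈ [0.75, 0.25]
```

In this toy case `BernoulliNB` is more confident because the absent columns also count as evidence for class 0.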

In [None]:
# Three query examples, one from the training data and two others.

squery = pd.DataFrame([["Moderate","Moderate","Warm","Light","Some"],
                       ["Moderate","Moderate","Cold","Moderate","Some"],
                       ["Moderate","Light","Warm","Light","None"]
                      ], columns=swim.columns)

In [None]:
# encode the query examples
X_query = onehot_encoder.transform(squery)
X_query, X_query.shape

#### Predictions and class probabilities for the query examples from the three NB classifiers

In [None]:
y_query = swim_gNB.predict(X_query)
y_query

In [None]:
q_probs = swim_gNB.predict_proba(X_query)  
q_probs

In [None]:
y_query = swim_mNB.predict(X_query)
y_query

In [None]:
q_probs = swim_mNB.predict_proba(X_query)  
q_probs

In [None]:
y_query = swim_bNB.predict(X_query)
y_query

In [None]:
q_probs = swim_bNB.predict_proba(X_query)  
q_probs

#### Looking at the probabilities, which NB classifier appears to work best for this data?