In [22]:
import pandas as pd
from ucimlrepo import fetch_ucirepo
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import CategoricalNB
from sklearn.metrics import accuracy_score, classification_report, recall_score
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
# Download the UCI "Mushroom" dataset (repository id 73).
mushroom = fetch_ucirepo(id=73)

# Feature and target tables, both provided as pandas DataFrames.
X, y = mushroom.data.features, mushroom.data.targets

# Dataset-level metadata first, then the variable information.
print(mushroom.metadata)
print(mushroom.variables)

{'uci_id': 73, 'name': 'Mushroom', 'repository_url': 'https://archive.ics.uci.edu/dataset/73/mushroom', 'data_url': 'https://archive.ics.uci.edu/static/public/73/data.csv', 'abstract': 'From Audobon Society Field Guide; mushrooms described in terms of physical characteristics; classification: poisonous or edible', 'area': 'Biology', 'tasks': ['Classification'], 'characteristics': ['Multivariate'], 'num_instances': 8124, 'num_features': 22, 'feature_types': ['Categorical'], 'demographics': [], 'target_col': ['poisonous'], 'index_col': None, 'has_missing_values': 'yes', 'missing_values_symbol': 'NaN', 'year_of_dataset_creation': 1981, 'last_updated': 'Thu Aug 10 2023', 'dataset_doi': '10.24432/C5959T', 'creators': [], 'intro_paper': None, 'additional_info': {'summary': "This data set includes descriptions of hypothetical samples corresponding to 23 species of gilled mushrooms in the Agaricus and Lepiota Family (pp. 500-525).  Each species is identified as definitely edible, definitely po

In [3]:
# List the entries available on the dataset's `data` container.
mushroom.data.keys()

dict_keys(['ids', 'features', 'targets', 'original', 'headers'])

In [4]:
# Preview the first rows of the feature table.
X.head()

Unnamed: 0,cap-shape,cap-surface,cap-color,bruises,odor,gill-attachment,gill-spacing,gill-size,gill-color,stalk-shape,...,stalk-surface-below-ring,stalk-color-above-ring,stalk-color-below-ring,veil-type,veil-color,ring-number,ring-type,spore-print-color,population,habitat
0,x,s,n,t,p,f,c,n,k,e,...,s,w,w,p,w,o,p,k,s,u
1,x,s,y,t,a,f,c,b,k,e,...,s,w,w,p,w,o,p,n,n,g
2,b,s,w,t,l,f,c,b,n,e,...,s,w,w,p,w,o,p,n,n,m
3,x,y,w,t,p,f,c,n,n,e,...,s,w,w,p,w,o,p,k,s,u
4,x,s,g,f,n,f,w,b,k,t,...,s,w,w,p,w,o,e,n,a,g


In [5]:
# Column dtypes and non-null counts; a count below the row total means missing values.
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8124 entries, 0 to 8123
Data columns (total 22 columns):
 #   Column                    Non-Null Count  Dtype 
---  ------                    --------------  ----- 
 0   cap-shape                 8124 non-null   object
 1   cap-surface               8124 non-null   object
 2   cap-color                 8124 non-null   object
 3   bruises                   8124 non-null   object
 4   odor                      8124 non-null   object
 5   gill-attachment           8124 non-null   object
 6   gill-spacing              8124 non-null   object
 7   gill-size                 8124 non-null   object
 8   gill-color                8124 non-null   object
 9   stalk-shape               8124 non-null   object
 10  stalk-root                5644 non-null   object
 11  stalk-surface-above-ring  8124 non-null   object
 12  stalk-surface-below-ring  8124 non-null   object
 13  stalk-color-above-ring    8124 non-null   object
 14  stalk-color-below-ring  

In [6]:
# Class balance of the target: number of rows per value of the 'poisonous' column.
y.value_counts(subset='poisonous')

poisonous
e    4208
p    3916
Name: count, dtype: int64

So we have a mix of categorical and binary features that are used to predict a binary variable: whether or not a mushroom is poisonous. After some research, we should start with a Naive Bayes model and then try tree-based models to compare. We will have to encode the data first, but it should be fine.

In [7]:
# Start by checking for missing data: flag every row that has at least one NaN,
# keep just those rows, and look at their stalk-root values.
has_missing = X.isna().any(axis=1)
nulls_X_df = X.loc[has_missing]
nulls_X_df['stalk-root']

3984    NaN
4023    NaN
4076    NaN
4100    NaN
4104    NaN
       ... 
8119    NaN
8120    NaN
8121    NaN
8122    NaN
8123    NaN
Name: stalk-root, Length: 2480, dtype: object

All right, there are lots of NaN values in the stalk-root column. Let's explore to see what the other values look like before we decide what to do about it.

In [8]:
# Frequency of each observed stalk-root value.
X.value_counts(subset='stalk-root')

stalk-root
b    3776
e    1120
c     556
r     192
Name: count, dtype: int64

Okay, so there are four categories and about 30% of the values (2,480 of 8,124) are missing. I'm interested in what NaN means. I would like to ask whether the data is truly missing or whether NaN is itself a category, e.g. these mushrooms don't have stalk roots. For now, we will drop the column.

In [9]:
# Drop the stalk-root column (it contains missing values). Build X from the
# untouched feature table on the dataset object rather than from X itself, so
# re-running this cell gives the same result instead of raising KeyError on a
# column that has already been removed.
X = mushroom.data.features.drop(columns=['stalk-root'])

In [10]:
# One-hot encode the categorical columns of both the features and the target
# with pandas; drop_first=True removes the first level of each column.
X, y = (pd.get_dummies(frame, drop_first=True) for frame in (X, y))

In [11]:
# Hold out 25% of the rows for testing. Stratifying on y keeps the class
# proportions the same in both partitions; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    stratify=y,
    random_state=42,
)

In [12]:
# Fit a categorical Naive Bayes classifier on the training split and report
# its accuracy on the held-out test split.
# NOTE(review): scikit-learn documents CategoricalNB as expecting each feature
# to be integer-coded categories (e.g. via OrdinalEncoder). Here it receives the
# one-hot columns produced by get_dummies, so every column is treated as its own
# two-level feature -- confirm this is intended.
model = CategoricalNB()
# y_train is a single-column DataFrame; flatten it to shape (n_samples,) so
# fit() does not emit a DataConversionWarning about a column-vector target.
model.fit(X_train, y_train.to_numpy().ravel())
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Categorical Naive Bayes Accuracy:", accuracy)

Categorical Naive Bayes Accuracy: 0.9369768586903003


  y = column_or_1d(y, warn=True)


Accuracy of 93.7% is pretty good. Let's try to see if we can find where the model is failing.

In [13]:
# Keep only the test labels whose true class differs from the predicted class.

wrong_mask = y_test['poisonous_p'].ne(y_pred)
y_incorrect = y_test.loc[wrong_mask]

In [14]:
# Break the misclassified rows down by their true class.
y_incorrect.value_counts()

poisonous_p
True           119
False            9
Name: count, dtype: int64

The test data has 119 poisonous mushrooms that the model predicts are safe... yikes.

In [15]:
# How many test rows the model assigned to each class.
pd.DataFrame(y_pred, columns=['poisonous']).value_counts()

poisonous
False        1162
True          869
Name: count, dtype: int64

In [16]:
# True class counts in the test split, for comparison with the predicted counts.
y_test.value_counts()

poisonous_p
False          1052
True            979
Name: count, dtype: int64

I wonder if we can tell the model that it is crucial we not deem a poisonous mushroom as safe, and less crucial otherwise. The thing we are looking for is a high recall for the poisonous class: of all the mushrooms that are actually poisonous, the fraction that the model correctly flags as poisonous, $\text{recall} = \frac{TP}{TP + FN}$. In this case we don't care about accuracy; we just need to make sure no one eats a poisonous mushroom. We care about recall: the ratio between how many poisonous mushrooms the model catches and how many truly are poisonous.

In [20]:
# Recall of the test-set predictions, using recall_score's default positive label.
recall = recall_score(y_true=y_test, y_pred=y_pred)
print('Categorical Naive Bayes recall: ', recall)

Categorical Naive Bayes recall:  0.8784473953013279


In [23]:
# Print the confusion matrix, then the per-class precision/recall/F1 report,
# both computed from the test labels and the model's predictions.
for summarize in (confusion_matrix, classification_report):
    print(summarize(y_test, y_pred))

[[1043    9]
 [ 119  860]]
              precision    recall  f1-score   support

       False       0.90      0.99      0.94      1052
        True       0.99      0.88      0.93       979

    accuracy                           0.94      2031
   macro avg       0.94      0.93      0.94      2031
weighted avg       0.94      0.94      0.94      2031

