# Describes the categorical Bayes encoding of selected mushrooms features to feed in CNN

In [1]:
import os
import numpy as np
from numpy import asarray
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import joblib

Loading DataFrame with mushroom features

In [2]:
df_shrooms = pd.read_csv('Shrooms_species_corrected.csv')

In [3]:
df_shrooms.head(17)

Unnamed: 0,species,hymenium,hymenium_color,ring,cap_color,cap_scales
0,agaricus_arvensis,gills,pink-brown,yes,white,no
1,agaricus_augustus,gills,pink-brown,yes,brown,yes
2,agaricus_bitorquis,gills,pink-brown,yes,white,no
3,agaricus_campestris,gills,pink-brown,yes,white,no
4,agaricus_impudicus,gills,pink-brown,yes,brown,yes
5,agaricus_xanthodermus,gills,pink-brown,yes,white,yes
6,agrocybe_pediades,gills,brown,no,yellow,no
7,agrocybe_praecox,gills,pink-white,no,yellow,no
8,amanita_excelsa,gills,white,yes,brown,yes
9,amanita_fulva,gills,white,no,brown,no


In [4]:
df_shrooms.columns

Index(['species', 'hymenium', 'hymenium_color', 'ring', 'cap_color',
       'cap_scales'],
      dtype='object')

Loading DataFrame with families to the corresponding species

In [5]:
df_ylabel = pd.read_csv('Shrooms_families.csv')

In [6]:
df_ylabel_series = df_ylabel[df_ylabel.columns[0]]

In [7]:
df_ylabel_series

0           agaricaceae
1           agaricaceae
2           agaricaceae
3           agaricaceae
4           agaricaceae
             ...       
209          pluteaceae
210    tricholomataceae
211        polyporaceae
212        polyporaceae
213          boletaceae
Name: family, Length: 214, dtype: object

In [8]:
len(df_ylabel.family.unique())

35

OneHotEncoding the 35 families to which the 214 species belong

In [9]:
ohe = OneHotEncoder(sparse = False)

In [10]:
y_train = ohe.fit_transform(df_ylabel)

In [11]:
y_train.shape

(214, 35)

convert 2d array to series

In [12]:
y_train_series = pd.Series(y_train.tolist())

In [13]:
y_train_series

0      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
1      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
2      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
3      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
4      [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
                             ...                        
209    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
210    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
211    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
212    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
213    [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
Length: 214, dtype: object

In [14]:
directory = 'original_data/Sample_data/'

List the directory with test data

In [15]:
os.listdir(directory)

['67478_Lentinus_substrictus',
 '21151_Tricholoma_sulphureum',
 '14119_Galerina_marginata',
 '14881_Cuphophyllus_virgineus',
 '61351_Hygrocybe_coccinea',
 '67628_Sutorius_luridiformis',
 '15955_Lactifluus_volemus',
 '13099_Dermoloma_cuneifolium',
 '17296_Mycena_sanguinolenta',
 '14056_Flammulina_velutipes var. velutipes',
 '21155_Tricholoma_ustale',
 '14422_Gymnopilus_spectabilis',
 '20093_Russula_vesca',
 '61239_Coprinopsis_lagopus',
 '12587_Cortinarius_rubellus',
 '17238_Mycena_haematopus',
 '17924_Mucidula_mucida',
 '17289_Mycena_rosea',
 '12656_Cortinarius_torvus',
 '13739_Entoloma_rhodopolium',
 '10062_Agaricus_bitorquis',
 '10065_Agaricus_campestris',
 '11062_Butyriboletus_appendiculatus',
 '16117_Leccinum_versipelle',
 '14812_Hydropus_subalpinus',
 '11091_Caloboletus_radicans',
 '16660_Macrolepiota_procera',
 '16601_Lycoperdon_perlatum',
 '21143_Tricholoma_scalpturatum',
 '18911_Pluteus_phlebophorus',
 '14798_Hydnum_repandum',
 '14914_Hygrophorus_olivaceoalbus',
 '11068_Calobole

In [16]:
from sklearn.preprocessing import OrdinalEncoder

Before applying Bayesian encoding to the features, features are transformed with ordinal encoder

In [17]:
oe = OrdinalEncoder()

In [18]:
df_shrooms_oe = oe.fit_transform(df_shrooms.drop(columns = 'species'))

In [19]:
df_shrooms_oe.shape

(214, 5)

In [20]:
df_shrooms_oe

array([[0., 4., 1., 8., 0.],
       [0., 4., 1., 2., 1.],
       [0., 4., 1., 8., 0.],
       ...,
       [1., 8., 0., 2., 0.],
       [1., 8., 0., 2., 1.],
       [1., 6., 0., 2., 0.]])

Saving the ordinal encoder as it will be necessary to transform the end user input when predicting a mushroom's class

In [21]:
joblib.dump(oe,'New_pickles/' + "User_ordinal_encoder.pkl")

['New_pickles/User_ordinal_encoder.pkl']

Testing the transformation

In [22]:
oe.categories_

[array(['gills', 'pores', 'teeth'], dtype=object),
 array(['black', 'brown', 'gray', 'pink', 'pink-brown', 'pink-white',
        'red', 'violet', 'white', 'yellow'], dtype=object),
 array(['no', 'yes'], dtype=object),
 array(['black', 'blue', 'brown', 'gray', 'greenish', 'pink', 'red',
        'violet', 'white', 'yellow'], dtype=object),
 array(['no', 'yes'], dtype=object)]

Renaming the categories for convenience when creating the final DataFrame

In [23]:
cap_color = []
for i in oe.categories_[3]:
    cap_color.append('cap_'+i)

In [24]:
hymenium_color = []
for i in oe.categories_[1]:
    hymenium_color.append('hymenium_'+i)

In [25]:
columns_names = ['gills', 'pores', 'teeth', 
                 'hymenium_black', 'hymenium_brown', 'hymenium_gray', 'hymenium_pink', 'hymenium_pink-brown', 'hymenium_pink-white', 'hymenium_red', 'hymenium_violet', 'hymenium_white', 'hymenium_yellow',
                 'ring_no', 'ring_yes',
                 'cap_black', 'cap_blue', 'cap_brown', 'cap_gray', 'cap_greenish', 'cap_pink', 'cap_red', 'cap_violet', 'cap_white', 'cap_yellow',
                 'scales_no', 'scales_yes'
                 ]

In [26]:
df_shrooms_oe[1]

array([0., 4., 1., 2., 1.])

Encoding the features with categorical Naive Bayes

In [27]:
from sklearn.naive_bayes import CategoricalNB

In [28]:
cNB = CategoricalNB(alpha = 0.1)

In [29]:
cNB.fit(df_shrooms_oe, df_ylabel_series)

In [30]:
X_predict = cNB.predict_proba(df_shrooms_oe)

In [31]:
X_predict.shape

(214, 35)

In [32]:
joblib.dump(cNB, 'New_pickles/User_cNB.pkl')

['New_pickles/User_cNB.pkl']

Testing the categorical encoder

user_family = np.argmax(user_cNB_pred)

In [33]:
cNB.classes_[0]

'agaricaceae'

In [34]:
cNB.category_count_

[array([[17.,  0.,  5.],
        [ 9.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 0., 22.,  0.],
        [ 1.,  0.,  0.],
        [ 5.,  0.,  0.],
        [11.,  0.,  0.],
        [ 4.,  0.,  0.],
        [ 0.,  1.,  0.],
        [ 1.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 3.,  0.,  0.],
        [19.,  0.,  0.],
        [ 3.,  0.,  0.],
        [ 3.,  0.,  0.],
        [ 6.,  0.,  0.],
        [ 2.,  0.,  0.],
        [ 4.,  0.,  0.],
        [ 0.,  0.,  4.],
        [14.,  0.,  0.],
        [ 4.,  0.,  0.],
        [ 1.,  0.,  0.],
        [ 0.,  0.,  1.],
        [ 6.,  0.,  0.],
        [ 2.,  0.,  0.],
        [ 4.,  0.,  0.],
        [ 0.,  2.,  1.],
        [10.,  0.,  0.],
        [ 1.,  0.,  0.],
        [14.,  0.,  0.],
        [ 0.,  2.,  0.],
        [11.,  0.,  0.],
        [ 0.,  2.,  0.],
        [16.,  0.,  0.],
        [ 1.,  0.,  0.]]),
 array([[ 0.,  2.,  0.,  0.,  6.,  0.,  0.,  0., 14.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  9.,  0.],
   

In [35]:
cNB.classes_

array(['agaricaceae', 'amanita', 'bolbitiaceae', 'boletaceae',
       'bulgariaceae', 'cantharellaceae', 'cortinariaceae',
       'entolomataceae', 'geastraceae', 'helvellaceae', 'hydnaceae',
       'hydnangiaceae', 'hygrophoraceae', 'hymenogastraceae', 'incertae',
       'inocybaceae', 'lyophyllaceae', 'marasmiaceae', 'morchellaceae',
       'mycenaceae', 'omphalotaceae', 'paxillaceae', 'phallaceae',
       'physalacriaceae', 'pleurotaceae', 'pluteaceae', 'polyporaceae',
       'psathyrelaceae', 'repetobasidiaceae', 'russulaceae',
       'sclerodermataceae', 'strophariaceae', 'suillaceae',
       'tricholomataceae', 'tubariaceae'], dtype='<U17')

In [36]:
cNB.feature_log_prob_

[array([[-0.26550821, -5.40717177, -1.47534614],
        [-0.02173999, -4.53259949, -4.53259949],
        [-0.16705408, -2.56494936, -2.56494936],
        [-5.40717177, -0.00900907, -5.40717177],
        [-0.16705408, -2.56494936, -2.56494936],
        [-0.03846628, -3.97029191, -3.97029191],
        [-0.01785762, -4.72738782, -4.72738782],
        [-0.04762805, -3.76120012, -3.76120012],
        [-2.56494936, -0.16705408, -2.56494936],
        [-0.16705408, -2.56494936, -2.56494936],
        [-0.16705408, -2.56494936, -2.56494936],
        [-0.06252036, -3.49650756, -3.49650756],
        [-0.01041676, -5.26269019, -5.26269019],
        [-0.06252036, -3.49650756, -3.49650756],
        [-0.06252036, -3.49650756, -3.49650756],
        [-0.03226086, -4.14313473, -4.14313473],
        [-0.09097178, -3.13549422, -3.13549422],
        [-0.04762805, -3.76120012, -3.76120012],
        [-3.76120012, -3.76120012, -0.04762805],
        [-0.01408474, -4.96284463, -4.96284463],
        [-0.04762805

Creating a DataFrame with the features log probability

In [37]:
df_feature_prob = pd.DataFrame(cNB.feature_log_prob_[0], columns = range(0,3))

In [38]:
df = pd.DataFrame(np.hstack(cNB.feature_log_prob_))

In [39]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
0,-0.265508,-5.407172,-1.475346,-5.438079,-2.393557,-5.438079,-5.438079,-1.327205,-5.438079,-5.438079,...,-0.728549,-3.040184,-5.438079,-5.438079,-5.438079,-5.438079,-0.92722,-3.040184,-1.008228,-0.453917
1,-0.02174,-4.532599,-4.532599,-4.60517,-4.60517,-4.60517,-4.60517,-4.60517,-4.60517,-4.60517,...,-0.891598,-4.60517,-2.207275,-4.60517,-2.207275,-4.60517,-1.171183,-4.60517,-0.808217,-0.589963
2,-0.167054,-2.564949,-2.564949,-2.995732,-0.597837,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-0.597837,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.087011,-2.484907
3,-5.407172,-0.009009,-5.407172,-5.438079,-5.438079,-5.438079,-5.438079,-5.438079,-5.438079,-3.040184,...,-0.562882,-2.004092,-5.438079,-5.438079,-1.724507,-5.438079,-5.438079,-2.393557,-0.150404,-1.96869
4,-0.167054,-2.564949,-2.564949,-0.597837,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.087011,-2.484907
5,-0.038466,-3.970292,-3.970292,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,...,-1.049822,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,-1.696449,-1.049822,-0.019418,-3.951244
6,-0.017858,-4.727388,-4.727388,-4.787492,-0.676618,-4.787492,-1.353505,-4.787492,-4.787492,-2.389596,...,-0.676618,-4.787492,-4.787492,-4.787492,-2.389596,-4.787492,-2.389596,-1.353505,-0.207639,-1.673976
7,-0.047628,-3.7612,-3.7612,-3.912023,-3.912023,-1.514128,-3.912023,-3.912023,-3.912023,-3.912023,...,-1.514128,-0.867501,-3.912023,-3.912023,-3.912023,-3.912023,-1.514128,-3.912023,-0.024098,-3.73767
8,-2.564949,-0.167054,-2.564949,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.597837,-2.995732,-0.087011,-2.484907
9,-0.167054,-2.564949,-2.564949,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.597837,-2.995732,-0.087011,-2.484907


In [40]:
df.columns = columns_names

In [41]:
df['family'] = cNB.classes_

In [42]:
df.loc[[0,15]]

Unnamed: 0,gills,pores,teeth,hymenium_black,hymenium_brown,hymenium_gray,hymenium_pink,hymenium_pink-brown,hymenium_pink-white,hymenium_red,...,cap_gray,cap_greenish,cap_pink,cap_red,cap_violet,cap_white,cap_yellow,scales_no,scales_yes,family
0,-0.265508,-5.407172,-1.475346,-5.438079,-2.393557,-5.438079,-5.438079,-1.327205,-5.438079,-5.438079,...,-3.040184,-5.438079,-5.438079,-5.438079,-5.438079,-0.92722,-3.040184,-1.008228,-0.453917,agaricaceae
15,-0.032261,-4.143135,-4.143135,-4.248495,-4.248495,-4.248495,-4.248495,-4.248495,-4.248495,-4.248495,...,-4.248495,-4.248495,-4.248495,-4.248495,-1.8506,-1.203973,-4.248495,-0.016261,-4.127134,inocybaceae


In [43]:
df[df.columns[0:27]]

Unnamed: 0,gills,pores,teeth,hymenium_black,hymenium_brown,hymenium_gray,hymenium_pink,hymenium_pink-brown,hymenium_pink-white,hymenium_red,...,cap_brown,cap_gray,cap_greenish,cap_pink,cap_red,cap_violet,cap_white,cap_yellow,scales_no,scales_yes
0,-0.265508,-5.407172,-1.475346,-5.438079,-2.393557,-5.438079,-5.438079,-1.327205,-5.438079,-5.438079,...,-0.728549,-3.040184,-5.438079,-5.438079,-5.438079,-5.438079,-0.92722,-3.040184,-1.008228,-0.453917
1,-0.02174,-4.532599,-4.532599,-4.60517,-4.60517,-4.60517,-4.60517,-4.60517,-4.60517,-4.60517,...,-0.891598,-4.60517,-2.207275,-4.60517,-2.207275,-4.60517,-1.171183,-4.60517,-0.808217,-0.589963
2,-0.167054,-2.564949,-2.564949,-2.995732,-0.597837,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-0.597837,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.087011,-2.484907
3,-5.407172,-0.009009,-5.407172,-5.438079,-5.438079,-5.438079,-5.438079,-5.438079,-5.438079,-3.040184,...,-0.562882,-2.004092,-5.438079,-5.438079,-1.724507,-5.438079,-5.438079,-2.393557,-0.150404,-1.96869
4,-0.167054,-2.564949,-2.564949,-0.597837,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.087011,-2.484907
5,-0.038466,-3.970292,-3.970292,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,...,-1.049822,-4.094345,-4.094345,-4.094345,-4.094345,-4.094345,-1.696449,-1.049822,-0.019418,-3.951244
6,-0.017858,-4.727388,-4.727388,-4.787492,-0.676618,-4.787492,-1.353505,-4.787492,-4.787492,-2.389596,...,-0.676618,-4.787492,-4.787492,-4.787492,-2.389596,-4.787492,-2.389596,-1.353505,-0.207639,-1.673976
7,-0.047628,-3.7612,-3.7612,-3.912023,-3.912023,-1.514128,-3.912023,-3.912023,-3.912023,-3.912023,...,-1.514128,-0.867501,-3.912023,-3.912023,-3.912023,-3.912023,-1.514128,-3.912023,-0.024098,-3.73767
8,-2.564949,-0.167054,-2.564949,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.597837,-2.995732,-0.087011,-2.484907
9,-0.167054,-2.564949,-2.564949,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,...,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-2.995732,-0.597837,-2.995732,-0.087011,-2.484907


Transforming the log probabilities

In [44]:
norm = lambda x: 10 ** x

In [45]:
df1 = df[df.columns[0:27]].apply(norm)

In [46]:
df1.head()

Unnamed: 0,gills,pores,teeth,hymenium_black,hymenium_brown,hymenium_gray,hymenium_pink,hymenium_pink-brown,hymenium_pink-white,hymenium_red,...,cap_brown,cap_gray,cap_greenish,cap_pink,cap_red,cap_violet,cap_white,cap_yellow,scales_no,scales_yes
0,0.542615,4e-06,0.03347,4e-06,0.004041,4e-06,4e-06,0.047075,4e-06,4e-06,...,0.186832,0.000912,4e-06,4e-06,4e-06,4e-06,0.118244,0.000912,0.098123,0.351627
1,0.951174,2.9e-05,2.9e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,...,0.128352,2.5e-05,0.006205,2.5e-05,0.006205,2.5e-05,0.067424,2.5e-05,0.155519,0.257062
2,0.680685,0.002723,0.002723,0.00101,0.252443,0.00101,0.00101,0.00101,0.00101,0.00101,...,0.252443,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.818443,0.003274
3,4e-06,0.97947,4e-06,4e-06,4e-06,4e-06,4e-06,4e-06,4e-06,0.000912,...,0.273601,0.009906,4e-06,4e-06,0.018858,4e-06,4e-06,0.004041,0.707288,0.010748
4,0.680685,0.002723,0.002723,0.252443,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,...,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.818443,0.003274


In [47]:
df1['family'] = cNB.classes_

In [48]:
df1

Unnamed: 0,gills,pores,teeth,hymenium_black,hymenium_brown,hymenium_gray,hymenium_pink,hymenium_pink-brown,hymenium_pink-white,hymenium_red,...,cap_gray,cap_greenish,cap_pink,cap_red,cap_violet,cap_white,cap_yellow,scales_no,scales_yes,family
0,0.542615,4e-06,0.03347,4e-06,0.004041,4e-06,4e-06,0.047075,4e-06,4e-06,...,0.000912,4e-06,4e-06,4e-06,4e-06,0.118244,0.000912,0.098123,0.351627,agaricaceae
1,0.951174,2.9e-05,2.9e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,2.5e-05,...,2.5e-05,0.006205,2.5e-05,0.006205,2.5e-05,0.067424,2.5e-05,0.155519,0.257062,amanita
2,0.680685,0.002723,0.002723,0.00101,0.252443,0.00101,0.00101,0.00101,0.00101,0.00101,...,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.818443,0.003274,bolbitiaceae
3,4e-06,0.97947,4e-06,4e-06,4e-06,4e-06,4e-06,4e-06,4e-06,0.000912,...,0.009906,4e-06,4e-06,0.018858,4e-06,4e-06,0.004041,0.707288,0.010748,boletaceae
4,0.680685,0.002723,0.002723,0.252443,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,...,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.818443,0.003274,bulgariaceae
5,0.915237,0.000107,0.000107,8e-05,8e-05,8e-05,8e-05,8e-05,8e-05,8e-05,...,8e-05,8e-05,8e-05,8e-05,8e-05,0.020116,0.089162,0.956273,0.000112,cantharellaceae
6,0.959715,1.9e-05,1.9e-05,1.6e-05,0.210563,1.6e-05,0.044309,1.6e-05,1.6e-05,0.004078,...,1.6e-05,1.6e-05,1.6e-05,0.004078,1.6e-05,0.004078,0.044309,0.619956,0.021185,cortinariaceae
7,0.896132,0.000173,0.000173,0.000122,0.000122,0.030611,0.000122,0.000122,0.000122,0.000122,...,0.135675,0.000122,0.000122,0.000122,0.000122,0.030611,0.000122,0.946025,0.000183,entolomataceae
8,0.002723,0.680685,0.002723,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,...,0.00101,0.00101,0.00101,0.00101,0.00101,0.252443,0.00101,0.818443,0.003274,geastraceae
9,0.680685,0.002723,0.002723,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,0.00101,...,0.00101,0.00101,0.00101,0.00101,0.00101,0.252443,0.00101,0.818443,0.003274,helvellaceae


Creating a DataFrame containing the transformed user input

Creating completed DataFrame to be used when fitting the model