# Naive Bayes (supervised machine learning)

In [None]:
# Import libraries
%matplotlib widget
from plantcv import plantcv as pcv 
from plantcv import utils
from plantcv import learn
import numpy as np
from plantcv.parallel import WorkflowInputs

In [None]:
pcv.__version__

To collect training data, download the program [Fiji](https://imagej.net/software/fiji/downloads#installation). We will use the tool "Pixel Inspector" to collect color values.

Collect pixel data from the two PNG files in the `imgs` folder: `wheat_rust1.png` and `wheat-rust2.png`.

In [None]:
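# Input/output options
args = WorkflowInputs(
    images=["./imgs/wheat_rust1.png"],  # image(s) to analyze
    names="image1",                     # name used to access the image: args.image1
    result="wheat-rust.csv",            # results file written by pcv.outputs.save_results
    outdir=".",                         # output directory
    writeimg=True,
    debug="plot")                       # show intermediate images in the notebook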

# Set debug to the global parameter 
pcv.params.debug = args.debug

In [None]:
# Read image 
img, path, filename = pcv.readimage(filename=args.image1)

In the image, we are interested in identifying pixels that belong to four different groups (or classes):

- `plant`: the green parts of the wheat leaves
- `pustule`: the red/orange wheat rust infection foci
- `chlorosis`: the yellowing around each pustule
- `background`: everything that does not belong to one of the other three classes

Use these exact lowercase names as the class headers in your pixel file, because they become the keys of the mask dictionary used later in the notebook.

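We will collect RGB values from the image by clicking on examples of each class and recording them in a plain-text file called `pixels_unformatted.txt`. Each class begins with a `#` header line containing the class name, followed by lines of tab-separated `R,G,B` values: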

```
#plant
96,154,72	95,153,72	91,155,71	91,160,70	90,155,67	92,152,66	92,157,70
54,104,39	56,104,38	59,106,41	57,105,43	54,104,40	54,103,35	56,101,39	58,99,41	59,99,41
#background
114,127,121	117,135,125	120,137,131	132,145,138	142,154,148	151,166,158	160,182,172
115,125,121	118,131,123	122,132,135	133,142,144	141,151,152	150,166,158	159,179,172
```
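
The file format above is simple enough to read with a short parser. This is a minimal sketch, assuming exactly the layout shown: a `#classname` header line followed by lines of tab-separated `R,G,B` triplets. `parse_pixel_samples` is a hypothetical helper for illustration, not part of PlantCV; in this workflow `utils.tabulate_bayes_classes` reads the file for you.

```python
def parse_pixel_samples(text):
    """Parse '#class' headers followed by tab-separated 'R,G,B' triplets.

    Returns a dict (in file order): class name -> list of (r, g, b) tuples.
    """
    samples = {}
    current = None
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith("#"):
            current = line[1:].strip()  # new class header, e.g. "#plant"
            samples.setdefault(current, [])
            continue
        if current is None:
            raise ValueError("RGB values found before any '#class' header")
        for triplet in filter(None, line.split("\t")):  # ignore empty cells
            r, g, b = (int(v) for v in triplet.split(","))
            samples[current].append((r, g, b))
    return samples


example = (
    "#plant\n"
    "96,154,72\t95,153,72\t91,155,71\n"
    "#background\n"
    "114,127,121\t117,135,125\n"
)
parsed = parse_pixel_samples(example)
print({name: len(pixels) for name, pixels in parsed.items()})
# {'plant': 3, 'background': 2}
```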

After you have created your table of RGB values, reformat it for PlantCV with `utils.tabulate_bayes_classes`. This utility can also be run from the command line, but here we call it directly from the notebook.

In [None]:
# Convert the tab-delimited file of pixel values (pixels_unformatted.txt)
# into the table format expected by learn.naive_bayes_multiclass.
# Tip: press Shift+Tab inside the parentheses to see the function's parameters.
utils.tabulate_bayes_classes(input_file="pixels_unformatted.txt", output_file="pixels_formatted.txt")
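
For intuition about the reformatting step, the sketch below turns the class sections into a tab-delimited table with one column per class, class names in the header row, and comma-delimited RGB values in each cell. That layout is an assumption based on how PlantCV describes the `samples_file` input of `learn.naive_bayes_multiclass`; `tabulate_samples` is a hypothetical helper, and the real utility may handle details (such as classes with different numbers of samples) differently.

```python
from itertools import zip_longest


def tabulate_samples(samples):
    """Convert {class: [(r, g, b), ...]} into tab-delimited text with one
    column per class. Shorter columns are padded with empty cells
    (an assumption made for this sketch)."""
    names = list(samples)
    lines = ["\t".join(names)]  # header row: class names
    columns = [[",".join(str(v) for v in rgb) for rgb in samples[n]] for n in names]
    for row in zip_longest(*columns, fillvalue=""):
        lines.append("\t".join(row))
    return "\n".join(lines) + "\n"


samples = {
    "plant": [(96, 154, 72), (95, 153, 72), (91, 155, 71)],
    "background": [(114, 127, 121), (117, 135, 125)],
}
print(tabulate_samples(samples))
```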

Now that we have reformatted our table, we can train our naive Bayes classifier. Since we have four classes (`plant`, `pustule`, `chlorosis`, and `background`), we will use `learn.naive_bayes_multiclass` to train the model; it writes the learned probability density functions (PDFs) to the file given as `outfile`.
For a more complete explanation of how naive Bayes classifiers work, see this [Stack Overflow answer](https://stackoverflow.com/a/20556654).
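
The core idea can be sketched in a few lines. This is a simplified Gaussian naive Bayes written for illustration, not PlantCV's implementation (PlantCV estimates its PDFs differently): it assumes each RGB channel is normally distributed within a class and that the channels are independent given the class (the "naive" assumption). A pixel goes to the class that maximizes log P(class) + Σ log P(channel | class). `fit_gaussian_nb` and `predict` are hypothetical names; the training pixels are taken from the example file above.

```python
import math


def fit_gaussian_nb(samples):
    """Estimate a prior and per-channel (mean, variance) for each class."""
    total = sum(len(pixels) for pixels in samples.values())
    model = {}
    for name, pixels in samples.items():
        prior = len(pixels) / total
        stats = []
        for channel in range(3):  # 0 = R, 1 = G, 2 = B
            values = [p[channel] for p in pixels]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            stats.append((mean, max(var, 1e-6)))  # avoid zero variance
        model[name] = (prior, stats)
    return model


def log_normal_pdf(x, mean, var):
    """Log of the normal probability density at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)


def predict(model, pixel):
    """Return the class with the highest log posterior (up to a constant)."""
    def score(name):
        prior, stats = model[name]
        return math.log(prior) + sum(
            log_normal_pdf(x, mean, var) for x, (mean, var) in zip(pixel, stats)
        )
    return max(model, key=score)


training = {
    "plant": [(96, 154, 72), (95, 153, 72), (91, 155, 71),
              (54, 104, 39), (56, 104, 38), (59, 106, 41)],
    "background": [(114, 127, 121), (117, 135, 125), (120, 137, 131),
                   (115, 125, 121), (118, 131, 123), (122, 132, 135)],
}
model = fit_gaussian_nb(training)
print(predict(model, (93, 155, 70)))    # plant
print(predict(model, (118, 131, 125)))  # background
```

Working in log space keeps the product of many small probabilities from underflowing to zero.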

In [None]:
learn.naive_bayes_multiclass(samples_file="pixels_formatted.txt", outfile="wheat-rust-pdf.txt")

In [None]:
# Use the output file from `learn.naive_bayes_multiclass` to run the multiclass
# naive Bayes classification on the image. The function returns a dictionary
# of binary masks keyed by class name (plant, pustule, chlorosis, background);
# with debug="plot", each mask is also displayed.

# Inputs:
#   rgb_img  - RGB image data
#   pdf_file - Output file containing the PDFs written by `learn.naive_bayes_multiclass`
mask = pcv.naive_bayes_classifier(rgb_img=img, pdf_file="wheat-rust-pdf.txt")
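
Conceptually, the returned `mask` is a dictionary with one binary mask per class name, each the same size as the image, where a pixel is 255 if it was assigned to that class and 0 otherwise. The sketch below shows that structure in plain Python, assuming every pixel is assigned to exactly one class. `masks_from_classifier` and the toy `toy_classify` rule are hypothetical stand-ins for the trained model, and PlantCV itself returns NumPy arrays rather than nested lists.

```python
def masks_from_classifier(image, class_names, classify):
    """Build {class_name: 2-D mask} from a per-pixel classifier.

    image is a list of rows of (r, g, b) tuples; classify maps a pixel to
    one of class_names. Masks use 255 for "in class" and 0 otherwise.
    """
    height, width = len(image), len(image[0])
    masks = {name: [[0] * width for _ in range(height)] for name in class_names}
    for y, row in enumerate(image):
        for x, pixel in enumerate(row):
            masks[classify(pixel)][y][x] = 255
    return masks


def toy_classify(pixel):
    """Hypothetical rule standing in for the model: green-dominant -> plant."""
    r, g, b = pixel
    return "plant" if g > r and g > b + 40 else "background"


tiny_image = [
    [(96, 154, 72), (114, 127, 121)],
    [(54, 104, 39), (160, 182, 172)],
]
masks = masks_from_classifier(tiny_image, ["plant", "background"], toy_classify)
print(masks["plant"])       # [[255, 0], [255, 0]]
print(masks["background"])  # [[0, 255], [0, 255]]
```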

In [None]:
# Apply each mask to the original image to see which pixels were
# assigned to each class (all other pixels are set to black)

pustule_img = pcv.apply_mask(img=img, mask=mask['pustule'], mask_color='black')
chlorosis_img = pcv.apply_mask(img=img, mask=mask['chlorosis'], mask_color='black')
plant_img = pcv.apply_mask(img=img, mask=mask['plant'], mask_color='black')
background_img = pcv.apply_mask(img=img, mask=mask['background'], mask_color='black')
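
`pcv.apply_mask` keeps the image pixels where the mask is nonzero and replaces the rest with the chosen `mask_color`. A plain-Python sketch of the `'black'` case, for intuition only: `apply_mask_black` is a hypothetical helper, and PlantCV operates on NumPy arrays.

```python
def apply_mask_black(image, mask):
    """Keep pixels where mask is nonzero; set all other pixels to black."""
    return [
        [pixel if m != 0 else (0, 0, 0) for pixel, m in zip(img_row, mask_row)]
        for img_row, mask_row in zip(image, mask)
    ]


image = [[(96, 154, 72), (114, 127, 121)],
         [(54, 104, 39), (160, 182, 172)]]
plant_mask = [[255, 0],
              [255, 0]]
print(apply_mask_black(image, plant_mask))
# [[(96, 154, 72), (0, 0, 0)], [(54, 104, 39), (0, 0, 0)]]
```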

In [None]:
# Write each mask to disk. Because mask_only=True, only the mask is saved
# (not a copy of the original image), in a mask folder inside the working
# directory that is created if it does not exist.

plant_maskpath, plant_analysis_images = pcv.output_mask(
    img=img, mask=mask['plant'], filename='plant.png', mask_only=True)
pustule_maskpath, pustule_analysis_images = pcv.output_mask(
    img=img, mask=mask['pustule'], filename='pustule.png', mask_only=True)
chlorosis_maskpath, chlorosis_analysis_images = pcv.output_mask(
    img=img, mask=mask['chlorosis'], filename='chlorosis.png', mask_only=True)
bkgrd_maskpath, bkgrd_analysis_images = pcv.output_mask(
    img=img, mask=mask['background'], filename='background.png', mask_only=True)

In [None]:
# To see all of these masks together, combine them into one image with
# plant in dark green, pustule in red, chlorosis in gold, and background in gray.

classified_img = pcv.visualize.colorize_masks(masks=[mask['plant'],
                                                     mask['pustule'], 
                                                     mask['chlorosis'],
                                                     mask['background']],
                                              colors=['dark green', 'red', 'gold', 'gray'])
# Compare the merged masks with your original image
pcv.plot_image(img=img)
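
Merging several masks into one color-coded image amounts to painting each pixel with the color of the mask that contains it. A minimal sketch, assuming the masks do not overlap; `colorize` is a hypothetical helper, and the RGB tuples used for the color names are illustrative guesses, not PlantCV's exact values.

```python
def colorize(masks, colors, height, width):
    """Paint each pixel with the color of the mask that covers it.

    masks: list of 2-D masks (0 = off, nonzero = on); colors: matching
    list of (r, g, b) tuples. Uncovered pixels stay black; if masks
    overlap, the mask listed later wins.
    """
    out = [[(0, 0, 0)] * width for _ in range(height)]
    for mask, color in zip(masks, colors):
        for y in range(height):
            for x in range(width):
                if mask[y][x]:
                    out[y][x] = color
    return out


# Hypothetical 2x2 masks and RGB values for three of the color names
plant = [[255, 0], [0, 0]]
pustule = [[0, 255], [0, 0]]
background = [[0, 0], [255, 255]]
colors = [(0, 100, 0), (255, 0, 0), (128, 128, 128)]  # dark green, red, gray
print(colorize([plant, pustule, background], colors, 2, 2))
# [[(0, 100, 0), (255, 0, 0)], [(128, 128, 128), (128, 128, 128)]]
```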

How did you do?

If some pixels are misclassified, revisit `pixels_unformatted.txt`, add or remove sample pixels, and then rerun the formatting and training steps. Naive Bayes is a *supervised* learning algorithm: the model learns from examples that you have labeled, so the quality of the classification depends directly on the training pixels you provide.
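
One way to decide which samples to revisit is to run a classifier back over your labeled pixels and list the ones it gets wrong. The sketch below assumes you have labeled samples as a dictionary and some `classify(pixel)` function; `find_misclassified`, the toy rule, and the example pixels are all hypothetical.

```python
def find_misclassified(samples, classify):
    """Return (true_class, pixel, predicted_class) for every labeled pixel
    that the classifier assigns to a different class."""
    errors = []
    for true_class, pixels in samples.items():
        for pixel in pixels:
            predicted = classify(pixel)
            if predicted != true_class:
                errors.append((true_class, pixel, predicted))
    return errors


def toy_classify(pixel):
    """Hypothetical stand-in for a trained model: green-dominant -> plant."""
    r, g, b = pixel
    return "plant" if g > b + 40 else "background"


labeled = {
    "plant": [(96, 154, 72), (54, 104, 39)],
    "background": [(114, 127, 121), (90, 150, 80)],  # last pixel looks like plant
}
for error in find_misclassified(labeled, toy_classify):
    print(error)
# ('background', (90, 150, 80), 'plant')
```

A sample that keeps showing up in this list is either mislabeled or genuinely ambiguous, and is a good candidate for removal.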

## Extracting information

Now that we have trained the classifier, we can extract measurements about our sample from the class masks.

In [None]:
# Calculate the percent of the plant found to be diseased.
# Diseased tissue = pustule + chlorosis pixels; healthy tissue = plant pixels.
# Background pixels are not part of the plant and are excluded from the total.

sick_plant = np.count_nonzero(mask['pustule']) + np.count_nonzero(mask['chlorosis'])
healthy_plant = np.count_nonzero(mask['plant'])
percent_diseased = 100 * sick_plant / (sick_plant + healthy_plant)

In [None]:
# Create a new measurement (gets saved to the outputs class)

pcv.outputs.add_observation(sample='default', variable='percent_diseased',
                            trait='percent of plant detected to be diseased',
                            method='ratio of pixels', scale='percent', datatype=float,
                            value=percent_diseased, label='percent')

In [None]:
# Data stored to the outputs class can be accessed using the variable name
pcv.outputs.observations['default']['percent_diseased']['value']

In [None]:
# Calculate the percent of the plant found to be healthy
percent_healthy = 100 * healthy_plant / (sick_plant + healthy_plant)

In [None]:
# Create a new measurement (gets saved to the outputs class) 

pcv.outputs.add_observation(sample='default', variable='percent_healthy', 
                            trait='percent of plant detected to be healthy',
                            method='ratio of pixels', scale='percent', datatype=float,
                            value=percent_healthy, label='percent')

In [None]:
# Data stored to the outputs class can be accessed using the variable name
pcv.outputs.observations['default']['percent_healthy']['value']

If you look at the percentage values from `pcv.outputs.observations`, they should add up to 100%.
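
The arithmetic behind that check: diseased tissue (pustule + chlorosis pixels) and healthy tissue (plant pixels) together make up the whole plant, so their two shares must sum to 100%. A stdlib sketch with hypothetical pixel counts standing in for `np.count_nonzero` on each mask; `disease_percentages` is an illustrative helper, not a PlantCV function.

```python
def disease_percentages(pustule_px, chlorosis_px, plant_px):
    """Return (percent_diseased, percent_healthy) from pixel counts.

    Background pixels are not part of the plant, so they are excluded.
    """
    sick = pustule_px + chlorosis_px
    total = sick + plant_px
    if total == 0:
        raise ValueError("no plant pixels were classified")
    return 100 * sick / total, 100 * plant_px / total


# Hypothetical pixel counts, for illustration only
diseased, healthy = disease_percentages(pustule_px=1200, chlorosis_px=3800, plant_px=15000)
print(diseased, healthy)  # 25.0 75.0
```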

In [None]:
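# Save all observations stored in pcv.outputs to args.result (wheat-rust.csv).
# save_results writes JSON by default, so request CSV to match the file extension.
pcv.outputs.save_results(filename=args.result, outformat="csv")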