# Mushroom classifier

This application uses a neural network to predict whether a white mushroom is _Amanita virosa_ or an _Agaricus_ species (_Agaricus sp._).
Upload an image, then click **Classify** to get a prediction.

Note: you are responsible for identifying mushrooms yourself. The results of this app are only indicative and must not be relied upon — in particular, never use them to decide whether a mushroom is safe to eat.

In [1]:
import torch

from ipywidgets import Button, FileUpload, Label, Output, VBox
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from pathlib import Path

ModuleNotFoundError: No module named 'fastai'

In [None]:
# Load the exported fastai model that sits next to this notebook.
# NOTE: load_learner unpickles the file, so only load an export.pkl you trust.
root_path = Path('.')
learner = load_learner(root_path.joinpath('export.pkl'))
# Class names known to the model, indexed like the model's probability output.
category_labels = learner.dls.vocab

In [None]:
# Widgets read by the classify handler and placed in the layout below.
lbl_pred = Label()      # prediction text
out_pl = Output()       # thumbnail of the classified image
out_pl.clear_output()
btn_upload = FileUpload()  # image picker

def _latest_upload_bytes(uploader):
    """Return the most recently uploaded file's contents as ``bytes``.

    Returns ``None`` when nothing has been uploaded yet. Supports both
    ipywidgets 7 (``FileUpload.data``: list of ``bytes``) and ipywidgets 8+
    (``FileUpload.value``: tuple of dicts whose ``'content'`` is a
    ``memoryview``); ``.data`` no longer exists in ipywidgets 8.
    """
    legacy = getattr(uploader, 'data', None)
    if legacy is not None:
        return bytes(legacy[-1]) if legacy else None
    uploads = uploader.value
    if not uploads:
        return None
    return bytes(uploads[-1]['content'])


def on_click_classify(change, *, threshold=0.95):
    """Classify the latest uploaded image and show the result.

    Click handler for ``btn_run``. Shows a thumbnail of the image in
    ``out_pl`` and writes the predicted class and its probability to
    ``lbl_pred``. When that probability is below ``threshold`` and the model
    has at least two classes, the second most likely class is appended as a
    "next guess".

    Args:
        change: the object ipywidgets passes to click handlers; unused.
        threshold: top-class probability below which the runner-up is shown.
    """
    data = _latest_upload_bytes(btn_upload)
    out_pl.clear_output()
    if data is None:
        # Without this guard, clicking before uploading raised IndexError
        # inside the callback, which the user never sees.
        lbl_pred.value = 'Please upload an image first.'
        return

    img = PILImage.create(data)
    with out_pl:
        # `display` is not imported here; presumably the IPython builtin
        # available in Jupyter notebooks.
        display(img.to_thumb(224, 224))

    pred, pred_idx, probs = learner.predict(img)
    top_prob = probs[pred_idx]
    lbl_pred.value = f'Prediction: {pred}; Probability: {100 * top_prob:.0f} %'
    # torch.topk(probs, 2) needs at least two classes to pick a runner-up.
    if top_prob < threshold and len(probs) > 1:
        next_idx = int(torch.topk(probs, 2).indices[1])
        lbl_pred.value += (
            f'; Next guess: {category_labels[next_idx]} '
            f'({100 * probs[next_idx]:.0f} %)'
        )

btn_run = Button(description='Classify')
btn_run.on_click(on_click_classify)

In [None]:
# Lay the app out vertically; as the cell's last expression, it is rendered.
VBox(children=[
    Label('Amanita or agaricus?'),
    btn_upload,
    btn_run,
    out_pl,
    lbl_pred,
])