In [84]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from PIL import Image, ImageOps
import pillow_heif

import os
import streamlit as st
from filesplit.merge import Merge


# Create a split file of the model's weights
from filesplit.split import Split

## App Helper Functions (testing)

In [25]:
# Settings
st.set_page_config(page_title="Herbarium Classification", page_icon=None, 
                  layout="wide", 
                  initial_sidebar_state="auto", 
                  menu_items=None)



# Helpers ---
@st.cache(allow_output_mutation=True)
def load_model():
    # Load model structure
    ## Load in JSON model config
    json_file = open('./data/herb-model.json', 'r')
    loaded_json_mod = json_file.read()
    json_file.close()
    ## Feed in JSON config to keras model
    model = tf.keras.models.model_from_json(loaded_json_mod)

    # Load model weights
    ## Concatenate split weights into one file (creates full_weights.h5)
    merge = Merge('./data/', './data/', 'full_weights.h5')
    merge.merge(cleanup=False)   ##keep split files in data/ dir with False
    ## Load saved weights into model
    model.load_weights('./data/full_weights.h5')

    return model

def upload_predict(image, model, image_size=(199,199)):
    """Accepts the uploaded image (from PIL), passes through the model and organizes the result."""
    image = np.array(image)
    if len(image.shape) > 2 and image.shape[2] == 4:
        #convert the image from RGBA2RGB (for example, if input is PNG)
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    image = prep_input_image(image, resize=image_size)
    prediction = model.predict(image, verbose=0)[0] ##remove [0] 
    result_df, result = preds_to_results(prediction)
    # pred_prob = [p for p in top5['prob']]
    # category, scientific_name, family, genus, species = query_plant_info(pred_classes)
    # out = pd.DataFrame({
    #     'category_id' : category,
    #     'confidence' : pred_prob,
    #     'scientific name' : scientific_name,
    #     'family' : family,
    #     'genus' : genus,
    #     'species' : species
    # })
    if prediction.item() < 0.5:
        confidence = 1-prediction.item()
    else:
        confidence = prediction.item()
    out = pd.DataFrame({
        'Prediction' : result,
        'Confidence' : [confidence]
    })

    return out


def prep_input_image(image, resize=(199, 199)):
    image = image.astype('float32')
    # Preprocessing
    image = np.array(image)
    image = preprocess_image(image, image_size=resize, resize=True, recolor=False) ##PIL reads in as RGB automatically
    # Rescale pixels
    image *= 1.0/255
    # standardize image - https://github.com/keras-team/keras/issues/2559
    image -= np.mean(image, axis=2, keepdims=True) ##samplewise center
    image /= (np.std(image, axis=2, keepdims=True) + 1e-7) ##samplewise std normalization
    # Reshape dimensions (add batch dimension)
    return np.expand_dims(image, axis=0)

def preprocess_image(img, bright_threshold=0.25, bright_value=30, image_size=(199,199), resize=True, recolor=True):
    """
    Applies a standardized preprocessing procedure to every image before it is passed into the model.
    Notes:
    If using tf.keras.preprocessing.image.ImageDataGenerator (with either flow_from_dataframe or flow_from_directory),
    do not specify the image size (ie, resize=False) or recolor the images (recolor=False). Resizing will cause issues, so 
    instead specify the image size in the flow_from_ function. Additionally, the flow_from_ functions automatically read in images
    as RGB (whereas cv2 reads them as BGR), so recoloring the images will have the opposite effect.
    """
    #Recolor
    if recolor:
        im = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    else:
        im = img
    #Adjust brightness
    hsv = cv2.cvtColor(im, cv2.COLOR_RGB2HSV)
    mean_val = np.mean(hsv[:,:,2])/255 #as a percentage of maximum pixel value
    if mean_val <= bright_threshold:
        h, s, v = cv2.split(hsv)
        lim = 255 - bright_value
        v[v > lim] = 255
        v[v <= lim] += bright_value
        final_hsv = cv2.merge((h, s, v))
        im = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2RGB)
    #Resize
    if resize:
        im = cv2.resize(im, image_size, interpolation = cv2.INTER_AREA)
    else:
        im = im
    return im


def preds_to_results(pred):
    # Initialize dictionary mappings
    i2tox = setup_dictionaries()
    # Top 5 class prediction
    # num_class = 15501
    # pred = pred.reshape(num_class)
    # pred_idx = np.argpartition(pred, -5)[-5: ]
    # Probabilities for top 5 preds
    # pred_prob = pred.reshape(num_class)[pred_idx]

    tox_pred_ix = np.where(pred > 0.5, 1, 0).item()

    # Map to get category labels
    pred_class = i2tox[tox_pred_ix]

    image_guess = pd.DataFrame({
    'class' : [pred_class],
    'prob' : [round(pred.item(), 3)],
    }).sort_values(by = 'prob', ascending=False)
    sorted_classes = [c for c in image_guess['class']]

    return image_guess, sorted_classes


def query_plant_info(categories, X=None,y=None):
    return None
    # # Locate row
    # rows = meta_data.loc[meta_data.category.isin(categories)]
    # # Extract info
    # category = [c for c in rows['category']]
    # scientific_name = [sc for sc in rows['scientific_name']]
    # family = [f for f in rows['family']]
    # genus = [g for g in rows['genus']]
    # species = [s for s in rows['species']]
    # images = []
    # # plot example image
    # if X is not None:
    #     for categ in [c for c in rows['category']]:
    #         img = X[np.where(y==categ)[0][0]]
    #         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    #         images.append(img)
            
    #         #plt.imshow(img)
    #         #plt.axis('off')
    #         #plt.show()
    
    # return category, scientific_name, family, genus, species#, images


def setup_dictionaries():
    i2tox = {0: "Nontoxic", 1: "Toxic"}
    meta = pd.DataFrame({
    "class_id" : [0, 1, 2, 3],
    "slang" : ["Poison Oak", "Poison Ivy", "Poison Sumac", "Nontoxic"],
    "scientific_name" : ["Toxicodendron diversilobum / Toxicodendron pubescens", 
                        "Toxicodendron radicans / Toxicodendron rydbergii", 
                        "Toxicodendron vernix", "'Nontoxic'"],
    }) ##for use later...
    return i2tox

In [None]:

# #-----#
# # APP #--------------------------------------------------------
# #-----#

# with st.spinner('Loading trained model...'):
#     model = load_model()
#     # Load metadata
#     meta_data = pd.read_csv('./data/herb22meta_data.csv')

# # INTRO
# st.write("""
#          # Rash Plant Classification
#          ### Helping to Identify Poison Oak, Ivy, and sumac
#          *DISCLAIMER:* This app is built on an image classification model which was trained to 80% test  
#          accuracy. The best way to avoid poison oak, ivy, or sumac is to learn to identify them yourself. 
         
#           ### Quick Facts:
#          -Eastern/Atlantic Poison Oak: Toxicodendron pubescens. 3 leaflets, usually lobed, fuzzy when young but dull when mature. 
#          Can have light green to cream colored berries and small, light green flowers. Typically grows in shrub form.  
#          -Western/Pacific Poison Oak: Toxicodendron diversilobum. 3 leaflets, usually lobed, both shiny and dull (often shiny). 
#           Can have White or cream colored berries and small, pink or green flowers. Can grow as vines or shrubs.
#          -Eastern/Atlantic Poison Ivy & Western/Pacific Poison Ivy: Toxicodendron radicans & Toxicodendron rydbergii. 
#          3 leafelts, either smooth (entire) or lobed, but never jagged (serrate).  
#          Can be shiny and fuzzy. May have white to cream colored berries and tiny white/green buds. Can grow as vines, shrubs, or ground cover.  
#          Poison Sumac: Toxicodendron vernix. 7-13 leaflets per single leaf, often very shiny and angled upward. 
#          May have white to green berries and greenish, drooping flowers. Grows as a shrub in wet, wooded areas. It is not as common as poison oak
#          and ivy and tends to be found in the swampy areas of the southeastern and northeastern united states. 
         
#          The deep learning model used for this project is a ResNet50 model pre-trained on ImageNet and then trained for 10
#          epochs on the Herbarium 2022 dataset. It achieved an F1-score of 0.73 on the Herbarium test data available on Kaggle.  
#          """
#          )

# # EXAMPLE IMAGES
# st.write("""
#         ### Examples  
#         Sources: [Leaf margins](https://biodiversity.utexas.edu/news/entry/leaves), [iNaturalist](https://www.inaturalist.org/) 
#         """)
# ip = "./images/"
# paths = [ip+'east-pois-oak.jpg', ip+'west-pois-oak.jpg', ip+'east-pois-ivy.jpg', ip+'west-pois-ivy.jpg', ip+'pois-sumac.jpg']
# ex_imgs = [Image.open(paths[0]), Image.open(paths[1]), Image.open(paths[2]), Image.open(paths[3]), Image.open(paths[4]), Image.open(paths[4])]
# st.image(ex_imgs, caption=['Eastern Poison Oak', 'Western Poison Oak',
#                             'Eastern Poison Ivy', 'Western Poison Ivy',
#                             'Poison Sumac'],
#                             use_column_width=False, width=400)



# # FILE UPLOADER
# file = st.file_uploader("Upload an image.", 
#                         help="Supported filetypes: jpg/jpeg, png, heic (iPhone).") #type=["jpg", "png", "heic, "])
# st.set_option('deprecation.showfileUploaderEncoding', False)



# # PROCESS IMAGE AND PREDICT
# if file is None:
#     st.text("Upload image for prediction.")
# else:
#     bytes_data = file.read()
#     filename = file.name
#     # If file is in HEIC format (ie, if uploaded from iphone)
#     if filename.split('.')[-1] in ['heic', 'HEIC', 'heif', 'HEIF']:
#         heic_file = pillow_heif.read_heif(file)
#         img = Image.frombytes(
#             heic_file.mode,
#             heic_file.size,
#             heic_file.data
#         )
#     else:
#         img = Image.open(file)
#     st.write("### Your Image:")
#     st.write("filename:", filename)
#     st.image(img, width=400, use_column_width=False)
#     pred_classes = upload_predict(img, model)
#     st.write("# Prediction")
#     st.write(pred_classes)


# st.write("---")
# st.write("by [Hans Elliott](https://hans-elliott99.github.io/)")


## Model Splitting

In [104]:
split1 = Split('./models/ToxRes101_1_best_mod.h5', './split_model_1/')
split1.bysize(size=5e+7) 

split2 = Split('./models/ToxRes101_1_final_model.h5', './split_model_2/')
split2.bysize(size=5e+7)

2022-09-22 14:27:42.539 INFO    filesplit.split: Starting file split process
2022-09-22 14:27:43.428 INFO    filesplit.split: Process completed in 0 min(s)
2022-09-22 14:27:43.429 INFO    filesplit.split: Starting file split process
2022-09-22 14:27:44.254 INFO    filesplit.split: Process completed in 0 min(s)


## Re-compile the split models and test

In [105]:
merge1 = Merge('./split_model_1/', './models/', 'compiled_model1.h5',)
merge1.merge(cleanup=False) ##keep split files with False

merge2 = Merge('./split_model_2/', './models/', 'compiled_model2.h5')
merge2.merge(cleanup=False)

2022-09-22 14:28:07.520 INFO    filesplit.merge: Starting file merge process
2022-09-22 14:28:11.962 INFO    filesplit.merge: Process completed in 0 min(s)
2022-09-22 14:28:11.963 INFO    filesplit.merge: Starting file merge process
2022-09-22 14:28:14.650 INFO    filesplit.merge: Process completed in 0 min(s)


In [107]:
bestmod = './models/compiled_model1.h5'
finalmod = './models/compiled_model2.h5'
bestmod = tf.keras.models.load_model(bestmod, compile=False)
finalmod = tf.keras.models.load_model(finalmod, compile=False)

In [169]:
img_path = './images/west-pois-oak.jpg'
img = Image.open(img_path)

ex_image = np.array(img)
ex_image = prep_input_image(ex_image, resize=(199, 199))

In [170]:
pred1 = bestmod(ex_image, training=False)
pred2 = finalmod(ex_image, training=False)

In [171]:
np.mean([pred1, pred2])
pred1.numpy().item()

0.989142119884491