In [5]:
import joblib
import cv2
import numpy as np

In [3]:
def load_model(model_path):
    """
    Load a joblib-serialized model (e.g. a scikit-learn estimator) from a file.

    This function only deserializes the model; it does not make predictions.

    Parameters:
    - model_path: The path to the model file (as written by joblib.dump).

    Returns:
    - loaded_model: The deserialized model object.

    Security:
    - joblib.load unpickles the file, which can execute arbitrary code.
      Only load model files that come from a trusted source.
    """
    # Deserialize the model from disk.
    loaded_model = joblib.load(model_path)
    return loaded_model


In [17]:
def preprocess_input(input_path, *, size=(128, 128)):
    """
    Preprocess an image for prediction (resize, convert color space,
    normalize, flatten).

    Parameters:
    - input_path: The path to the user's image file.
    - size: (width, height) passed to cv2.resize; defaults to (128, 128).
      NOTE(review): presumably this must match the size used when the model
      was trained — confirm before changing it.

    Returns:
    - img: A single-row float array of shape (1, width * height * channels)
      holding RGB pixel values divided by 255.0, ready for model.predict.

    Raises:
    - FileNotFoundError: If OpenCV cannot read an image from input_path.
    """
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise fail later with an opaque cv2.error in resize.
    image = cv2.imread(input_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {input_path!r}")

    image = cv2.resize(image, size)
    # OpenCV loads images in BGR channel order; convert to RGB.
    img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_normalized = img_rgb / 255.0
    # Flatten into one row so the image is treated as a single sample.
    img = img_normalized.reshape(1, -1)
    return img

In [28]:
def decode_out(encoded_label, *, class_names=("Compost", "Non-Recycle", "Recycle")):
    """
    Map encoded label rows (one-hot or per-class scores) to class names.

    Parameters:
    - encoded_label: Array-like of shape (n_samples, n_classes), or a single
      1-D row of length n_classes. The argmax of each row selects the class.
    - class_names: Class name for each column index; defaults to
      ("Compost", "Non-Recycle", "Recycle").

    Returns:
    - decoded_label: List of class names, one per row.

    Note:
    - np.argmax returns the first index of the maximum, so a row with ties
      (including an all-zero row) decodes to the earliest tied class.
    """
    # atleast_2d lets a single 1-D row be decoded as one sample instead of
    # failing because axis=1 does not exist.
    indices = np.argmax(np.atleast_2d(encoded_label), axis=1)

    # Map each column index to its class name.
    decoded_label = [class_names[index] for index in indices]
    return decoded_label


In [25]:
def predict(model_path, input_path):
    """
    Predict the trash category of a user's image.

    The model is loaded from disk on every call; for many predictions,
    consider loading it once with load_model and reusing it.

    Parameters:
    - model_path: The path to the serialized model file.
    - input_path: The path to the user's image.

    Returns:
    - prediction: List of decoded class names, one per input row (a single
      element here, e.g. ['Recycle']); each is Recycle / Non-Recycle / Compost.
    """
    model = load_model(model_path)

    # Named `features` rather than `input` to avoid shadowing the builtin.
    features = preprocess_input(input_path)

    # decode_out takes argmax over axis 1, so this assumes model.predict
    # returns a 2-D (n_samples, n_classes) one-hot/score array — TODO confirm.
    raw_prediction = model.predict(features)

    prediction = decode_out(raw_prediction)
    return prediction

In [31]:
# Example run: classify one image from the augmented "Recycle" folder.
model_path = (
    "/content/drive/MyDrive/Colab Notebooks/AI - School/random_forest.pkl"
)
input_path = (
    "/content/drive/MyDrive/Dataset/School AI-Project/augumented_dataset/Recycle/01036.png"
)

predict(input_path=input_path, model_path=model_path)

['Recycle']