In [None]:
!pip -q install ultralytics
!pip -q install colorthief

!wget -q -O best.pt https://drive.google.com/uc?id=1yKBydqWKQhbIjZ4-faCnqz905PYhOod3&export=download
!wget -q -O finalized_model.sav https://drive.google.com/uc?id=1OcRpqUxXiTtTVJYCwGcO1fZ8MQ7TrKEC&export=download

In [None]:
import numpy as np
import cv2
import os
from ultralytics import YOLO
import glob
from IPython.display import Image, display
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from colorthief import ColorThief
import math
import sklearn
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import SGDRegressor
import pickle
from google.colab import files
from ipywidgets import widgets

In [None]:
class DetectUtil():
  """Detect sneakers with a YOLO model and turn the detections into a feature vector.

  ``run_model_from_image_path`` builds ``max_boxes`` rows of
  ``4 + 3 * PALETTE_SIZE`` (= 13) values each: the normalized box (x, y, w, h)
  followed by the flattened RGB palette of the box crop scaled by 1/256.
  Rows without a detection stay zero.
  """

  PALETTE_SIZE = 3  # colors extracted from each crop
  CROP_SIZE = 200   # side (px) each detected box is resized to before palette extraction

  def __init__(self, weights_path="best.pt", threshold=.7, img_size=416, max_boxes=10):
    """Load the YOLO weights and store the detection settings.

    Args:
      weights_path: path of the YOLO weights file.
      threshold: detections with confidence <= threshold are discarded.
      img_size: side length the input image is resized to before detection.
      max_boxes: number of feature rows per image. NOTE(review): the pickled
        regressor presumably expects the default (10 rows -> 130 features).
    """
    self.model = YOLO(weights_path)
    self.threshold = threshold
    self.img_size = img_size
    self.max_boxes = max_boxes

  def detect(self, frame):
    """Run the detector on ``frame`` and keep confident class-0 boxes.

    Returns:
      ``(found, boxes)`` where ``boxes`` is a list of
      ``(score_percent, [x1, y1, x2, y2])`` sorted best-first and ``found``
      is True when ``boxes`` is non-empty.
    """
    results = self.model(frame, verbose=False)[0]
    sneaker_boxes = []
    for x1, y1, x2, y2, score, class_id in results.boxes.data.tolist():
      # Class 0 is the only class treated as a sneaker here.
      if score > self.threshold and class_id == 0:
        sneaker_boxes.append((round(100 * score), [int(x1), int(y1), int(x2), int(y2)]))
    sneaker_boxes.sort(reverse=True)
    return len(sneaker_boxes) > 0, sneaker_boxes

  def crop_image(self, image, crop_rect, target_size):
    """Crop ``crop_rect`` = (x, y, w, h) out of ``image`` and resize it to a ``target_size`` square."""
    x, y, w, h = crop_rect
    cropped_image = image[y:y + h, x:x + w]
    return cv2.resize(cropped_image, (target_size, target_size), interpolation=cv2.INTER_LINEAR)

  def return_top3_palette(self, image_path, plot):
    """Return the ColorThief palette (requested size ``PALETTE_SIZE``) of an image file.

    The returned list may not contain exactly ``PALETTE_SIZE`` colors; use
    ``_palette_features`` to obtain a fixed-width feature row. When ``plot`` is
    True, the image and its palette are displayed.
    """
    color_count = self.PALETTE_SIZE
    palette = ColorThief(image_path).get_palette(color_count=color_count)
    if plot:
      # imread returns BGR; matplotlib expects RGB.
      image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
      fig, (ax1, ax2) = plt.subplots(2)
      ax1.imshow(image)
      # Slice instead of indexing so a shorter palette does not raise IndexError.
      ax2.imshow([palette[:color_count]])
      fig.suptitle('ImagePaletteThief')
      plt.show()
    return palette

  @staticmethod
  def _palette_features(palette, color_count=3):
    """Return ``palette`` as exactly ``3 * color_count`` floats scaled by 1/256.

    Extra colors are dropped and missing ones are zero-filled, so the result
    always fits one feature row even if the palette has an unexpected length.
    """
    width = 3 * color_count
    row = np.zeros(width)
    values = np.asarray(palette, dtype=float).ravel()[:width]
    row[:values.size] = values
    # /256 (not /255) is the scaling the original pipeline used; keep it unless
    # the regressor is retrained.
    return row / 256

  def run_model_from_image_path(self, image_path, plot):
    """Build the flattened feature vector for the image at ``image_path``.

    Each kept box is cropped, saved in the current working directory as
    ``<image stem>_<index>.jpg`` and summarized by its color palette.

    Returns:
      A 1-D array of ``max_boxes * (4 + 3 * PALETTE_SIZE)`` floats when at
      least one box was detected, otherwise an empty list.

    Raises:
      FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    row_width = 4 + 3 * self.PALETTE_SIZE
    features = np.zeros((self.max_boxes, row_width))
    img = cv2.imread(image_path)
    if img is None:
      raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
    # NOTE(review): Ultralytics treats numpy inputs as BGR, but ``img`` is RGB here.
    # Kept as-is because the regressor's training features were presumably produced
    # by this same pipeline — confirm before switching the detector to a BGR input.
    detected, boxes = self.detect(img)
    # Only max_boxes feature rows exist; boxes are sorted best-first, so keep the
    # top ones instead of indexing past the end of ``features``.
    boxes = boxes[:self.max_boxes]
    if plot:
      fig, ax = plt.subplots()
      ax.imshow(img)

    stem = os.path.splitext(os.path.basename(image_path))[0]
    cropped_img_paths = []
    for index, (_score, (xmin, ymin, xmax, ymax)) in enumerate(boxes):
      w, h = xmax - xmin, ymax - ymin
      features[index][0:4] = [xmin / self.img_size, ymin / self.img_size, w / self.img_size, h / self.img_size]
      crop_img = self.crop_image(img, [xmin, ymin, w, h], self.CROP_SIZE)
      crop_path = f"{stem}_{index}.jpg"
      cropped_img_paths.append(crop_path)
      # imwrite expects BGR, so convert the RGB crop back before saving.
      cv2.imwrite(crop_path, cv2.cvtColor(crop_img, cv2.COLOR_RGB2BGR))
      if plot:
        ax.add_patch(Rectangle((xmin, ymin), w, h, linewidth=2, edgecolor='r', facecolor='none'))

    # Palettes are computed only after every box has been drawn, so the detection
    # figure already has all rectangles when return_top3_palette calls plt.show().
    for index, crop_path in enumerate(cropped_img_paths):
      palette = self.return_top3_palette(crop_path, plot)
      features[index][4:] = self._palette_features(palette, self.PALETTE_SIZE)

    if detected:
      return features.flatten()
    return []



class PredictUtil():
    """Predict a like count for an image by feeding DetectUtil features to a pickled regressor."""

    def __init__(self, model_path='finalized_model.sav'):
        """Load the regressor from ``model_path`` and create the sneaker detector.

        Security: ``pickle.load`` can execute arbitrary code. The default file is
        downloaded from Google Drive by the setup cell, so only load it from a
        source you trust.
        """
        # ``with`` closes the file handle instead of leaking it.
        with open(model_path, 'rb') as model_file:
            self.model = pickle.load(model_file)
        self.detectutil = DetectUtil()
        print("init done!")

    def predict_from_image_path(self, image_path, plot=False):
        """Return the predicted like count for ``image_path``.

        Args:
          image_path: path of the image to score.
          plot: when True, the detector shows the detections and palettes.

        Returns:
          The regressor's prediction for the single image, or -1 when the
          detector produced no features (no box above the threshold).
        """
        input_features = self.detectutil.run_model_from_image_path(image_path, plot)
        if len(input_features) == 0:
            return -1
        # predict() takes a 2-D batch: wrap the single sample, then unwrap the result.
        return self.model.predict([input_features])[0]

**Upload a picture, then click "Predict" to run the prediction model!**


In [None]:
uploaded = files.upload()
if not uploaded:
    # files.upload() returns an empty dict when the dialog is cancelled.
    raise ValueError("No file was uploaded; re-run this cell and choose an image.")
file_path = next(iter(uploaded))
predict_model = PredictUtil()

# Button callbacks run outside the cell's execution; capturing their output in an
# Output widget makes sure it is rendered under the button (e.g. in JupyterLab,
# print() inside a widget callback is otherwise not shown in the cell).
prediction_output = widgets.Output()

def predict_and_print(event):
    """Button callback: run the model on the uploaded file and print the prediction."""
    with prediction_output:
        likes = predict_model.predict_from_image_path(file_path, plot=False)
        if likes == -1:
            # -1 is PredictUtil's "nothing detected" sentinel, not a real prediction.
            print("No sneaker detected in", file_path)
        else:
            print("Predict like =", likes)

button_choice = widgets.Button(description="Predict", button_style='success')
button_choice.on_click(predict_and_print)
display(button_choice, prediction_output)


Saving IMG_6257.jpg to IMG_6257 (3).jpg
init done!


Button(button_style='success', description='Predict', style=ButtonStyle())

Predict like = 463.78
