# In [None]:
import os
# script_dir = os.path.dirname(os.path.realpath(__file__)) + os.sep
# os.system("bash " +script_dir+ "run_tf_serving/run_tf_serving.sh")
# Class names; position i is the name reported for the i-th model output score.
labels = [
    '전갱이', '황아귀', '도루묵', '참조기', '갈치',
    '청어', '멸치', '삼치', '방어', '고등어',
]
import torch
import torch.nn as nn
import timm
import numpy as np
import gradio as gr
import cv2
from torchvision import transforms
from object_detection import *
from rembg.bg import remove as remove_bg
import mimetypes
# Map the ".js" extension to the JavaScript MIME type explicitly.
# NOTE(review): presumably needed so served .js assets get a usable
# Content-Type on hosts whose MIME registry maps it differently - confirm.
mimetypes.init()
mimetypes.add_type(type='application/javascript', ext='.js')

# Model setup: build one timm classifier (10 output classes, no pretrained
# download) per architecture, then restore each one's weights from its
# checkpoint file on CPU and switch it to inference mode.
model_names = ['efficientnet_b0', 'efficientnet_b4', 'nfnet_l0']

# Checkpoint file for each entry of `model_names`, in the same order.
_checkpoint_paths = [
    './efficientnet_b0_5epoch.pt',
    './efficientnet_b4_5epoch.pt',
    './nfnet_l0_5epoch.pt',
]

models = []
for model_name in model_names:
    models.append(timm.create_model(model_name, pretrained = False, num_classes = 10))

for _net, _checkpoint_path in zip(models, _checkpoint_paths):
    _net.load_state_dict(torch.load(_checkpoint_path, map_location=torch.device('cpu')))
    _net.eval()

#define AccuracyEnsemble class
class AccuracyEnsemble(torch.nn.Module):
    """Weighted-sum ensemble over several models.

    ``forward`` runs every member model on the same input and returns
    ``sum(weight_i * model_i(x))``.

    Args:
        models: Sequence of callables (normally ``torch.nn.Module`` instances)
            that all accept the same input.
        train_accuracy: One weight per model, in the same order as ``models``.
            When omitted, every model gets the uniform weight
            ``1 / len(models)``.

    Raises:
        ValueError: If the number of weights differs from the number of models.

    Note:
        The weights are applied exactly as given; they are not normalised.
        Passing raw percentages therefore scales the summed output by roughly
        their total.
    """

    def __init__(self, models, train_accuracy=None):
        super().__init__()
        models = list(models)

        # The default is computed from the models actually passed in, at call
        # time, instead of from a module-level global at definition time.
        if train_accuracy is None:
            train_accuracy = [1 / len(models)] * len(models) if models else []
        if len(train_accuracy) != len(models):
            raise ValueError(
                f"expected {len(models)} weights, got {len(train_accuracy)}"
            )

        # Registering the members lets .eval()/.train()/.to() reach them.
        # Plain callables that are not Modules are still accepted as a list.
        if all(isinstance(member, torch.nn.Module) for member in models):
            self.models = torch.nn.ModuleList(models)
        else:
            self.models = models

        # A buffer follows the module across devices; still read as self.weights.
        self.register_buffer("weights", torch.tensor(train_accuracy))

    def forward(self, x):
        """Return the weighted sum of every member model's output for ``x``."""
        total_output = 0
        for weight, model in zip(self.weights, self.models):
            total_output = total_output + weight * model(x)
        return total_output

# Per-model ensemble weights, listed in the same order as `models`.
# NOTE(review): presumably each model's last-epoch training accuracy in
# percent - confirm against the training logs.
train_lastest_accuracy = [
    99.84285736083984,
    99.71428680419922,
    99.5999984741211,
]
ensemble = AccuracyEnsemble(models, train_accuracy=train_lastest_accuracy)
ensemble.eval()

def predict(inp):
    """Classify one image and return a ``{label: probability}`` dict.

    The image is resized to 224x224 with ``cv2.resize``, turned into a
    single-item batch with ``transforms.ToTensor`` and fed to the module-level
    ``ensemble``. Softmax is taken over the scores of the first (only) batch
    item, and the first 10 probabilities are paired with ``labels``.
    """
    resized = cv2.resize(inp, (224, 224))
    batch = transforms.ToTensor()(resized).unsqueeze(0)
    with torch.no_grad():
        scores = ensemble(batch)[0]
        prediction = torch.nn.functional.softmax(scores, dim=0)
        confidences = {}
        for index in range(10):
            confidences[labels[index]] = float(prediction[index])
    return confidences
def classify_image(inp):
    """Return ``predict(inp)`` unchanged; kept as the classification entry point."""
    return predict(inp)
    
# Label output component showing the three highest-scoring classes. Built with
# the top-level gr.Label class (the same one the interface uses) rather than
# the deprecated gr.outputs namespace.
label = gr.Label(num_top_classes=3)


def _detect_large_contours(frame, min_area=20000):
    """Return the external contours of ``frame`` whose area exceeds ``min_area``.

    The frame is converted to grayscale and binarised with an inverted
    adaptive mean threshold (block size 21, constant 5) before the contours
    are extracted.
    """
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    mask = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY_INV, 21, 5)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Fish scales or fins are sometimes picked up as separate contours; the
    # area floor keeps only objects above a minimum size.
    return [cnt for cnt in contours if cv2.contourArea(cnt) > min_area]


def fish_length(input_img, *, card_long_side_cm=8.56, min_contour_area=20000):
    """Estimate the length of a fish photographed next to a reference card.

    The background is stripped with ``remove_bg`` and large contours are
    detected. Of the first two contours, the one whose minimum-area rectangle
    has the smaller ``width + height`` is taken to be the card and the other
    one the fish. The card's longer side fixes the pixel-per-cm scale, and the
    fish's longer side is converted with that scale.

    Args:
        input_img: Image handed to ``remove_bg``.
        card_long_side_cm: Real length of the card's longer side. The default
            8.56 is the width of a credit card in cm.
        min_contour_area: Contours with an area at or below this many pixels
            are ignored.

    Returns:
        The fish's longer side, in the unit of ``card_long_side_cm``.

    Raises:
        ValueError: If fewer than two large objects are detected, so that no
            card/fish pair exists to measure.
    """
    img = remove_bg(input_img)
    contours = _detect_large_contours(img, min_contour_area)
    if len(contours) < 2:
        raise ValueError(
            "fish_length needs both a reference card and a fish in the image, "
            f"but only {len(contours)} large object(s) were detected"
        )

    _, (w1, h1), _ = cv2.minAreaRect(contours[0])
    _, (w2, h2), _ = cv2.minAreaRect(contours[1])

    # The rectangle's width/height are not reliably ordered, so the two
    # objects are compared by width + height and each one is measured along
    # whichever side is longer.
    if w1 + h1 < w2 + h2:
        card_sides, fish_sides = (w1, h1), (w2, h2)
    else:
        card_sides, fish_sides = (w2, h2), (w1, h1)

    pixel_cm_ratio = max(card_sides) / card_long_side_cm
    return max(fish_sides) / pixel_cm_ratio
def fish_Measurement(fish_category, fish_image):
    """Measure a fish and, for supported categories, estimate its weight.

    The length comes from ``fish_length(fish_image)``. For the three
    categories listed below, the weight is a linear function of the unrounded
    length; any other category gets a fixed message instead of a weight.

    Returns:
        ``[length, weight_or_message]`` with numeric values rounded to one
        decimal place.
    """
    # category -> (slope, offset) of the fit: weight = slope * length - offset
    linear_fits = {
        "참조기": (18.48386958, 284.26583617666523),
        "방어": (132.8907088, 5022.848364318911),
        "도루묵": (7.37708892, 74.46301691482816),
    }

    length_result = fish_length(fish_image)
    rounded_length = round(length_result, 1)

    fit = linear_fits.get(fish_category)
    if fit is None:
        return [rounded_length, "길이만 측정하셨습니다"]

    slope, offset = fit
    return [rounded_length, round(slope * length_result - offset, 1)]


def total_fn(input_img, selected):
    """Gradio callback: classify the image and optionally measure the fish.

    Args:
        input_img: Image received from the interface.
        selected: Chosen mode. ``"분류"`` means classification only; any
            other value also runs the length/weight measurement.

    Returns:
        ``[confidences, length, weight]``. In classification-only mode the
        last two entries are the string ``"0"``; otherwise they are the two
        values returned by ``fish_Measurement`` for the top-scoring class.
    """
    # Run the ensemble once and reuse the result for both the returned
    # confidences and the choice of category.
    fish_labels = classify_image(input_img)
    if selected == "분류":
        return [fish_labels, "0", "0"]

    # max() returns the first highest-scoring label, which matches taking the
    # head of a stable descending sort.
    top_category = max(fish_labels, key=fish_labels.get)
    result = fish_Measurement(top_category, input_img)
    return [fish_labels, result[0], result[1]]
   
            
def flip_text(x):
    """Return ``x`` with its elements in reverse order (slice with step -1)."""
    reversed_x = x[::-1]
    return reversed_x

def flip_image(x):
    """Return ``x`` mirrored left-to-right, as produced by ``np.fliplr``."""
    mirrored = np.fliplr(x)
    return mirrored
# NOTE(review): css_code is defined but never handed to the interface, so it
# has no effect as written - confirm whether it should be passed as css=.
css_code='body{background-image:url("file = background.jpg");}'

# Input components: the image to analyse and the mode selector.
# NOTE(review): the elem_id value looks like a CSS rule rather than an element
# id - confirm the intended styling.
classification = gr.Image(invert_colors = True, show_label = False, elem_id = "{background-color : #c7d8e5}" )
seleteFunction = gr.Radio(["분류", "길이 & 무게"], value = "분류", label = "기능")

# Outputs: top-3 class confidences, then the length and weight text boxes.
demo = gr.Interface(
    fn = total_fn,
    title = '어류 분류 및 길이 & 무게 측정',
    inputs = [classification, seleteFunction],
    outputs = [
        gr.Label(num_top_classes=3, show_label = False),
        gr.Textbox(label="결과, 단위(cm)"),
        gr.Textbox(label="결과, 단위(g)"),
    ],
    allow_flagging = "never",
)

# The bind address and port are launch() options. Passed to gr.Interface they
# are ignored with a deprecation warning, so the server would start on the
# default host and port instead of 0.0.0.0:8898.
demo.launch(server_name='0.0.0.0', server_port=8898)