In [None]:
# Re-import edited modules before each cell runs, so changes to the project .py files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

# Put the project root first on sys.path so the project packages imported below
# (export_model, model, detector, segmentator, classification_*, segmentation) resolve.
# Set FISHIAL_MODELS_DIR to point at your own checkout instead of the original
# author's absolute path; the default is unchanged.
MODELS_ROOT = os.environ.get('FISHIAL_MODELS_DIR', '/home/andrew/Andrew/MMLAB/saved_models')
sys.path.insert(0, MODELS_ROOT)

import os
import time
import json
import logging

import fiftyone as fo

from PIL import Image
from torchvision import transforms

import torchvision
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

from export_model.inference import Inference
from model.model_light import FishSeg
from model.dataset import SimpleFishialFishDataset
from model.utils import generate_random_image, draw_polygon, create_mask, visualize, scale_polygon, load_image, max_iou
from shapely.geometry import Polygon
# Mute the root logger so library log chatter does not flood the cell outputs.
# logging.root is the same object that logging.getLogger() returns with no name.
logger = logging.root
logger.disabled = True

In [None]:
from detector.inference import YOLOInference
from segmentator.inference import Inference
from segmentator.inference import poly_array_to_dict, convert_local_polygons_to_global
from classification_rect.inference import EmbeddingClassifier
from classification_old.inference import EmbeddingClassifier as EmbeddingClassifierOld

from segmentation.inference_segm import SegmentationInference

In [None]:
def save_json(data, path):
    """Serialize ``data`` to JSON and write it to ``path`` as UTF-8, replacing any existing file.

    The document is fully serialized *before* the file is opened. If ``json``
    cannot encode ``data`` (for example, a ``TypeError`` for an unsupported
    object), the error is raised without truncating or half-overwriting a result
    previously saved at ``path``.

    Args:
        data: Any JSON-serializable value.
        path: Destination file path.
    """
    serialized = json.dumps(data)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(serialized)

In [None]:
# Public download locations of the packaged model archives. These constants are not
# referenced elsewhere in this notebook section; the models below are loaded from local
# saved_models/ paths (DETECTRON_URL presumably packages saved_models/segmentation — confirm).
CLASSIFICATION_RECT_URL = 'https://storage.googleapis.com/fishial-ml-resources/classification_rectangle.zip'
DETECTRON_URL = 'https://storage.googleapis.com/fishial-ml-resources/segmentation.zip'

In [None]:
# Polygon segmentator that is run on every YOLO crop (TorchScript model, image_size=416).
# `Inference` here is segmentator.inference.Inference: that later import shadows the
# export_model.inference.Inference imported in the first import cell.
SEGMENTATION_MODEL_PATH = 'saved_models/segmentator/model.ts'
segmentator = Inference(model_path=SEGMENTATION_MODEL_PATH, image_size=416)

In [None]:
# Detectron2 baseline segmentation model (TorchScript), compared against the YOLO pipeline below.
DETECTRON2_MODEL_PATH = 'saved_models/segmentation/model.ts'
detectron = SegmentationInference(model_path=DETECTRON2_MODEL_PATH)

In [None]:
# Current embedding classifier: TorchScript model plus its embedding database.
classification_path = 'saved_models/classification_rect/model.ts'
data_base_path = 'saved_models/classification_rect/database.pt'

model_classifier = EmbeddingClassifier(classification_path, data_base_path)

In [None]:
# Previous-generation ("v7") embedding classifier: model, embeddings and an id mapping file.
# NOTE(review): this cell reuses the names classification_path / data_base_path from the
# previous cell and overwrites them.
classification_path = 'saved_models/classification_old/model.ts'
data_base_path = 'saved_models/classification_old/embeddings.pt'
data_id_path = 'saved_models/classification_old/idx.json'

model_classifier_v7 = EmbeddingClassifierOld(classification_path, data_base_path, data_id_path=data_id_path)


In [None]:
# YOLO fish detector (TorchScript export). Set FISHIAL_YOLO_MODEL_PATH to pick a checkpoint
# without editing this cell. Paths previously used here:
# 'saved_models/detector/best.torchscript' and 'saved_models/detector/model.ts'.
# yolo_ver must match the checkpoint's architecture (the default checkpoint is a yolov10s run).
YOLO_MODEL_PATH = os.environ.get(
    'FISHIAL_YOLO_MODEL_PATH',
    '/home/andrew/Andrew/yolov9/UltraSegmTrainFISHIAL_OBJECT_DETECTION/yolov10s_640_False/weights/best.torchscript',
)

detector = YOLOInference(YOLO_MODEL_PATH, imsz=(640, 640), conf_threshold=0.3, nms_threshold=0.9, yolo_ver='v10')

In [None]:
# Visual comparison of two segmentation pipelines against the ground truth on a random
# subset of the dataset. For each sample, visualize() receives three panels:
#   * ground_truth: 'General body shape' polylines drawn with color (0, 255, 0)
#   * DETECTRON2:   Detectron2 polygons drawn with color (0, 255, 0)
#   * NEW_YOLO:     YOLO detector + polygon segmentator drawn with color (255, 0, 0)
USE_YOLO_SEGMENTATOR = True
DETECTRON2 = True
CLASSIFICATION = True

line_width = 7
data = fo.load_dataset('SEGM-2024-V0.8')
# data = fo.load_dataset('SEGM-2024-V0.8-VALIDATION')

count_to_visualize = 20
view = data.take(count_to_visualize)
# NOTE(review): nothing in this cell adds entries to final_dict, so the skip below never
# fires; presumably it is meant to be filled before the save_json call at the end.
final_dict = {}

for sample in tqdm(view):
    if sample.image_id in final_dict:
        continue

    np_img = cv2.imread(sample.filepath)
    if np_img is None:
        # cv2.imread reports a missing or unreadable file by returning None, not by raising.
        tqdm.write(f"Skipping unreadable image: {sample.filepath}")
        continue
    img_rgb = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(img_rgb)
    height, width = np_img.shape[:2]

    # Every panel starts as an untouched copy of the image, taken before the ground truth
    # is drawn onto pil_image. visualize() therefore always gets all three panels, even
    # when a pipeline is disabled or detects nothing.
    yolo_img = np.array(pil_image)
    detectron_img = pil_image.copy()
    cropped_images = []
    list_of_imgs_detectron2 = []
    mask_rcnn_output = []

    if USE_YOLO_SEGMENTATOR:
        boxes = detector.predict(np_img)[0]
        cropped_images = [box_inst.get_mask_BGR() for box_inst in boxes]
        # Only call the segmentator when there is something to segment.
        if cropped_images:
            np_polygon_yolo = segmentator.predict(cropped_images)
            for box_inst, poly_inst in zip(boxes, np_polygon_yolo):
                # Shift each polygon by its box's top-left corner; presumably this converts
                # crop-local coordinates to full-image coordinates.
                poly_inst.move_to(box_inst.x1, box_inst.y1)
                poly_inst.draw_polygon(yolo_img, color=(255, 0, 0), thickness=line_width)

    if DETECTRON2:
        list_of_boxes_mask_rcnn, mask_rcnn_output = detectron.inference(np_img)
        for poly_dict in list_of_boxes_mask_rcnn:
            # Each dict stores its vertices under the keys x1, y1, x2, y2, ...
            converted_poly = [
                (poly_dict[f"x{point_id}"], poly_dict[f"y{point_id}"])
                for point_id in range(1, len(poly_dict) // 2 + 1)
            ]
            if not converted_poly:
                continue
            points = np.asarray(converted_poly)
            # Slice bounds must be integers. The minimums are clamped at 0 so that an
            # out-of-frame vertex cannot become a negative (wrap-around) index.
            x_min, y_min = (max(int(v), 0) for v in points.min(axis=0))
            x_max, y_max = (int(v) for v in points.max(axis=0))
            list_of_imgs_detectron2.append(np_img[y_min:y_max, x_min:x_max])

            draw_polygon(detectron_img, converted_poly, line_color=(0, 255, 0), line_width=line_width)

    if CLASSIFICATION:
        # NOTE(review): these predictions are computed but not yet shown or stored anywhere.
        output_class_yolo = (
            model_classifier.batch_inference(cropped_images) if cropped_images else None
        )
        # NOTE(review): as in the original, the guard checks the Detectron2 crops but the
        # classifier is fed mask_rcnn_output; confirm this is intended.
        output_class_detectron2 = (
            model_classifier_v7.batch_inference(mask_rcnn_output) if list_of_imgs_detectron2 else None
        )

    gt_shapes = sample['General body shape']
    if gt_shapes is not None:
        for polylin_inst in gt_shapes.polylines:
            # Ground-truth points are scaled by (width, height); presumably they are stored
            # as relative coordinates.
            polygon = scale_polygon(polylin_inst.points[0], width, height)
            draw_polygon(pil_image, polygon, line_color=(0, 255, 0), line_width=line_width)

    visualize(
        ground_truth=pil_image,
        DETECTRON2=detectron_img,
        NEW_YOLO=yolo_img,
    )
# save_json(final_dict, "result_actual.json")

In [None]:
# Single-image smoke test: fetch one example fish photo by URL via the project's load_image helper.
URL = 'https://reefguide.org/pix/emperorangelfish2.jpg'
im = load_image(dict(imageURL=URL))

In [None]:
# Display the downloaded image inline (the last expression of the cell is rendered).
# NOTE(review): assumes load_image returns an RGB array. If it returns OpenCV BGR order,
# the displayed colors will be swapped.
Image.fromarray(im)

In [None]:
!wget https://storage.googleapis.com/fishial-ml-resources/classification.zip

In [None]:
!unzip classification.zip -d classification

In [None]:
# Debug peek at the last Detectron2 polygon. `converted_poly` is the loop variable leaked
# from the visualization cell, so this only works after that cell has run on a fresh kernel.
[*converted_poly]