# IMPORT

In [None]:
import os
import torch
import time
from detectron2.export import (
    STABLE_ONNX_OPSET_VERSION,
    TracingAdapter,
    dump_torchscript_IR,
    scripting_with_instances,
)
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.config import get_cfg
from detectron2.projects.point_rend import ColorAugSSDTransform, add_pointrend_config

import detectron2.data.transforms as T
from detectron2.utils.file_io import PathManager

In [None]:
# return inputs for segmentation model by using image path

def get_sample_inputs(sample_image):
    """Build the one-element input batch used to trace the segmentation model.

    The image is read in ``cfg.INPUT.FORMAT``, resized with
    ``ResizeShortestEdge`` configured from the test-time sizes in ``cfg``
    (the same preprocessing as DefaultPredictor), and converted to a
    float32 channel-first tensor.

    Args:
        sample_image: path of the image to load.

    Returns:
        list: ``[{"image": tensor, "height": h, "width": w}]`` where ``h`` and
        ``w`` are the dimensions of the image *before* resizing.

    Note:
        Reads the notebook-level ``cfg`` object, so that cell must run first.
    """
    raw_image = detection_utils.read_image(sample_image, format=cfg.INPUT.FORMAT)

    # Same preprocessing as DefaultPredictor: fixed shortest edge, capped longest edge
    resize = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST],
        cfg.INPUT.MAX_SIZE_TEST,
    )
    print(cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST)

    height, width = raw_image.shape[:2]
    resized = resize.get_transform(raw_image).apply_image(raw_image)
    # HWC -> CHW float32 tensor
    chw_tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))

    return [{"image": chw_tensor, "height": height, "width": width}]

#helper for segmentation model
def inference(model, inputs):
    """Run ``model.inference`` without post-processing and wrap the first result.

    ``do_postprocess=False`` is passed so the model returns the ROI mask.

    Args:
        model: object exposing ``inference(inputs, do_postprocess=...)``.
        inputs: batch passed straight through to ``model.inference``.

    Returns:
        list: ``[{"instances": first_result}]``.
    """
    raw_results = model.inference(inputs, do_postprocess=False)
    first_result = raw_results[0]
    return [{"instances": first_result}]

In [None]:
def get_cfg_segm_model():
    """Create the detectron2 config for the PointRend segmentation model.

    Starts from the default detectron2 config, registers the PointRend
    options, then merges the hard-coded PointRend YAML file on top.

    Returns:
        The merged config object.
    """
    segm_cfg = get_cfg()
    add_pointrend_config(segm_cfg)
    segm_cfg.merge_from_file(
        r'/home/fishial/Fishial/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml'
    )
    return segm_cfg


In [None]:
# Sample image used to build the tracing inputs for the segmentation export.
IMAGE_PATH = r'/home/fishial/Fishial/output/species_53_576c9cb12ac66.w1000.h666.jpg'
# Output folder and file name of the exported TorchScript segmentation model.
# (The identifier is spelled "FODLER"; later cells reference it under this name.)
SEGMENTATION_FODLER = r'/home/fishial/Fishial/output/segmentation_export_torchscript'
MODEL_NAME = "model_segm_21_08_2023.ts"
# NOTE(review): folder creation is disabled here, so the export cell relies on the folder existing.
#os.makedirs(SEGMENTATION_FODLER, exist_ok=True)

# Notebook-level config; read as a global by get_sample_inputs().
cfg = get_cfg_segm_model()

In [None]:
from PIL import Image
import numpy as np

In [None]:
# Load the sample image as a float32 CHW tensor.
# The notebook's own `read_image` helper is only defined in a later cell, so
# calling it here raises NameError on a fresh kernel; detectron2's
# `detection_utils.read_image` (imported at the top) is used instead.
image_path = r'/home/fishial/Fishial/output/species_53_576c9cb12ac66.w1000.h666.jpg'
image = detection_utils.read_image(image_path)
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

In [None]:
# Load the pickled eager segmentation model.
# NOTE(review): torch.load unpickles arbitrary objects - only use it on trusted files.
full_model = r'/home/fishial/Fishial/output/model_21_08_2023_test.pt'
model_script = torch.load(full_model)
model_script.eval()

# Build tracing inputs from the sample image; only the image tensor is fed to the trace.
inputs = get_sample_inputs(IMAGE_PATH)
image = inputs[0]["image"]
inputs = [{"image": image}]  # remove other unused keys
# TracingAdapter wraps the model (called through `inference`) so it can be traced on tensors.
traceable_model = TracingAdapter(model_script, inputs, inference)
ts_model = torch.jit.trace(traceable_model, (image,))
ts_model.eval()

frozen_module = torch.jit.freeze(ts_model)
# Create the export folder up front: opening the output file for writing
# fails when the directory does not exist yet.
os.makedirs(SEGMENTATION_FODLER, exist_ok=True)
with PathManager.open(os.path.join(SEGMENTATION_FODLER, MODEL_NAME), "wb") as f:
    torch.jit.save(frozen_module, f)
# Also dump the TorchScript IR next to the model for inspection.
dump_torchscript_IR(frozen_module, SEGMENTATION_FODLER)

In [None]:
from torch.utils.mobile_optimizer import optimize_for_mobile

# Run PyTorch's mobile optimisation pass over `frozen_module` (the frozen
# module produced by the export cell) and save the result in the export folder.
# NOTE(review): this uses the regular `.save()`; confirm whether the consumer
# expects this format or a lite-interpreter file.
optimized_torchscript_model = optimize_for_mobile(frozen_module)
optimized_torchscript_model.save(os.path.join(SEGMENTATION_FODLER,'optimized_torchscript_model.pth'))

## Load the exported TorchScript file and test it

In [None]:
from PIL import Image
import numpy as np
from torch.nn import functional as F
import cv2
import matplotlib.pyplot as plt
# RGB -> YUV (BT.601) conversion matrix used by the "YUV-BT.601" branch below.
# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601
_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]]


def convert_PIL_to_numpy(image, format):
    """
    Convert PIL image to numpy array of target format.

    Args:
        image (PIL.Image): a PIL image
        format (str): the format of output image. ``None`` keeps the image
            mode as is; "BGR" and "YUV-BT.601" are handled here because PIL
            does not support them directly.

    Returns:
        (np.ndarray): an HWC array, also see `read_image`
    """
    if format is not None:
        # PIL only supports RGB, so convert to RGB and flip channels over below
        conversion_format = format
        if format in ["BGR", "YUV-BT.601"]:
            conversion_format = "RGB"
        image = image.convert(conversion_format)
    image = np.asarray(image)
    # PIL squeezes out the channel dimension for "L", so make it HWC
    if format == "L":
        image = np.expand_dims(image, -1)

    # handle formats not supported by PIL
    elif format == "BGR":
        # flip channels if needed
        image = image[:, :, ::-1]
    elif format == "YUV-BT.601":
        # scale to 0-1, then apply the RGB -> YUV matrix to every pixel
        image = image / 255.0
        image = np.dot(image, np.array(_M_RGB2YUV).T)

    return image

def _apply_exif_orientation(image):
    """
    Applies the exif orientation correctly.

    This code exists per the bug:
      https://github.com/python-pillow/Pillow/issues/3973
    with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
    various methods, especially `tobytes`

    Function based on:
      https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
      https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527

    Args:
        image (PIL.Image): a PIL image

    Returns:
        (PIL.Image): the PIL image with exif orientation applied, if applicable
    """
    _EXIF_ORIENT = 274 
    
    if not hasattr(image, "getexif"):
        return image

    try:
        exif = image.getexif()
    except Exception:  # https://github.com/facebookresearch/detectron2/issues/1885
        exif = None

    if exif is None:
        return image

    orientation = exif.get(_EXIF_ORIENT)

    method = {
        2: Image.FLIP_LEFT_RIGHT,
        3: Image.ROTATE_180,
        4: Image.FLIP_TOP_BOTTOM,
        5: Image.TRANSPOSE,
        6: Image.ROTATE_270,
        7: Image.TRANSVERSE,
        8: Image.ROTATE_90,
    }.get(orientation)

    if method is not None:
        return image.transpose(method)
    return image

def read_image(file_name, format=None):
    """
    Read an image into the given format.
    Will apply rotation and flipping if the image has such exif information.

    Args:
        file_name (str): image file path
        format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601".

    Returns:
        image (np.ndarray):
            an HWC image in the given format, which is 0-255, uint8 for
            supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601.
    """
    # Open the file as a context manager so the handle is released even when
    # decoding or conversion raises. The conversion to numpy happens inside
    # the block, while the image data is still readable.
    with Image.open(file_name) as pil_image:
        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        oriented_image = _apply_exif_orientation(pil_image)
        return convert_PIL_to_numpy(oriented_image, format)

def _do_paste_mask(masks, img_h: int, img_w: int):
    """
    Args:
        masks: N, 1, H, W
        boxes: N, 4
        img_h, img_w (int):
        skip_empty (bool): only paste masks within the region that
            tightly bound all boxes, and returns the results this region only.
            An important optimization for CPU.

    Returns:
        if skip_empty == False, a mask of shape (N, img_h, img_w)
        if skip_empty == True, a mask of shape (N, h', w'), and the slice
            object for the corresponding region.
    """
    # On GPU, paste all masks together (up to chunk size)
    # by using the entire image to sample the masks
    # Compared to pasting them one by one,
    # this has more operations but is faster on COCO-scale dataset.
    device = masks.device

    x0_int, y0_int = 0, 0
    x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 =  torch.Tensor([[0]]), torch.Tensor([[0]]), torch.Tensor([[img_w]]), torch.Tensor([[img_h]])

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)
    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    resized_mask = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    return resized_mask

In [None]:
# Reload the exported TorchScript segmentation model from disk, switch it to
# eval mode and call .cpu() on it.
# NOTE: this binds the notebook-level name `model`, which a later cell rebinds.
model_path = r'/home/fishial/Fishial/output/segmentation_export_torchscript/model_segm_21_08_2023.ts'
model = torch.jit.load(model_path)
model.eval()
model.cpu()

In [None]:
# Run the TorchScript segmentation model on one image and cut out every detection.
# NOTE(review): the image is read with format=None here, while the tracing inputs
# were read with cfg.INPUT.FORMAT - confirm the channel order matches.
image_path = '/home/fishial/Fishial/output/dbfded7e-77a4-433e-b1ee-c52a67d8b0fd.jpg'
input_np = read_image(image_path)
input_ts = torch.as_tensor(input_np.astype("float32").transpose(2, 0, 1))

# The traced model returns a tuple; index 0 is used as per-detection boxes
# (x1, y1, x2, y2) and index 2 as per-detection masks below.
# presumably this order follows the traced Instances fields - TODO confirm
outputs_ts = model(input_ts) 
masks = outputs_ts[2]  # overwritten inside the loop before it is read
device = 'cpu'  # not used in this cell

s_t = time.time()  # start time; never read in this cell
mask_outputs = []  # one [crop_image, binary_mask] pair per detection
for ind in range(len(outputs_ts[0])):
    # Integer box corners of detection `ind`
    x1, y1, x2, y2 = int(outputs_ts[0][ind][0]), int(outputs_ts[0][ind][1]), int(outputs_ts[0][ind][2]) , int(outputs_ts[0][ind][3])  
    img_w, img_h = x2 - x1, y2 - y1
    # Keep a leading batch axis so the mask can be passed to _do_paste_mask
    masks = outputs_ts[2][ind, None, :, :]
    # Resample the ROI mask to the box size in pixels
    masks_chunk = _do_paste_mask(masks, img_h, img_w)

    # Binarise at 0.5 -> values 0 / 255
    np_mask = masks_chunk.numpy()
    np_mask = np.where(np_mask > 0.5, 255, 0)
    
    # NOTE(review): if a box extends past the image border this slice is smaller
    # than the (img_h, img_w) mask, and the bitwise_and below gets mismatched sizes.
    crop_image = input_np[y1:y1 + img_h, x1:x1+img_w]
    crop_image = crop_image.astype(int)
    
    np_mask = np_mask.astype(np.uint8)
    
    # `res` (the masked crop) is computed but not stored; only the crop and the mask are kept.
    res = cv2.bitwise_and(crop_image, crop_image, mask = np_mask[0][0])
    mask_outputs.append([crop_image, np_mask[0][0]])


In [None]:
# Show the binary mask and the image crop of the second detection.
# NOTE(review): index 1 is hard-coded, so this raises IndexError when fewer
# than two detections were produced.
plt.imshow(mask_outputs[1][1])
plt.show()

plt.imshow(mask_outputs[1][0])
plt.show()

# Classification model

In [None]:
import os
import sys
# Change this path to your local checkout of the Object-Detection-Model
# repository so that the `module.*` packages imported below can be found.
sys.path.insert(1, '/home/fishial/Fishial/Object-Detection-Model')

In [None]:
# Folder holding every classification artifact used below.
# (The identifier is spelled "FODLER"; other cells reference it under this name.)
CLASSIFICATION_FODLER = r'/home/fishial/Fishial/output/classification/resnet_18_triplet_08_09_2023_v06_under_train_cross'
# Pickled eager model, checkpoint it is built from, and the TorchScript export file name.
MODEL_SCRIPT_PATH     = os.path.join(CLASSIFICATION_FODLER, 'model.pt')
CKPT_PATH             = os.path.join(CLASSIFICATION_FODLER, 'model.ckpt')
MODEL_NAME_TS = "model.ts"

# Make sure the folder exists before anything is written into it.
os.makedirs(CLASSIFICATION_FODLER, exist_ok=True)


## If the pickled model (`model.pt`) has not been created yet, build it from the state-dict checkpoint

In [None]:
from module.classification_package.src.model import init_model

# Build the eager classifier (289 is passed as the first argument of
# init_model) from the checkpoint in the classification folder, switch it to
# eval mode and pickle the whole model object.
model_script = init_model(289, device='cpu', checkpoint=CKPT_PATH)
model_script.eval()
torch.save(model_script, MODEL_SCRIPT_PATH)


# Get model from PT

In [None]:
# Reload the pickled eager classifier and put it in eval mode.
# NOTE(review): torch.load unpickles arbitrary objects - only use it on trusted
# files. Depending on the installed torch version, loading a full module may
# need an explicit `weights_only=False` - confirm.
model_script = torch.load(MODEL_SCRIPT_PATH)
model_script.eval()

# Convert the loaded model to TorchScript and freeze it

In [None]:
# Compile the eager classifier with torch.jit.script, freeze it and write it
# to CLASSIFICATION_FODLER / MODEL_NAME_TS.
# NOTE: `ts_model` and `frozen_module` are rebound here; the segmentation
# export cell uses the same names.
ts_model = torch.jit.script(model_script)
ts_model.eval()
frozen_module = torch.jit.freeze(ts_model)
with PathManager.open(os.path.join(CLASSIFICATION_FODLER, MODEL_NAME_TS), "wb") as f:
    torch.jit.save(frozen_module, f)
    
# Also dump the TorchScript IR into the same folder for inspection.
dump_torchscript_IR(frozen_module, CLASSIFICATION_FODLER)

# Run some tests: load the exported model and feed it random input

In [None]:
# Load the frozen TorchScript classifier back from disk for the checks below.
model_ts = torch.jit.load(os.path.join(CLASSIFICATION_FODLER, MODEL_NAME_TS))
model_ts.eval()

# Performance and accuracy test of the optimized model

After warm-up, the frozen TorchScript model should be roughly 50% faster than the original model (`model_script`), and the difference between their outputs should tend to 0.

In [None]:
# Compare the eager model (`model_script`) with the frozen TorchScript model
# (`model_ts`) on 10 random inputs: per-call latency and output agreement.
# - torch.no_grad(): autograd bookkeeping is not needed for inference and
#   would otherwise be part of the measured time.
# - time.perf_counter(): monotonic, high-resolution clock meant for timing
#   short intervals (time.time() can jump and has coarser resolution).
# The first iterations include warm-up cost; judge the speed-up on the later lines.
with torch.no_grad():
    for _ in range(10):
        input_rand = torch.rand(1,3,224,224)

        s_t = time.perf_counter()
        output_src = model_script(input_rand)
        exec_src = time.perf_counter() - s_t

        s_t = time.perf_counter()
        output_ts = model_ts(input_rand)
        exec_ts = time.perf_counter() - s_t

        # Both models return two outputs; index 0 is the embedding, index 1 the FC output
        diff_emb = (output_src[0] - output_ts[0]).abs().sum() # diff on a embedding values
        diff_acc = (output_src[1] - output_ts[1]).abs().sum() # diff on a accuracy output

        print(f"Execution time: |ts: {exec_ts} vs src: {exec_src}| Difference on outputs: |{diff_emb} vs {diff_acc}|")

## Test on a real data, reading, converting and inference

In [None]:
import cv2
import numpy as np
import fiftyone as fo

from PIL import Image
from torch import nn
from tqdm import tqdm
from torchvision import transforms

from module.segmentation_package.src.utils import get_mask
from module.classification_package.src.utils import read_json, save_json

In [None]:
def get_image(img_path, polyline):
    """Load an image, cut out the polygon region and prepare it for the classifier.

    Args:
        img_path: path of the image file, read with ``cv2.imread``.
        polyline: list of ``[x, y]`` pixel points, passed to ``get_mask``
            as a numpy array.

    Returns:
        tuple: ``(mask_pil, batch)`` - the PIL image built from the
        ``get_mask`` result, and the tensor produced by the notebook-level
        ``loader`` with a leading batch dimension added.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    full_image = cv2.imread(img_path)
    if full_image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # NOTE(review): cv2.imread yields BGR channel order - confirm get_mask /
    # Image.fromarray below expect that order.
    mask_np = get_mask(full_image, np.array(polyline))
    mask_pil = Image.fromarray(mask_np)
    mask_tensor = loader(mask_pil)

    return mask_pil, mask_tensor.unsqueeze(0) # add batch dimension

def classify_fc(output):
    """Pick the most likely class from the FC-layer output.

    Applies the notebook-level ``softmax`` and takes the argmax over the
    whole tensor; the score is read from row 0, so a single-row input is assumed.

    Args:
        output: FC-layer output tensor.

    Returns:
        tuple: ``(class_id, score)`` - the winning index as a Python int and
        its softmax value (a tensor element).
    """
    probabilities = softmax(output)
    best_class = torch.argmax(probabilities).item()
    best_score = probabilities[0][best_class]
    return best_class, best_score

def classify_embedding(data_base, embedding, indexes_of_elements):
    """Return the 10 database entries closest to ``embedding``.

    Distance is the Euclidean distance between ``embedding`` and every row
    of ``data_base``.

    Args:
        data_base: 2-D tensor with one stored embedding per row.
        embedding: the query embedding.
        indexes_of_elements: dict with ``'list_of_ids'`` (one entry per
            database row, whose first element is a category id) and
            ``'categories'`` (keyed by the category id as a string).

    Returns:
        list: ``[category, id_entry, distance]`` per neighbour, nearest first.
    """
    distances = ((data_base - embedding) ** 2).sum(dim=1).sqrt()
    _, order = torch.sort(distances)

    neighbours = []
    for row_idx in order[:10]:
        id_entry = indexes_of_elements['list_of_ids'][row_idx]
        category = indexes_of_elements['categories'][str(id_entry[0])]
        neighbours.append([category, id_entry, distances[row_idx]])
    return neighbours

## If the embedding tensor has already been created, load it and check the accuracy

In [None]:
# Load the embedding database onto the CPU.
# `Tensor.cpu()` returns a new tensor rather than moving in place, so calling
# it without using the result has no effect; map_location='cpu' places the
# tensor on the CPU at load time (and also works on a CPU-only machine).
data_base = torch.load('/home/fishial/Fishial/output/classification/resnet_18_triplet_08_09_2023_v06_under_train_cross/embeddings.pt', map_location='cpu')
# Index metadata for the database ('list_of_ids' and 'categories' are read from it later).
indexes_of_elements = read_json('/home/fishial/Fishial/output/classification/resnet_18_triplet_08_09_2023_v06_under_train_cross/idx.json')


In [None]:

# Load the FiftyOne dataset and keep only the samples matched by the 'val' / 'train' tags.
fo_dataset = fo.load_dataset("classification-05-09-2023-v06")
fo_dataset = fo_dataset.match_tags(['val', 'train'])

# get softmax function
# (used as a notebook-level global by classify_fc)
softmax = nn.Softmax(dim=1)

# Preprocessing applied to every mask crop before it is fed to the classifier;
# used as a notebook-level global by get_image.
# NOTE(review): Image.BILINEAR is passed positionally as the interpolation
# argument - confirm the installed torchvision still accepts the PIL constant.
loader = transforms.Compose([
        transforms.Resize((224, 224), Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [None]:
# Evaluate the TorchScript classifier on every sample of the dataset, both via
# the FC head and via nearest-neighbour lookup in the embedding database.
# NOTE: the same tag filter was already applied to fo_dataset when it was loaded.
pbar = tqdm(fo_dataset.match_tags(['val', 'train']))

acc_fc = []  # per-sample bool: FC prediction matches the label
acc_eb = []  # per-sample bool: embedding prediction matches the label

# Records for samples that BOTH methods got wrong
wrong_classified_by_both = []

for sample in pbar:
    img_path = sample['filepath']
    label = sample['polyline']['label']
    image_id, annotation_id, drawn_fish_id = sample['image_id'], sample['annotation_id'], sample['drawn_fish_id']
    width, height = sample['width'], sample['height']
    
    # Polyline points are scaled by the image size to get integer pixel coordinates
    polyline = sample['polyline']['points'][0]
    polyline = [[int(point[0] * width), int(point[1] * height)] for point in polyline]
    
    pil_image, input_tensor = get_image(img_path, polyline)   
    # output_ts[0] is used as the embedding batch, output_ts[1] as the FC output
    output_ts     = model_ts(input_tensor) 
    
    class_id_ts, score_ts = classify_fc(output_ts[1])
    output = classify_embedding(data_base, output_ts[0][0], indexes_of_elements)
    
    # The nearest database entry is expected to be at distance exactly 0
    # (presumably the sample itself is stored in the database - TODO confirm);
    # the second-nearest entry, output[1], is therefore used as the prediction.
    # NOTE(review): an exact float comparison with 0.0 is fragile.
    if output[0][2].item() != 0.0:
        print("ERROR, something went wrong in propper way it's impossible")
        
    acc_fc.append(indexes_of_elements['categories'][str(class_id_ts)]['name'] == label)
    acc_eb.append(output[1][0]['name'] == label)
    if indexes_of_elements['categories'][str(class_id_ts)]['name'] != label and output[1][0]['name'] != label:
        wrong_classified_by_both.append([label, image_id, annotation_id, drawn_fish_id,indexes_of_elements['categories'][str(class_id_ts)]['name'],output[1][0]['name'], output[1][2].item()])
        # Side effect: the wrong predictions are written back onto the dataset sample
        sample['wrong_emb'] = output[1][0]['name']
        sample['wrong_dist'] = output[1][2].item()
        sample['wrong_fc'] = indexes_of_elements['categories'][str(class_id_ts)]['name']
        sample.save()
    
    # Running accuracies: the first number is the FC accuracy, the second the embedding accuracy
    pbar.set_description(f"Eval acc_fc: {sum(acc_fc)/len(acc_fc)} vs {sum(acc_eb)/len(acc_eb)} acc_eb: WRONG: {len(wrong_classified_by_both)}")

In [None]:
# Persist the list of samples misclassified by both the FC head and the
# embedding lookup as JSON in the classification folder.
wrong_classified_path = r'/home/fishial/Fishial/output/classification/resnet_18_triplet_08_09_2023_v06_under_train_cross/wrong_classified.json'
save_json(wrong_classified_by_both, wrong_classified_path)

## The following code is temporary; if it has not been removed yet, please remove it

In [None]:
import torch.nn as nn
import numpy as np
import logging
import torch
import json

from PIL import Image
from torchvision import transforms


def read_json(path_to_json):
    """Parse the JSON file at ``path_to_json`` and return its content.

    Note: this redefines the ``read_json`` name that an earlier cell imports
    from the project's classification utils.
    """
    with open(path_to_json) as json_file:
        parsed = json.load(json_file)
    return parsed
    
def get_results(output):
    """Find the labels with the smallest ``'top_1'`` and smallest ``'median'``.

    Args:
        output: dict mapping a label to a dict with ``'top_1'`` and
            ``'median'`` values.

    Returns:
        list: ``[label_with_min_top_1, label_with_min_median]``. On ties the
        first label in iteration order wins; a slot stays ``None`` when no
        value is below the initial bound of 10e9.
    """
    # metric name -> (best label so far, best value so far)
    best = {'top_1': (None, 10e9), 'median': (None, 10e9)}

    for label in output:
        for metric in ('top_1', 'median'):
            candidate = output[label][metric]
            if candidate < best[metric][1]:
                best[metric] = (label, candidate)

    return [best['top_1'][0], best['median'][0]]
    
class EmbeddingClassifier:
    """Classifier combining a TorchScript model's FC head with an embedding lookup.

    The model is expected to return ``(embeddings, fc_output)``. Every
    embedding is compared (Euclidean distance) against a stored database of
    embeddings; the nearest entries are grouped by category name and turned
    into a confidence in [0, 1]. The FC prediction is appended with a fixed
    confidence of 0.1 when the embedding lookup did not already produce it.
    """

    # Distance at (or below) which the confidence is 1.0; THRESHOLD is the
    # distance at (or above) which it drops to 0.0.
    MIN_DIST = 4.2

    def __init__(self, model_path, data_set_path, data_id_path, device='cpu', THRESHOLD = 8.84):
        """
        Args:
            model_path: path of the TorchScript model (loaded with torch.jit.load).
            data_set_path: path of the saved embedding database tensor.
            data_id_path: path of the JSON index; its 'list_of_ids' and
                'categories' entries are read by the methods below.
            device: device the model, the database and the inputs are placed on.
            THRESHOLD: distance mapped to confidence 0.0.
        """
        self.device = device
        self.THRESHOLD = THRESHOLD
        self.indexes_of_elements = read_json(data_id_path)
        self.softmax = nn.Softmax(dim=1)

        # map_location lets artifacts saved from another device load directly onto `device`
        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)

        self.loader = transforms.Compose([
            transforms.Resize((224, 224), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        self.data_base = torch.load(data_set_path, map_location=device).to(device)
        logging.info("[INIT][CLASSIFICATION] Initialization of classifier was finished")

    def __inference(self, image, top_k = 15):
        """Classify one preprocessed image tensor.

        Returns:
            list of ``[species_id, [label, conf, annotation]]`` sorted by
            descending confidence.
        """
        logging.info("[PROCESSING][CLASSIFICATION] Getting embedding for a single detection mask")

        # Inference only: no autograd bookkeeping needed
        with torch.no_grad():
            dump_embed, fc_output = self.model(image.unsqueeze(0).to(self.device))

        logging.info("[PROCESSING][CLASSIFICATION] Classification by Full Connected layer for a single detection mask")
        classes, scores = self.__classify_fc(fc_output)

        logging.info("[PROCESSING][CLASSIFICATION] Classification by embedding for a single detection mask")
        output_by_embeddings = self.__classify_embedding(dump_embed[0], top_k)

        logging.info("[PROCESSING][CLASSIFICATION] Beautify output for a single detection mask")
        result = self.__beautifier_output(output_by_embeddings, self.indexes_of_elements['categories'][str(classes[0].item())]['name'])
        result = [[self.__get_species_name(record[0]), record] for record in result]
        return result

    def __beautifier_output(self, output_by_embeddings, classification_label):
        """Aggregate nearest-neighbour hits per label and merge in the FC prediction.

        Args:
            output_by_embeddings: ``[category, id_entry, distance]`` records,
                nearest first, as returned by ``__classify_embedding``.
            classification_label: category name predicted by the FC head.

        Returns:
            list of ``[label, conf, annotation]`` sorted by descending ``conf``.
        """
        # Group distances and annotations by category name (insertion order = nearest first)
        dict_results = {}
        for category, annotation, distance in output_by_embeddings:
            entry = dict_results.setdefault(category['name'], {'values': [], 'annotations': []})
            entry['values'].append(distance.item())
            entry['annotations'].append(annotation)

        # Reduce every group to its nearest hit and its median distance
        for entry in dict_results.values():
            values = entry.pop('values')
            annotations = entry.pop('annotations')
            entry.update({'top_1': values[0]})
            entry.update({'annotation': annotations[0]})
            entry.update({'median': np.median(values)})

        # Keep only the label(s) winning on nearest hit and/or on median distance
        labels = get_results(dict_results)
        labels = list(set(labels))

        for result in list(dict_results.keys()):
            if result not in labels:
                del dict_results[result]
            else:
                mean_distance = (dict_results[result]['top_1'] + dict_results[result]['median'])/2
                dict_results[result]['dist'] = mean_distance
                dict_results[result]['conf'] = round(self.__get_confidence(mean_distance), 3)
                logging.info(f"[PROCESSING][CLASSIFICATION] the threshold |{mean_distance}| has been recalculated to |{dict_results[result]['conf']}|")
        results = [[label, dict_results[label]['conf'], dict_results[label]['annotation']] for label in dict_results]
        logging.info("[PROCESSING][CLASSIFICATION] Classification by embedding was finished successfuly")

        if classification_label not in labels:
            # The FC head disagrees with the embedding lookup: report it with a fixed low confidence
            logging.info("[PROCESSING][CLASSIFICATION] Append into output classification result by FC - layer")
            results.append([classification_label, 0.1, [None, None, None]])
        else:
            logging.info("[PROCESSING][CLASSIFICATION] Output from FC layer exist in Embedding results")
        results = sorted(results, key=lambda x: x[1], reverse=True)
        return results

    def __get_confidence(self, dist):
        """Map a distance to a confidence: 1.0 at MIN_DIST or below, 0.0 at THRESHOLD or above, linear in between."""
        min_dist = self.MIN_DIST
        max_dist = self.THRESHOLD
        delta = max_dist - min_dist
        if delta == 0:
            # Degenerate range: avoid dividing by zero, act as a hard cut-off
            return 1.0 if dist <= min_dist else 0.0
        return 1.0 - (max(min(max_dist, dist), min_dist) - min_dist) / delta

    def inference_numpy(self, img, top_k=10):
        """Classify a single image given as a numpy array.

        Args:
            img: array accepted by ``PIL.Image.fromarray``.
            top_k: number of nearest database entries to consider.

        Returns:
            list of ``[species_id, [label, conf, annotation]]`` sorted by
            descending confidence.
        """
        image = Image.fromarray(img)
        image = self.loader(image)

        return self.__inference(image, top_k)

    def batch_inference(self, imgs, top_k=15):
        """Classify several images with one model call.

        Args:
            imgs: sequence of arrays accepted by ``PIL.Image.fromarray``.
            top_k: number of nearest database entries considered per image.

        Returns:
            list with one result per input image, each shaped like the
            return value of ``inference_numpy``; ``[]`` for an empty input.
        """
        if len(imgs) == 0:
            # torch.stack cannot handle an empty list
            return []

        batch_input = torch.stack([self.loader(Image.fromarray(img)) for img in imgs])
        # The batch must live on the same device as the model and the database
        batch_input = batch_input.to(self.device)
        with torch.no_grad():
            dump_embeds, class_ids = self.model(batch_input)

        logging.info("[PROCESSING][CLASSIFICATION] Classification by Full Connected layer for a single detection mask")
        classes, scores = self.__classify_fc(class_ids)

        outputs = []
        for output_id in range(len(classes)):

            logging.info("[PROCESSING][CLASSIFICATION] Classification by embedding for a single detection mask")
            output_by_embeddings = self.__classify_embedding(dump_embeds[output_id], top_k)
            result = self.__beautifier_output(output_by_embeddings, self.indexes_of_elements['categories'][str(classes[output_id].item())]['name'])
            result = [[self.__get_species_name(record[0]), record] for record in result]

            outputs.append(result)

        return outputs

    def __classify_fc(self, output):
        """Return ``(class_ids, probabilities)``: per-row argmax and the softmax over dim 1."""
        acc_values = self.softmax(output)
        class_id = torch.argmax(acc_values, dim=1)
        return class_id, acc_values

    def __classify_embedding(self, embedding, top_k = 15):
        """Return the ``top_k`` database entries nearest to ``embedding``.

        Returns:
            list of ``[category, id_entry, distance]``, nearest first, where
            ``id_entry`` is the 'list_of_ids' record of the database row and
            ``category`` is looked up by its first element.
        """
        diff = (self.data_base - embedding).pow(2).sum(dim=1).sqrt()
        _, indi = torch.sort(diff)
        class_lib = [[self.indexes_of_elements['list_of_ids'][indiece], diff[indiece]] for indiece in indi[:top_k]]
        class_lib = [[self.indexes_of_elements['categories'][str(rec[0][0])],rec[0], rec[1]] for rec in class_lib]
        return class_lib


    def __get_species_name(self, category_name):
        """Return the 'species_id' of the first category named ``category_name`` (None when absent)."""
        for i in self.indexes_of_elements['categories']:
            if self.indexes_of_elements['categories'][i]['name'] == category_name:
                return self.indexes_of_elements['categories'][i]['species_id']

In [None]:
# Build an EmbeddingClassifier from the exported TorchScript model and an
# embedding database / index.
# NOTE(review): the model comes from the `..._08_09_2023_v06_under_train_cross`
# folder while the embeddings and index come from `..._05_09_2023` - confirm
# these artifacts belong together.
# NOTE: this rebinds the notebook-level name `model`.
model_path = '/home/fishial/Fishial/output/classification/resnet_18_triplet_08_09_2023_v06_under_train_cross/model.ts'
embedding_path = '/home/fishial/Fishial/output/classification/resnet_18_triplet_05_09_2023/embeddings_new.pt'
index_path  = '/home/fishial/Fishial/output/classification/resnet_18_triplet_05_09_2023/idx_new_short.json'

model = EmbeddingClassifier(model_path, embedding_path, index_path)
