# Convert Faster R-CNN predictions to LabelMe format for AI-assisted labelling

In [3]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN_ResNet50_FPN_Weights

def get_detection_model(num_classes, weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box-predictor head.

    Args:
        num_classes: Number of output classes of the new head, including the
            background class.
        weights: Pretrained detector weights passed to
            ``torchvision.models.detection.fasterrcnn_resnet50_fpn``. Defaults to
            ``FasterRCNN_ResNet50_FPN_Weights.DEFAULT`` (COCO-pretrained). ``None``
            skips the COCO detector weights (torchvision may still load ImageNet
            backbone weights by default).

    Returns:
        The detector with ``roi_heads.box_predictor`` replaced by a
        ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    # Pass an enum *member* (or None). The bare enum class used to be the default
    # here; torchvision only tolerated that via its deprecated legacy `pretrained`
    # handling, which warns and may be removed in a future release.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

    # Width of the features feeding the existing classifier, reused for the new head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pretrained head so the model predicts `num_classes` classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

In [4]:
# Template for a LabelMe annotation file; each exported JSON starts from a copy
# of it. The image size fields hold fixed defaults of 1440 (height) x 1920 (width).
labelme_dict = dict(
    version='5.3.1',
    flags={},
    shapes=[],
    imagePath='',
    imageData=None,
    imageHeight=1440,
    imageWidth=1920,
)

In [8]:
# Make detection predictions and format predictions into LabelMe format
import torch
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
from torchvision.io import read_image
import os, json

from torchvision.io import ImageReadMode

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Remember to switch to the correct weights!
model_weights = "/fasterrcnn_resnet50_fpn/Christian_5epochs/model_weights1_Masteroppgave.pt"
PATH = "/Users/magnuswiik/Documents/NTNU/5.klasse/prosjektoppgave/FishID/models/" + model_weights

target_folder = '/Users/magnuswiik/prosjektoppgave_data/Masteroppgave_data/AI_Assisted_Deteksjonsset_Helfisk/'

# Files treated as input images (JPEG/PNG, which read_image decodes). Selecting
# by extension, instead of skipping names that contain "DS", keeps .DS_Store out
# without also dropping images like "FISH_DS01.jpg". It also never feeds back the
# .json files this cell writes into the same folder; previously a re-run opened
# "x.json" for writing (truncating it) and then crashed trying to decode it.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Boxes scoring below this are not exported. 0.0 keeps every box the model
# returns; raise it to hand fewer false positives to the annotator.
score_threshold = 0.0

num_classes = 2


def predictions_to_labelme(image_name, pred, height, width, min_score=0.0):
    """Convert one detector prediction into a LabelMe annotation dict.

    Args:
        image_name: File name stored as ``imagePath`` (the JSON is written next
            to the image).
        pred: One element of the list the model returns in eval mode; only its
            ``"boxes"`` and ``"scores"`` tensors are used. Each box is split into
            two corner points (torchvision documents the format as x1, y1, x2, y2).
        height: Image height in pixels, written to ``imageHeight``.
        width: Image width in pixels, written to ``imageWidth``.
        min_score: Boxes whose score is below this value are dropped.

    Returns:
        A new dict based on ``labelme_dict`` with one "Salmon" rectangle per
        kept box. Coordinates are truncated to integers.
    """
    keep = pred["scores"] >= min_score
    boxes = pred["boxes"][keep].long().tolist()

    # Shallow copy is enough: every key that differs per image is reassigned.
    output = labelme_dict.copy()
    output['imagePath'] = image_name
    output['imageHeight'] = height
    output['imageWidth'] = width
    output['shapes'] = [
        {"label": "Salmon", "points": [box[:2], box[2:]], 'group_id': None,
         'description': '', 'shape_type': 'rectangle', 'flags': {}}
        for box in boxes
    ]
    return output


# Hidden files are skipped too: macOS can create "._name.jpg" AppleDouble files
# that carry an image extension but are not images.
imgs = sorted(
    name for name in os.listdir(target_folder)
    if not name.startswith('.') and name.lower().endswith(IMAGE_EXTENSIONS)
)

model = get_detection_model(num_classes)
model.to(device)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(PATH, map_location=device))
model.eval()

for img in imgs:
    # Decode directly to 3-channel RGB (converting grayscale/RGBA inputs) and
    # scale to [0, 1]; read_image returns a (C, H, W) uint8 tensor.
    image = read_image(os.path.join(target_folder, img), mode=ImageReadMode.RGB).float() / 255.0
    height, width = image.shape[-2:]

    with torch.no_grad():
        pred = model([image.to(device)])[0]

    output = predictions_to_labelme(img, pred, height, width, score_threshold)

    # Open the output only after inference succeeded, so a failing image never
    # leaves an empty or truncated JSON behind. splitext also copes with
    # four-letter extensions such as ".jpeg".
    filename = os.path.splitext(img)[0] + '.json'
    with open(os.path.join(target_folder, filename), 'w', encoding='utf-8') as file:
        json.dump(output, file, indent=2)
    



#output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="black", width=3, font="/System/Library/Fonts/Helvetica.ttc", font_size=30)

#print(pred["scores"])

#plt.figure(figsize=(12, 12))
#plt.imshow(output_image.permute(1, 2, 0))

