In [None]:
import detectron2
import torchvision
import pickle
import json
import cv2
import matplotlib.pyplot as plt
import torch

from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures import BoxMode
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode
from detectron2.modeling import build_model
import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer

from matplotlib.pyplot import imshow
from PIL import Image

# Stretch the notebook's `.container` element to the full browser width.
# `display` and `HTML` are imported from the public `IPython.display` module: importing
# `display` from `IPython.core.display` has been deprecated since IPython 7.14 and is
# reported to fail outright on newer IPython releases.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
'''User settings: point the two paths at your fathomnet config and weights, then tune the thresholds.'''
# Training-specific detectron2 config file for fathomnet.
CONFIG_FILE = ""
# Path to the model checkpoint holding the fathomnet weights.
WEIGHTS_FILE = ""
# Threshold passed to the extra NMS pass that runs over all predicted boxes.
NMS_THRESH = 0.45
# Score threshold applied by the model at test time.
SCORE_THRESH = 0.3

In [None]:
# Metadata object carrying the fathomnet class names; it is handed to the Visualizer
# in the inference cell.
# NOTE(review): the order of `thing_classes` presumably has to match the class indices
# the model was trained with -- confirm before reordering or editing this list.
fathomnet_metadata = Metadata(
    name="fathomnet_val",
    thing_classes=[
        "Anemone", "Fish", "Eel", "Gastropod",
        "Sea star", "Feather star", "Sea cucumber", "Urchin",
        "Glass sponge", "Sea fan", "Soft coral", "Sea pen",
        "Stony coral", "Ray", "Crab", "Shrimp",
        "Squat lobster", "Flatfish", "Sea spider", "Worm",
    ],
)

In [None]:
'''
Build the inference config. The detectron2 yaml files contain a lot of nested arguments, so the settings are
layered: baseline defaults from the model zoo first, then the dataset-specific file, then the notebook overrides.
'''

cfg = get_cfg()
# Merge order is significant: the fathomnet file is applied on top of the RetinaNet baseline.
for config_path in (
    model_zoo.get_config_file("COCO-Detection/retinanet_R_50_FPN_3x.yaml"),
    CONFIG_FILE,
):
    cfg.merge_from_file(config_path)
# Notebook-level overrides taken from the settings cell.
cfg.MODEL.WEIGHTS = WEIGHTS_FILE
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = SCORE_THRESH

In [None]:
'''
Instantiate the model described by cfg and load the fathomnet weights into it. This is the step where the model
becomes something that can take inputs and provide outputs; it is left in eval mode for inference.
'''

# An empty weights path is not treated as an error by the checkpointer: it skips loading and leaves the model
# with its initial (untrained) parameters, which produces meaningless detections. Fail loudly instead.
if not cfg.MODEL.WEIGHTS:
    raise ValueError(
        "cfg.MODEL.WEIGHTS is empty: set WEIGHTS_FILE to the path of the fathomnet model weights "
        "and re-run the config cell before building the model."
    )

model = build_model(cfg)  # returns a torch.nn.Module
checkpointer = DetectionCheckpointer(model)
checkpointer.load(cfg.MODEL.WEIGHTS)
model.eval()

In [None]:
'''
Resize transforms applied to an image before it is given to the model. Two variants are defined here; check which
one the preprocessing step further down actually calls.
'''
# Variant 1: size limits taken from the test-time input settings in cfg.
aug = T.ResizeShortestEdge(
    short_edge_length=[cfg.INPUT.MIN_SIZE_TEST],
    max_size=cfg.INPUT.MAX_SIZE_TEST,
    sample_style="choice",
)
# Variant 2: fixed size limits (short edge 1080, max size 1980).
# NOTE(review): 1980 may have been intended as 1920 (1080p width) -- confirm.
aug1 = T.ResizeShortestEdge(
    short_edge_length=[1080],
    max_size=1980,
    sample_style="choice",
)

In [None]:
'''
We use a separate NMS layer because initially detectron only does nms intra class, so we want to do nms on all boxes.
'''
# Alias for torchvision's NMS op. The inference cell calls it with the predicted box tensor, the prediction
# scores and NMS_THRESH, and uses the returned value to index the instances that are kept.
post_process_nms = torchvision.ops.nms

In [None]:
# Path to the image the model should be run on.
TEST_IMAGE = ""

In [None]:
'''
Run the single-image pipeline: load the image, record its original height and width, resize it, run the model,
apply NMS across all predicted boxes, and draw the surviving predictions with a Visualizer.
'''

im = cv2.imread(TEST_IMAGE)
# cv2.imread does not raise on a missing or undecodable file, it returns None. Without this check the failure
# only surfaces as an AttributeError on `im.shape` below.
if im is None:
    raise FileNotFoundError(
        f"Could not read an image from TEST_IMAGE={TEST_IMAGE!r} (file missing or not a readable image)"
    )
im_height, im_width, _ = im.shape

# The visualizer draws on the original-resolution image, with the channel axis reversed (cv2 loads BGR).
v_inf = Visualizer(
    im[:, :, ::-1],
    metadata=fathomnet_metadata,
    scale=1.0,
    instance_mode=ColorMode.IMAGE_BW,
)

# Resize with the fixed-size transform (aug1), then reorder HWC -> CHW as float32 for the model.
# NOTE(review): the image is fed in cv2's channel order; this assumes cfg.INPUT.FORMAT is BGR -- confirm.
im = aug1.get_transform(im).apply_image(im)
with torch.no_grad():
    im = torch.as_tensor(im.astype("float32").transpose(2, 0, 1))
    # "height"/"width" are the dimensions of the original image, not of the resized one.
    model_outputs = model([{"image": im, "height": im_height, "width": im_width}])[0]

# Extra NMS pass over every predicted box regardless of class; keep only the surviving instances.
instances = model_outputs["instances"]
keep_indices = post_process_nms(instances.pred_boxes.tensor, instances.scores, NMS_THRESH)
model_outputs["instances"] = instances[keep_indices.to("cpu").tolist()]

out_inf_raw = v_inf.draw_instance_predictions(model_outputs["instances"].to("cpu"))
im_pil_inf_raw = Image.fromarray(out_inf_raw.get_image())

In [None]:
'''Show the annotated image produced by the inference cell.'''
fig, ax = plt.subplots(figsize=(24, 16))
ax.imshow(im_pil_inf_raw)