In [18]:
from train import *

from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer

import matplotlib.pyplot as plt
import random
import cv2

%matplotlib qt

In [24]:
MODEL = "charge_configuration"

# Every path below lives under the example folder for MODEL.
DATA_DIR = Path("../examples", MODEL, "data")
CONFIG_FILE = Path("../examples", MODEL, "configuration.yaml")
TRAINED_MODEL_PTH = "../examples/{0}/trained_models/{0}.pth".format(MODEL)

# One directory per dataset split.
TRAIN_DIR, VAL_DIR, TEST_DIR = (DATA_DIR / split for split in ("train", "val", "test"))

# Detector score cut-off passed to SCORE_THRESH_TEST; 0.0 keeps every
# detection (the notebook later reduces them to the best one per class).
CONFIDENCE_THRESHOLD = 0.0

In [25]:
info, hyperparams = parse_configuration_file(CONFIG_FILE)

# Start from the COCO Mask R-CNN R50-FPN 3x baseline, then override.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

# Model: run on CPU, load our trained weights, one output per configured class.
cfg.MODEL.DEVICE = "cpu"
cfg.MODEL.WEIGHTS = TRAINED_MODEL_PTH
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(info['classes'])
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = hyperparams['batch_size_per_img']
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = CONFIDENCE_THRESHOLD  # custom testing threshold

# Data loading happens in the main process.
cfg.DATALOADER.NUM_WORKERS = 0

# Solver values come from the configuration file's hyperparameters.
cfg.SOLVER.IMS_PER_BATCH = hyperparams['batch_num']
cfg.SOLVER.BASE_LR = hyperparams['learning_rate']
cfg.SOLVER.MAX_ITER = hyperparams['num_epochs']
cfg.SOLVER.STEPS = []

predictor = DefaultPredictor(cfg)

[06/24 11:38:45 d2.checkpoint.detection_checkpoint]: [DetectionCheckpointer] Loading from ../examples/charge_configuration/trained_models/charge_configuration.pth ...


In [26]:
def filter_highest_confidence_per_class(instances):
    """Keep only the single highest-scoring detection for each predicted class.

    Args:
        instances: a detection container exposing parallel ``scores`` and
            ``pred_classes`` sequences and supporting indexing by a list of
            positions (presumably a detectron2 ``Instances``).

    Returns:
        ``instances`` indexed by the kept positions. Classes appear in the
        order they are first encountered. On equal scores the earlier
        detection wins, because the comparison is strict.
    """
    best = {}  # class id -> (score, position) of the strongest detection so far
    for position, (confidence, label) in enumerate(zip(instances.scores, instances.pred_classes)):
        label = int(label)  # tensor element -> plain int dict key
        incumbent = best.get(label)
        if incumbent is None or confidence > incumbent[0]:
            # Re-assigning an existing key keeps its original dict position.
            best[label] = (confidence, position)

    return instances[[position for _, position in best.values()]]

### Test on Validation Data

In [27]:
# Clear both catalogs so re-running this cell does not hit
# duplicate-registration errors.
DatasetCatalog.clear()
MetadataCatalog.clear()

for d in ["train", "val"]:
    # `d=d` binds the current split name into each lambda; without it every
    # lambda would see the loop variable's final value. Note DATA_DIR is read
    # when the lambda runs, not when it is registered.
    DatasetCatalog.register(info["name"] + " " + d, lambda d=d: construct_dataset_dict(DATA_DIR / d, info["classes"]))
    MetadataCatalog.get(info["name"] + " " + d).set(thing_classes=list(info["classes"].keys()))

train_metadata = MetadataCatalog.get(info["name"] + " " + "train")
val_metadata = MetadataCatalog.get(info["name"] + " " + "val")

dataset_dict = DatasetCatalog.get(info["name"] + " " + "val")
for d in random.sample(dataset_dict, 1):
    im = cv2.imread(d['file_name'])
    if im is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {d['file_name']}")
    outputs = predictor(im)
    outputs["instances"] = filter_highest_confidence_per_class(outputs["instances"])

    # See raw predictions. OpenCV loads BGR; Visualizer wants RGB, hence the flip.
    v = Visualizer(im[:, :, ::-1], metadata=val_metadata, scale=5)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    plt.figure(layout='tight')
    plt.title("Predictions")
    plt.imshow(out.get_image())  # get_image() is RGB, as matplotlib expects
    plt.show()

    # See raw annotations
    visualizer = Visualizer(im[:, :, ::-1], metadata=val_metadata, scale=5)
    out = visualizer.draw_dataset_dict(d)
    plt.figure(layout='tight')
    plt.title("Annotations")
    # get_image() already returns RGB. The previous extra [:, :, ::-1] flipped it
    # back to BGR, so matplotlib showed red and blue swapped.
    plt.imshow(out.get_image())
    plt.show()

### Test on Experimental Data

In [28]:
MODEL = "dot_configuration"

DATA_DIR = Path(f"../examples/{MODEL}/data/sensor")

TRAIN_DIR = DATA_DIR / "train"
VAL_DIR = DATA_DIR / "val"
TEST_DIR = DATA_DIR / "test"

# NOTE(review): this rebinds the globals set in the configuration cell. The
# dataset-registration cell reads DATA_DIR only when its lambdas run, so
# re-running it after this cell would load from this directory instead.
# Also, `predictor` and `val_metadata` are still the charge_configuration
# ones. Confirm that applying that model to these images is intended.

all_test_exp_images = list(TEST_DIR.glob("exp*.jpg"))
all_val_exp_images = list(VAL_DIR.glob("exp*.jpg"))
candidate_images = all_test_exp_images + all_val_exp_images
if not candidate_images:
    raise FileNotFoundError(f"No exp*.jpg images found in {TEST_DIR} or {VAL_DIR}")
random_test_img_path = random.choice(candidate_images)

# str(): older OpenCV builds do not accept pathlib.Path. imread also returns
# None (rather than raising) when the file cannot be decoded.
random_test_img = cv2.imread(str(random_test_img_path))
if random_test_img is None:
    raise FileNotFoundError(f"Could not read image: {random_test_img_path}")

outputs = predictor(random_test_img)
outputs["instances"] = filter_highest_confidence_per_class(outputs["instances"])

# OpenCV loads BGR; Visualizer expects RGB, hence the channel flip.
v = Visualizer(random_test_img[:, :, ::-1], metadata=val_metadata, scale=5)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
plt.figure(layout='tight')
plt.title("Predictions")
plt.imshow(out.get_image())  # get_image() is RGB, as matplotlib expects
plt.show()