In [None]:
from train import *

from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
import matplotlib.pyplot as plt

import random
import cv2

%matplotlib qt

In [None]:
# Every asset for this example lives under a single directory; derive all paths from it.
EXAMPLE_DIR = Path("../examples/dot_configuration")

DATA_DIR = EXAMPLE_DIR / "data"
CONFIG_FILE = EXAMPLE_DIR / "configuration.yaml"
# Kept as a plain (forward-slash) string because it is assigned directly to cfg.MODEL.WEIGHTS.
TRAINED_MODEL_PTH = (EXAMPLE_DIR / "trained_models" / "dot_configuration.pth").as_posix()

# One sub-directory per dataset split.
TRAIN_DIR, VAL_DIR, TEST_DIR = (DATA_DIR / split for split in ("train", "val", "test"))

In [None]:
# Minimum detection confidence kept at inference time (cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST).
SCORE_THRESHOLD = 0.9

info, hyperparams = parse_configuration_file(CONFIG_FILE)

# Rebuild the configuration used for training (Mask R-CNN R50-FPN base plus the
# hyperparameters from CONFIG_FILE) so the model built here matches the checkpoint.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))

cfg.MODEL.DEVICE = "cpu"
cfg.DATALOADER.NUM_WORKERS = 0

# Solver settings do not affect inference; they are mirrored from training for completeness.
cfg.SOLVER.IMS_PER_BATCH = hyperparams["batch_num"]
cfg.SOLVER.BASE_LR = hyperparams["learning_rate"]
# NOTE(review): an epoch count is used as an iteration count (MAX_ITER) — kept as-is to
# mirror training; confirm against train.py.
cfg.SOLVER.MAX_ITER = hyperparams["num_epochs"]
cfg.SOLVER.STEPS = []
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = hyperparams["batch_size_per_img"]

# Must agree with the number of classes the checkpoint's ROI heads were trained with.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(info["classes"])

# Inference-only overrides: load the trained weights and drop low-confidence detections.
cfg.MODEL.WEIGHTS = TRAINED_MODEL_PTH
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESHOLD
predictor = DefaultPredictor(cfg)

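### Test on Validation Data

Compare the model's predictions with the ground-truth annotations for a randomly sampled validation image.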

In [None]:
# Empty both catalogs first so re-running this cell does not try to register the same names twice.
DatasetCatalog.clear()
MetadataCatalog.clear()

for split in ("train", "val"):
    registered_name = info["name"] + " " + split
    # `split=split` freezes the current split in the lambda; a bare closure would
    # see only the loop's last value when the catalog later calls it.
    DatasetCatalog.register(
        registered_name,
        lambda split=split: construct_dataset_dict(DATA_DIR / split, info["classes"]),
    )
    MetadataCatalog.get(registered_name).set(thing_classes=list(info["classes"].keys()))

train_metadata, val_metadata = (
    MetadataCatalog.get(info["name"] + " " + split) for split in ("train", "val")
)

def check_image_exists(directory, image_name) -> bool:
    """Return whether ``directory`` contains an entry named exactly ``image_name``.

    The match is an exact, case-sensitive comparison against the names returned
    by ``os.listdir`` — no extension-insensitive matching is performed.

    Args:
        directory: Directory to search (str or path-like).
        image_name: Entry name to look for, including its extension.

    Returns:
        True if an entry with that exact name exists, otherwise False.

    Raises:
        OSError: (e.g. FileNotFoundError) if ``directory`` cannot be listed.
    """
    return image_name in os.listdir(directory)

# How many random validation images to inspect.
NUM_VAL_SAMPLES = 1


def _show_rgb(image_rgb, title):
    """Show an RGB image in its own tightly laid-out figure titled ``title``."""
    _, ax = plt.subplots(layout="tight")
    ax.set_title(title)
    ax.imshow(image_rgb)
    plt.show()


# List of detectron2 dataset dicts for the validation split.
dataset_dict = DatasetCatalog.get(info["name"] + " " + "val")
for d in random.sample(dataset_dict, min(NUM_VAL_SAMPLES, len(dataset_dict))):
    im = cv2.imread(d["file_name"])  # OpenCV loads images in BGR channel order
    if im is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"Could not read image: {d['file_name']}")
    im_rgb = im[:, :, ::-1]  # Visualizer is given the RGB version of the image

    # Raw predictions; the predictor receives the image exactly as cv2 loaded it.
    outputs = predictor(im)
    pred_vis = Visualizer(im_rgb, metadata=val_metadata, scale=1)
    pred_image = pred_vis.draw_instance_predictions(outputs["instances"].to("cpu")).get_image()
    _show_rgb(pred_image, "Predictions")

    # Raw annotations, drawn on a fresh Visualizer so they are not overlaid on the predictions.
    # get_image() already returns RGB, so no channel flip is applied before display.
    gt_vis = Visualizer(im_rgb, metadata=val_metadata, scale=1)
    _show_rgb(gt_vis.draw_dataset_dict(d).get_image(), "Annotations")

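### Test on Experimental Data

Run the predictor on a randomly chosen experimental image (`exp*.jpg`) from the test or validation split; only the predictions are displayed.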

In [None]:
# Experimental images are the exp*.jpg files in the test and val splits. The globs are sorted
# so the candidate list has a stable order (glob order is filesystem-dependent).
all_test_exp_images = sorted(TEST_DIR.glob("exp*.jpg"))
all_val_exp_images = sorted(VAL_DIR.glob("exp*.jpg"))
exp_image_paths = all_test_exp_images + all_val_exp_images
if not exp_image_paths:
    raise FileNotFoundError(f"No exp*.jpg images found in {TEST_DIR} or {VAL_DIR}")

random_test_img_path = random.choice(exp_image_paths)
# Pass a str: older OpenCV builds reject pathlib.Path. imread returns None on failure.
random_test_img = cv2.imread(str(random_test_img_path))
if random_test_img is None:
    raise FileNotFoundError(f"Could not read image: {random_test_img_path}")

outputs = predictor(random_test_img)
v = Visualizer(
    random_test_img[:, :, ::-1],  # Visualizer is given the RGB version of the image
    metadata=val_metadata,
    scale=1,
)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

_, ax = plt.subplots(layout="tight")
ax.set_title("Predictions")
ax.set_xlabel(random_test_img_path.name)  # identify which image was sampled
ax.imshow(out.get_image())  # get_image() is already RGB
plt.show()