In [1]:
import os
import random
import numpy as np
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.model_zoo import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

# Train a single-class Mask R-CNN on a custom COCO-format dataset, evaluate it
# on the held-out test split, and save the configuration for later inference.

# Dataset names are used in several places below; define them once.
TRAIN_DATASET = "custom_dataset_train"
TEST_DATASET = "custom_dataset_test"

# Register the COCO-format splits with Detectron2. Registering a name twice
# raises an error, so guard the calls to keep this cell re-runnable within the
# same kernel.
if TRAIN_DATASET not in DatasetCatalog.list():
    register_coco_instances(TRAIN_DATASET, {}, "train_output_coco_annotations.json", "train/images")
if TEST_DATASET not in DatasetCatalog.list():
    register_coco_instances(TEST_DATASET, {}, "test_output_coco_annotations.json", "test/images")

# Metadata handles for the two splits (class names are filled in when the
# annotation files are loaded).
metadata_train = MetadataCatalog.get(TRAIN_DATASET)
metadata_test = MetadataCatalog.get(TEST_DATASET)

# Start from the COCO-pretrained Mask R-CNN R50-FPN baseline and override the
# options that are specific to this dataset.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.DATASETS.TRAIN = (TRAIN_DATASET,)
cfg.DATASETS.TEST = (TEST_DATASET,)
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1111
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single foreground category
cfg.MODEL.DEVICE = 'cpu'  # Use 'cuda' if GPU is available

# Checkpoints, metrics and evaluation dumps are written under cfg.OUTPUT_DIR.
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Save the configuration BEFORE the long-running training/evaluation steps so
# that it is available for inference even if a later step fails.
model_config_path = "model_config.yaml"
with open(model_config_path, "w") as f:
    f.write(cfg.dump())

print(f"Model's configuration saved to {model_config_path}")

# Train from the pretrained weights (resume=False ignores earlier checkpoints).
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()

# Evaluate the trained model on the test split. The evaluator is configured
# with explicit keyword arguments; passing `cfg` as the second positional
# argument is the deprecated config-based form.
evaluator = COCOEvaluator(TEST_DATASET, distributed=False, output_dir=cfg.OUTPUT_DIR)
test_loader = build_detection_test_loader(cfg, TEST_DATASET)
test_results = inference_on_dataset(trainer.model, test_loader, evaluator)

# Print the test results
print("Test results:")
print(test_results)


[32m[09/09 00:59:46 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

[32m[09/09 00:59:47 d2.data.datasets.coco]: [0mLoaded 5690 images in COCO format from train_output_coco_annotations.json
[32m[09/09 00:59:47 d2.data.build]: [0mRemoved 4198 images with no usable annotations. 1492 images left.
[32m[09/09 00:59:47 d2.data.build]: [0mDistribution of instances among all 1 categories:
[36m|  category  | #instances   |
|:----------:|:-------------|
|   damage   | 32467        |
|            |              |[0m
[32m[09/09 00:59:47 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in training: [ResizeShortestEdge(short_edge_length=(640, 672, 704, 736, 768, 800), max_size=1333, sample_style='choice'), RandomFlip()]
[32m[09/09 00:59:47 d2.data.build]: [0mUsing training sampler TrainingSampler
[32m[09/09 00:59:47 d2.data.common]: [0mSerializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[32m[09/09 00:59:47 d2.data.common]: [0mSerializing 1492 elements to byte tensors and concatenating them all ...
[32m[0

Skip loading parameter 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (2, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (2,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (4, 1024) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (4,) in the model! You might want to double check if this is expected.
Skip loading parameter 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (1, 256, 1, 1) in

[32m[09/09 00:59:47 d2.engine.train_loop]: [0mStarting training from iteration 0


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


[32m[09/09 01:01:07 d2.utils.events]: [0m eta: 1:09:32  iter: 19  total_loss: 4.188  loss_cls: 0.7786  loss_box_reg: 0.261  loss_mask: 0.6933  loss_rpn_cls: 2.115  loss_rpn_loc: 0.2883    time: 3.8693  last_time: 4.4125  data_time: 0.0709  last_data_time: 0.0013   lr: 4.9953e-06  
[32m[09/09 01:02:36 d2.utils.events]: [0m eta: 1:14:36  iter: 39  total_loss: 3.891  loss_cls: 0.7078  loss_box_reg: 0.1829  loss_mask: 0.691  loss_rpn_cls: 1.942  loss_rpn_loc: 0.4196    time: 4.0669  last_time: 4.0970  data_time: 0.0016  last_data_time: 0.0012   lr: 9.9902e-06  
[32m[09/09 01:03:57 d2.utils.events]: [0m eta: 1:12:52  iter: 59  total_loss: 2.704  loss_cls: 0.6341  loss_box_reg: 0.3535  loss_mask: 0.6862  loss_rpn_cls: 0.7001  loss_rpn_loc: 0.3549    time: 4.0556  last_time: 4.0599  data_time: 0.0013  last_data_time: 0.0017   lr: 1.4985e-05  
[32m[09/09 01:05:19 d2.utils.events]: [0m eta: 1:12:28  iter: 79  total_loss: 2.272  loss_cls: 0.5406  loss_box_reg: 0.2017  loss_mask: 0.6804  

[32m[09/09 01:39:16 d2.utils.events]: [0m eta: 0:34:15  iter: 599  total_loss: 1.817  loss_cls: 0.3818  loss_box_reg: 0.3158  loss_mask: 0.4499  loss_rpn_cls: 0.2546  loss_rpn_loc: 0.2534    time: 3.9369  last_time: 4.3976  data_time: 0.0012  last_data_time: 0.0008   lr: 0.00014985  
[32m[09/09 01:40:31 d2.utils.events]: [0m eta: 0:32:50  iter: 619  total_loss: 1.744  loss_cls: 0.3448  loss_box_reg: 0.3154  loss_mask: 0.4771  loss_rpn_cls: 0.3075  loss_rpn_loc: 0.2816    time: 3.9321  last_time: 3.3826  data_time: 0.0012  last_data_time: 0.0012   lr: 0.00015485  
[32m[09/09 01:41:48 d2.utils.events]: [0m eta: 0:31:30  iter: 639  total_loss: 1.85  loss_cls: 0.3631  loss_box_reg: 0.3785  loss_mask: 0.4789  loss_rpn_cls: 0.2928  loss_rpn_loc: 0.3147    time: 3.9296  last_time: 4.1614  data_time: 0.0012  last_data_time: 0.0009   lr: 0.00015984  
[32m[09/09 01:43:06 d2.utils.events]: [0m eta: 0:30:07  iter: 659  total_loss: 1.944  loss_cls: 0.3629  loss_box_reg: 0.4342  loss_mask: 0

[32m[09/09 02:12:48 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1333, sample_style='choice')]
[32m[09/09 02:12:48 d2.data.common]: [0mSerializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[32m[09/09 02:12:48 d2.data.common]: [0mSerializing 1866 elements to byte tensors and concatenating them all ...
[32m[09/09 02:12:48 d2.data.common]: [0mSerialized dataset takes 3.36 MiB
[32m[09/09 02:12:49 d2.data.datasets.coco]: [0mLoaded 1866 images in COCO format from test_output_coco_annotations.json
[32m[09/09 02:12:49 d2.data.dataset_mapper]: [0m[DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(800, 800), max_size=1333, sample_style='choice')]
[32m[09/09 02:12:49 d2.data.common]: [0mSerializing the dataset using: <class 'detectron2.data.common._TorchSerializedList'>
[32m[09/09 02:12:49 d2.data.common]: [0mSerializing 186

[32m[09/09 02:16:14 d2.evaluation.evaluator]: [0mInference done 191/1866. Dataloading: 0.0003 s/iter. Inference: 0.9549 s/iter. Eval: 0.1069 s/iter. Total: 1.0623 s/iter. ETA=0:29:39
[32m[09/09 02:16:20 d2.evaluation.evaluator]: [0mInference done 196/1866. Dataloading: 0.0003 s/iter. Inference: 0.9553 s/iter. Eval: 0.1074 s/iter. Total: 1.0632 s/iter. ETA=0:29:35
[32m[09/09 02:16:25 d2.evaluation.evaluator]: [0mInference done 201/1866. Dataloading: 0.0003 s/iter. Inference: 0.9553 s/iter. Eval: 0.1074 s/iter. Total: 1.0632 s/iter. ETA=0:29:30
[32m[09/09 02:16:30 d2.evaluation.evaluator]: [0mInference done 206/1866. Dataloading: 0.0003 s/iter. Inference: 0.9558 s/iter. Eval: 0.1074 s/iter. Total: 1.0637 s/iter. ETA=0:29:25
[32m[09/09 02:16:36 d2.evaluation.evaluator]: [0mInference done 211/1866. Dataloading: 0.0003 s/iter. Inference: 0.9557 s/iter. Eval: 0.1074 s/iter. Total: 1.0637 s/iter. ETA=0:29:20
[32m[09/09 02:16:41 d2.evaluation.evaluator]: [0mInference done 216/1866.

[32m[09/09 02:20:17 d2.evaluation.evaluator]: [0mInference done 416/1866. Dataloading: 0.0003 s/iter. Inference: 0.9620 s/iter. Eval: 0.1080 s/iter. Total: 1.0705 s/iter. ETA=0:25:52
[32m[09/09 02:20:22 d2.evaluation.evaluator]: [0mInference done 421/1866. Dataloading: 0.0003 s/iter. Inference: 0.9616 s/iter. Eval: 0.1080 s/iter. Total: 1.0701 s/iter. ETA=0:25:46
[32m[09/09 02:20:27 d2.evaluation.evaluator]: [0mInference done 426/1866. Dataloading: 0.0003 s/iter. Inference: 0.9618 s/iter. Eval: 0.1080 s/iter. Total: 1.0703 s/iter. ETA=0:25:41
[32m[09/09 02:20:32 d2.evaluation.evaluator]: [0mInference done 431/1866. Dataloading: 0.0003 s/iter. Inference: 0.9617 s/iter. Eval: 0.1080 s/iter. Total: 1.0701 s/iter. ETA=0:25:35
[32m[09/09 02:20:38 d2.evaluation.evaluator]: [0mInference done 436/1866. Dataloading: 0.0003 s/iter. Inference: 0.9624 s/iter. Eval: 0.1080 s/iter. Total: 1.0709 s/iter. ETA=0:25:31
[32m[09/09 02:20:44 d2.evaluation.evaluator]: [0mInference done 441/1866.

[32m[09/09 02:24:19 d2.evaluation.evaluator]: [0mInference done 641/1866. Dataloading: 0.0003 s/iter. Inference: 0.9637 s/iter. Eval: 0.1080 s/iter. Total: 1.0723 s/iter. ETA=0:21:53
[32m[09/09 02:24:24 d2.evaluation.evaluator]: [0mInference done 646/1866. Dataloading: 0.0003 s/iter. Inference: 0.9636 s/iter. Eval: 0.1080 s/iter. Total: 1.0721 s/iter. ETA=0:21:47
[32m[09/09 02:24:29 d2.evaluation.evaluator]: [0mInference done 651/1866. Dataloading: 0.0003 s/iter. Inference: 0.9636 s/iter. Eval: 0.1080 s/iter. Total: 1.0721 s/iter. ETA=0:21:42
[32m[09/09 02:24:34 d2.evaluation.evaluator]: [0mInference done 656/1866. Dataloading: 0.0003 s/iter. Inference: 0.9635 s/iter. Eval: 0.1080 s/iter. Total: 1.0721 s/iter. ETA=0:21:37
[32m[09/09 02:24:40 d2.evaluation.evaluator]: [0mInference done 661/1866. Dataloading: 0.0003 s/iter. Inference: 0.9635 s/iter. Eval: 0.1080 s/iter. Total: 1.0720 s/iter. ETA=0:21:31
[32m[09/09 02:24:45 d2.evaluation.evaluator]: [0mInference done 666/1866.

[32m[09/09 02:28:20 d2.evaluation.evaluator]: [0mInference done 866/1866. Dataloading: 0.0003 s/iter. Inference: 0.9639 s/iter. Eval: 0.1079 s/iter. Total: 1.0723 s/iter. ETA=0:17:52
[32m[09/09 02:28:25 d2.evaluation.evaluator]: [0mInference done 871/1866. Dataloading: 0.0003 s/iter. Inference: 0.9638 s/iter. Eval: 0.1079 s/iter. Total: 1.0723 s/iter. ETA=0:17:46
[32m[09/09 02:28:31 d2.evaluation.evaluator]: [0mInference done 876/1866. Dataloading: 0.0003 s/iter. Inference: 0.9639 s/iter. Eval: 0.1079 s/iter. Total: 1.0723 s/iter. ETA=0:17:41
[32m[09/09 02:28:36 d2.evaluation.evaluator]: [0mInference done 881/1866. Dataloading: 0.0003 s/iter. Inference: 0.9638 s/iter. Eval: 0.1079 s/iter. Total: 1.0723 s/iter. ETA=0:17:36
[32m[09/09 02:28:41 d2.evaluation.evaluator]: [0mInference done 886/1866. Dataloading: 0.0003 s/iter. Inference: 0.9638 s/iter. Eval: 0.1079 s/iter. Total: 1.0722 s/iter. ETA=0:17:30
[32m[09/09 02:28:47 d2.evaluation.evaluator]: [0mInference done 891/1866.

[32m[09/09 02:32:23 d2.evaluation.evaluator]: [0mInference done 1091/1866. Dataloading: 0.0003 s/iter. Inference: 0.9653 s/iter. Eval: 0.1080 s/iter. Total: 1.0738 s/iter. ETA=0:13:52
[32m[09/09 02:32:28 d2.evaluation.evaluator]: [0mInference done 1096/1866. Dataloading: 0.0003 s/iter. Inference: 0.9653 s/iter. Eval: 0.1080 s/iter. Total: 1.0738 s/iter. ETA=0:13:46
[32m[09/09 02:32:33 d2.evaluation.evaluator]: [0mInference done 1101/1866. Dataloading: 0.0003 s/iter. Inference: 0.9653 s/iter. Eval: 0.1080 s/iter. Total: 1.0738 s/iter. ETA=0:13:41
[32m[09/09 02:32:39 d2.evaluation.evaluator]: [0mInference done 1106/1866. Dataloading: 0.0003 s/iter. Inference: 0.9652 s/iter. Eval: 0.1080 s/iter. Total: 1.0737 s/iter. ETA=0:13:36
[32m[09/09 02:32:44 d2.evaluation.evaluator]: [0mInference done 1111/1866. Dataloading: 0.0003 s/iter. Inference: 0.9652 s/iter. Eval: 0.1080 s/iter. Total: 1.0737 s/iter. ETA=0:13:30
[32m[09/09 02:32:49 d2.evaluation.evaluator]: [0mInference done 1116

[32m[09/09 02:36:26 d2.evaluation.evaluator]: [0mInference done 1316/1866. Dataloading: 0.0003 s/iter. Inference: 0.9662 s/iter. Eval: 0.1082 s/iter. Total: 1.0750 s/iter. ETA=0:09:51
[32m[09/09 02:36:31 d2.evaluation.evaluator]: [0mInference done 1321/1866. Dataloading: 0.0003 s/iter. Inference: 0.9663 s/iter. Eval: 0.1082 s/iter. Total: 1.0750 s/iter. ETA=0:09:45
[32m[09/09 02:36:36 d2.evaluation.evaluator]: [0mInference done 1326/1866. Dataloading: 0.0003 s/iter. Inference: 0.9663 s/iter. Eval: 0.1082 s/iter. Total: 1.0750 s/iter. ETA=0:09:40
[32m[09/09 02:36:42 d2.evaluation.evaluator]: [0mInference done 1331/1866. Dataloading: 0.0003 s/iter. Inference: 0.9664 s/iter. Eval: 0.1082 s/iter. Total: 1.0751 s/iter. ETA=0:09:35
[32m[09/09 02:36:47 d2.evaluation.evaluator]: [0mInference done 1336/1866. Dataloading: 0.0003 s/iter. Inference: 0.9665 s/iter. Eval: 0.1082 s/iter. Total: 1.0752 s/iter. ETA=0:09:29
[32m[09/09 02:36:53 d2.evaluation.evaluator]: [0mInference done 1341

[32m[09/09 02:40:25 d2.evaluation.evaluator]: [0mInference done 1541/1866. Dataloading: 0.0003 s/iter. Inference: 0.9650 s/iter. Eval: 0.1082 s/iter. Total: 1.0737 s/iter. ETA=0:05:48
[32m[09/09 02:40:31 d2.evaluation.evaluator]: [0mInference done 1546/1866. Dataloading: 0.0003 s/iter. Inference: 0.9649 s/iter. Eval: 0.1082 s/iter. Total: 1.0737 s/iter. ETA=0:05:43
[32m[09/09 02:40:36 d2.evaluation.evaluator]: [0mInference done 1551/1866. Dataloading: 0.0003 s/iter. Inference: 0.9649 s/iter. Eval: 0.1082 s/iter. Total: 1.0736 s/iter. ETA=0:05:38
[32m[09/09 02:40:41 d2.evaluation.evaluator]: [0mInference done 1556/1866. Dataloading: 0.0003 s/iter. Inference: 0.9649 s/iter. Eval: 0.1082 s/iter. Total: 1.0736 s/iter. ETA=0:05:32
[32m[09/09 02:40:47 d2.evaluation.evaluator]: [0mInference done 1561/1866. Dataloading: 0.0003 s/iter. Inference: 0.9648 s/iter. Eval: 0.1082 s/iter. Total: 1.0735 s/iter. ETA=0:05:27
[32m[09/09 02:40:52 d2.evaluation.evaluator]: [0mInference done 1566

[32m[09/09 02:44:24 d2.evaluation.evaluator]: [0mInference done 1766/1866. Dataloading: 0.0003 s/iter. Inference: 0.9633 s/iter. Eval: 0.1081 s/iter. Total: 1.0719 s/iter. ETA=0:01:47
[32m[09/09 02:44:29 d2.evaluation.evaluator]: [0mInference done 1771/1866. Dataloading: 0.0003 s/iter. Inference: 0.9633 s/iter. Eval: 0.1081 s/iter. Total: 1.0719 s/iter. ETA=0:01:41
[32m[09/09 02:44:35 d2.evaluation.evaluator]: [0mInference done 1776/1866. Dataloading: 0.0003 s/iter. Inference: 0.9633 s/iter. Eval: 0.1081 s/iter. Total: 1.0719 s/iter. ETA=0:01:36
[32m[09/09 02:44:40 d2.evaluation.evaluator]: [0mInference done 1781/1866. Dataloading: 0.0003 s/iter. Inference: 0.9633 s/iter. Eval: 0.1081 s/iter. Total: 1.0719 s/iter. ETA=0:01:31
[32m[09/09 02:44:45 d2.evaluation.evaluator]: [0mInference done 1786/1866. Dataloading: 0.0003 s/iter. Inference: 0.9633 s/iter. Eval: 0.1081 s/iter. Total: 1.0719 s/iter. ETA=0:01:25
[32m[09/09 02:44:51 d2.evaluation.evaluator]: [0mInference done 1791

In [3]:
import os
import random
import numpy as np
from detectron2.data import MetadataCatalog
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
import cv2

# Run the trained model on one random training image and show the confident
# predictions next to the corresponding ground-truth mask image.

# Minimum confidence score a detection needs in order to be drawn.
SCORE_THRESHOLD = 0.75

# Load the saved model weights and configuration
cfg = get_cfg()
cfg.merge_from_file("model_config.yaml")  # Replace with the path to your model's configuration file
cfg.MODEL.WEIGHTS = "output/model_final.pth"  # Replace with the path to your saved model weights

# Set the device to CPU
cfg.MODEL.DEVICE = 'cpu'

# Create a predictor using the loaded model
predictor = DefaultPredictor(cfg)

# Pick a random file from the image folder, ignoring sub-directories
image_folder = 'train/images'  # Replace with the path to your image folder
image_files = [
    name for name in os.listdir(image_folder)
    if os.path.isfile(os.path.join(image_folder, name))
]
if not image_files:
    raise FileNotFoundError(f"No files found in image folder: {image_folder}")
random_image_file = random.choice(image_files)
image_path = os.path.join(image_folder, random_image_file)

# cv2.imread returns None (instead of raising) when the file cannot be decoded
image_bgr = cv2.imread(image_path)
if image_bgr is None:
    raise ValueError(f"Could not read image: {image_path}")

# DefaultPredictor expects the BGR array produced by cv2.imread and applies
# cfg.INPUT.FORMAT itself; feeding it an RGB image would swap the channels
# relative to training and degrade the predictions.
outputs = predictor(image_bgr)

# Keep only predictions whose confidence exceeds the threshold
instances = outputs["instances"]
filtered_instances = instances[instances.scores > SCORE_THRESHOLD]

# The Visualizer, unlike the predictor, works on RGB images
image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
v = Visualizer(image, metadata=MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = v.draw_instance_predictions(filtered_instances.to("cpu"))

# Show the image with predictions (converted back to BGR for cv2.imshow)
cv2.imshow("Predictions", out.get_image()[:, :, ::-1])

# Construct the path to the corresponding mask image
mask_folder = 'train/binned_targets'  # Replace with the path to your mask folder
mask_file = os.path.join(mask_folder, os.path.splitext(random_image_file)[0] + '_target.png')

# Read and show the mask image; skip it with a message if it is missing
mask_image = cv2.imread(mask_file)
if mask_image is None:
    print(f"Mask image not found or unreadable: {mask_file}")
else:
    cv2.imshow("Mask", mask_image)

# Block until any key is pressed, then close all OpenCV windows
cv2.waitKey(0)
cv2.destroyAllWindows()
# Give the GUI event loop a moment to process the window-close events
cv2.waitKey(1)

-1