In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

from data.dataset_utils import get_dataset, get_transforms

In [None]:
# Load the validation split; bbox_format is shared by the transforms and the dataset.
bbox_format = 'pascal_voc'
data_root = '/home/thuynh/data/04_model_input'

# Settings forwarded to get_transforms for the 'val' split.
# NOTE(review): min_area / min_visibility look like bbox-filtering thresholds applied
# after resizing — confirm their exact meaning in data.dataset_utils.
val_transform_params = {
    'transforms': {'resize': 512, 'min_area': 900, 'min_visibility': 0.25},
    'format': bbox_format,
}
val_transforms = get_transforms('val', params=val_transform_params, normalize=True)

dataset = get_dataset('val', data_root, val_transforms, bbox_format)

In [None]:
# Model and checkpoint specifications (the dataset itself is configured in the previous cell).
# NOTE(review): absolute, user-specific path — point this at your own training run.
ckpt_path = '/home/thuynh/torchvision_tutorial/runs/2022-05-11_18_56_13_2173/model_9.pth'

# Name of a model builder looked up in torchvision.models.detection.
model_name = 'fasterrcnn_mobilenet_v3_large_fpn'

# Keyword arguments forwarded to the torchvision model builder:
# - min_size / max_size = 512: presumably chosen to match the dataset's 'resize': 512 — confirm.
# - image_mean = 0 / image_std = 1: turns the model's internal input normalization into a
#   no-op, presumably because the dataset transforms already use normalize=True — confirm.
# - box_score_thresh = 0.5: score cutoff for detections returned by the RCNN head
#   (per torchvision's FasterRCNN options).
kwargs = {
    "trainable_backbone_layers": 5,
    "min_size": 512,
    "max_size": 512,
    'image_mean': (0., 0., 0.),
    'image_std': (1., 1., 1.),
    'box_score_thresh': 0.5,
}

# Size of the replacement box predictor; must agree with the checkpoint's head for the
# state dict to load.
num_classes = 7

# NOTE(review): the checkpoint loaded later overwrites whatever weights these flags download,
# but torchvision may also pick layer variants based on them, so they presumably need to
# match the settings used for training — confirm.
pretrained = True
pretrained_backbone = True

In [None]:
# Load model: build the torchvision detection architecture, then restore the trained weights.
detection_builder = torchvision.models.detection.__dict__[model_name]
builder_args = dict(pretrained=pretrained, pretrained_backbone=pretrained_backbone)

if 'rcnn' not in model_name:
    # Non-RCNN detectors receive the class count directly in their constructor.
    # NOTE(review): torchvision's pretrained detection weights are 91-class, so
    # pretrained=True together with num_classes != 91 presumably fails here — confirm
    # before switching model_name to a non-RCNN architecture.
    model = detection_builder(**builder_args, num_classes=num_classes, **kwargs)
else:
    # RCNN detectors are built with their default head, which is then replaced by a
    # FastRCNNPredictor sized to num_classes so the checkpoint's head fits.
    model = detection_builder(**builder_args, **kwargs)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )

# Restore weights onto CPU; the checkpoint is expected to keep them under the 'model' key.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location='cpu')
state_dict = ckpt['model']
model.load_state_dict(state_dict)

# Switch to inference mode (assignment keeps the cell from echoing the model repr).
model = model.eval()

In [None]:
# The dataset's .visualize() method renders the data; when a model is supplied it also plots
# the model's predictions over the original images and the ground truths (if any).
# Running it under torch.no_grad() keeps the model's forward pass from recording an autograd
# graph: this saves memory, and if the returned image is derived from model outputs, it stays
# convertible to a NumPy array.
with torch.no_grad():
    vis_image = dataset.visualize(model=model)

fig, ax = plt.subplots(figsize=(30, 15))
# assumes visualize() returns a channels-first (C, H, W) tensor; imshow needs (H, W, C) — TODO confirm
ax.imshow(np.array(vis_image.permute(1, 2, 0)))
ax.set_axis_off()
plt.show()