# Evaluate the trained model

In [1]:
from utils import *
from datasets import PascalVOCDataset
from tqdm import tqdm
from pprint import PrettyPrinter
from IPython.display import clear_output

### Load model

In [2]:
# Pretty printer used to display the per-class APs
pp = PrettyPrinter()

# Parameters
data_folder = '/media/bruno/HD-Arquivos2/Data_Object_Detect/'  # output folder of the data-preparation step (absolute local path: adjust for your machine)
keep_difficult = True  # difficult ground truth objects must always be considered in mAP calculation, because these objects DO exist!
batch_size = 50
workers = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = 'checkpoint_ssd300.pth.tar'  # path of the saved model checkpoint

# Load the model checkpoint that is to be evaluated.
# map_location remaps tensors that were saved on a GPU onto `device`, so this also works on CPU-only machines.
# NOTE(security): torch.load unpickles arbitrary objects (this checkpoint stores the whole model object),
# so only load checkpoints from trusted sources. Recent PyTorch releases default to weights_only=True,
# in which case loading a full pickled model may additionally require weights_only=False.
checkpoint = torch.load(checkpoint_path, map_location=device)
model = checkpoint['model']
model = model.to(device)

### Load test data

In [3]:
# Switch to eval mode (evaluate() does this too; kept so the model is ready for ad-hoc inference)
model.eval()

# Load the test split.
# shuffle=False: example order does not matter for mAP, and a fixed order keeps repeated
# evaluations reproducible (presumably tied detection scores could otherwise be ranked
# differently between runs inside calculate_mAP).
# pin_memory only speeds up host-to-GPU copies, so enable it only when evaluating on CUDA.
test_dataset = PascalVOCDataset(data_folder,
                                split='test',
                                keep_difficult=keep_difficult)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                          collate_fn=test_dataset.collate_fn, num_workers=workers,
                                          pin_memory=(device.type == 'cuda'))

### Evaluate the model with mean Average Precision ([mAP](https://medium.com/@jonathan_hui/map-mean-average-precision-for-object-detection-45c121a31173))

In [4]:
def evaluate(test_loader, model):
    """
    Evaluate a detector on a test set with mean Average Precision (mAP).

    Runs ``model`` over every batch of ``test_loader``, decodes the raw outputs
    with ``model.detect_objects`` and scores all accumulated detections at once
    with ``calculate_mAP`` (from utils). Prints the per-class APs and the mAP.

    Relies on the notebook-level globals ``device``, ``pp`` and ``calculate_mAP``.

    :param test_loader: DataLoader yielding (images, boxes, labels, difficulties) batches
    :param model: SSD model exposing ``detect_objects``
    :return: tuple (APs, mAP) exactly as returned by ``calculate_mAP``
    """
    # Make sure it's in eval mode
    model.eval()

    # Detected and ground-truth boxes, labels and scores, accumulated across all batches
    det_boxes = []
    det_labels = []
    det_scores = []
    true_boxes = []
    true_labels = []
    true_difficulties = []  # it is necessary to know which objects are 'difficult', see 'calculate_mAP' in utils.py

    with torch.no_grad():
        for images, boxes, labels, difficulties in tqdm(test_loader, desc='Evaluating'):
            images = images.to(device)  # presumably (N, 3, 300, 300) for SSD300

            # Forward prop.
            predicted_locs, predicted_scores = model(images)

            # Detect objects in SSD output.
            # Evaluation MUST be at min_score=0.01, max_overlap=0.45, top_k=200 for fair
            # comparison with the paper's results and other repos.
            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
                predicted_locs, predicted_scores, min_score=0.01, max_overlap=0.45, top_k=200)

            # Move this batch's ground truth to `device` for mAP calculation
            boxes = [b.to(device) for b in boxes]
            labels = [l.to(device) for l in labels]
            difficulties = [d.to(device) for d in difficulties]

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)
            true_boxes.extend(boxes)
            true_labels.extend(labels)
            true_difficulties.extend(difficulties)

        # Calculate mAP over the whole test set
        APs, mAP = calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties)

    # Print AP for each class, then the overall mAP (previously unreachable after `return`)
    pp.pprint(APs)
    print('\nMean Average Precision (mAP): %.3f' % mAP)
    return APs, mAP

In [5]:
# Evaluate on the test set, then clear the cell output (progress bar and printed results);
# the results are displayed in the following cells instead.
APs, mAp = evaluate(test_loader, model)
clear_output()

In [6]:
# Overall mean Average Precision (mAP) of the model on the test set
mAp

0.6427056789398193

In [7]:
# Average Precision (AP) for each class
APs

{'aeroplane': 0.708512008190155,
 'bicycle': 0.7454192638397217,
 'bird': 0.6230981945991516,
 'boat': 0.5221483111381531,
 'bottle': 0.24800018966197968,
 'bus': 0.7201552391052246,
 'car': 0.7631471753120422,
 'cat': 0.8401749730110168,
 'chair': 0.3495808243751526,
 'cow': 0.703272819519043,
 'diningtable': 0.6119518876075745,
 'dog': 0.7849552035331726,
 'horse': 0.8075953722000122,
 'motorbike': 0.7407485842704773,
 'person': 0.658084511756897,
 'pottedplant': 0.26374727487564087,
 'sheep': 0.6770304441452026,
 'sofa': 0.6690605878829956,
 'train': 0.7741146087646484,
 'tvmonitor': 0.6433150172233582}