<a href="https://colab.research.google.com/github/jeongin7103/CalCheck/blob/main/evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the project files stored under MyDrive are reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import sys

# Make the project directory importable so `utils` and `datasets` (imported in the
# next cell) can be found. Only append when missing, so re-running this cell does not
# keep adding duplicate entries to sys.path.
project_dir = '/content/drive/MyDrive/SSD_detection'
if project_dir not in sys.path:
    sys.path.append(project_dir)

In [None]:
from utils import *
from datasets import CustomDataset
from tqdm import tqdm
from pprint import PrettyPrinter

In [None]:
# Pretty-printer used to display the per-class AP dictionary produced by the evaluation.
pp = PrettyPrinter()

# ---- Evaluation parameters ----
# Folder holding the test split on Google Drive.
data_folder = '/content/drive/MyDrive/SSD_detection/test'
# Mini-batch size, and the number of data-loading worker processes
# (NOTE(review): only takes effect if passed to the DataLoader as `num_workers`).
batch_size, workers = 4, 4
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Path of the SSD300 checkpoint file to evaluate (saved at epoch 19, per its filename).
checkpoint = '/content/drive/MyDrive/SSD_detection/checkpoints/checkpoint_ssd300_epoch_19.pth.tar'

In [None]:
# Load the model checkpoint that is to be evaluated.
# The loaded dict gets its own name so `checkpoint` keeps holding the file path and
# this cell can be re-run without failing.
# The checkpoint stores the whole pickled model object under 'model' (not just a
# state_dict), so `weights_only=False` is needed on PyTorch >= 2.6, where the default
# became True. Unpickling can execute arbitrary code: only load trusted checkpoints.
checkpoint_state = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)
model = checkpoint_state['model']
# Tensors were mapped to CPU at load time; move the model to the evaluation device.
model = model.to(device)

In [None]:
# Switch to eval mode
model.eval()

# Load test data
test_dataset = CustomDataset(data_folder, split='test')
# `workers` is defined in the parameters cell but was never passed in, so loading ran
# in the main process only. Pinned memory only helps host->GPU copies, so enable it
# just when evaluating on CUDA.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # keep a fixed order for evaluation
    num_workers=workers,
    collate_fn=test_dataset.collate_fn,
    pin_memory=device.type == 'cuda',
)

In [None]:
def evaluate(test_loader, model, min_score=0.01, max_overlap=0.45, top_k=200):
    """
    Run the model over the whole test set and report per-class AP and the mAP.

    Detections and ground-truth boxes/labels for every image are accumulated, then
    passed to `calculate_mAP` (from utils). The per-class APs are pretty-printed and
    the mAP is printed with three decimals. Nothing is returned.

    :param test_loader: DataLoader yielding (images, boxes, labels) batches, where
        boxes and labels are per-image sequences of tensors
    :param model: detection model that is callable on a batch of images and exposes
        `detect_objects(predicted_locs, predicted_scores, min_score=..., max_overlap=..., top_k=...)`
    :param min_score: forwarded to `model.detect_objects` (default 0.01)
    :param max_overlap: forwarded to `model.detect_objects` (default 0.45)
    :param top_k: forwarded to `model.detect_objects` (default 200)
    """
    # Make sure it's in eval mode even if the caller forgot.
    model.eval()

    # Accumulated over all batches, one entry per image.
    det_boxes, det_labels, det_scores = [], [], []
    true_boxes, true_labels = [], []

    with torch.no_grad():
        for images, boxes, labels in tqdm(test_loader, desc='Evaluating'):
            images = images.to(device)  # (N, 3, 300, 300)

            # Forward prop.
            predicted_locs, predicted_scores = model(images)

            # Decode raw SSD outputs into per-image detections. The defaults
            # (min_score=0.01, max_overlap=0.45, top_k=200) are the paper's evaluation
            # settings; keep them for a fair comparison with published results and other repos.
            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(
                predicted_locs, predicted_scores,
                min_score=min_score, max_overlap=max_overlap, top_k=top_k)

            # Store this batch's results for the mAP calculation.
            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)
            true_boxes.extend(box.to(device) for box in boxes)
            true_labels.extend(lbl.to(device) for lbl in labels)

        # The loader yields no 'difficult' flags (only images, boxes, labels), so the
        # mAP is computed from boxes and labels alone.
        APs, mAP = calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels)

    # Print AP for each class
    pp.pprint(APs)

    print('\nMean Average Precision (mAP): %.3f' % mAP)

In [None]:
# Run the full evaluation on the test split; per-class APs and the mAP are printed.
evaluate(test_loader=test_loader, model=model)

Evaluating: 100%|██████████| 250/250 [24:16<00:00,  5.83s/it]


{'Kohlrabi': 0.8771983981132507,
 'Mushroom': 0.926325261592865,
 'Paprika': 0.7396257519721985,
 'Pimento': 0.4829806387424469,
 'Pumpkin': 0.9344001412391663,
 'Tomato': 0.9296153783798218,
 'apple': 0.8561944365501404,
 'blueberry': 0.8924987316131592,
 'cherry': 0.7098320722579956,
 'chestnuts': 0.8994967937469482,
 'chicory': 0.9920255541801453,
 'grape': 0.9398550391197205,
 'grapefruit': 0.8048553466796875,
 'mango': 0.5383621454238892,
 'melon': 0.8975515365600586,
 'peach': 0.9330011606216431,
 'pepper': 0.6311404705047607,
 'plum': 0.9017970561981201,
 'strawberry': 0.9217002391815186,
 'yam': 0.8972849249839783}

Mean Average Precision (mAP): 0.835
