# Explore evaluation results

In [None]:
import time
from coco_utils import get_coco  # get_coco_kp
from torchvision import transforms
import torchvision
import torchvision.models.detection
import transforms as T
import torch
import utils
from matplotlib.pyplot import figure, imshow, show
import matplotlib
import numpy as np

# --- Run settings -----------------------------------------------------------
model_path = 'runs/whiskers_longer/model_29_finished.pth'  # checkpoint to inspect
# Fall back to CPU so the notebook still runs on machines without CUDA
# (a hard-coded 'cuda' makes torch.load(map_location=...) fail there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_path = 'datasets/whiskers/'  # root passed to get_coco
batch_size = 3  # images per batch for the train loader
workers = 4  # DataLoader worker processes
draw_threshold = 0.5  # detections scoring at or below this are not drawn
DPI = 220  # matplotlib figure DPI used for the displayed images
convert_to_pil = torchvision.transforms.ToPILImage()

def get_transform(train):
    """Return the ``T.Compose`` pipeline applied to dataset samples.

    The pipeline always starts with ``T.ToTensor()``; when ``train`` is
    truthy a ``T.RandomHorizontalFlip(0.5)`` step is appended after it.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

# Datasets. Both splits use the deterministic (no-augmentation) transform:
# this notebook only visualises predictions, and a random horizontal flip on
# the train images would make the displayed images change from run to run.
dataset_train, num_classes, label_names = get_coco(
    data_path, image_set='train', transforms=get_transform(train=False)
)
dataset_test, _, _ = get_coco(
    data_path, image_set='val', transforms=get_transform(train=False)
)

# Samplers: sequential, so images are visited in dataset order.
train_sampler = torch.utils.data.SequentialSampler(dataset_train)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

# Keep the final partial batch so every training image can be inspected.
train_batch_sampler = torch.utils.data.BatchSampler(
    train_sampler, batch_size, drop_last=False)

# Loaders
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, batch_sampler=train_batch_sampler, num_workers=workers,
    collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1,
    sampler=test_sampler, num_workers=workers,
    collate_fn=utils.collate_fn)

# Load checkpoint. Its label names replace the ones returned by get_coco so
# that the model head built below matches the saved weights.
checkpoint = torch.load(model_path, map_location=device)
label_names = checkpoint['label_names']

# Set up model. torchvision detection models reserve class 0 for background,
# hence the +1.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    num_classes=len(label_names) + 1, pretrained_backbone=False
)
model.to(device)
model.load_state_dict(checkpoint['model'])
model.eval()
print('Done loading model')

In [None]:
matplotlib.rcParams['figure.dpi'] = DPI

def print_inference_results(data_loader, model, threshold=None):
    """Run ``model`` over ``data_loader`` and display every prediction.

    For each image yielded by the loader, the detections whose score is above
    ``threshold`` are drawn with ``utils.draw_boxes`` and shown in a new
    matplotlib figure, preceded by a ``# <index>`` line.

    Args:
        data_loader: loader yielding ``(images, targets)`` batches as produced
            with ``utils.collate_fn``; the targets are not used here.
        model: detection model, expected to already be on ``device`` and in
            eval mode.
        threshold: minimum score for a detection to be drawn. ``None`` (the
            default) means use the module-level ``draw_threshold``.
    """
    if threshold is None:
        threshold = draw_threshold

    images_evaluated = 0
    # Pure inference: no_grad avoids building autograd state, which otherwise
    # costs (GPU) memory for every forward pass.
    with torch.no_grad():
        for images, _targets in data_loader:
            outputs = model([img.to(device) for img in images])

            # Show every image in the batch, not just the first one.
            for pre_model_image, output in zip(images, outputs):
                output = {k: v.to('cpu') for k, v in output.items()}
                keep = output['scores'] > threshold
                # Filter scores with the same mask as boxes/labels so all
                # three stay index-aligned.
                image_with_boxes = utils.draw_boxes(
                    pre_model_image,
                    output['boxes'][keep],
                    output['labels'][keep],
                    label_names,
                    output['scores'][keep],
                    vert_size=600, line_width=1, draw_label=False,
                )
                print(f"# {images_evaluated}")
                figure()
                imshow(np.asarray(convert_to_pil(image_with_boxes)))
                show()
                images_evaluated += 1


In [None]:
# Inspect one loader per run (swap in data_loader_test for the validation
# split); running both back to back will probably exhaust GPU memory.
loader_to_inspect = data_loader_train
print_inference_results(loader_to_inspect, model)