In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt 
from torchvision import datasets, transforms, models
from helpers import iou_pytorch, display_image, get_best_device

In [None]:
# Pick the compute device, then load the pretrained FCN-ResNet101
# segmentation network and move it onto that device.
# NOTE(review): newer torchvision releases favour a `weights=` argument over
# `pretrained=` — confirm against the installed version before changing it.
device = get_best_device()
model = models.segmentation.fcn_resnet101(pretrained=True).to(device)

# eval() puts the network in prediction mode.
model = model.eval()

In [None]:
# Resizing is necessary for a reasonable testing time.
#
# Images and masks get separate pipelines that share the same geometry
# (resize + centre crop) so they stay pixel-aligned, but only the images are
# normalized.

# Input images: torchvision's pretrained weights expect inputs normalized with
# the ImageNet channel statistics; feeding raw [0, 1] tensors degrades the
# predictions.
image_transform = transforms.Compose([
    transforms.Resize(200),
    transforms.CenterCrop((200, 200)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Segmentation masks: geometry only, no normalization. ToTensor rescales the
# stored values to [0, 1]; the evaluation loop multiplies by 255 to recover
# the integer class ids.
# NOTE(review): presumably the masks are palette-mode PNGs, for which Pillow
# resizes with nearest-neighbour so class ids are not blended — confirm.
transform_pipeline = transforms.Compose([
    transforms.Resize(200),
    transforms.CenterCrop((200, 200)),
    transforms.ToTensor()
])

test = datasets.VOCSegmentation('../torch', image_set='val', download=True,
                                transform=image_transform,
                                target_transform=transform_pipeline)

test_loader = torch.utils.data.DataLoader(test, batch_size=8, num_workers=8)

In [None]:
# Evaluate the model over the test loader and report the mean of the
# per-batch Intersection-over-Union scores.

# Set to an integer to stop after that many batches (quick check);
# None evaluates every batch in the loader.
max_batches = None

totalIoU = 0.0
count = 0

# Inference only: disabling autograd avoids building a graph for every
# forward pass, which saves memory and time.
with torch.no_grad():
    for idx, (images, labels) in enumerate(test_loader):
        print(f"Predicting for batch: {idx}")
        logits = model(images.to(device))['out']

        # Get the maximum predicted class for each pixel in the image
        pred = torch.argmax(logits, dim=1).cpu().to(torch.uint8)

        # Drop the singleton channel dimension of the masks and undo the
        # [0, 1] scaling. Round before the integer cast so that float error
        # (e.g. 6.9999) cannot truncate to the wrong class id.
        labels = labels.squeeze(dim=1)
        labels = (labels * 255).round().to(torch.uint8)

        # Value 255 marks the segmentation borders; map it to class 0.
        labels[labels == 255] = 0

        # Calculate the Batch Intersection over Union
        # NOTE(review): batches are averaged with equal weight, so a smaller
        # final batch counts as much as a full one — confirm this is intended.
        batchIoU = iou_pytorch(pred, labels).mean().item()
        totalIoU += batchIoU
        count += 1

        if max_batches is not None and count >= max_batches:
            break

print("Average IoU on dataset: ", np.round(totalIoU / count, 3))