In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread, imshow
from skimage.transform import resize
import torch; print('Torch Version: {}'.format(torch.__version__))
import time

from model01_PT import Net

In [None]:
# Root directory that holds everything this notebook reads
base_path = os.path.join('/', 'workspace', 'optimization')

# Dataset layout: <base_path>/datasets/{images,masks}
datasets_path = os.path.join(base_path, 'datasets')
images_path, masks_path = (
    os.path.join(datasets_path, folder) for folder in ('images', 'masks')
)

# Weights file, expressed as a relative path (weights/<file name>)
weights_path, weights_file_name = 'weights', 'Jan_2019_99_w_rejects.pth'
weights_file_path = os.path.join(weights_path, weights_file_name)

In [None]:
# Load the weights from a file (.pth usually).
# map_location='cpu' deserializes every tensor onto the CPU, so the file loads
# no matter which device it was saved from; the model is moved to the GPU in
# a separate, later step.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load weights files that come from a trusted source.
print('Loading weights file from:', weights_file_path)
state_dict = torch.load(weights_file_path, map_location='cpu')

In [None]:
# Create instance of model: instantiate the Net class imported from model01_PT.
# Only the architecture is built here; the weights from the file are applied
# to it in a separate step.
model = Net()

In [None]:
# Put the model in evaluation mode before it is used for inference, so that
# train/eval-dependent layers (e.g. dropout or batch norm -- whether Net has
# any is not visible here) use their inference behaviour.
model.eval()

# Load the weights now into a model net architecture defined by our class.
# Kept as the last expression so the notebook still shows the key-matching result.
model.load_state_dict(state_dict)

In [None]:
# Network input settings
dimensions = [1, 512, 384]  # [channels, height, width]
n_channel, n_height, n_width = dimensions
batch_size = 1
architecture = 'v100'  # options are 't4' (default), 'v100' and 'xavier'

In [None]:
# Load test images and masks.
# os.listdir returns entries in arbitrary order, so both listings are sorted
# to make index j deterministic across runs and machines.
# NOTE(review): pairing image j with mask j additionally assumes both folders
# use matching file names -- confirm against the dataset.
images_file_names = sorted(os.listdir(images_path))
images_file_paths = [os.path.join(images_path, f) for f in images_file_names]
masks_file_names = sorted(os.listdir(masks_path))
masks_file_paths = [os.path.join(masks_path, f) for f in masks_file_names]

# Print first 2
print('Images:', images_file_paths[:2])
print('Masks:', masks_file_paths[:2])

In [None]:
# Select the image/mask pair stored at index j of the two path lists
j = 0
image_file_path, mask_file_path = images_file_paths[j], masks_file_paths[j]
print('Image:', image_file_path)
print('Mask:', mask_file_path)

In [None]:
# Read the selected image and mask from disk, then report shape and dtype
image, mask = imread(image_file_path), imread(mask_file_path)
print('Image:', image.shape, image.dtype)
print('Mask:', mask.shape, mask.dtype)

In [None]:
# Change image dimension from 3 to 1 by keeping only the first channel.
# Guarded on ndim so that an image that is already 2-D (a grayscale file, or
# this cell being run a second time) is left untouched instead of raising
# an IndexError.
if image.ndim == 3:
    image = image[:, :, 0]
print('Image:', image.shape, image.dtype)
print('Mask:', mask.shape, mask.dtype)

In [None]:
# Display the image with the 'binary' colormap
imshow(image, cmap='binary')
plt.show()

In [None]:
# Display the mask with the 'binary' colormap
imshow(mask, cmap='binary')
plt.show()

In [None]:
# Resize the image to the configured (height, width)
image = resize(image, output_shape=(n_height, n_width))
print('Image:', image.shape, image.dtype)

In [None]:
# Insert a leading channel axis
image = np.expand_dims(image, axis=0)
print('Image:', image.shape, image.dtype)

In [None]:
# Cast the pixel data to 32-bit floats
image = image.astype('float32')
print('Image:', image.shape, image.dtype)

In [None]:
# Insert a leading batch axis in front of the existing axes
image = image[None, ...]
print('Image:', image.shape, image.dtype)

In [None]:
# Repeat data along batch axis.
# batch_size is taken from the network settings cell rather than re-assigned
# here, so changing it there actually takes effect in this step.
# Only the first example is repeated (image[:1]), which keeps the cell
# idempotent: running it again yields batch_size examples, not batch_size**2.
image = np.repeat(image[:1], batch_size, axis=0)
print('Image:', image.shape, image.dtype)

In [None]:
# Create a float32 Torch tensor from the image data. The tensor lives on the
# CPU at this point; the transfer to the GPU is a separate step.
# torch.as_tensor replaces the legacy torch.Tensor(...) constructor: it states
# the dtype explicitly and reuses the array's memory instead of copying it.
image = torch.as_tensor(image, dtype=torch.float32)

In [None]:
# Copy the input tensor to the current CUDA device
image_gpu = image.to('cuda')

In [None]:
# Move the model's parameters to the current CUDA device
model_gpu = model.to('cuda')

In [None]:
# Perform inference on the GPU.
# Gradients are not needed for a forward-only pass, so autograd is disabled:
# no graph is recorded, which saves memory and time.
with torch.no_grad():
    output_gpu = model_gpu(image_gpu)

In [None]:
# Detach the prediction from autograd, copy it to the CPU and expose it as a NumPy array
output = output_gpu.detach().cpu().numpy()

In [None]:
# Log info: report the shape of the NumPy prediction array
print("Prediction Shape: {}".format(output.shape)) 
# Disabled on purpose -- would dump the full prediction array to the output:
# print("Prediction: {} ".format(output))

In [None]:
# Show prediction: the first example when batched, otherwise the whole output,
# with size-1 axes squeezed away before display
imshow(np.squeeze(output[0] if batch_size > 1 else output), cmap='binary')
plt.show()

## Benchmarking

In [None]:
# Number of timed inference round trips
n = 2000

# Drain any GPU work still queued from earlier cells so it is not billed to
# the first timed iteration.
torch.cuda.synchronize()

# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()

# Benchmark inference only: with autograd disabled no graph is recorded, so
# the measurement reflects the forward pass rather than gradient bookkeeping.
with torch.no_grad():
    for _ in range(n):
        # Transfer input data to device
        image_gpu = image.cuda()

        # Execute model
        output_gpu = model_gpu(image_gpu)

        # Transfer predictions back (this blocking copy waits for the result,
        # so each iteration is fully finished before the next one starts)
        output = output_gpu.cpu()

end = time.perf_counter()

In [None]:
# Derive per-inference latency and throughput from the timed loop, then
# print the four summary lines in a single call
delta = end - start
average_latency = delta / n
average_throughput = batch_size * (1 / average_latency)
print(
    'Inference: {} seconds'.format(delta),
    'Number of Inferences: {}'.format(n),
    'Average Latency: {} seconds'.format(average_latency),
    'Average Throughput w/ Batch Size {}: {} examples per second'.format(batch_size, average_throughput),
    sep='\n',
)