What you’ll learn here:

    How to run inference tasks with PyTorch on the GPU cluster
    How to use batch processing to accelerate your inference tasks with PyTorch on the GPU cluster

To begin, we need to ensure that our image dataset is available and that our GPU cluster is running.

In our case, we have stored the data on S3 and use the s3fs library to work with it, as you’ll see below. If you would like to use this same dataset, it is the Stanford Dogs dataset, available here: http://vision.stanford.edu/aditya86/ImageNetDogs/

In [None]:
!wget 

In [None]:
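from dask.distributed import Client
import torch

client = Client(cluster)  # `cluster` is the running GPU cluster
client.run(lambda: torch.cuda.is_available())  # each worker should report True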



Next, we set the device for PyTorch computations, falling back to the CPU if no GPU is available.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#Inference

Now, we’re ready to start doing some classification! We’re going to use some custom-written functions to do this efficiently and make sure our jobs can take full advantage of the parallelization of the GPU cluster.


#Preprocessing
Single Image Processing

In [None]:
@dask.delayed
def preprocess(path, fs=__builtins__):
    '''Ingest images directly from S3, apply transformations,
    and extract the ground truth and image identifier. Accepts
    a filepath. '''

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(250),
        transforms.ToTensor()])

    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        nvis = transform(img)

    # Raw strings with an escaped dot, so only a literal ".jpg" extension matches
    truth = re.search(r'dogs/Images/n[0-9]+-([^/]+)/n[0-9]+_[0-9]+\.jpg', path).group(1)
    name = re.search(r'dogs/Images/n[0-9]+-[^/]+/(n[0-9]+_[0-9]+)\.jpg', path).group(1)

    return [name, nvis, truth]
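To see what the two regular expressions pull out of a path, here is a minimal sketch that uses only the standard library. The path is an example written in the folder layout this tutorial assumes, and the single combined pattern is our own simplification rather than the code used in preprocess.

```python
import re

# Example path in the layout assumed by this tutorial (illustrative only).
path = "dask-datasets/dogs/Images/n02085620-Chihuahua/n02085620_10074.jpg"

# One pattern, two groups: the breed part of the folder name and the file stem.
pattern = re.compile(r"dogs/Images/n[0-9]+-([^/]+)/(n[0-9]+_[0-9]+)\.jpg")

match = pattern.search(path)
truth, name = match.group(1), match.group(2)

print(truth)  # Chihuahua
print(name)   # n02085620_10074
```

The first group becomes the ground truth label, and the second becomes the image identifier that is later used to name the saved prediction.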


This function allows us to process one image, but of course, we have a lot of images to work with here! We’re going to use some list comprehension strategies to create our batches and get them ready for our inference.

First, we break the list of images we have from our S3 file path into chunks that will define the batches.



In [None]:
s3fpath = 's3://dask-datasets/dogs/Images/*/*.jpg'

batch_breaks = [list(batch) for batch in toolz.partition_all(60, s3.glob(s3fpath))]
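The later cells refer to image_batches, which is not defined in this section. One way it might be built is to run preprocess on every path in a chunk and then regroup each chunk into three parallel lists, which is the shape run_batch_to_s3 unpacks. The sketch below shows that chunk-and-regroup logic in plain Python. The partition_all and regroup helpers and the placeholder records are assumptions made for illustration; in the real pipeline the work would be wrapped in dask.delayed and preprocess would be called with fs=s3.

```python
from itertools import islice

def partition_all(n, seq):
    """Yield chunks of at most n items (stdlib stand-in for toolz.partition_all)."""
    it = iter(seq)
    while True:
        chunk = list(islice(it, n))
        if not chunk:
            return
        yield chunk

def regroup(batch):
    """Turn [[name, tensor, truth], ...] into [names, tensors, truths]."""
    names, tensors, truths = (list(col) for col in zip(*batch))
    return [names, tensors, truths]

# Placeholder records standing in for preprocess outputs: [name, tensor, truth].
records = [[f"img_{i}", [float(i)], "breed"] for i in range(130)]

batches = [regroup(chunk) for chunk in partition_all(60, records)]
print([len(b[0]) for b in batches])  # [60, 60, 10]
```

The last batch is simply shorter than the others, so no images are dropped when the total is not a multiple of the batch size.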


In [None]:
def evaluate_pred_batch(batch, gtruth, classes):
    ''' Accepts batch of images, returns human readable predictions. '''
    _, indices = torch.sort(batch, descending=True)
    # Softmax over the class dimension, kept per image (not just the first row)
    percentage = torch.nn.functional.softmax(batch, dim=1) * 100

    preds = []
    labslist = []
    for i in range(len(batch)):
        # Top-1 class and its confidence for image i
        pred = [(classes[idx], percentage[i][idx].item()) for idx in indices[i][:1]]
        preds.append(pred)

        labs = gtruth[i]
        labslist.append(labs)

    return preds, labslist

def is_match(la, ev):
    ''' Evaluate human readable prediction against ground truth.
    (Used in both methods)'''
    return bool(re.search(la.replace('_', ' '), str(ev).replace('_', ' ')))


@dask.delayed
def run_batch_to_s3(iteritem):
    ''' Accepts iterable result of preprocessing,
    generates inferences and evaluates. '''

    # Read the class labels in text mode so predictions are stored as str, not bytes
    with s3.open('s3://dask-datasets/dogs/imagenet1000_clsidx_to_labels.txt', 'r') as f:
        classes = [line.strip() for line in f.readlines()]

    names, images, truelabels = iteritem

    images = torch.stack(images)

    with torch.no_grad():
        # Set up model
        resnet = models.resnet50(pretrained=True)
        resnet = resnet.to(device)
        resnet.eval()

        # run model on batch
        images = images.to(device)
        pred_batch = resnet(images)

        #Evaluate batch
        preds, labslist = evaluate_pred_batch(pred_batch, truelabels, classes)

        #Organize prediction results
        for j in range(0, len(images)):
            predicted = preds[j]
            groundtruth = labslist[j]
            name = names[j]
            match = is_match(groundtruth, predicted)

            outcome = {'name': name, 'ground_truth': groundtruth, 'prediction': predicted, 'evaluation': match}

            # Write each result to S3 directly
            with s3.open(f"s3://dask-datasets/dogs/preds/{name}.pkl", "wb") as f:
                pickle.dump(outcome, f)

        return names
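The top-1 step in evaluate_pred_batch comes down to a softmax over each image's scores followed by picking the largest one. Here is a minimal pure-Python sketch of that idea, with made-up class names and logits. It stands in for the torch.sort and softmax calls rather than reproducing them, and it makes the point that the percentage has to be looked up in the row that belongs to each image.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def top1(logits, classes):
    """Return (label, percentage) for the highest-scoring class of each row."""
    results = []
    for row in logits:
        probs = softmax(row)  # probabilities for this image only
        best = max(range(len(row)), key=row.__getitem__)
        results.append((classes[best], probs[best] * 100))
    return results

# Made-up example: three classes, two images.
classes = ["beagle", "pug", "collie"]
logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 3.0]]

for label, pct in top1(logits, classes):
    print(f"{label}: {pct:.1f}%")
```

Each image gets its own probability row, so the confidence reported for the second image comes from the second row of scores and not from the first.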


Now we can use the Dask client to send our batches to the cluster.

In [None]:
futures = client.map(run_batch_to_s3, image_batches)
futures_gathered = client.gather(futures)
futures_computed = client.compute(futures_gathered, sync=False)


With map, we apply the function to every one of our batches. Because run_batch_to_s3 is wrapped in dask.delayed, each of those futures resolves to a lazy Delayed object rather than a finished result, so we use gather to collect the whole list at once rather than one by one. Finally, compute(sync=False) submits that delayed work to the cluster and immediately returns a list of futures without waiting for them to finish. This may seem arduous, but these steps give us futures that we can iterate over.

Now we collect the results as the tasks finish, and we also have a simple error handling system in case any of our files are corrupted or a task fails: a failed batch is logged and recorded, and the loop moves on to the next one.

In [None]:
import logging

results = []
errors = []
for fut in futures_computed:
    try:
        result = fut.result()
    except Exception as e:
        errors.append(e)
        logging.error(e)
    else:
        results.extend(result)
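The same collect-and-log pattern can be tried without a cluster. The sketch below swaps in the standard library's concurrent.futures for Dask, purely as a stand-in: run_batch and the tiny batches are hypothetical, and one batch is deliberately empty so that it raises an error.

```python
from concurrent.futures import ThreadPoolExecutor
import logging

def run_batch(batch):
    """Hypothetical stand-in for run_batch_to_s3: returns the names it processed."""
    if not batch:
        raise ValueError("empty batch")
    return [name.upper() for name in batch]

batches = [["a", "b"], [], ["c"]]

results, errors = [], []
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(run_batch, b) for b in batches]
    for fut in futures:
        try:
            result = fut.result()  # re-raises any exception from the task
        except Exception as e:
            errors.append(e)
            logging.error(e)
        else:
            results.extend(result)

print(results)      # ['A', 'B', 'C']
print(len(errors))  # 1
```

One failing batch does not stop the loop, and the errors list keeps the exceptions around so the affected batches can be inspected or retried later.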


Evaluate
We want to make sure we have high-quality results coming out of this model, of course! First, we can peek at a single result.

In [None]:
with s3.open('s3://dask-datasets/dogs/preds/n02086240_1082.pkl', 'rb') as data:
    result = pickle.load(data)

# Display outside the with block so the notebook renders the value
result
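Beyond looking at a single file, a natural next check is the share of predictions that matched their ground truth. Here is a minimal sketch of that calculation. The records are made up for illustration in the same shape that run_batch_to_s3 writes, and they are round-tripped through pickle in memory instead of being read from S3.

```python
import io
import pickle

# Made-up outcome records, for illustration only (not real model output).
outcomes = [
    {"name": "img_a", "ground_truth": "pug", "prediction": [("pug", 91.0)], "evaluation": True},
    {"name": "img_b", "ground_truth": "beagle", "prediction": [("beagle", 77.5)], "evaluation": True},
    {"name": "img_c", "ground_truth": "collie", "prediction": [("pug", 40.2)], "evaluation": False},
    {"name": "img_d", "ground_truth": "pug", "prediction": [("pug", 88.3)], "evaluation": True},
]

# Round-trip each record through pickle, as the pipeline does via S3.
loaded = []
for rec in outcomes:
    buf = io.BytesIO()
    pickle.dump(rec, buf)
    buf.seek(0)
    loaded.append(pickle.load(buf))

# Fraction of records whose prediction matched the ground truth.
accuracy = sum(r["evaluation"] for r in loaded) / len(loaded)
print(f"accuracy: {accuracy:.2f}")  # accuracy: 0.75
```

With the real results, the loaded list would come from opening each saved .pkl file under the preds folder instead of from the in-memory buffer.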