# In [17]:
import glob
import json
import os
import random
from itertools import islice
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision import transforms
import v3io.dataplane

class StreamDataLoader(IterableDataset):
    def __init__(self, image_list, device, batch_size=32, img_dimensions=224):
        """
        Stream Ingestion for DataLoader with batching. Also performs image pre-processing.

        image_list: iterable of image paths (anything PIL's Image.open accepts).
        device: device each pre-processed tensor is moved to.
        batch_size: stored on the instance; not read by this class itself.
        img_dimensions: images are resized to (img_dimensions, img_dimensions).
        """
        super().__init__()
        self.device = device
        self.image_list = image_list
        self.batch_size = batch_size
        self.img_dimensions = img_dimensions
        self.transform = transforms.Compose([transforms.Resize((img_dimensions, img_dimensions)),
                                            transforms.ToTensor()])

    def load_image(self, image_path):
        """
        Load image from path, perform pre-processing, and load onto device.

        The image is always converted to RGB so that grayscale / RGBA / palette
        files yield the same 3-channel tensor layout as regular colour images.
        """
        # The context manager closes the underlying file handle; convert()
        # forces the (lazy) pixel data to be read before that happens.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")
        return self.transform(rgb_image).to(self.device)

    def get_stream(self, image_list):
        """
        Iterator that loads image from stream.

        Images are loaded lazily, one per iteration step.
        """
        for image in image_list:
            yield self.load_image(image)

    def __iter__(self):
        """
        PyTorch sub-classed method to invoke.

        Returns a fresh generator over self.image_list.
        """
        return self.get_stream(self.image_list)
    
class ModelHandler:
    def __init__(self, device, model_path='./dogs_vs_cats_resnet50.pth'):
        """
        Handler for PyTorch model. Loads pre-trained model, performs predictions,
        displays prediction labels, and displays original images with prediction label.

        device: device the model is placed on (e.g. "cpu" or "cuda").
        model_path: path to the saved state dict for the fine-tuned ResNet-50.
        """
        self.device = device
        self.labels = ["cat", "dog"]
        self.model = self.load_model(model_path)

    def load_model(self, model_path, num_classes=2):
        """
        Loads pre-trained model, sets to evaluation mode, and sends to device.

        The ResNet-50 architecture is fetched through torch.hub and its final
        layer is replaced by a small classification head with num_classes outputs
        before the saved weights are loaded.
        """
        model = torch.hub.load('pytorch/vision', 'resnet50')
        model.fc = nn.Sequential(nn.Linear(model.fc.in_features, 512),
                                 nn.ReLU(),
                                 nn.Dropout(),
                                 nn.Linear(512, num_classes))
        # map_location remaps the stored tensors onto the target device, so a
        # checkpoint saved on GPU can still be loaded on a CPU-only host.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.eval()
        return model.to(self.device)

    def batch_predict(self, batch):
        """
        Gives prediction for batch of inputs.

        Returns a list with one 0-dim tensor per sample holding the index of the
        highest-scoring class.
        """
        # Inference only: disable autograd so no graph is built for the forward pass.
        with torch.no_grad():
            preds = self.model(batch)
        return [pred.argmax() for pred in preds]

    def get_preds_labels(self, preds):
        """
        Gives labels for batch of predictions

        preds: iterable of class indices into self.labels.
        """
        return [self.labels[pred] for pred in preds]

    def display_preds(self, batch, preds, width, height):
        """
        Displays original images with prediction labels in grid.

        width / height: number of columns / rows of the subplot grid.
        """
        plt.figure(figsize=(15, 7))
        for num, (sample, pred) in enumerate(zip(batch, preds)):
            plt.subplot(height, width, num+1)
            plt.title(self.labels[pred])
            plt.axis('off')
            # NumPy can only view host memory, so bring the tensor to the CPU
            # (a no-op when it already lives there) before converting.
            sample = sample.detach().cpu().numpy()
            # Reorder the axes from (0, 1, 2) to (1, 2, 0) for imshow.
            plt.imshow(np.transpose(sample, (1, 2, 0)))
  
def init_context(context):
    """
    Init pre-trained model

    Reads the "model_path" and "device" environment variables, builds the
    ModelHandler and creates the v3io dataplane client; both are stored on
    the context.
    """
    # Fall back to CPU when no device is configured.
    handler_kwargs = {"device": os.getenv("device", "cpu")}
    # Only forward model_path when it is actually set; passing None would
    # override ModelHandler's own default path.
    model_path = os.getenv("model_path")
    if model_path:
        handler_kwargs["model_path"] = model_path
    context.model_handler = ModelHandler(**handler_kwargs)
    context.v3io_client = v3io.dataplane.Client()

def handler(context, event):
    """
    Handler to perform real-time model inference using images
    from stream. Can perform on GPU and CPU. Writes per-batch and overall
    inference times to the log and stores the run's metrics in a KV table.

    event.body is expected to be a JSON-encoded list of image paths.
    Configuration is read from the environment variables "device",
    "batch_size", "num_batches" and "table_path".
    """
    image_list = json.loads(event.body)

    # Read the configuration once instead of re-parsing it at every use.
    device = os.getenv("device")
    batch_size = int(os.getenv("batch_size"))
    num_batches = int(os.getenv("num_batches"))

    # CUDA kernels are launched asynchronously; without an explicit
    # synchronize the timers would stop before the GPU has finished.
    sync_cuda = str(device).startswith("cuda") and torch.cuda.is_available()

    kwargs = {'num_workers': 0, 'pin_memory': False} if device == 'cuda' else {}
    context.stream_data_loader = StreamDataLoader(image_list=image_list, device=device)
    context.loader = DataLoader(context.stream_data_loader, batch_size=batch_size, **kwargs)

    infer_start = time.time()

    # Iterate through batches from DataLoader, query model and log timings
    batch_times = []
    for i, batch in enumerate(islice(context.loader, num_batches)):
        # Time prediction
        if sync_cuda:
            torch.cuda.synchronize()
        batch_start = time.time()
        context.model_handler.batch_predict(batch=batch)
        if sync_cuda:
            torch.cuda.synchronize()
        batch_end = time.time()

        # Calculate time taken
        batch_time = batch_end - batch_start
        batch_times.append(batch_time)

        # Log batch inference time
        context.logger.info(f"Batch {i+1} Inference Time: {batch_time}")

    infer_end = time.time()
    total_inference_time = infer_end - infer_start
    # Average over the batches actually processed: the stream may hold fewer
    # than num_batches batches, or none at all.
    avg_batch_inference_time = sum(batch_times) / len(batch_times) if batch_times else 0.0

    context.logger.info(f"Total Inference Time: {total_inference_time}")
    context.logger.info(f"Avg Batch Inference Time: {avg_batch_inference_time}")

    # Write metrics to KV, keyed by the run's start timestamp
    context.logger.info("Writing to KV")
    attributes = {"total_inference_time" : str(total_inference_time),
                  "avg_batch_inference_time" : str(avg_batch_inference_time),
                  "device" : str(device),
                  "batch_size" : str(os.getenv("batch_size")),
                  "num_batches" : str(len(batch_times))}
    context.v3io_client.kv.put(container="bigdata",
                               table_path=os.getenv("table_path"),
                               key=str(infer_start),
                               attributes=attributes)
    context.logger.info("Done")