In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from cloudpickle import load, loads

from mlrun.artifacts import get_model, update_model

def prep_model(device, num_classes=2):
    """Build a ResNet-50 classifier with a fresh two-layer head on ``device``.

    The backbone is fetched with ``torch.hub.load('pytorch/vision', 'resnet50',
    pretrained=True)``. Every backbone parameter whose name does not contain
    "bn" is frozen. The original ``fc`` layer is then replaced by
    Linear(in_features, 512) -> ReLU -> Dropout -> Linear(512, num_classes).

    Args:
        device: Target device passed to ``Module.to`` (e.g. "cpu", "cuda").
        num_classes: Output width of the final linear layer.

    Returns:
        The assembled model, already moved to ``device``.
    """
    backbone = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)

    # Freeze all parameters except those with "bn" in their name. The new head
    # is attached only after this, so its parameters remain trainable.
    # NOTE(review): the intent looks like "keep batch-norm layers trainable";
    # any batch-norm whose parameter name lacks "bn" would be frozen too — confirm.
    frozen_params = (
        param for param_name, param in backbone.named_parameters()
        if "bn" not in param_name
    )
    for param in frozen_params:
        param.requires_grad = False

    feature_width = backbone.fc.in_features
    head_layers = [
        nn.Linear(feature_width, 512),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(512, num_classes),
    ]
    backbone.fc = nn.Sequential(*head_layers)

    return backbone.to(device)

def _evaluate(model, data_loader, device):
    """Return ``(mean cross-entropy loss, accuracy)`` of ``model`` over ``data_loader``.

    ``data_loader`` may be any iterable of ``(inputs, targets)`` batches. Both
    metrics are averaged over the examples actually evaluated, so they stay
    consistent with subset samplers or ``drop_last``. The caller is responsible
    for putting ``model`` into eval mode.

    Raises:
        ValueError: if the loader yields no examples.
    """
    criterion = nn.CrossEntropyLoss()  # built once, not per batch
    total_loss = 0.0
    num_correct = 0
    num_examples = 0

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)

            # CrossEntropyLoss averages over the batch by default; rescale by the
            # batch size so batches of different sizes are weighted correctly.
            total_loss += criterion(output, targets).item() * inputs.size(0)

            # argmax of the raw logits equals argmax of softmax(logits), since
            # softmax is monotonic, so the softmax is skipped.
            correct = torch.eq(output.argmax(dim=1), targets).view(-1)
            num_correct += int(correct.sum().item())
            num_examples += correct.numel()

    if num_examples == 0:
        raise ValueError("test data loader yielded no examples")
    return total_loss / num_examples, num_correct / num_examples


def handler(context, event):
    """Evaluate a stored model on a pickled test loader and log the metrics.

    Reads three entries from ``context.inputs``:

    * ``'device'``: converted with ``str`` and used as the torch device.
    * ``'model'``: its ``.url`` is resolved with ``get_model(..., suffix='.pth')``
      to a file holding a ``state_dict`` for ``prep_model()``.
    * ``'test_data_loader'``: its ``.get()`` bytes are cloudpickle-decoded into
      an iterable of ``(inputs, targets)`` batches.

    Logs ``test_loss`` and ``test_accuracy`` via ``context.log_results`` and
    attaches ``context.results`` to the model artifact via ``update_model``.
    ``event`` is not used.

    Raises:
        ValueError: if the test loader yields no examples.
    """
    context.logger.info("Loading Model")
    device = str(context.inputs['device'])
    model_file, _, _ = get_model(context.inputs['model'].url, suffix='.pth')
    model = prep_model(device=device)
    # Close the file deterministically. map_location lets a checkpoint saved on
    # a GPU load onto a CPU-only host (and vice versa).
    with open(model_file, "rb") as model_fp:
        state_dict = torch.load(model_fp, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    context.logger.info("Loading Data Loader")
    # NOTE(security): cloudpickle.loads can execute arbitrary code; only feed it
    # artifacts from a trusted source.
    test_data_loader = loads(context.inputs['test_data_loader'].get())

    context.logger.info("Inferencing Model")
    test_loss, test_accuracy = _evaluate(model, test_data_loader, device)

    context.logger.info("Logging Results")
    context.logger.info('Test Loss: {:.4f}, Test Accuracy = {:.4f}'.format(test_loss, test_accuracy))
    results = {"test_loss" : test_loss,
               "test_accuracy" : test_accuracy}
    context.log_results(results)
    update_model(context.inputs['model'], metrics=context.results)