In [1]:
import os
from typing import List, Dict, Any, Tuple

import mlflow
import mlflow.pyfunc
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms

In [2]:

# Registered model to evaluate, addressed as models:/<name>/<version> in the MLflow Model Registry.
model_name = 'mnist_lr_optimized'
model_version = 3

# Honour an MLFLOW_TRACKING_URI set in the environment; otherwise use the local tracking server.
mlflow.set_tracking_uri(os.environ.get('MLFLOW_TRACKING_URI', 'http://localhost:5001/'))
model = mlflow.pyfunc.load_model(model_uri=f"models:/{model_name}/{model_version}")
type(model)

  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 20.09it/s]  


mlflow.pyfunc.PyFuncModel

In [3]:
def load_test_images(batch_size: int, data_dir: str = './mnistdata', shuffle: bool = True) -> DataLoader:
    """Build a DataLoader over the MNIST test split.

    Args:
        batch_size: Number of samples per yielded batch.
        data_dir: Directory torchvision uses to store / look up the MNIST
            files (``download=True`` lets it fetch them when needed).
        shuffle: Whether the DataLoader shuffles sample order. Defaults to
            ``True``; pass ``False`` for a deterministic evaluation order.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches.
    """
    # Convert images to tensors, then normalize the single channel with mean=0.5, std=0.5.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # train=False selects the held-out test split.
    test_dataset = datasets.MNIST(data_dir, download=True, train=False, transform=transform)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

In [4]:
def test_model(model: mlflow.pyfunc.PyFuncModel, loader: DataLoader) -> Dict[str, Any]:
    correct_count, total_count = 0, 0
    for images, labels in loader:
        for i in range(len(labels)):
            img = images[i].view(1, 784)
            # Turn off gradients to speed up this part
            with torch.no_grad():
                logps = model.predict(img.numpy())

            # Output of the network are log-probabilities, need to take exponential for probabilities
            ps = np.exp(logps)
            probab = list(ps[0])
            pred_label = probab.index(max(probab))
            true_label = labels.numpy()[i]
            if(true_label == pred_label):
                correct_count += 1
            total_count += 1
    
    testing_metrics = {
        'incorrect_count': total_count-correct_count,
        'correct_count': correct_count,
        'accuracy': (correct_count/total_count)
    }
    print("Number Of Images Tested =", total_count)
    print("\nModel Accuracy =", (correct_count/total_count))
    return testing_metrics

In [5]:
# Number of test images per DataLoader batch.
TEST_BATCH_SIZE = 64

test_loader = load_test_images(TEST_BATCH_SIZE)
testing_metrics = test_model(model, test_loader)
testing_metrics

Number Of Images Tested = 10000

Model Accuracy = 0.9795


{'incorrect_count': 205, 'correct_count': 9795, 'accuracy': 0.9795}