In [54]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from pathlib import Path

In [55]:
# Label for each output index of the classifier: index 0 -> "covid", index 1 -> "non".
class_names = "covid non".split()

In [56]:
def model_loader(path, num_classes=2, strict=True):
    """
    Build a ResNet-18 with a ``num_classes``-way output layer and load saved weights into it.

    :param path: path to a checkpoint; ``torch.load`` must return a state dict
        accepted by ``model.load_state_dict``.
    :param num_classes: output size of the replacement ``fc`` layer. It must match
        the checkpoint (2 for this notebook's ``class_names``).
    :param strict: forwarded to ``load_state_dict``. Defaults to True, so a checkpoint
        whose keys do not match the architecture raises instead of silently
        leaving layers at their initial (random) values.
    :return: the loaded model on CPU, switched to eval mode.
    """
    # No ImageNet download: with strict loading every parameter is overwritten
    # from the checkpoint, so the backbone's initial weights are irrelevant.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict, strict=strict)
    # Inference mode for layers like BatchNorm/Dropout.
    model.eval()
    return model

In [57]:
# Fine-tuned weights are read from 'model_ft.pt', resolved relative to the kernel's working directory.
checkpoint = Path('.') / 'model_ft.pt'
model = model_loader(path=checkpoint)

In [58]:
# Built once: the pipeline is stateless, so there is no reason to recreate it per call.
# Resize + 224x224 center crop, then normalise with the usual ImageNet channel mean/std.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def preprocess(image_path):
    """
    Load an image from disk and turn it into a model-ready batch of one.

    :param image_path: path to image.
    :return: float tensor of shape (1, 3, 224, 224).
    """
    # The context manager closes the underlying file. convert() loads the pixel
    # data into a new in-memory image, so the handle is not needed afterwards.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    # Add a leading batch dimension: the model expects (batch, C, H, W).
    return _PREPROCESS_TRANSFORM(rgb_image).unsqueeze(0)

In [59]:
def prediction(image_path, net=None):
    """
    Return the model's predicted label for an X-ray image.

    :param image_path: path to image.
    :param net: classifier to run. Defaults to the notebook-level ``model``.
    :return: predicted label, one of ``class_names``.
    """
    net = model if net is None else net
    batch = preprocess(image_path)
    # Inference only: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        logits = net(batch)
    # Index of the highest score for the single image in the batch, as a Python int.
    pred_index = logits.argmax(dim=1).item()
    return class_names[pred_index]

In [65]:
print("Predictions for X-rays with COVID-19:")
# Classify each demo image in order; the resulting tuple is the cell's displayed output.
tuple(map(prediction, [
    'demo_images/covid/5CBC2E94-D358-401E-8928-965CCD965C5C.jpeg',
    'demo_images/covid/53EC07C9-5CC6-4BE4-9B6F-D7B0D72AAA7E.jpeg',
    'demo_images/covid/7-fatal-covid19.jpg',
]))

Predictions for X-rays with COVID-19:


('covid', 'covid', 'covid')

In [63]:
print("Predictions for X-rays without COVID-19:")
# Classify each demo image in order; the resulting tuple is the cell's displayed output.
tuple(map(prediction, [
    'demo_images/non/patient00106-study1-view1_frontal.jpg',
    'demo_images/non/patient00107-study1-view1_frontal.jpg',
    'demo_images/non/patient00140-study8-view1_frontal.jpg',
]))

Predictions for X-rays without COVID-19:


('non', 'non', 'non')

In [66]:
%%timeit
print("Predictions for X-rays with COVID-19:")
prediction('demo_images/covid/5CBC2E94-D358-401E-8928-965CCD965C5C.jpeg')

Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
Predictions for X-rays with COVID-19:
196 ms ± 35.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
