In [13]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import argparse
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from network import resnet


Set parameters (dataroot, batch size, etc.)

In [14]:
# Evaluation settings. parse_args(args=[]) parses an empty list instead of the
# notebook kernel's own sys.argv, so the defaults below are what `opt` holds.
# The data root can be overridden without editing the notebook by setting the
# MONKEYPOX_DATAROOT environment variable.
args = argparse.ArgumentParser()
args.add_argument('--dataroot', type=str,
                  default=os.environ.get('MONKEYPOX_DATAROOT', 'G:/mydataset/monkeypoxskin/Fold1/test'),
                  help='ImageFolder-style directory to evaluate (one sub-folder per class)')
args.add_argument('--batch_size', type=int, default=32,
                  help='samples per evaluation batch')
args.add_argument('--num_workers', type=int, default=4,
                  help='DataLoader worker processes')
args.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                  help="torch device for inference; falls back to 'cpu' when CUDA is unavailable")

opt = args.parse_args(args=[])

Load the dataset and model

In [19]:
# Resize to 224x224 and normalise with the commonly used ImageNet per-channel
# mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = datasets.ImageFolder(root=opt.dataroot, transform=transform)
# Evaluation only: shuffling buys nothing here, and a fixed order keeps
# label_list / pred_list reproducible between runs (accuracy is unaffected).
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)

model = resnet.resnet34()
# Replace the classifier head with a 3-way output so the checkpoint's weights fit.
# NOTE(review): assumes the backbone's final feature width is 512 — confirm in network/resnet.py.
model.fc = nn.Linear(512, 3)
# map_location remaps tensors saved on a GPU onto opt.device, so the checkpoint
# also loads on a CPU-only machine (where opt.device defaults to 'cpu').
# torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load('checkpoints/model_epoch_best.pth', map_location=opt.device)
model.load_state_dict(checkpoint['model'])
model = model.to(opt.device)


Validate the model and calculate accuracy.

In [20]:
def validate(model, dataroot, device, loader=None):
    """Run ``model`` over an evaluation loader and collect its predictions.

    Args:
        model: classifier whose forward pass returns per-class scores with the
            class dimension last (``argmax(-1)`` is taken over it).
        dataroot: unused; kept so existing calls such as
            ``validate(model, opt.dataroot, opt.device)`` keep working. The data
            actually evaluated comes from ``loader``.
        device: device each input batch is moved to before the forward pass.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``dataloader`` built in the dataset cell.

    Returns:
        ``(accuracy, label_list, pred_list)``: ``accuracy`` is the fraction of
        correctly classified samples (``nan`` if the loader yields no samples);
        the two lists hold the per-sample true and predicted class indices in
        iteration order.
    """
    if loader is None:
        # Backward-compatible fallback to the global built in the dataset cell.
        loader = dataloader

    model.eval()
    correct, total = 0, 0
    label_list, pred_list = [], []
    with torch.no_grad():
        for img, label in tqdm(loader):
            logits = model(img.to(device))
            # argmax on the raw scores: applying sigmoid first can saturate large
            # logits to exactly 1.0 in float32, producing ties that argmax
            # resolves toward the lowest class index (a wrong prediction).
            pred = logits.argmax(-1).cpu()
            label = label.cpu()

            correct += (pred == label).sum().item()
            total += label.numel()
            label_list.extend(label.numpy())
            pred_list.extend(pred.numpy())

    accuracy = correct / total if total else float('nan')
    return accuracy, label_list, pred_list


# Main

In [21]:
correct, label_list, pred_list = validate(model, opt.dataroot, opt.device)
# Convert once; every statistic below is computed from these two arrays.
labels = np.asarray(label_list)
preds = np.asarray(pred_list)

# NOTE(review): assumes class index 0 = healthy, 1 = monkeypox, 2 = other.
# ImageFolder assigns indices from the sorted sub-folder names — confirm
# against dataset.classes.
class_names = ['Health', 'Monkeypox', 'Other']
num_health, num_monkey, num_other = (int((labels == k).sum()) for k in range(len(class_names)))

print('Accuracy: {:.2f}%'.format(correct * 100))
print(f'Number of healthy: {num_health}, Number of monkeypox: {num_monkey}, Number of other: {num_other}')

recall_parts = []
for k, name in enumerate(class_names):
    mask = labels == k
    # Per-class accuracy (i.e. recall). Report NaN explicitly when the class is
    # absent from this split instead of dividing by zero.
    recall = (preds[mask] == k).mean() * 100 if mask.any() else float('nan')
    recall_parts.append('{}: {:.2f}%'.format(name, recall))
print(*recall_parts, sep=' \t')


100%|██████████| 3/3 [00:02<00:00,  1.04it/s]

Accuracy: 73.85%
Number of healthy: 20, Number of monkeypox: 20, Number of other: 25
Health: 80.00% 	Monkeypox: 85.00% 	Other: 60.00%



