In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from ExtCIFAR10 import ExtCIFAR10
from resnet import *

# Path to the pretrained ResNet-20 checkpoint loaded by main(); the loaded
# object is expected to contain a 'state_dict' entry.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
model_file = 'd:/Lab/models/resnet20.th'

def test_validate(model, device, test_loader, test_valid='Test'):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += nn.CrossEntropyLoss(reduction='sum')(output, target).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)
    print('\n{} set: Average loss: {:.6f}, Accuracy: {}/{} ({:.6f}%)\n'.format(
        test_valid, test_loss, correct, len(test_loader.dataset), acc))
    return test_loss

def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it.

    Weights saved from a wrapped model (e.g. ``nn.DataParallel`` — presumably
    the case for this checkpoint) carry a 'module.' prefix on every key.
    Keys without the prefix are kept unchanged rather than having their
    first characters chopped off.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def main():
    """Evaluate the pretrained ResNet-20 checkpoint on the CIFAR-10 test split.

    Builds the CIFAR-10 test loader from a local copy of the dataset (no
    download), loads ``checkpoint['state_dict']`` from ``model_file`` into a
    fresh ``resnet20()``, and prints loss/accuracy via ``test_validate``.
    """
    # Raw string: in a plain literal '\L' and '\d' are invalid escape
    # sequences (DeprecationWarning; SyntaxWarning since Python 3.12).
    dataroot = r'D:\Lab\dataset'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    b_size = 512

    # NOTE(review): these are the ImageNet channel statistics, not CIFAR-10's;
    # presumably they match the checkpoint's training preprocessing — confirm.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    test_dataset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=False, transform=transform_test)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=b_size, shuffle=False, num_workers=2)

    print('Start testing...')
    model = resnet20()
    model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(model_file, map_location=device)
    print(checkpoint.keys())
    model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))
    test_validate(model, device, test_loader)
    
# Run the evaluation when this cell executes. main() has no return statement,
# so the cell shows only the printed log, not an Out[...] value.
main()

Start testing...


  init.kaiming_normal(m.weight)


dict_keys(['best_prec1', 'state_dict'])

Test set: Average loss: 0.372661, Accuracy: 9173/10000 (91.730000%)

