In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

#from tqdm import tqdm
from tqdm import tqdm_notebook as tqdm
from time import sleep

# DataLoader (CIFAR-10 test set)

In [2]:
# Preprocessing pipeline: turn each image into a tensor, then normalise every
# RGB channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 evaluation split (train=False), downloaded into ./data when missing.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Fixed-order batches of 128 images, fetched by two worker processes.
test_loader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names; the model cell uses len(classes) as the number
# of output units.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified


# Model

In [3]:
# Start from an untrained VGG11-BN; the actual weights come from the
# checkpoint loaded at the end of this cell.
model = models.vgg11_bn(pretrained=False)
# Replace the first classifier layer so it accepts 512 input features, and the
# last one so it emits one score per entry in `classes`.
# NOTE(review): Linear(512, ...) presumes the convolutional part yields 512
# values per image for the 32x32 inputs used here. torchvision releases whose
# VGG applies a 7x7 adaptive average pool would produce 25088 values instead --
# confirm against the installed torchvision version.
model.classifier[0] = nn.Linear(512, 4096)
model.classifier[6] = nn.Linear(model.classifier[6].in_features, len(classes))
model.cuda()
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from (e.g. a GPU index that does not exist on this machine);
# load_state_dict then copies the values into the CUDA parameters.
model.load_state_dict(torch.load('params/model.pth', map_location='cpu'))

# Validation function

In [4]:
def inference_loop(model, loader, device=None):
    """Run ``model`` over every batch of ``loader`` and measure top-1 accuracy.

    The model is switched to evaluation mode (and left in it), gradients are
    disabled, and a progress bar shows the running accuracy.

    Args:
        model: ``torch.nn.Module`` whose output has the class scores along
            dimension 1.
        loader: Sized iterable yielding ``(inputs, labels)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, so a model that was
            moved with ``model.cuda()`` is evaluated on the GPU as before.

    Returns:
        Tuple ``(total_num, accuracy)``: the number of evaluated samples and
        the accuracy in percent. ``(0, 0.0)`` is returned for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            # Parameter-free model: fall back to CUDA when it is available.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    total_acc, total_num = 0, 0
    # Size the bar from the loader that was passed in, not from a notebook
    # global; it is advanced manually once per batch.
    bar = tqdm(total=len(loader), leave=False)
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                # Prepare data
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Forward
                outputs = model(inputs)
                # Accuracy: index of the highest score is the predicted class
                pred = outputs.argmax(dim=1)
                total_acc += (pred == labels.view_as(pred)).sum().item()
                total_num += labels.size(0)
                # Update bar with the running score
                if total_num:
                    bar.set_description("Accuracy: {:.2f}".format(total_acc / total_num * 100), refresh=True)
                bar.update()
    finally:
        # Always release the bar, even when a batch raises.
        bar.close()

    # Guard against an empty loader instead of dividing by zero.
    accuracy = total_acc / total_num * 100 if total_num else 0.0
    return total_num, accuracy

In [5]:
# Score the loaded model on the test loader, then report the sample count and
# the accuracy (in percent) on a single line.
data_num, test_acc = inference_loop(loader=test_loader, model=model)
print(
    'Data num: {}, Test Accuracy: {:.2f}'.format(data_num, test_acc)
)


HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

Data num: 10000, Test Accuracy: 85.51
