In [1]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm

import model_files as model_all #模型的连接在__init__.py中体现

#### 全局参数

In [2]:
# Global configuration for this evaluation run.
# Fall back to CPU when no CUDA device is available so the notebook still runs end to end.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataset_type = "CIFAR10"  # "MNIST" or "CIFAR10"; selects the transform, dataset and checkpoint directory
model_name = "kjl_AlexNet"  # model under test; also the checkpoint file name (<model_name>.pt)

In [3]:
# Dataset-specific preprocessing and train/test DataLoaders.
# The normalization constants must match those used when the checkpoint was trained.
if dataset_type == "MNIST":
    # MNIST images are resized to 32x32 (presumably to match the input size the models expect — TODO confirm).
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([32, 32]), transforms.Normalize([0.5], [0.5])])
    dataset_cls = datasets.MNIST
    data_root = './static/data/MNIST/MNIST'
elif dataset_type == "CIFAR10":
    # Per-channel CIFAR10 mean and standard deviation (values commonly used across reference networks).
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
    dataset_cls = datasets.CIFAR10
    data_root = './static/data/CIFAR10/CIFAR10'
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(f"Unsupported dataset_type: {dataset_type!r} (expected 'MNIST' or 'CIFAR10')")

train_dataloader = DataLoader(dataset_cls(data_root, train=True, download=True, transform=transform), batch_size=128, shuffle=True)
# Evaluation order does not affect accuracy; shuffle=False keeps test iteration deterministic.
test_dataloader = DataLoader(dataset_cls(data_root, train=False, download=True, transform=transform), batch_size=256, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


#### 加载模型

In [4]:
# Build the architecture for (dataset_type, model_name) and load its trained weights.
model = model_all.get_DNN_model(dataset_type, model_name)
# Important: eval() switches layers such as Dropout/BatchNorm to inference behavior;
# forgetting it can change the measured test accuracy.
model.eval()
model = model.to(device)
checkpoint_path = f"./model_files/{dataset_type}/checkpoints/classify_model/{model_name}.pt"
# map_location lets a checkpoint saved on GPU load onto whatever `device` is (e.g. the CPU fallback).
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [5]:
# Top-1 accuracy of the loaded model on the test split.
model.eval()  # inference behavior for Dropout/BatchNorm; set once, not per batch
total = 0
test_correct = 0
# Evaluation needs no gradients; disabling autograd avoids building a graph per batch.
with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        # ============= forward =============
        outputs = model(inputs)
        # ============= precision ===========
        # Predicted class = index of the largest output along dim 1
        # (assumes outputs are (batch, num_classes) logits/scores — TODO confirm).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError("test_dataloader yielded no samples; cannot compute accuracy")
print("acc=", test_correct/total)

acc= 0.8294
