In [1]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm

import model_files as model_all #模型的连接在__init__.py中体现

#### 全局参数

In [2]:
device = torch.device("cuda:1")
dataset_type = "CIFAR10"
# model_name = "AlexNet" #要被测试的模型
# model_name = "MobileNetV2_x1_4" #要被测试的模型
model_name = "ResNet56"

In [3]:
# 不同的transform
if dataset_type == "MNIST":
	transform = transforms.Compose([transforms.ToTensor(),transforms.Resize([32,32]), transforms.Normalize([0.5], [0.5])])
	train_dataloader = DataLoader(datasets.MNIST('./static/data/MNIST/MNIST', train=True, download=True, transform=transform), batch_size=128, shuffle=True)
	test_dataloader = DataLoader(datasets.MNIST('./static/data/MNIST/MNIST', train=False, download=True, transform=transform), batch_size=256, shuffle=True)
elif dataset_type == "CIFAR10":
	transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.4914,0.4822,0.4465], [0.2023,0.1994,0.2010])]) #CIFAR10数据集的均值和方差，多处网络验证
	train_dataloader = DataLoader(datasets.CIFAR10('./static/data/CIFAR10/CIFAR10', train=True, download=True, transform=transform), batch_size=100, shuffle=True)
	test_dataloader = DataLoader(datasets.CIFAR10('./static/data/CIFAR10/CIFAR10', train=False, download=True, transform=transform), batch_size=100, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


#### 加载模型

In [4]:
model = model_all.get_DNN_model(dataset_type, model_name)
model.eval() #!!!!!!!!!!要注意这个地方
model = model.to(device)
model.load_state_dict(torch.load("./model_files/" + dataset_type + "/checkpoints/classify_model/" + model_name + ".pt"))

<All keys matched successfully>

In [5]:
# 测试模型
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
cifar10_classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

with torch.no_grad():
    for data in test_dataloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(100):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%, total=%2d' % (cifar10_classes[i], 100 * class_correct[i] / class_total[i], class_total[i]))

Accuracy of plane : 94 %, total=1000
Accuracy of   car : 97 %, total=1000
Accuracy of  bird : 92 %, total=1000
Accuracy of   cat : 88 %, total=1000
Accuracy of  deer : 96 %, total=1000
Accuracy of   dog : 90 %, total=1000
Accuracy of  frog : 96 %, total=1000
Accuracy of horse : 96 %, total=1000
Accuracy of  ship : 96 %, total=1000
Accuracy of truck : 95 %, total=1000


#### 使用生成数据集进行测试

In [6]:
model_files_dir = "./model_files/" # 模型位置
sys.path.append(model_files_dir)
import numpy as np
from CIFAR10.models import BigGAN

number = 100
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
cifar10_classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
checkpoints_path = "./model_files/CIFAR10/checkpoints/BigGAN/model=G-best-weights-step=162000.pth"
G = BigGAN.Generator().to(device)
G.load_state_dict(torch.load(checkpoints_path, map_location=device)["state_dict"])
G.eval()

for class_type in tqdm(range(10)):
    for i in range(10): # 分10次，每次只生成100张，总计1000张
        for num in range(number):
            z = torch.tensor(np.random.RandomState(num).randn(1, 80)).to(torch.float32).to(device)    # latent code
            label = torch.tensor(class_type).unsqueeze(0).to(device)
            shared_label = G.shared(label)
            z_and_shared_label = torch.cat((z, shared_label), dim = 1)
            if num == 0:
                z_and_shared_labels = z_and_shared_label
            else:
                z_and_shared_labels = torch.cat((z_and_shared_labels, z_and_shared_label))
            img = G(z = z_and_shared_label)                           # NCHW, float32, dynamic range [-1, +1]
            # img = ((img + 1)/2).clamp(0.0, 1.0) # 变换到[0,1]范围内
            if num == 0:
                imgs = img
                labels = label
            else:
                imgs = torch.cat((imgs, img))
                labels = torch.cat((labels, label))
        
        with torch.no_grad():
            images, labels = imgs.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(number):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%, total=%2d' % (cifar10_classes[i], 100 * class_correct[i] / class_total[i], class_total[i]))

100%|██████████| 10/10 [01:50<00:00, 11.04s/it]

Accuracy of plane : 79 %, total=1000
Accuracy of   car : 87 %, total=1000
Accuracy of  bird : 75 %, total=1000
Accuracy of   cat : 88 %, total=1000
Accuracy of  deer : 85 %, total=1000
Accuracy of   dog : 47 %, total=1000
Accuracy of  frog : 94 %, total=1000
Accuracy of horse : 68 %, total=1000
Accuracy of  ship : 95 %, total=1000
Accuracy of truck : 70 %, total=1000





In [7]:
np.sum(class_correct)

7880.0