In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from resnet18_32x32 import ResNet18_32x32
# 定义转换器
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])
# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else False)
# 加载测试集
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

# 加载已训练的模型
model = ResNet18_32x32().to(device)  # 用你的已训练的模型替换 YourModel()
model.load_state_dict(torch.load('./resnet/model.pth'))
model.eval()

tensor_list = []

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        probabilities = torch.softmax(outputs, dim=1) 
        tensor_list.append((predicted, probabilities)) #训练数据集

print("Size of tensor_list:", len(tensor_list))
for i in range(5):
    predicted, probabilities = tensor_list[i]
    print(f"Element {i+1}:")
    print("Predicted Label:", predicted)
    print("Probabilities:", probabilities)
    print()

Files already downloaded and verified
Size of tensor_list: 157
Element 1:
Predicted Label: tensor([2, 4, 4, 9, 0, 4, 2, 6, 0, 2, 2, 7, 9, 0, 2, 2, 2, 3, 1, 3, 3, 8, 0, 8,
        5, 1, 8, 7, 7, 2, 0, 6, 0, 0, 3, 3, 9, 5, 9, 9, 4, 2, 0, 5, 0, 1, 9, 1,
        6, 9, 3, 6, 2, 3, 5, 3, 8, 6, 3, 8, 5, 6, 7, 4], device='cuda:0')
Probabilities: tensor([[7.2259e-02, 5.3529e-05, 8.3453e-01, 6.9571e-05, 1.4357e-05, 4.0270e-04,
         2.2941e-05, 3.5111e-04, 8.9090e-02, 3.2109e-03],
        [2.2321e-09, 1.9119e-07, 2.4308e-07, 4.6671e-04, 9.9781e-01, 6.3366e-07,
         1.7203e-03, 9.4557e-07, 1.1654e-08, 1.5059e-07],
        [6.8752e-03, 3.5559e-07, 5.5659e-03, 7.8453e-07, 9.6762e-01, 1.5947e-04,
         6.0613e-05, 1.8605e-02, 1.1017e-03, 1.0300e-05],
        [7.0182e-04, 2.5245e-02, 3.5421e-04, 2.2778e-03, 2.6842e-04, 2.1041e-05,
         4.2666e-03, 2.2776e-04, 2.2363e-04, 9.6641e-01],
        [9.3274e-01, 1.1587e-05, 2.9657e-04, 5.3664e-03, 7.7533e-04, 2.9620e-04,
         1.3453e-03, 1.