# In [None]:
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import resnet50

# Run inference on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing pipeline for the input image: resize, centre-crop to 224,
# convert to a tensor, then normalize each channel with the same mean/std
# parameters that were used during training.
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# load image
img_path = "../tulip.jpg"
# Force three channels: the Normalize step in data_transform is configured
# with 3-element mean/std, so grayscale, palette, CMYK or RGBA inputs would
# otherwise fail with a channel-count mismatch.
img = Image.open(img_path).convert("RGB")
plt.imshow(img)
# apply the preprocessing pipeline -> tensor laid out as [C, H, W]
img = data_transform(img)
# expand batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# read class_indict: JSON mapping from class index (as a string) to class name
try:
    # The context manager closes the file handle even when parsing fails.
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except (OSError, ValueError) as e:
    # OSError: the file is missing or unreadable.
    # ValueError: malformed JSON (json.JSONDecodeError) or undecodable bytes.
    print(e)
    # Raise SystemExit directly instead of calling exit(), which is a helper
    # injected by the site module and is not guaranteed to exist.
    raise SystemExit(-1) from e

# Build the network via resnet50(num_classes=2), then move it to the
# selected device.
model = resnet50(num_classes=2)
model = model.to(device)

# Restore the saved parameters, mapping the stored tensors onto `device`.
# NOTE(review): the file name says "resNet34" while the model is built with
# resnet50 -- confirm the checkpoint actually matches this architecture.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
weights_path = "./resNet34.pth"
model.load_state_dict(
    torch.load(weights_path, map_location=device)
)

# prediction
model.eval()   # switch the model to evaluation mode
with torch.no_grad():   # do not track gradients during inference
    logits = model(img.to(device))
    # squeeze drops the batch dimension; results are moved back to the CPU
    output = logits.squeeze().cpu()
    # softmax turns the raw scores into a probability distribution
    predict = output.softmax(dim=0)
    # index of the highest-probability class
    predict_cla = predict.argmax().numpy()

# Show the top prediction as the title of the image plot.
best_label = class_indict[str(predict_cla)]
best_prob = predict[predict_cla].numpy()
print_res = f"class: {best_label}   prob: {best_prob:.3}"
plt.title(print_res)

# Print every class name together with its probability.
for i, prob in enumerate(predict):
    label = class_indict[str(i)]
    print(f"class: {label:10}   prob: {prob.numpy():.3}")
plt.show()