# In [None]:
import torch
import torchvision
import torch.nn.functional as F

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from PIL import Image, ImageFont, ImageDraw
# %matplotlib inline
# Use the SimHei font so the Chinese axis labels render, and disable the
# Unicode minus sign for tick labels.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

# TrueType font (size 64) used when writing text onto the image with PIL.
font = ImageFont.truetype('SimHei.ttf', 64)

# Load the previously created index -> label dictionary. The .npy file holds
# a pickled Python object, hence allow_pickle=True; .item() unwraps it.
label_array = np.load('idx_to_labels.npy', allow_pickle=True)
idx_to_labels = label_array.item()
print(idx_to_labels)

# Load the complete pickled model object (not just a state dict) onto the CPU.
# NOTE(security): unpickling can execute arbitrary code, so only load
# checkpoints that come from a trusted source.
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # a whole nn.Module; request full unpickling explicitly.
    model = torch.load('net.pth', map_location='cpu', weights_only=False)
except TypeError:
    # Older PyTorch releases do not know the weights_only argument.
    model = torch.load('net.pth', map_location='cpu')
# Switch to evaluation mode: layers such as dropout and batch norm use their
# inference behaviour. This alone does not disable gradient tracking.
model = model.eval()

from torchvision import transforms

# Preprocessing applied to test images:
#   1. resize the image (shorter side becomes 256 px)
#   2. crop the central 224 x 224 region
#   3. convert the PIL image to a float tensor with values in [0, 1]
#   4. standardise each channel: (value - mean) / std
# The mean / std values are the usual ImageNet channel statistics; the result
# is roughly zero-centred, not confined to [-1, 1].
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]
preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=norm_mean, std=norm_std),
]
test_transform = transforms.Compose(preprocess_steps)


img_path = 'medicine.jpg'
# Force three colour channels: the Normalize step of test_transform carries
# 3-channel statistics, so grayscale / RGBA / palette images would otherwise
# fail or be normalised incorrectly.
img_pil = Image.open(img_path).convert('RGB')
print(np.array(img_pil).shape)
# img_pil.show()
input_img = test_transform(img_pil)  # preprocess the test image
print(input_img.shape)
input_img = input_img.unsqueeze(0)  # add the leading batch dimension
print(input_img.shape)
# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    pred_logits = model(input_img)
print(pred_logits)

# Turn the logits into per-class probabilities along the class axis.
pred_softmax = F.softmax(pred_logits, dim=1)
print(pred_softmax)


# Bar chart of the predicted confidence (in percent) for every class.
plt.figure(figsize=(22, 10))

# Probabilities of the single image in the batch, scaled to percent.
y = pred_softmax.cpu().detach().numpy()[0] * 100
# Look each label up by its class index so that bar i is always labelled with
# class i, independent of the dictionary's insertion order.
x = [idx_to_labels[i] for i in range(len(y))]
width = 0.45
ax = plt.bar(x, y, width)  # note: plt.bar returns the bar container, not an Axes
# plt.bar_label(ax, fmt='%.2f', fontsize=15)
plt.tick_params(labelsize=20)
plt.title(img_path, fontsize=30)
plt.xticks(rotation=45)  # rotate the x-axis tick labels
plt.xlabel('类别', fontsize=20)
plt.ylabel('置信度', fontsize=20)
plt.show()

# Report up to 4 predictions, but never more than the number of classes
# (torch.topk raises when k exceeds the size of the dimension).
n = min(4, pred_softmax.shape[1])
top_n = torch.topk(pred_softmax, n)  # (values, indices) of the n most confident classes
# Select the single batch row explicitly; squeeze() would also drop the class
# axis when n == 1 and make the arrays un-indexable.
pred_ids = top_n[1][0].cpu().detach().numpy()
confs = top_n[0][0].cpu().detach().numpy()
draw = ImageDraw.Draw(img_pil)
# Derive the line spacing from the font size so consecutive lines do not
# overlap (a fixed 50 px step is smaller than the 64 pt font loaded above).
line_step = max(50, int(getattr(font, 'size', 50) * 1.2))
for i, (class_id, conf) in enumerate(zip(pred_ids, confs)):
    class_name = idx_to_labels[int(class_id)]  # plain int key for the dict lookup
    confidence = conf * 100  # probability -> percent
    text = '{:<15} {:>.4f}'.format(class_name, confidence)
    print(text)

    # Opaque red; an alpha value of 1 would be almost fully transparent if
    # the image has an alpha channel.
    draw.text((50, 100 + line_step * i), text, font=font, fill=(255, 0, 0))

img_pil.show()