# In [None]:
import mindspore.dataset.vision.c_transforms as cv
import numpy as np
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
import mindspore.nn as nn
import mindspore as ms
from PIL import Image
import matplotlib.pyplot as plt


class AlexNet(nn.Cell):
    """AlexNet-style convolutional classifier.

    Five convolution layers followed by three fully connected layers. For a
    227x227 input (the size ``img_pre_process`` produces) the spatial size
    evolves as 227 -> 55 (conv1, valid, stride 4) -> 27 (pool) -> 27 (conv2)
    -> 13 (pool) -> 13 (conv3..conv5) -> 6 (pool), which is why ``fc1``
    expects ``6 * 6 * 256`` input features.

    Args:
        num_classes: number of logits produced by the final dense layer.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # NOTE: attribute names define the parameter names (e.g.
        # "conv1.weight") that load_param_into_net matches against the
        # checkpoint, so they must not be renamed.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11,
                               stride=4, pad_mode='valid')
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5,
                               stride=1, pad_mode='same')
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3,
                               stride=1, pad_mode='same')
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3,
                               stride=1, pad_mode='same')
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3,
                               stride=1, pad_mode='same')
        # Parameter-free layers, shared by every stage that needs them.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(in_channels=6 * 6 * 256, out_channels=4096)
        self.fc2 = nn.Dense(in_channels=4096, out_channels=4096)
        self.fc3 = nn.Dense(in_channels=4096, out_channels=num_classes)

    def construct(self, x):
        """Return raw (un-normalized) class logits for the input batch ``x``."""
        # Feature extractor: conv -> ReLU, pooled after stages 1, 2 and 5.
        out = self.max_pool2d(self.relu(self.conv1(x)))
        out = self.max_pool2d(self.relu(self.conv2(out)))
        out = self.relu(self.conv3(out))
        out = self.relu(self.conv4(out))
        out = self.max_pool2d(self.relu(self.conv5(out)))

        # Classifier head: two ReLU dense layers, then the logit layer.
        out = self.flatten(out)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        return self.fc3(out)


def data_pre_process(dataset, columns_list, resize):
    """Attach decode/resize/normalize/layout ops to a dataset and batch it.

    Args:
        dataset: presumably a ``mindspore.dataset`` object; it must provide
            ``map`` and ``batch``.
        columns_list: column names; only ``columns_list[0]`` is transformed.
        resize: target edge length; images are resized to ``resize x resize``.

    Returns:
        The transformed dataset, batched with batch size 1.
    """
    channel_mean = 122.96757279 / 255
    channel_std = 55.55022323 / 255

    # NOTE(review): MindSpore's c_transforms.Normalize applies (pixel - mean)
    # / std to the decoded pixel values (presumably 0-255), but these
    # constants were divided by 255. Confirm this matches how the checkpoint
    # was trained before changing it.
    pipeline = [
        cv.Decode(),
        cv.Resize([resize, resize]),
        cv.Normalize(mean=[channel_mean] * 3, std=[channel_std] * 3),
        cv.HWC2CHW(),  # the network consumes channel-first tensors
    ]

    mapped = dataset.map(operations=pipeline,
                         input_columns=columns_list[0],
                         num_parallel_workers=1)
    return mapped.batch(1)


def img_pre_process(path, resize=227):
    """Load one image file and convert it to a (1, 3, resize, resize) float32 array.

    The normalization mirrors ``data_pre_process``: MindSpore's
    ``cv.Normalize`` there computes ``(pixel - mean) / std`` directly on the
    decoded 0-255 pixel values, using the same mean/std constants as below.
    Do not add an extra ``/ 255`` here, because single-image inputs would then
    be ~255x smaller than the batches ``data_pre_process`` produces.
    NOTE(review): this assumes the checkpoint was trained through
    ``data_pre_process``; confirm against the training code.

    Side effect: draws the resized image on the current matplotlib axes.

    Args:
        path: path of the image file to read.
        resize: target edge length (default 227, the AlexNet input size).

    Returns:
        A float32 numpy array of shape (1, 3, resize, resize) in NCHW layout.
    """
    # Use a context manager so the underlying file handle is closed.
    # convert() loads the pixels into a new, file-independent image.
    with Image.open(path) as src:
        image = src.convert("RGB")
    image = image.resize((resize, resize))
    plt.imshow(image)

    mean = np.full(3, 122.96757279 / 255)
    std = np.full(3, 55.55022323 / 255)
    # Normalize the raw 0-255 pixels exactly as cv.Normalize does.
    pixels = np.asarray(image, dtype=np.float64)
    normalized = ((pixels - mean) / std).astype(np.float32)

    # (H, W, C) -> (C, H, W), then add the batch axis -> (1, C, H, W).
    chw = np.transpose(normalized, (2, 0, 1))
    return np.expand_dims(chw, axis=0)


def reload_model(ckpt_path="alex-150_79.ckpt"):
    """Build an AlexNet, load trained weights into it, and wrap it in a Model.

    Args:
        ckpt_path: path of the MindSpore checkpoint to load. Defaults to
            ``"alex-150_79.ckpt"``, resolved relative to the working directory.

    Returns:
        A ``mindspore.train.Model`` wrapping the loaded network. ``predict``
        only needs the network. The MSE loss, Momentum optimizer and accuracy
        metric are attached so the returned object also supports
        ``train``/``eval``.
    """
    import warnings

    param_dict = load_checkpoint(ckpt_path)
    net = AlexNet(num_classes=2)

    # load_param_into_net reports network parameters it could not fill:
    # a list in older MindSpore releases, a (param_not_load, ckpt_not_load)
    # tuple in newer ones. Ignoring it silently would leave those layers at
    # their initial (untrained) values.
    not_loaded = load_param_into_net(net, param_dict)
    if isinstance(not_loaded, tuple):
        not_loaded = not_loaded[0]
    if not_loaded:
        warnings.warn(f"parameters not loaded from {ckpt_path}: {not_loaded}")

    net_loss = nn.MSELoss()
    net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
    model = Model(net, loss_fn=net_loss, optimizer=net_opt, metrics={"accuracy"})
    return model


# Classify one test X-ray and display it titled with the predicted label.
labels = {0: "NORMAL", 1: "PNEUMONIA"}
image1 = r"F:\chest_xray\test\NORMAL\IM-0079-0001.jpeg"

# Loading the network does not touch matplotlib, so do it before plotting.
model = reload_model()

plt.figure(figsize=(15, 7))
plt.subplot(1, 2, 1)
# img_pre_process also draws the image on this subplot.
x1 = img_pre_process(image1)
y1_pre = model.predict(ms.Tensor(x1))
# Index of the largest logit per row; the batch holds a single image.
label = y1_pre.asnumpy().argmax(axis=1)
plt.title(labels[label[0]])
plt.show()
