# In [None]:
import mindspore.nn as nn
from mindspore import context
from  mindspore.train import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import LossMonitor
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
import mindspore.dataset as ds

# Training hyper-parameters.
batch_size, lr, momentum = 32, 0.01, 0.9
epoch_size = 10   # number of training epochs
repeat_size = 1
num_classes = 2   # number of target classes

# CPU is used here only to demonstrate the program's output; in practice the
# checkpoint files come from multi-epoch training on a GPU server.
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# MindRecord files for the train / validation / test splits.
train_file_name = './datasets/convert_dataset_to_mindrecord/data_to_mindrecord/train.mindrecord'
val_file_name = './datasets/convert_dataset_to_mindrecord/data_to_mindrecord/test/val.mindrecord'
test_file_name = './datasets/convert_dataset_to_mindrecord/data_to_mindrecord/reasoning/test.mindrecord'

#创建训练数据集
def create_dateset(file_name, batch_size=32, repeat_size=1, status='train', num_parallel_workers=1):
    """Build the image-classification input pipeline from a MindRecord file.

    Reads the ``data`` and ``label`` columns and decodes the images. When
    ``status == 'train'`` it applies a random crop and a random horizontal
    flip. It then resizes to 227x227, multiplies pixel values by 1/255 and
    converts HWC -> CHW. Labels are cast to int32. Finally the dataset is
    shuffled, batched (dropping the last incomplete batch) and repeated.

    Args:
        file_name: Path to the MindRecord file.
        batch_size: Number of samples per batch.
        repeat_size: Number of times the dataset is repeated.
        status: ``'train'`` enables random augmentation. Any other value
            skips it; previously a non-train value raised ``NameError``
            because the augmentation ops were never defined.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The batched, repeated MindSpore dataset.
    """
    data_set = ds.MindDataset(file_name, columns_list=['data', 'label'])  # parse the MindRecord file
    data_set = data_set.map(operations=vision.Decode(), input_columns=["data"],
                            num_parallel_workers=num_parallel_workers)

    resize_height, resize_width = 227, 227
    rescale = 1.0 / 255.0  # pixel scale factor
    shift = 0.0            # pixel shift

    # Ops applied to the "data" column, in this order.
    image_ops = []
    if status == 'train':
        # NOTE(review): this takes a 32x32 crop (after 4-pixel padding on each side)
        # *before* resizing to 227x227. That looks carried over from a CIFAR-10 recipe.
        # On source images larger than 32x32 it keeps only a small patch — confirm intended.
        image_ops.append(CV.RandomCrop([32, 32], [4, 4, 4, 4]))
        image_ops.append(CV.RandomHorizontalFlip())  # random horizontal flip, probability 0.5
    image_ops += [
        CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR),  # bilinear resize
        CV.Rescale(rescale, shift),  # changes pixel VALUES, not size: output = image * rescale + shift
        CV.HWC2CHW(),
    ]
    # Optional normalization (disabled):
    # CV.Normalize((122.96757279 / 255,) * 3, (55.55022323 / 255,) * 3)

    data_set = data_set.map(operations=C.TypeCast(mstype.int32), input_columns='label',
                            num_parallel_workers=num_parallel_workers)
    for image_op in image_ops:
        data_set = data_set.map(operations=image_op, input_columns='data',
                                num_parallel_workers=num_parallel_workers)

    buffer_size = 10000
    data_set = data_set.shuffle(buffer_size=buffer_size)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    data_set = data_set.repeat(repeat_size)
    return data_set
#创建验证数据集
def create_dateset_val(file_name, batch_size=32, repeat_size=1, status='val', num_parallel_workers=1):
    """Build the non-augmented evaluation pipeline from a MindRecord file.

    Decodes ``data`` and casts ``label`` to int32. It then resizes images to
    227x227 (bilinear), multiplies pixel values by 1/255 and converts
    HWC -> CHW. Finally it shuffles, batches (dropping the last incomplete
    batch) and repeats.

    ``status`` is accepted for signature parity with ``create_dateset`` but
    is not used.

    Returns:
        The batched, repeated MindSpore dataset.
    """
    workers = num_parallel_workers
    pipeline = ds.MindDataset(file_name, columns_list=['data', 'label'])
    pipeline = pipeline.map(operations=vision.Decode(), input_columns=["data"], num_parallel_workers=workers)
    pipeline = pipeline.map(operations=C.TypeCast(mstype.int32), input_columns='label', num_parallel_workers=workers)

    target_size = (227, 227)
    for image_op in (
        CV.Resize(target_size, interpolation=Inter.LINEAR),
        CV.Rescale(1.0 / 255.0, 0.0),  # pixel values: image * (1/255) + 0
        CV.HWC2CHW(),
    ):
        pipeline = pipeline.map(operations=image_op, input_columns='data', num_parallel_workers=workers)

    # NOTE(review): this pipeline is used for evaluation. drop_remainder=True means the
    # last partial batch is never evaluated, and shuffling is unnecessary here — confirm
    # both are intended.
    return pipeline.shuffle(buffer_size=10000).batch(batch_size, drop_remainder=True).repeat(repeat_size)



class AlexNet(nn.Cell):
    """AlexNet-style classifier: five conv layers followed by three dense layers.

    Spatial sizes for a 227x227 input:
    227 -conv1-> 55 -pool-> 27 -conv2-> 27 -pool-> 13 -conv3..5-> 13 -pool-> 6.
    ``fc1`` therefore expects 6*6*256 features, i.e. a 227x227 input. That is
    the size the dataset helpers in this file resize to.

    Args:
        num_classes: Number of output logits.
        channel: Number of input image channels. conv1 previously ignored
            this argument and hard-coded 3, which remains the default.
    """

    def __init__(self, num_classes=10, channel=3):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Conv2d(channel, 96, 11, stride=4, pad_mode='valid')
        self.conv2 = nn.Conv2d(96, 256, 5, stride=1, pad_mode='same')
        self.conv3 = nn.Conv2d(256, 384, 3, stride=1, pad_mode='same')
        self.conv4 = nn.Conv2d(384, 384, 3, stride=1, pad_mode='same')
        self.conv5 = nn.Conv2d(384, 256, 3, stride=1, pad_mode='same')
        self.relu = nn.ReLU()
        # One pooling cell shared by all three pooling stages (it holds no parameters).
        self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(6*6*256, 4096)
        self.fc2 = nn.Dense(4096, 4096)
        self.fc3 = nn.Dense(4096, num_classes)

    def construct(self, x):
        # Stage 1: convolution -> activation -> pooling.
        x = self.max_pool2d(self.relu(self.conv1(x)))
        # Stage 2.
        x = self.max_pool2d(self.relu(self.conv2(x)))
        # Stages 3-5: three convolutions, pooling only after the last one.
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.max_pool2d(self.relu(self.conv5(x)))

        # Classifier head; fc3 returns raw logits (no softmax).
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Input pipelines. ds_test is built here but not used in the code shown.
ds_train = create_dateset(train_file_name, batch_size, repeat_size)
ds_val = create_dateset_val(val_file_name, batch_size, repeat_size)
ds_test = create_dateset_val(test_file_name, batch_size, repeat_size)

# Network, optimizer, loss and metrics wrapped into a Model.
network = AlexNet(num_classes=num_classes)
net_opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=momentum)
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
metrics = {"Accuracy": nn.Accuracy()}
model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics=metrics)

# Checkpoint every 30 steps (keep at most 100 files under ./ckpt); print the loss every 30 steps.
ckpt_config = CheckpointConfig(save_checkpoint_steps=30, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='alex', directory='./ckpt', config=ckpt_config)
model.train(epoch_size, ds_train,
            callbacks=[ckpt_callback, LossMonitor(per_print_times=30)],
            dataset_sink_mode=False)

# Evaluate on the validation split and print the resulting metrics.
result = model.eval(ds_val)
print(result)
