In [2]:
import mindspore
from mindspore import nn
import mindspore.dataset.vision as vision
from mindspore.dataset import MnistDataset, transforms

## 定义模型类

In [3]:
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros"),
        )


    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

In [4]:
model = Network()
print(model)

Network(
  (flatten): Flatten()
  (dense_relu_sequential): SequentialCell(
    (0): Dense(input_channels=784, output_channels=512, has_bias=True)
    (1): ReLU()
    (2): Dense(input_channels=512, output_channels=512, has_bias=True)
    (3): ReLU()
    (4): Dense(input_channels=512, output_channels=10, has_bias=True)
  )
)


## 数据加载

In [5]:
def datapipe(path: str, batch_size: int = 32):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]

    label_transforms = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(operations=image_transforms, input_columns="image")
    dataset = dataset.map(operations=label_transforms, input_columns="label")
    dataset = dataset.batch(batch_size=batch_size)

    return dataset


train_dataset = datapipe("MNIST_Data/train", 64)
test_dataset = datapipe("MNIST_Data/test", 64)

image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, label.shape)

(64, 1, 28, 28) (64,)


## 模型层

#### nn.Flatten

In [6]:
flatten_layer = nn.Flatten()
image_after_flatten = flatten_layer(image)
print(image_after_flatten.shape)

(64, 784)


#### nn.Dense

In [7]:
print("before dense:", image_after_flatten.shape)
dense_layer = nn.Dense(in_channels=784, out_channels=512, weight_init="normal", bias_init="zeros")
image_after_dense1 = dense_layer(image_after_flatten)
print("After dense:", image_after_dense1.shape)

before dense: (64, 784)
After dense: (64, 512)


#### nn.ReLU

In [8]:
print(f"before ReLU:{image_after_dense1.shape[0:5]}")
relu_layer = nn.ReLU()
image_after_relu = relu_layer(image_after_dense1[0:5])
print(f"after ReLU:{image_after_relu.shape[0:5]}")

before ReLU:(64, 512)
after ReLU:(5, 512)


#### nn.SequentialCell

In [9]:
dense_relu_sequential = nn.SequentialCell(
    nn.Dense(in_channels=28*28, out_channels=512),
    nn.ReLU(),
    nn.Dense(in_channels=512, out_channels=512),
    nn.ReLU(),
    nn.Dense(in_channels=512, out_channels=10)
)

image_after_sequential = dense_relu_sequential(image_after_flatten)
print(f"Shape of image after SequentialCell:{image_after_sequential.shape}")

Shape of image after SequentialCell:(64, 10)


#### nn.Softmax

In [None]:
softmax = nn.Softmax(axis=-1)
pred_probab = softmax(logits)

## 模型参数

In [11]:
model =  Network()

for param in model.get_parameters():
    print(param)
    break

Parameter (name=dense_relu_sequential.0.weight, shape=(512, 784), dtype=Float32, requires_grad=True)


In [15]:
model.untrainable_params

<bound method Cell.trainable_params of Network(
  (flatten): Flatten()
  (dense_relu_sequential): SequentialCell(
    (0): Dense(input_channels=784, output_channels=512, has_bias=True)
    (1): ReLU()
    (2): Dense(input_channels=512, output_channels=512, has_bias=True)
    (3): ReLU()
    (4): Dense(input_channels=512, output_channels=10, has_bias=True)
  )
)>