## 数据集下载和加载

In [6]:
import mindspore
import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore.dataset import MnistDataset, transforms
from download import download

In [None]:
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"

path = download(url, "./", kind="zip", progressbar=True, replace=True)

In [11]:
train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

In [4]:
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape, image.dtype, label)

(28, 28, 1) UInt8 1


In [4]:
print(train_dataset.get_col_names())

['image', 'label']


## 数据集处理和增强

#### 图像缩放

In [12]:
rescale = vision.Rescale(1.0/255.0, 0)
rescaled_image = rescale(image.asnumpy())
print(rescaled_image)

[[[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]]


 [[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]]


 [[[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  ...

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.]
   [0.]
   [0.]
   ...
   [0.]
   [0.]
   [0.]]

  [[0.

#### 标准化

In [13]:
normalize = vision.Normalize(mean=(0.1307,), std=(0.3081,))
normalized_image = normalize(rescaled_image)
print(normalized_image)

[[[[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  ...

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]]


 [[[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.42421296]]

  [[-0.42421296]
   [-0.42421296]
   [-0.42421296]
   ...
   [-0.42421296]
   [-0.42421296]
   [-0.4242

#### HWC2CHW

In [15]:
hwc2chw = vision.HWC2CHW()
chw_image = hwc2chw(normalized_image)
print(normalized_image.shape, chw_image.shape)

RuntimeError: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. 

------------------------------------------------------------------
- Dataset Pipeline Error Message: 
------------------------------------------------------------------
[ERROR] HWC2CHW: image shape should be <H,W> or <H,W,C>, but got rank: 4.

------------------------------------------------------------------
- C++ Call Stack: (For framework developers) 
------------------------------------------------------------------
mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc(726).




#### 数据分批

In [16]:
train_dataset = train_dataset.batch(batch_size=32)
for image, label in train_dataset.create_tuple_iterator():
    print(f"shape of image [N C H W]:{image.shape}")
    break

shape of image [N C H W]:(32, 28, 28, 1)


#### 数据预处理流水线

In [4]:
def datapipe(path:str, batch_size:int=32):
    image_transforms = [
        vision.Rescale(1.0/255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]

    label_transforms = transforms.TypeCast(mindspore.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(operations=image_transforms, input_columns="label")
    dataset = dataset.map(operations=label_transforms, input_columns="label")
    dataset = dataset.batch(batch_size=batch_size)

    return dataset

In [9]:
train_dataset = datapipe("MNIST_Data/train", 64)
test_dataset = datapipe("MNIST_Data/test", 64)

## 数据集迭代

In [17]:
for data in train_dataset.create_dict_iterator():
    print(data['image'].shape)
    print(data['label'].shape)
    break

(32, 28, 28, 1)
(32,)


In [18]:
image, label = next(train_dataset.create_tuple_iterator())
print(image.shape)
print(label.shape)

(32, 28, 28, 1)
(32,)
