## Dataset Definition and Loading

[Official documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/02_paddle2.0_develop/02_data_load_cn.html)

```python
import paddle
import warnings

# Silence framework warnings to keep the notebook output readable
warnings.filterwarnings('ignore')
```

### Built-in Datasets

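Paddle ships ready-made datasets for vision (`paddle.vision.datasets`) and text (`paddle.text`); printing their `__all__` lists what the installed version provides:

```python
print("Vision datasets:", paddle.vision.datasets.__all__)
print("Text datasets:", paddle.text.__all__)
```

```
Vision datasets: ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']
Text datasets: ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16', 'ViterbiDecoder', 'viterbi_decode']
```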

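```python
from paddle.vision.transforms import ToTensor
```

On first use, each dataset is downloaded from the remote server into a local cache directory; later runs reuse the cached copy.

```python
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=ToTensor())
val_dataset = paddle.vision.datasets.MNIST(mode='test', transform=ToTensor())
```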
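Passing `transform=ToTensor()` converts every image as it is read. Paddle's documentation describes `ToTensor` as turning an H x W x C image with values in [0, 255] into a C x H x W float tensor with values in [0.0, 1.0], so each single-channel MNIST image should come out with shape [1, 28, 28]. The sketch below imitates that conversion with plain Python lists so it runs without Paddle; `to_tensor_like` and `shape_of` are hypothetical helpers for illustration, not Paddle APIs.

```python
# Framework-free sketch of ToTensor's effect on a single-channel image:
# an H x W grid of uint8 values in [0, 255] becomes a 1 x H x W grid
# of floats in [0.0, 1.0].

def to_tensor_like(image_hw):
    """Scale pixels to [0, 1] and add a leading channel axis (hypothetical helper)."""
    channel = [[pixel / 255.0 for pixel in row] for row in image_hw]
    return [channel]  # shape: [1, H, W]


def shape_of(nested):
    """Return the shape of a rectangular nested list (hypothetical helper)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return dims


# A toy 28 x 28 grayscale "image" with values in 0..255
toy_image = [[(r * 28 + c) % 256 for c in range(28)] for r in range(28)]
tensor = to_tensor_like(toy_image)
print(shape_of(tensor))  # [1, 28, 28]
```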

### Custom Datasets

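```python
import paddle
from paddle.io import Dataset

BATCH_SIZE = 64
BATCH_NUM = 20

IMAGE_SIZE = (28, 28)
CLASS_NUM = 10

class MyDataset(Dataset):
    """
    Step 1: subclass paddle.io.Dataset.
    """
    def __init__(self, num_samples):
        """
        Step 2: implement the constructor and record the dataset size.
        """
        super().__init__()
        self.num_samples = num_samples

    def __getitem__(self, index):
        """
        Step 3: override __getitem__ to define how the sample at `index`
        is produced; return the data and its label.
        """
        # Random placeholder data: a new sample is drawn on every call
        data = paddle.uniform(IMAGE_SIZE, dtype='float32')
        # randint's upper bound is exclusive, so pass CLASS_NUM to cover labels 0..9
        label = paddle.randint(0, CLASS_NUM, shape=[1], dtype='int64')
        return data, label

    def __len__(self):
        """
        Step 4: override __len__ to return the number of samples.
        """
        return self.num_samples
```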
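Because `MyDataset` draws new random values on every `__getitem__` call, reading the same index twice returns two different samples. When reproducibility matters, one option is to generate (or load) the samples once in `__init__` and only look them up afterwards. The sketch below shows that variant in plain Python so it runs without Paddle; `InMemoryDataset` is a hypothetical name, and in real Paddle code the class would subclass `paddle.io.Dataset` and return tensors. The uniform range [-1, 1] matches the default `min`/`max` of `paddle.uniform`.

```python
import random

IMAGE_SIZE = (28, 28)
CLASS_NUM = 10


class InMemoryDataset:
    """Map-style dataset over pre-generated samples (hypothetical example).

    Implements the same two methods as MyDataset: __getitem__ and __len__.
    """

    def __init__(self, num_samples, seed=0):
        rng = random.Random(seed)  # fixed seed -> reproducible samples
        h, w = IMAGE_SIZE
        self.images = [
            [[rng.uniform(-1.0, 1.0) for _ in range(w)] for _ in range(h)]
            for _ in range(num_samples)
        ]
        # randrange's upper bound is exclusive, so labels cover 0..CLASS_NUM-1
        self.labels = [rng.randrange(CLASS_NUM) for _ in range(num_samples)]

    def __getitem__(self, index):
        # Raising IndexError past the end also lets a plain for-loop terminate
        if not 0 <= index < len(self.labels):
            raise IndexError(f"index {index} out of range")
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.labels)


dataset = InMemoryDataset(num_samples=100)
image, label = dataset[3]
print(len(dataset), len(image), len(image[0]), label)
```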

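```python
custom_dataset = MyDataset(BATCH_SIZE * BATCH_NUM)
print("==========custom dataset=============")
# Inspect only the first sample (MyDataset never signals the end, so `break` is required)
for data, label in custom_dataset:
    print(data.shape, label.shape)
    break
```

```
==========custom dataset=============
[28, 28] [1]
```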
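The `for ... in custom_dataset` loop works even though `MyDataset` defines no `__iter__`: Python falls back to calling `__getitem__` with 0, 1, 2, ... until it raises `IndexError`, and `__len__` is not consulted. Since `MyDataset.__getitem__` never raises, the `break` is what ends the loop. The plain-Python classes below (hypothetical names) demonstrate that fallback; this assumes the base `paddle.io.Dataset` does not add an `__iter__` of its own.

```python
import itertools


class CountingDataset:
    """Toy map-style dataset that does NOT define __iter__."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __getitem__(self, index):
        # Raising IndexError past the end is what stops a plain for-loop
        if index >= self.num_samples:
            raise IndexError(index)
        return index, index % 10

    def __len__(self):
        return self.num_samples


class EndlessDataset:
    """Never raises IndexError, like MyDataset in the notebook."""

    def __getitem__(self, index):
        return index

    def __len__(self):
        return 3  # ignored by the iteration fallback


# Legacy sequence protocol: __getitem__(0), (1), ... until IndexError
samples = [sample for sample in CountingDataset(5)]
print(samples)  # [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]

# Without IndexError the iterator is unbounded; islice takes just 5 items
first_five = list(itertools.islice(iter(EndlessDataset()), 5))
print(first_five)  # [0, 1, 2, 3, 4], even though __len__ returns 3
```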

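### Data Loading
> `paddle.io.DataLoader` wraps a `Dataset`, groups its samples into mini-batches, and loads data asynchronously (the next batch is prefetched while the current one is being used), which speeds up data loading.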
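To make "asynchronous loading" concrete: while the training step consumes one batch, the next one is already being prepared in the background. The sketch below illustrates that idea with a producer thread and a bounded `queue.Queue`; it is a simplified illustration, not Paddle's implementation (which, per its documentation, prefetches through a buffer reader and can use worker subprocesses when `num_workers > 0`). `prefetch` and `slow_batches` are hypothetical names.

```python
import queue
import threading
import time

_END = object()  # sentinel marking the end of the stream


def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable` while a background thread produces them.

    The bounded queue lets the producer run at most `buffer_size` items
    ahead of the consumer, so preparing the next batch overlaps with
    using the current one.
    """
    buf = queue.Queue(maxsize=buffer_size)
    errors = []

    def producer():
        try:
            for item in iterable:
                buf.put(item)
        except Exception as exc:  # hand the error over to the consumer
            errors.append(exc)
        finally:
            buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            if errors:
                raise errors[0]
            return
        yield item


def slow_batches(n):
    """Hypothetical stand-in for reading and collating batches from disk."""
    for i in range(n):
        time.sleep(0.01)  # simulated I/O latency
        yield i


batches = list(prefetch(slow_batches(5)))
print(batches)  # [0, 1, 2, 3, 4]
```

The daemon thread keeps the sketch short; if a consumer stops early, the producer simply stays blocked until the process exits.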

```python
train_data = paddle.io.DataLoader(custom_dataset, batch_size=BATCH_SIZE, shuffle=True)
```
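With `batch_size=BATCH_SIZE` and `shuffle=True`, each epoch the loader permutes the sample indices and cuts them into groups of 64, so the 1280 samples of `custom_dataset` give 20 batches per epoch. The sketch below reproduces only that index bookkeeping in plain Python. `batch_indices` is a hypothetical helper; its `drop_last` flag mirrors the `DataLoader` parameter of the same name, which defaults to `False`.

```python
import random


def batch_indices(num_samples, batch_size, shuffle=False, drop_last=False, seed=None):
    """Split sample indices into mini-batches, as a batch sampler would."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # a real loader reshuffles every epoch
    batches = [indices[i:i + batch_size] for i in range(0, num_samples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the final incomplete batch
    return batches


batches = batch_indices(64 * 20, batch_size=64, shuffle=True, seed=0)
print(len(batches), len(batches[0]))  # 20 64
```

When the sample count is not a multiple of the batch size, the last batch is smaller unless `drop_last=True`: `batch_indices(100, 64)` yields batches of 64 and 36 indices.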

```python
for batch_id, data in enumerate(train_data):
    x_data = data[0]  # batch of images
    y_data = data[1]  # batch of labels
    print(x_data.shape)
    print(y_data.shape)
    break
```

```
[64, 28, 28]
[64, 1]
```
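These shapes follow from stacking: each sample contributes a [28, 28] image and a [1] label, and the loader stacks 64 of them along a new leading axis, giving [64, 28, 28] and [64, 1]. A plain-Python sketch of that collation step follows; `collate`, `stack` and `shape_of` are hypothetical helpers, and Paddle's own default collate function works on tensors and arrays rather than nested lists.

```python
def shape_of(nested):
    """Return the shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return dims


def stack(samples):
    """Stack equally shaped samples along a new leading (batch) axis."""
    expected = shape_of(samples[0])
    if any(shape_of(s) != expected for s in samples):
        raise ValueError("all samples in a batch must have the same shape")
    return list(samples)


def collate(batch):
    """Turn [(data, label), ...] into (stacked_data, stacked_labels)."""
    data, labels = zip(*batch)
    return stack(data), stack(labels)


# 64 fake samples: a 28 x 28 image of zeros and a label of shape [1]
batch = [([[0.0] * 28 for _ in range(28)], [k % 10]) for k in range(64)]
x_data, y_data = collate(batch)
print(shape_of(x_data), shape_of(y_data))  # [64, 28, 28] [64, 1]
```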


## Data Preprocessing

```python
print("Data transforms:", paddle.vision.transforms.__all__)
```

```
Data transforms: ['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']
```
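Transforms from this list are typically chained with `Compose`, which feeds each transform the output of the previous one, and `Normalize` computes `(x - mean) / std` per channel. The sketch below mimics those two ideas on a flat list of pixel values so it runs without Paddle; `Compose` here is a hypothetical stand-in, not `paddle.vision.transforms.Compose`, and `scale_to_unit` and `normalize` are likewise illustrative helpers. With mean = std = 0.5, values in [0, 1] map to [-1, 1].

```python
class Compose:
    """Apply a list of callables in order (mirrors the idea of transforms.Compose)."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for transform in self.transforms:
            x = transform(x)
        return x


def scale_to_unit(pixels):
    """Map uint8 values in [0, 255] to floats in [0, 1]."""
    return [p / 255.0 for p in pixels]


def normalize(mean, std):
    """Return a transform computing (x - mean) / std for every value."""
    def _apply(values):
        return [(v - mean) / std for v in values]
    return _apply


pipeline = Compose([scale_to_unit, normalize(mean=0.5, std=0.5)])
print(pipeline([0, 255]))  # [-1.0, 1.0]

# Order matters: normalizing raw 0..255 values before scaling gives a different result
reversed_pipeline = Compose([normalize(mean=0.5, std=0.5), scale_to_unit])
print(reversed_pipeline([0, 255]))
```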
