An introduction to data loading and batching in PyTorch.

In [7]:
import torch
import torch.utils.data as Data
import warnings
warnings.filterwarnings('ignore')

### 1. The Dataset class

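`Dataset` is an abstract class. To read data conveniently, wrap the data you want to use in a `Dataset`. It is also the parent class of dataset classes such as `TensorDataset`. `DataLoader` is not a subclass of `Dataset`; it takes a `Dataset` as input.
A custom dataset must inherit from `Dataset` and implement two methods:
1. `__getitem__()` defines how a single sample is fetched for a given index
2. `__len__()` returns the total number of samples in the dataset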

In [8]:
# Use the iris data as an example
import pandas as pd
class IrisDataSet(Data.Dataset):
    def __init__(self, csv_file):
        """Load the data from the CSV file when the dataset is created."""
        self.df=pd.read_csv(csv_file)
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        return self.df.iloc[idx]

In [11]:
iris = IrisDataSet('./iris.csv')

In [13]:
# __len__ is implemented, so len() returns the number of samples directly
len(iris)

150

In [14]:
# __getitem__ is implemented, so a sample can be accessed directly by index
iris[0]

SepalLength            5.1
SepalWidth             3.5
PetalLength            1.4
PetalWidth             0.2
Name           Iris-setosa
Name: 0, dtype: object
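
Note that `IrisDataSet.__getitem__` above returns a pandas `Series` that still contains the string label, which is generally not something a `DataLoader` can batch with its default collate function. The following is a minimal sketch of a variant that returns `(features, label)` tensors instead. The class name `IrisTensorDataset`, the `label_map` argument, and the three-row `toy_df` are hypothetical and only stand in for `iris.csv`; the column layout (four numeric columns followed by the name) is assumed from the output shown above.

```python
import pandas as pd
import torch
import torch.utils.data as Data

class IrisTensorDataset(Data.Dataset):
    """Hypothetical variant of IrisDataSet that returns (features, label) tensors."""

    def __init__(self, df, label_map):
        # Features: every column except the last, stored as float32
        self.x = torch.tensor(df.iloc[:, :-1].values, dtype=torch.float32)
        # Labels: last column mapped from class name to integer id
        self.y = torch.tensor(df.iloc[:, -1].map(label_map).values, dtype=torch.long)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Tiny made-up frame standing in for iris.csv (assumption, not the real file)
toy_df = pd.DataFrame({
    "SepalLength": [5.1, 7.0, 6.3],
    "SepalWidth": [3.5, 3.2, 3.3],
    "PetalLength": [1.4, 4.7, 6.0],
    "PetalWidth": [0.2, 1.4, 2.5],
    "Name": ["Iris-setosa", "Iris-versicolor", "Iris-virginica"],
})
toy_map = {"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2}

toy_ds = IrisTensorDataset(toy_df, toy_map)
features, label = toy_ds[0]   # a float32 tensor of 4 features and an integer label
```

Converting to tensors once in `__init__` keeps `__getitem__` cheap, which is reasonable for a dataset this small; for data that does not fit in memory the conversion would more likely happen per sample.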

### 2. The TensorDataset class

Inherits from `Dataset` and wraps several tensors together (for example, features and labels).

In [30]:
iris_df = pd.read_csv('./iris.csv')
map_dict = {"Iris-virginica":0, "Iris-setosa":1, "Iris-versicolor":2}   # TensorDataset holds tensors, which cannot store strings, so map labels to integers
iris_X = iris_df.iloc[:, :-1].values
iris_Y = iris_df.iloc[:, -1].apply(lambda x: map_dict[x]).values

In [31]:
iris_X_tensor = torch.from_numpy(iris_X)
iris_Y_tensor = torch.from_numpy(iris_Y)

In [34]:
iris_dataset = Data.TensorDataset(iris_X_tensor, iris_Y_tensor)
iris_dataset

<torch.utils.data.dataset.TensorDataset at 0x11e9a7128>
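
To make the wrapping concrete, here is a small sketch of what indexing a `TensorDataset` returns. The arrays `toy_X` and `toy_Y` are made-up stand-ins for `iris_X` and `iris_Y`. Two points it illustrates: all wrapped tensors must have the same size in their first dimension, and `torch.from_numpy` keeps the NumPy dtype (float64 for the features here), so a cast to float32 is often added before feeding a model.

```python
import numpy as np
import torch
import torch.utils.data as Data

# Made-up arrays standing in for iris_X and iris_Y (assumption)
toy_X = np.array([[5.1, 3.5, 1.4, 0.2],
                  [7.0, 3.2, 4.7, 1.4],
                  [6.3, 3.3, 6.0, 2.5]])
toy_Y = np.array([1, 2, 0])

# from_numpy preserves the NumPy dtype; cast explicitly to the usual training dtypes
toy_X_tensor = torch.from_numpy(toy_X).float()
toy_Y_tensor = torch.from_numpy(toy_Y).long()

toy_dataset = Data.TensorDataset(toy_X_tensor, toy_Y_tensor)

# Indexing returns a tuple with one element per wrapped tensor
x0, y0 = toy_dataset[0]
print(len(toy_dataset), x0.shape, y0.item())
```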

### 3. DataLoader类

An iterable that loads data from a dataset; options such as `shuffle` and `batch_size` are set here.
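
The effect of these two options can be checked on a toy dataset. The sketch below uses 10 made-up samples (not the iris data) so that the batch sizes are easy to verify by hand, and it also shows `drop_last`, a `DataLoader` parameter not used elsewhere in this notebook.

```python
import torch
import torch.utils.data as Data

# 10 made-up samples with 2 features each; each label is just the sample index
toy_x = torch.arange(20, dtype=torch.float32).reshape(10, 2)
toy_y = torch.arange(10)
toy_ds = Data.TensorDataset(toy_x, toy_y)

# batch_size=4 over 10 samples gives batches of 4, 4 and 2
toy_loader = Data.DataLoader(toy_ds, batch_size=4, shuffle=True)
batch_sizes = [len(batch_y) for _, batch_y in toy_loader]

# drop_last=True discards the incomplete final batch
toy_loader_drop = Data.DataLoader(toy_ds, batch_size=4, shuffle=True, drop_last=True)
batch_sizes_drop = [len(batch_y) for _, batch_y in toy_loader_drop]

# With shuffle=True each epoch visits every sample exactly once, in a random order
epoch_labels = torch.cat([batch_y for _, batch_y in toy_loader])
seen_once = sorted(epoch_labels.tolist()) == list(range(10))
```

Iterating over the loader again starts a new epoch with a fresh permutation, which is why the first batch differs between epochs.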

In [35]:
"""
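shuffle: if True, the whole dataset is reshuffled at the start of every epoch (the shuffling is not done inside each batch)
batch_size: number of samples per batch; if the dataset size is not divisible by it, the last batch of each epoch holds the remaining samples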
"""
loader = Data.DataLoader(dataset=iris_dataset, shuffle=True, batch_size=20)

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):    # one step per batch
        if step % 10 == 0:
            print("Epoch: ", epoch, "| Step: ", step, "|batch_x: ", batch_x.numpy(), "|batch_y: ", batch_y.numpy())

Epoch:  0 | Step:  0 |batch_x:  [[5.5 2.4 3.7 1. ]
 [5.1 3.5 1.4 0.2]
 [4.9 3.1 1.5 0.1]
 [5.8 2.7 5.1 1.9]
 [5.1 2.5 3.  1.1]
 [4.6 3.2 1.4 0.2]
 [7.2 3.6 6.1 2.5]
 [6.5 2.8 4.6 1.5]
 [4.9 2.5 4.5 1.7]
 [6.1 3.  4.9 1.8]
 [6.7 3.1 4.7 1.5]
 [4.8 3.1 1.6 0.2]
 [7.1 3.  5.9 2.1]
 [6.7 3.3 5.7 2.5]
 [5.4 3.  4.5 1.5]
 [7.3 2.9 6.3 1.8]
 [6.  2.2 5.  1.5]
 [5.  3.3 1.4 0.2]
 [5.5 2.4 3.8 1.1]
 [4.4 3.2 1.3 0.2]] |batch_y:  [2 1 1 0 2 1 0 2 0 0 2 1 0 0 2 0 0 1 2 1]
Epoch:  1 | Step:  0 |batch_x:  [[4.9 2.5 4.5 1.7]
 [5.4 3.  4.5 1.5]
 [6.  3.4 4.5 1.6]
 [5.6 3.  4.5 1.5]
 [6.7 3.3 5.7 2.5]
 [5.3 3.7 1.5 0.2]
 [5.9 3.  4.2 1.5]
 [7.2 3.2 6.  1.8]
 [6.8 3.  5.5 2.1]
 [6.7 2.5 5.8 1.8]
 [4.8 3.  1.4 0.3]
 [7.2 3.  5.8 1.6]
 [5.8 2.6 4.  1.2]
 [6.9 3.1 4.9 1.5]
 [6.6 3.  4.4 1.4]
 [5.2 4.1 1.5 0.1]
 [5.  3.4 1.6 0.4]
 [6.2 3.4 5.4 2.3]
 [6.5 2.8 4.6 1.5]
 [5.5 2.5 4.  1.3]] |batch_y:  [0 2 2 2 0 1 2 0 0 0 1 0 2 2 2 1 1 0 2 2]
Epoch:  2 | Step:  0 |batch_x:  [[4.3 3.  1.1 0.1]
 [7.3 2.9 6.3 1.8