In [1]:
import torch
import numpy as np

## 主要内容
* `torch.utils.data.Dataset`   : 对单个样本的处理 
* `torch.utils.data.DataLoader`: 对多个样本的处理（把样本组成小批次并按批迭代）

In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt


In [3]:
# Load the FashionMNIST train and test splits from ../data/, passing ToTensor()
# as the image transform.
# download=True keeps the cell runnable on a fresh checkout: torchvision only
# fetches the files when they are missing under `root` and reuses an existing
# local copy otherwise.
training_data = datasets.FashionMNIST(
    root="../data/",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="../data/",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [6]:
# Wrap each split in a DataLoader that serves shuffled mini-batches of 64 samples.
train_dataloader = DataLoader(
    dataset=training_data,
    batch_size=64,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
)

# collate_fn: 对某个小批次进行处理

In [None]:
# FashionMNIST training split whose class-index labels are converted to one-hot
# float vectors of length 10 by the target transform.
# download=True: the files are fetched only if they are missing under `root`,
# so the cell also runs on a machine that has no local copy yet.
ds = datasets.FashionMNIST(
    root="../data/",
    train=True,
    download=True,
    transform=ToTensor(),
    # One-hot encoding: start from a zero vector and write a 1 at index y.
    target_transform=Lambda(
        lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)
    ),
)

## 构建自定义的`Dataset`
* 必须实现的3个函数：`__getitem__(self, idx)`, `__len__`, `__init__`
* 有两种类型：
    * `map-style datasets`: 使用更多
    * `iterable-style datasets`: 更适用于流式（streaming）数据的计算场景（需实现`__iter__`，而不是`__getitem__`）

In [4]:
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    """Map-style dataset backed by an annotations CSV and an image directory.

    The CSV is read with ``pd.read_csv``. For every row, column 0 is used as
    the image file name (joined onto ``img_dir``) and column 1 as the label.

    Args:
        annotations_file: Path (or buffer) handed to ``pd.read_csv``.
        img_dir: Directory that the file names in column 0 are relative to.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.img_labels = pd.read_csv(annotations_file)

    def __len__(self):
        """Return the number of annotation rows, i.e. the number of samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load the sample at row ``idx`` and return it as ``(image, label)``.

        The image is read with ``read_image``; the optional transforms are
        then applied to the image and to the label respectively.
        """
        file_name = self.img_labels.iloc[idx, 0]
        image = read_image(os.path.join(self.img_dir, file_name))
        label = self.img_labels.iloc[idx, 1]

        # Preprocess the raw image / the label only when a transform was supplied.
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label

## 构建自定义的DataLoader
* 可以按需自行编写下面一些组件，并作为参数传给`DataLoader`
    * `sampler`
    * `batch_sampler(sampler, batch_size, drop_last)`
    * `collate_fn(batch_sample)`
* 通常利用`next(iter(dataloader))`取出一个批次进行查看