# 概要
[PyTorch Tutorial](https://pytorch.org/data/main/tutorial.html) を元にPyTorch DataのDataPipeの使い方を解説。

In [1]:
import pandas as pd
import torchdata.datapipes as dp
from torchdata.datapipes.iter import IterDataPipe

## Using DataPipes
DataPipeをつなげて使う方法の基本

In [2]:
ls data/

a.csv  b.csv  c.csv


In [3]:
# サンプルのCSVファイル
pd.read_csv('data/a.csv').head()

Unnamed: 0,a,b,c,d
0,3,2,8,7
1,7,7,5,0
2,8,3,4,5
3,0,1,8,4
4,8,6,0,4


In [4]:
# FileListerでファイルのリストを取得する
FOLDER = 'data'
datapipe = dp.iter.FileLister([FOLDER]).filter(filter_fn=(lambda filename: filename.endswith('.csv')))
datapipe



<torch.utils.data.datapipes.iter.selecting.FilterIterDataPipe at 0x123072ca0>

In [5]:
list(datapipe)

['data/a.csv', 'data/b.csv', 'data/c.csv']

In [6]:
# FileOpenerでファイルを開く
# 引数にFileListenerで取得したDataPipeを与えている
datapipe = dp.iter.FileOpener(datapipe, mode='rt')
datapipe

<torch.utils.data.datapipes.iter.fileopener.FileOpenerIterDataPipe at 0x123072100>

In [7]:
# (ファイル名, TextIO, StreamWrapper) というtupleで構成されたイテレータが返ってくる
list(datapipe)

[('data/a.csv',
  StreamWrapper<<_io.TextIOWrapper name='data/a.csv' mode='rt' encoding='UTF-8'>>),
 ('data/b.csv',
  StreamWrapper<<_io.TextIOWrapper name='data/b.csv' mode='rt' encoding='UTF-8'>>),
 ('data/c.csv',
  StreamWrapper<<_io.TextIOWrapper name='data/c.csv' mode='rt' encoding='UTF-8'>>)]

In [8]:
# StreamWrapperからファイルの内容を取得できる
sw = list(datapipe)[0][1]
sw.read()

'a,b,c,d\n3,2,8,7\n7,7,5,0\n8,3,4,5\n0,1,8,4\n8,6,0,4\n7,0,6,2\n7,4,5,5\n9,6,0,2\n9,9,8,8\n8,3,8,9\n3,7,2,1\n8,8,9,2\n6,6,2,9\n4,9,4,6\n2,1,0,4\n2,2,8,2\n8,6,0,5\n4,6,6,4\n1,2,4,7\n3,8,9,3\n7,5,1,4\n9,7,9,6\n2,6,2,5\n1,4,0,7\n3,6,5,0\n8,0,2,0\n1,3,9,0\n7,3,8,2\n4,4,6,1\n2,0,8,3\n0,4,6,6\n9,6,6,4\n9,2,9,1\n8,2,7,0\n0,1,8,4\n1,3,4,7\n6,9,0,1\n0,3,4,9\n5,3,3,5\n5,6,3,1\n9,7,2,9\n7,1,1,7\n4,0,7,6\n7,6,3,5\n8,6,9,3\n2,9,1,9\n4,5,5,2\n3,0,7,4\n0,1,5,6\n6,5,8,7\n4,4,8,1\n6,4,9,0\n3,6,8,0\n3,1,6,5\n8,8,5,3\n8,1,6,1\n7,3,9,8\n9,2,7,2\n4,5,4,4\n9,7,9,1\n5,3,9,3\n6,9,0,9\n9,3,4,2\n0,9,0,3\n2,6,5,4\n8,5,6,1\n6,2,9,2\n9,1,3,7\n7,7,4,4\n0,5,3,7\n4,6,0,3\n0,0,2,2\n6,8,3,0\n1,4,6,6\n4,7,8,6\n2,3,9,9\n6,6,4,8\n2,1,2,7\n0,8,3,0\n0,2,4,1\n4,0,8,0\n8,9,8,6\n7,0,1,5\n0,7,5,6\n3,2,9,6\n8,6,0,5\n2,4,1,7\n1,7,9,1\n5,9,8,5\n9,3,4,0\n2,8,0,2\n9,4,0,5\n4,9,1,1\n1,0,0,6\n0,0,8,1\n8,0,6,0\n5,1,2,3\n2,5,2,0\n4,5,1,0\n7,8,7,3\n'

In [9]:
# parse_csvメソッドでCSVをパース
datapipe_csv = datapipe.parse_csv(delimiter=',')
datapipe_csv

<torchdata.datapipes.iter.util.plain_text_reader.CSVParserIterDataPipe at 0x123093100>

In [10]:
list(datapipe_csv)

[['a', 'b', 'c', 'd'],
 ['3', '2', '8', '7'],
 ['7', '7', '5', '0'],
 ['8', '3', '4', '5'],
 ['0', '1', '8', '4'],
 ['8', '6', '0', '4'],
 ['7', '0', '6', '2'],
 ['7', '4', '5', '5'],
 ['9', '6', '0', '2'],
 ['9', '9', '8', '8'],
 ['8', '3', '8', '9'],
 ['3', '7', '2', '1'],
 ['8', '8', '9', '2'],
 ['6', '6', '2', '9'],
 ['4', '9', '4', '6'],
 ['2', '1', '0', '4'],
 ['2', '2', '8', '2'],
 ['8', '6', '0', '5'],
 ['4', '6', '6', '4'],
 ['1', '2', '4', '7'],
 ['3', '8', '9', '3'],
 ['7', '5', '1', '4'],
 ['9', '7', '9', '6'],
 ['2', '6', '2', '5'],
 ['1', '4', '0', '7'],
 ['3', '6', '5', '0'],
 ['8', '0', '2', '0'],
 ['1', '3', '9', '0'],
 ['7', '3', '8', '2'],
 ['4', '4', '6', '1'],
 ['2', '0', '8', '3'],
 ['0', '4', '6', '6'],
 ['9', '6', '6', '4'],
 ['9', '2', '9', '1'],
 ['8', '2', '7', '0'],
 ['0', '1', '8', '4'],
 ['1', '3', '4', '7'],
 ['6', '9', '0', '1'],
 ['0', '3', '4', '9'],
 ['5', '3', '3', '5'],
 ['5', '6', '3', '1'],
 ['9', '7', '2', '9'],
 ['7', '1', '1', '7'],
 ['4', '0',

In [11]:
# parse_csvはCSVParserをDataPipeに登録したものなので、CSVParserを使ってもできる
datapipe_with_csv_parser = dp.iter.CSVParser(datapipe, delimiter=',')

In [12]:
list(datapipe_with_csv_parser)

[['a', 'b', 'c', 'd'],
 ['3', '2', '8', '7'],
 ['7', '7', '5', '0'],
 ['8', '3', '4', '5'],
 ['0', '1', '8', '4'],
 ['8', '6', '0', '4'],
 ['7', '0', '6', '2'],
 ['7', '4', '5', '5'],
 ['9', '6', '0', '2'],
 ['9', '9', '8', '8'],
 ['8', '3', '8', '9'],
 ['3', '7', '2', '1'],
 ['8', '8', '9', '2'],
 ['6', '6', '2', '9'],
 ['4', '9', '4', '6'],
 ['2', '1', '0', '4'],
 ['2', '2', '8', '2'],
 ['8', '6', '0', '5'],
 ['4', '6', '6', '4'],
 ['1', '2', '4', '7'],
 ['3', '8', '9', '3'],
 ['7', '5', '1', '4'],
 ['9', '7', '9', '6'],
 ['2', '6', '2', '5'],
 ['1', '4', '0', '7'],
 ['3', '6', '5', '0'],
 ['8', '0', '2', '0'],
 ['1', '3', '9', '0'],
 ['7', '3', '8', '2'],
 ['4', '4', '6', '1'],
 ['2', '0', '8', '3'],
 ['0', '4', '6', '6'],
 ['9', '6', '6', '4'],
 ['9', '2', '9', '1'],
 ['8', '2', '7', '0'],
 ['0', '1', '8', '4'],
 ['1', '3', '4', '7'],
 ['6', '9', '0', '1'],
 ['0', '3', '4', '9'],
 ['5', '3', '3', '5'],
 ['5', '6', '3', '1'],
 ['9', '7', '2', '9'],
 ['7', '1', '1', '7'],
 ['4', '0',

## Working with DataLoader
DataPipeのDataLoaderへの渡し方。

In [13]:
# 1個のlabel、num_features個(defaultは20)のfeature、labelとfeatureは0 ~ 9の整数となるデータnum_rows行(defaultは5000行)を生成する関数。
# sample_data{file_label}.csvというファイルに保存される
import csv
import random

def generate_csv(file_label, num_rows: int = 5000, num_features: int = 20) -> None:
    fieldnames = ["label"] + [f"c{i}" for i in range(num_features)]
    writer = csv.DictWriter(
        open(f"sample_data{file_label}.csv", "w"), fieldnames=fieldnames
    )
    writer.writeheader()
    for i in range(num_rows):
        row_data = {col: random.random() for col in fieldnames}
        row_data["label"] = random.randint(0, 9)
        writer.writerow(row_data)

In [14]:
# sample_data*.csvというファイルをFileListerでリストし、FileOpenerで開き、parse_csvでCSVをparseし、
# mapにより、labelとdataに分割する関数。
# DataPipeを返す。

import numpy as np
import torchdata.datapipes as dp

def build_datapipes(root_dir="."):
    datapipe = dp.iter.FileLister(root_dir)
    datapipe = datapipe.filter(
        filter_fn=(
            lambda filename: "sample_data" in filename and filename.endswith(".csv")
        )
    )
    datapipe = dp.iter.FileOpener(datapipe, mode="rt")
    datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
    datapipe = datapipe.map(
        lambda row: {
            "label": np.array(row[0], np.int32),
            "data": np.array(row[1:], dtype=np.float64),
        }
    )
    return datapipe

In [15]:
# ファイルを生成
num_files_to_generate = 3
for i in range(num_files_to_generate):
    generate_csv(file_label=i)

In [16]:
ls sample* # 3つのファイルが作成されている

sample_data0.csv  sample_data1.csv  sample_data2.csv


In [17]:
# 5000行21列のデータが存在
pd.read_csv("sample_data0.csv").shape

(5000, 21)

In [18]:
# labelとc0~c19のfeatureのデータ
pd.read_csv("sample_data0.csv").head()

Unnamed: 0,label,c0,c1,c2,c3,c4,c5,c6,c7,c8,...,c10,c11,c12,c13,c14,c15,c16,c17,c18,c19
0,2,0.733731,0.73441,0.007763,0.864862,0.766706,0.684404,0.85887,0.847387,0.346198,...,0.703272,0.095117,0.956363,0.036068,0.868932,0.233669,0.521803,0.109048,0.985162,0.697013
1,2,0.311069,0.600335,0.667525,0.349549,0.962604,0.200384,0.420654,0.613835,0.769992,...,0.656904,0.365,0.810078,0.548219,0.43337,0.065903,0.189145,0.106306,0.971309,0.073375
2,4,0.76448,0.223954,0.505824,0.912549,0.788068,0.502638,0.141343,0.927995,0.109977,...,0.63408,0.067928,0.536124,0.472923,0.604922,0.679119,0.330595,0.314843,0.245045,0.796531
3,7,0.046492,0.447479,0.245859,0.892229,0.536488,0.688338,0.936303,0.3106,0.623684,...,0.535334,0.291872,0.949824,0.312508,0.151952,0.616139,0.498997,0.01087,0.911386,0.574811
4,9,0.465764,0.060966,0.828345,0.169574,0.594609,0.635722,0.43834,0.72482,0.957098,...,0.798309,0.281965,0.123671,0.102455,0.061425,0.046492,0.777849,0.764276,0.59407,0.211788


In [19]:
# datapipeを作成
datapipe = build_datapipes()



In [20]:
# DataLoaderにはdataset=datapipeで渡せば良い
from torch.utils.data import DataLoader

dl = DataLoader(dataset=datapipe, batch_size=50, shuffle=True)
dl

<torch.utils.data.dataloader.DataLoader at 0x1230a39a0>

In [21]:
first = next(iter(dl))
first

{'label': tensor([2, 2, 4, 7, 9, 6, 1, 2, 4, 0, 9, 7, 1, 5, 3, 5, 7, 3, 5, 6, 9, 6, 1, 5,
         6, 0, 3, 3, 0, 9, 2, 2, 1, 6, 9, 4, 5, 8, 0, 0, 8, 7, 1, 5, 3, 1, 5, 2,
         2, 3], dtype=torch.int32),
 'data': tensor([[7.3373e-01, 7.3441e-01, 7.7627e-03, 8.6486e-01, 7.6671e-01, 6.8440e-01,
          8.5887e-01, 8.4739e-01, 3.4620e-01, 8.7435e-01, 7.0327e-01, 9.5117e-02,
          9.5636e-01, 3.6068e-02, 8.6893e-01, 2.3367e-01, 5.2180e-01, 1.0905e-01,
          9.8516e-01, 6.9701e-01],
         [3.1107e-01, 6.0033e-01, 6.6753e-01, 3.4955e-01, 9.6260e-01, 2.0038e-01,
          4.2065e-01, 6.1383e-01, 7.6999e-01, 3.2894e-01, 6.5690e-01, 3.6500e-01,
          8.1008e-01, 5.4822e-01, 4.3337e-01, 6.5903e-02, 1.8914e-01, 1.0631e-01,
          9.7131e-01, 7.3375e-02],
         [7.6448e-01, 2.2395e-01, 5.0582e-01, 9.1255e-01, 7.8807e-01, 5.0264e-01,
          1.4134e-01, 9.2800e-01, 1.0998e-01, 1.5579e-02, 6.3408e-01, 6.7928e-02,
          5.3612e-01, 4.7292e-01, 6.0492e-01, 6.7912e-01, 3

In [22]:
labels, features = first["label"], first["data"]

In [23]:
print(f"Labels batch shape: {labels.size()}")
print(f"Feature batch shape: {features.size()}")

Labels batch shape: torch.Size([50])
Feature batch shape: torch.Size([50, 20])


## Implementing a Custom DataPipe
独自のDataPipeを作成する。

命名規則は"Operation"-eｒ + IterDataPipe or MapDataPipe。エイリアスではIterDataPipeとMapDataPipeは取り除く。

この例では、 `MapperIterDataPipe` を作る。

In [24]:
# IterDataPipeを継承して、MapperIterDataPipeを作成。
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("new_map")   # DataPipeにmapメソッドを登録
class MapperIterDataPipe(IterDataPipe):
    def __init__(self, source_dp: IterDataPipe, fn) -> None:
        super().__init__()
        self.dp = source_dp
        self.fn = fn   # 関数により変換を加える

    def __iter__(self):
        for d in self.dp:
            yield self.fn(d["data"])   # 変換を加えたあとのiteratorを作成

    def __len__(self):   # DataPipeの長さを返す
        return len(self.dp)

In [25]:
# MaapperIterDataPipe内で実行したい関数を定義
def decoder(x):
    return x*2

In [26]:
datapipe = build_datapipes()
list(datapipe)

[{'label': array(2, dtype=int32),
  'data': array([0.73373068, 0.73441013, 0.00776266, 0.86486213, 0.76670628,
         0.68440372, 0.85886995, 0.84738717, 0.34619812, 0.8743479 ,
         0.70327222, 0.09511704, 0.95636327, 0.03606831, 0.86893188,
         0.23366899, 0.5218033 , 0.10904759, 0.98516194, 0.69701251])},
 {'label': array(2, dtype=int32),
  'data': array([0.31106927, 0.60033499, 0.66752523, 0.34954898, 0.96260424,
         0.20038396, 0.42065393, 0.61383459, 0.76999198, 0.32894419,
         0.65690442, 0.36499982, 0.8100784 , 0.54821943, 0.43336967,
         0.06590321, 0.18914463, 0.10630628, 0.97130934, 0.0733747 ])},
 {'label': array(4, dtype=int32),
  'data': array([0.76448027, 0.22395405, 0.5058241 , 0.91254884, 0.78806814,
         0.50263827, 0.14134268, 0.92799505, 0.10997732, 0.01557919,
         0.63408005, 0.06792791, 0.53612388, 0.47292323, 0.60492193,
         0.679119  , 0.33059482, 0.31484348, 0.24504498, 0.79653113])},
 {'label': array(7, dtype=int32),
  '

In [27]:
list(datapipe.new_map(fn=decoder))

[array([1.46746135, 1.46882026, 0.01552532, 1.72972426, 1.53341256,
        1.36880745, 1.7177399 , 1.69477433, 0.69239623, 1.7486958 ,
        1.40654444, 0.19023409, 1.91272655, 0.07213661, 1.73786377,
        0.46733799, 1.0436066 , 0.21809518, 1.97032387, 1.39402501]),
 array([0.62213854, 1.20066998, 1.33505045, 0.69909795, 1.92520848,
        0.40076791, 0.84130786, 1.22766918, 1.53998396, 0.65788838,
        1.31380884, 0.72999964, 1.62015681, 1.09643887, 0.86673934,
        0.13180643, 0.37828927, 0.21261256, 1.94261867, 0.14674939]),
 array([1.52896053, 0.44790811, 1.01164821, 1.82509768, 1.57613628,
        1.00527654, 0.28268537, 1.8559901 , 0.21995464, 0.03115838,
        1.2681601 , 0.13585581, 1.07224776, 0.94584645, 1.20984386,
        1.35823799, 0.66118964, 0.62968696, 0.49008996, 1.59306226]),
 array([0.09298397, 0.89495734, 0.49171844, 1.78445804, 1.07297629,
        1.37667534, 1.87260526, 0.62120076, 1.24736754, 1.96513032,
        1.0706688 , 0.58374302, 1.89964817