# Working with DataLoader

In [1]:
import csv
import os
import random

def generate_csv(root_dir, file_label, num_rows: int = 5000, num_features: int = 20, seed=None) -> None:
    """Write a synthetic CSV file ``sample_data{file_label}.csv`` into ``root_dir``.

    The file starts with a header row ``label,c0,...,c{num_features-1}`` followed
    by ``num_rows`` data rows. Each ``label`` is a random integer in [0, 9] and
    each feature ``c*`` is a random float in [0.0, 1.0).

    Args:
        root_dir: Existing directory to write the file into.
        file_label: Suffix used in the file name (e.g. 0 -> ``sample_data0.csv``).
        num_rows: Number of data rows (header not counted).
        num_features: Number of feature columns ``c0 .. c{num_features-1}``.
        seed: Optional seed for a private ``random.Random`` instance. When ``None``
            (the default) the module-level ``random`` generator is used, so a
            global ``random.seed(...)`` still controls the output.
    """
    fieldnames = ['label'] + [f'c{i}' for i in range(num_features)]
    rng = random if seed is None else random.Random(seed)
    path = os.path.join(root_dir, f"sample_data{file_label}.csv")
    # newline="" lets the csv module control line endings itself (otherwise
    # Windows gets blank lines between rows); `with` guarantees the file is
    # flushed and closed before anything tries to read it back.
    with open(path, "w", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for _ in range(num_rows):
            # Same draw order as before: one float per column (incl. 'label'),
            # then the integer label overwrites the placeholder.
            row_data = {col: rng.random() for col in fieldnames}
            row_data['label'] = rng.randint(0, 9)
            writer.writerow(row_data)

In [2]:
import numpy as np
import torchdata.datapipes as dp

def _is_sample_csv(filename):
    """Return True for files whose basename contains "sample_data" and ends in ".csv"."""
    # Match on the basename only, so a directory whose path happens to contain
    # "sample_data" does not pull in unrelated CSV files.
    basename = os.path.basename(filename)
    return "sample_data" in basename and basename.endswith(".csv")


def _row_to_sample(row):
    """Convert one parsed CSV row (a list of strings) into a sample dict.

    ``row[0]`` becomes the ``"label"`` (int32 scalar array); the remaining
    fields become ``"data"`` (float64 1-D array).
    """
    return {"label": np.array(row[0], np.int32),
            "data": np.array(row[1:], dtype=np.float64)}


def build_datapipes(root_dir="."):
    """Build an iterable DataPipe yielding one sample dict per CSV data row.

    Stages: list files in ``root_dir`` -> keep ``sample_data*.csv`` files ->
    open as text -> parse rows (skipping each file's header line) -> shuffle ->
    shard across DataLoader workers -> convert each row with ``_row_to_sample``.

    Named module-level functions are used instead of lambdas because the
    standard ``pickle`` module cannot serialize lambdas, which breaks the
    pipeline when DataLoader uses worker processes.

    Args:
        root_dir: Directory to scan for the CSV files.

    Returns:
        The final DataPipe of ``{"label": ..., "data": ...}`` dicts.
    """
    datapipe = dp.iter.FileLister(root_dir)
    datapipe = datapipe.filter(filter_fn=_is_sample_csv)
    datapipe = dp.iter.FileOpener(datapipe, mode='rt')
    datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
    # DataLoader's `shuffle` flag only enables/disables Shuffler stages already in
    # the graph; without this stage `shuffle=True` would silently do nothing.
    datapipe = datapipe.shuffle()
    # Splits rows across DataLoader workers (no-op with num_workers=0) so each
    # row is yielded once rather than once per worker.
    datapipe = datapipe.sharding_filter()
    datapipe = datapipe.map(_row_to_sample)
    return datapipe

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os
# Target directory for the generated sample_data*.csv files (relative to the CWD).
root_dir = "data/demo1"
# Create it plus any missing parents; exist_ok keeps re-runs of this cell from failing.
os.makedirs(name=root_dir, exist_ok=True)

In [4]:
from torch.utils.data import DataLoader

# Seed Python's global RNG so the generated CSV contents are identical on every
# Restart & Run All (generate_csv draws its values from the `random` module).
random.seed(0)

num_files_to_generate = 3
for file_idx in range(num_files_to_generate):
    generate_csv(root_dir, file_label=file_idx)

In [5]:
# Assemble the CSV -> sample-dict pipeline, then batch it 4 samples at a time.
# NOTE(review): for DataPipes, DataLoader's `shuffle` only takes effect if the
# pipeline contains a `.shuffle()` stage — confirm build_datapipes adds one.
datapipe = build_datapipes(root_dir)
dl = DataLoader(datapipe, shuffle=True, batch_size=4)



In [6]:
# Pull a single batch from a fresh iterator over the loader.
first = next(iter(dl))
# The batch is a dict keyed like the per-sample dicts, with values batched into tensors.
labels = first['label']
features = first['data']
print(f"Labels batch shape: {labels.size()}")
print(f"Feature batch shape: {features.size()}")

Labels batch shape: torch.Size([4])
Feature batch shape: torch.Size([4, 20])


In [7]:
# Display the label tensor of the first batch (one int32 label per sample).
labels

tensor([1, 2, 2, 1], dtype=torch.int32)

In [8]:
# Display the feature tensor of the first batch: one row of float64 features per sample.
features

tensor([[0.1605, 0.4524, 0.6453, 0.3441, 0.8943, 0.9665, 0.7061, 0.0848, 0.0392,
         0.7990, 0.3001, 0.2297, 0.9757, 0.4467, 0.1888, 0.2140, 0.6491, 0.1927,
         0.3649, 0.6646],
        [0.5885, 0.1191, 0.0158, 0.9088, 0.3281, 0.3654, 0.8590, 0.4963, 0.4889,
         0.3741, 0.6869, 0.7386, 0.8826, 0.0956, 0.6551, 0.6319, 0.2490, 0.8867,
         0.9994, 0.6703],
        [0.5185, 0.8365, 0.8828, 0.6613, 0.8167, 0.6729, 0.4997, 0.8273, 0.6124,
         0.2519, 0.3120, 0.6760, 0.4806, 0.3849, 0.9350, 0.5775, 0.6906, 0.3942,
         0.5723, 0.3695],
        [0.1265, 0.9049, 0.9926, 0.3432, 0.3639, 0.5668, 0.0344, 0.1707, 0.6944,
         0.1797, 0.7056, 0.3878, 0.1656, 0.1083, 0.3677, 0.3308, 0.2670, 0.6090,
         0.5456, 0.6112]], dtype=torch.float64)