In [None]:
import torch
from torch.utils.data import Dataset, DataLoader


class CustomDataset(Dataset):
    """
    Map-style dataset pairing input features with their labels.

    Sample ``i`` is ``(x[i], y[i])`` and the dataset length is ``len(x)``.
    """

    def __init__(self, x, y):
        """
        Initialize the dataset with x and y values.
        Arguments:
        x (torch.Tensor): The input features, indexed along the first dimension.
        y (torch.Tensor): The output labels, one entry per sample in ``x``.
        Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
        """
        # __len__ only looks at x, so a mismatch would otherwise surface later as
        # an IndexError mid-epoch (y shorter) or silently ignored labels (y longer).
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """
        Return the total number of samples in the dataset.
        """
        return len(self.x)

    def __getitem__(self, idx):
        """
        Fetch the sample at index `idx` from the dataset.
        Arguments:
        idx (int): The index of the sample to retrieve.
        Returns:
        tuple: ``(x[idx], y[idx])``.
        """
        return self.x[idx], self.y[idx]


# Synthetic linear data: every label is y = 2x - 1.
torch.manual_seed(0)  # For reproducibility of the shuffled batch order
x = torch.arange(0, 100, dtype=torch.float32)
y = 2 * x - 1

# Wrap the tensors in the dataset and let DataLoader handle batching and shuffling.
dataset = CustomDataset(x, y)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# Only the first shuffled batch is shown, so pull it directly instead of looping.
batch_idx = 0
inputs, labels = next(iter(data_loader))
print(f"Batch {batch_idx + 1}")
print("Inputs:", inputs)
print("Labels:", labels)


Batch 1
Inputs: tensor([26., 88., 59., 58., 73., 11., 65.,  2., 84., 79.])
Labels: tensor([ 51., 175., 117., 115., 145.,  21., 129.,   3., 167., 157.])


In [None]:
class DebugDataset(Dataset):
    """
    Tiny dataset that prints a line on every ``__len__`` / ``__getitem__`` call.

    It makes visible how ``DataLoader`` queries a dataset. Sample ``i`` is
    ``(data[i], data[i] * 2)``, where ``data`` holds ``0, 1, ..., size - 1`` as float32.
    """

    def __init__(self, size=20):
        """Store the values ``0 .. size - 1`` as a float32 tensor."""
        self.data = torch.arange(size, dtype=torch.float32)

    def __len__(self):
        """Print and return the number of stored samples."""
        n_items = len(self.data)
        print(f"__len__ called! Returning {n_items}")
        return n_items

    def __getitem__(self, idx):
        """Print the requested index and return ``(value, value * 2)``."""
        print(f"__getitem__ called with idx={idx}")
        value = self.data[idx]
        return value, value * 2


# Try it out: watch which dataset methods DataLoader calls, and when.
# Seed here (as in the first cell) so the shuffled indices are reproducible
# even when this cell is run on its own.
torch.manual_seed(0)
debug_dataset = DebugDataset(size=20)
debug_loader = DataLoader(debug_dataset, batch_size=3, shuffle=True)

print("Creating iterator...")
data_iter = iter(debug_loader)

print("\nGetting first batch...")
batch = next(data_iter)

__len__ called! Returning 20
__len__ called! Returning 20
Creating iterator...

Getting first batch...
__len__ called! Returning 20
__len__ called! Returning 20
__getitem__ called with idx=12
__getitem__ called with idx=1
__getitem__ called with idx=13
