In [1]:
import math
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler

In [2]:
class MyDataset(Dataset):
    """Toy map-style dataset holding the ten even numbers 0, 2, ..., 18."""

    def __init__(self):
        # The value stored at position i is 2 * i.
        self.data = list(range(0, 20, 2))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

In [3]:
class MySampler(Sampler):
    """Sampler that yields the indices of ``data_source`` in reverse order.

    Args:
        data_source: Any sized object (only ``len()`` is called on it),
            typically the dataset that will be handed to the DataLoader.
    """

    def __init__(self, data_source):
        self.data = data_source

    def __len__(self):
        """Return the number of indices produced by one full iteration."""
        # The value must be returned explicitly; a ``__len__`` that falls off
        # the end returns None and makes ``len(sampler)`` raise TypeError.
        return len(self.data)

    def __iter__(self):
        """Yield the indices ``N-1, N-2, ..., 0`` where ``N = len(data_source)``."""
        return iter(reversed(range(len(self.data))))

In [4]:
class MyBatchSampler(BatchSampler):
    """Batch sampler that groups consecutive indices into fixed-size batches.

    The last batch is kept even when it holds fewer than ``batch_size``
    indices.

    Args:
        data_source: Any sized object (only ``len()`` is called on it).
        batch_size: Number of indices per batch.
    """

    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        """Yield one list of consecutive indices per batch."""
        N = len(self.data_source)
        for i in range(0, N, self.batch_size):
            # Clamp the upper bound so the final batch stops at the last index.
            yield list(range(i, min(i + self.batch_size, N)))

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        # Read ``data_source`` (the attribute set in ``__init__``); the class
        # never assigns a ``data`` attribute.
        return math.ceil(len(self.data_source) / self.batch_size)


In [10]:
# Build the dataset and the sampler that decides the order in which its
# indices are visited.
myds = MyDataset()
mysampler = MySampler(myds)
# With batch_size=1 every batch holds exactly one element; the elements are
# fetched following the indices produced by the sampler.
mydl = DataLoader(dataset=myds, batch_size=1, sampler=mysampler)
print(
    "Use Sampler to determine the loop order of all values from the dataset (sequential, reversed sequential, or random)",
    "The Sampler will return a list of indices",
    sep="\n",
)
for batch in mydl:
    print(f"{batch}")

Use Sampler to determine the loop order of all values from the dataset (sequential, reversed sequential, or random)
The Sampler will return a list of indices
tensor([18])
tensor([16])
tensor([14])
tensor([12])
tensor([10])
tensor([8])
tensor([6])
tensor([4])
tensor([2])
tensor([0])


In [None]:
# Build the dataset and a batch sampler that emits groups of three indices.
myds = MyDataset()
mybatchsampler = MyBatchSampler(myds, batch_size=3)
# Direct batch sampling: each batch is retrieved using the list of indices
# provided by mybatchsampler.
mydl = DataLoader(dataset=myds, batch_sampler=mybatchsampler)
print(
    "Use BatchSampler to specify which elements to pick for a batch",
    "BatchSampler returns an iterator, where each element is a list of indices",
    sep="\n",
)
for batch in mydl:
    print(f"{len(batch)}, {batch}")

3, tensor([0, 2, 4])
3, tensor([ 6,  8, 10])
3, tensor([12, 14, 16])
1, tensor([18])
