In [2]:
import torchvision.transforms as T
from torch.utils.data import Dataset, Sampler, DataLoader
from PIL import Image
import os
import random

class OverlapDataset(Dataset):
    """Image-folder dataset where each sub-directory of ``root`` is one class.

    Class labels are assigned in sorted order of the sub-directory names.
    Every file inside a class directory whose name ends in ``.jpg``,
    ``.png`` or ``.jpeg`` (case-insensitive) becomes one sample. Images are
    returned as single-channel ``"L"`` (grayscale) PIL images, optionally
    passed through ``transform``.

    Args:
        root: Directory whose immediate sub-directories are the classes.
            Plain files directly inside ``root`` are ignored.
        transform: Optional callable applied to each loaded PIL image.

    Attributes:
        samples: List of ``(image_path, label)`` tuples in deterministic
            (sorted class, then sorted file name) order.
        class_to_idx: Mapping of class directory name -> integer label.
        num_classes: Number of class directories found.
    """

    # File-name suffixes (lower-case) accepted as images.
    IMG_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        # Only directories are classes: a stray file in root (README,
        # .DS_Store, ...) would otherwise shift every label and then crash
        # os.listdir() below with NotADirectoryError.
        classes = sorted(
            name for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )
        self.num_classes = len(classes)
        for label, cls in enumerate(classes):
            self.class_to_idx[cls] = label
            class_dir = os.path.join(root, cls)
            # os.listdir order is filesystem-dependent; sort so that sample
            # indices (and therefore sampler batches) are reproducible.
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(self.IMG_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, fname), label))

    def __len__(self):
        """Return the number of image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to grayscale ``"L"`` mode and, if a transform
        was given, passed through it before being returned.
        """
        img_path, label = self.samples[idx]
        # The context manager releases the file handle promptly; convert()
        # returns a new, fully loaded image that outlives the `with` block.
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class OverlapSampler(Sampler):
    """Batch sampler whose consecutive batches share a fraction of their indices.

    Each epoch builds one permutation of ``range(len(data))`` (shuffled with
    ``random.shuffle`` when ``shuffle`` is true). It then slides a window of
    ``batch_size`` over it with stride ``step``, so batch ``k + 1`` repeats
    the last ``batch_size - step`` indices of batch ``k``. Only full windows
    are yielded; trailing indices that cannot fill a window are skipped for
    that epoch.

    Intended for ``DataLoader(dataset, batch_sampler=OverlapSampler(...))``.

    Args:
        data: Any sized object; only ``len(data)`` is used, and it is re-read
            on every iteration.
        batch_size: Number of indices per batch (must be >= 1).
        overlap_ratio: Fraction of a batch shared with the next one. The
            stride is ``floor(batch_size * (1 - overlap_ratio))``.
        shuffle: Whether to reshuffle the permutation each epoch.

    Raises:
        ValueError: If ``batch_size < 1`` or the resulting stride is < 1.
    """

    def __init__(self, data, batch_size, overlap_ratio=0.4, shuffle=True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.data = data
        self.batch_size = batch_size
        self.overlap_ratio = overlap_ratio
        # Round away float noise before truncating: e.g. 10 * (1 - 0.9)
        # evaluates to 0.9999999999999998, which plain int() turns into 0.
        self.step = int(round(batch_size * (1 - overlap_ratio), 9))
        if self.step < 1:
            raise ValueError(
                f"overlap_ratio={overlap_ratio} with batch_size={batch_size} "
                f"gives a stride of {self.step}; the stride must be >= 1"
            )
        self.shuffle = shuffle
        # Unshuffled window layout, kept for backward compatibility.
        # __iter__ does not read it.
        self.indices = [
            list(range(start, start + batch_size)) for start in self._starts()
        ]

    def _starts(self):
        """Return the start offsets of every full window over ``len(data)``."""
        # Empty when len(data) < batch_size, giving zero batches (never a
        # negative length).
        return range(0, len(self.data) - self.batch_size + 1, self.step)

    def __iter__(self):
        """Yield one list of ``batch_size`` dataset indices per batch."""
        order = list(range(len(self.data)))
        if self.shuffle:
            random.shuffle(order)
        for start in self._starts():
            yield order[start:start + self.batch_size]

    def __len__(self):
        """Return the number of batches ``__iter__`` yields (always >= 0)."""
        return len(self._starts())

# ======================================================
# ===== Sanity-check dataset and sampler construction =====
# ======================================================
if __name__ == "__main__":
    from itertools import islice

    transform = T.Compose([
        T.ToTensor()
    ])
    dataset = OverlapDataset("./mnist_test_torch", transform=transform)
    sampler = OverlapSampler(dataset, batch_size=8, overlap_ratio=0.4)
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0, pin_memory=True)
    # How DataLoader uses the sampler: if you supply your own it is used as
    # is; otherwise DataLoader builds a default one (SequentialSampler, or
    # RandomSampler when shuffle=True). Each epoch it calls iter() on the
    # batch sampler once, then next() on that iterator for every batch,
    # receiving one list of indices per batch.
    # Like a Dataset, a sampler implements __iter__ and __len__: __iter__
    # yields the index lists whose samples DataLoader loads.
    # islice stops cleanly if the loader has fewer than 4 batches, instead of
    # raising StopIteration as repeated next() calls would.
    for i, (imgs, labels) in enumerate(islice(loader, 4)):  # first 4 batches
        print(f"batch {i}: imgs.shape={imgs.shape}, labels={labels.tolist()}")

batch 0: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[1, 7, 6, 5, 4, 4, 3, 8]
batch 1: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[4, 4, 3, 8, 7, 9, 5, 1]
batch 2: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[7, 9, 5, 1, 8, 9, 7, 7]
batch 3: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[8, 9, 7, 7, 5, 6, 7, 8]
