In [5]:
import torch
from torch.utils.data import Dataset, Sampler, DataLoader
from PIL import Image
import os, math
import random
import torchvision.transforms as T

class PKDataset(Dataset):
    """Image-folder dataset where every sub-directory of ``root`` is one class.

    Layout expected on disk::

        root/<class_name>/<image>.{jpg,jpeg,png}

    Attributes:
        root: Directory that was scanned.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        samples: List of ``(image_path, label)`` tuples, ordered by class name
            and then by file name.
        class_to_idx: Mapping from class directory name to integer label.
        num_classes: Number of class directories found.
    """

    def __init__(self, root, transform=None):
        """Scan ``root`` and index every image file below its class directories.

        Args:
            root: Path of the dataset directory.
            transform: Optional callable ``transform(img) -> Any`` applied to each
                loaded RGB image.
        """
        self.root = root
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        # Only directories are classes. Stray files in ``root`` (README, .DS_Store, ...)
        # must be skipped: they would otherwise be counted in ``num_classes`` and
        # make the os.listdir() call below fail.
        classes = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        self.num_classes = len(classes)
        for label, cls in enumerate(classes):
            path = os.path.join(root, cls)
            self.class_to_idx[cls] = label
            # os.listdir() order is arbitrary; sort so that sample indices are
            # reproducible from run to run.
            for fname in sorted(os.listdir(path)):
                if fname.lower().endswith((".jpg", ".png", ".jpeg")):
                    img_path = os.path.join(path, fname)
                    self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, when a transform was given, passed
        through it before being returned.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert().
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class PKSampler(Sampler):
    """Batch sampler that yields batches of ``P`` classes x ``K`` samples.

    Every yielded batch is a list of ``P * K`` dataset indices. One pass over
    the sampler visits every class once (in groups of ``P``); the last group is
    padded with extra classes when the class count is not a multiple of ``P``.

    Intended to be passed to ``DataLoader(..., batch_sampler=PKSampler(...))``.

    Attributes:
        data: The dataset; must expose ``samples`` as a sequence of
            ``(item, label)`` pairs.
        P: Number of classes per batch.
        K: Number of samples drawn per class.
        shuffle: Whether the class order is shuffled on every iteration.
        lab2idx: Mapping ``label -> [dataset indices with that label]``.
        labels_unique: List of distinct labels, in first-seen order.
    """

    def __init__(self, data, P=16, K=16, shuffle=True):
        """Build the label -> indices lookup.

        Raises:
            ValueError: If ``P`` or ``K`` is smaller than 1.
        """
        if P < 1 or K < 1:
            raise ValueError(f"P and K must be >= 1, got P={P}, K={K}")
        self.data = data
        self.P = P
        self.K = K
        self.shuffle = shuffle
        self.lab2idx = {}
        # Build the label -> indices map, e.g. {0: [0, 5, 12], 1: [1, 9], ...}.
        for idx, (_, label) in enumerate(data.samples):
            self.lab2idx.setdefault(label, []).append(idx)
        self.labels_unique = list(self.lab2idx.keys())

    def _pad_classes(self, cls_batch):
        """Return ``cls_batch`` extended to exactly ``P`` classes.

        Padding classes are taken from classes NOT already in the batch so the
        batch keeps ``P`` distinct classes. Only when the dataset has fewer than
        ``P`` classes in total are repeats used to fill the remainder.
        """
        missing = self.P - len(cls_batch)
        already_in = set(cls_batch)
        candidates = [c for c in self.labels_unique if c not in already_in]
        if len(candidates) >= missing:
            return cls_batch + random.sample(candidates, missing)
        # Fewer than P distinct classes exist: use them all, then repeat.
        padded = cls_batch + candidates
        return padded + random.choices(self.labels_unique, k=self.P - len(padded))

    def __iter__(self):
        """Yield one list of ``P * K`` dataset indices per batch."""
        # Work on a copy so shuffling never reorders self.labels_unique.
        classes = self.labels_unique.copy()
        if self.shuffle:
            random.shuffle(classes)
        # Walk the class list in strides of P.
        for start in range(0, len(classes), self.P):
            cls_batch = classes[start:start + self.P]
            if len(cls_batch) < self.P:
                cls_batch = self._pad_classes(cls_batch)
            batch = []
            for c in cls_batch:
                idx_pool = self.lab2idx[c]
                if len(idx_pool) >= self.K:
                    # Enough samples: draw K without replacement.
                    chosen = random.sample(idx_pool, self.K)
                else:
                    # Too few samples: draw K with replacement.
                    chosen = random.choices(idx_pool, k=self.K)
                batch.extend(chosen)
            # Mix the classes inside the batch.
            random.shuffle(batch)
            yield batch

    def __len__(self):
        """Return the number of batches per pass (not the number of samples).

        This is exact: ``__iter__`` yields one batch per group of ``P`` classes,
        and the final partial group is padded rather than dropped.
        """
        return math.ceil(len(self.labels_unique) / self.P)
    
# ======================================================
# ========= Sanity-check the PK dataset/sampler =========
# ======================================================
if __name__ == "__main__":
    transform = T.Compose([
        T.Resize((112, 112)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    dataset = PKDataset("./CASIA-WebFace", transform=transform)
    sampler = PKSampler(dataset, P=8, K=8)
    # batch_sampler: every item the sampler yields is the index list of one batch.
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0, pin_memory=True)
    loader_iter = iter(loader)
    # Show at most the first 4 batches. zip() ends the loop quietly when the
    # loader has fewer batches, where a bare next() would raise StopIteration.
    for i, (imgs, labels) in zip(range(4), loader_iter):
        print(f"batch {i}: imgs.shape={imgs.shape}, unique_classes={len(torch.unique(labels))}")
        print(labels)

batch 0: imgs.shape=torch.Size([64, 3, 112, 112]), unique_classes=8
tensor([ 8533,  8072,  8072,  8533,  9137,  8072,  8047,  9137,  8047,  8533,
         4938,  8533,   297,   297, 10219, 10219, 10219,  8533,   297,  9137,
         4938, 10219,  8072,  4938,  8072,  1927, 10219,  8047,  4938,  4938,
         8533,   297,  8533,  4938,  9137,  1927,   297,  9137,  8047,  8047,
         1927,  1927,  8047,  1927,  8533,   297,  8072,  4938,  1927,  9137,
         1927, 10219,  1927,  8072,   297,   297, 10219,  8072,  4938,  8047,
         8047,  9137, 10219,  9137])
batch 1: imgs.shape=torch.Size([64, 3, 112, 112]), unique_classes=8
tensor([ 931, 6165, 7620, 6165,  931,  931, 7620, 7620, 3396, 1070, 1070,  835,
        7620, 1070, 7620, 3396, 7620,  835,  931, 3396, 1070, 3396,  835, 5486,
        3536,  931, 3536, 3536, 3536, 3536, 6165, 5486,  931, 5486, 7620,  931,
        6165, 6165,  835, 5486,  835, 7620,  835, 3536, 3396, 5486, 3396,  931,
        6165, 1070, 3396, 5486, 3536, 1

In [2]:
import torchvision.transforms as T
from torch.utils.data import Dataset, Sampler, DataLoader
from PIL import Image
import os
import random

class OverlapDataset(Dataset):
    """Image-folder dataset (one sub-directory of ``root`` per class), loaded as grayscale.

    Layout expected on disk::

        root/<class_name>/<image>.{jpg,jpeg,png}

    Attributes:
        root: Directory that was scanned.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        samples: List of ``(image_path, label)`` tuples, ordered by class name
            and then by file name.
        class_to_idx: Mapping from class directory name to integer label.
        num_classes: Number of class directories found.
    """

    def __init__(self, root, transform=None):
        """Scan ``root`` and index every image file below its class directories.

        Args:
            root: Path of the dataset directory.
            transform: Optional callable ``transform(img) -> Any`` applied to each
                loaded grayscale image.
        """
        self.root = root
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        # Only directories are classes. Stray files in ``root`` (README, .DS_Store, ...)
        # must be skipped: they would otherwise be counted in ``num_classes`` and
        # make the os.listdir() call below fail.
        classes = sorted(
            entry for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        self.num_classes = len(classes)
        for label, cls in enumerate(classes):
            path = os.path.join(root, cls)
            self.class_to_idx[cls] = label
            # os.listdir() order is arbitrary; sort so that sample indices are
            # reproducible from run to run.
            for fname in sorted(os.listdir(path)):
                if fname.lower().endswith((".jpg", ".png", ".jpeg")):
                    img_path = os.path.join(path, fname)
                    self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel mode ``"L"`` and, when a
        transform was given, passed through it before being returned.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert().
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class OverlapSampler(Sampler):
    """Batch sampler whose consecutive batches share a fraction of their samples.

    A window of ``batch_size`` positions slides over a (optionally shuffled)
    permutation of the dataset indices, advancing ``step`` positions per batch,
    so neighbouring batches overlap in ``batch_size - step`` positions. Only
    full windows are produced; trailing indices that do not fill a whole
    window are not emitted.

    Intended to be passed to ``DataLoader(..., batch_sampler=OverlapSampler(...))``.

    Attributes:
        data: The dataset; only ``len(data)`` is used.
        batch_size: Number of indices per batch.
        overlap_ratio: Requested fraction of a batch shared with the next one.
        step: Window stride, ``floor(batch_size * (1 - overlap_ratio))``, at least 1.
        shuffle: Whether the index permutation is reshuffled on every iteration.
        indices: The position windows (before shuffling), one list per batch.
    """

    def __init__(self, data, batch_size, overlap_ratio=0.4, shuffle=True):
        """Compute the stride and the window layout.

        Raises:
            ValueError: If ``batch_size < 1`` or ``overlap_ratio >= 1``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if overlap_ratio >= 1:
            raise ValueError(f"overlap_ratio must be < 1, got {overlap_ratio}")
        self.data = data
        self.batch_size = batch_size
        self.overlap_ratio = overlap_ratio
        # round() removes binary floating-point noise before truncating, e.g.
        # 10 * (1 - 0.9) == 0.9999999999999998 must give a step of 1, not 0.
        # max() keeps the window moving for ratios very close to 1.
        self.step = max(1, int(round(batch_size * (1 - overlap_ratio), 9)))
        self.shuffle = shuffle
        self.indices = [
            list(range(start, start + batch_size))
            for start in self._window_starts()
        ]

    def _window_starts(self):
        """Return the range of start positions of all full windows."""
        return range(0, len(self.data) - self.batch_size + 1, self.step)

    def __iter__(self):
        """Yield one list of ``batch_size`` dataset indices per window."""
        indices = list(range(len(self.data)))
        if self.shuffle:
            random.shuffle(indices)
        for start in self._window_starts():
            yield indices[start:start + self.batch_size]

    def __len__(self):
        """Return the number of batches per pass; 0 when the data is smaller than one batch."""
        # len(range) is never negative, unlike the closed-form expression
        # (len(data) - batch_size) // step + 1 for very small datasets.
        return len(self._window_starts())

# ======================================================
# ====== Sanity-check the overlap dataset/sampler ======
# ======================================================
if __name__ == "__main__":
    transform = T.Compose([
        T.ToTensor()
    ])
    dataset = OverlapDataset("./mnist_test_torch", transform=transform)
    sampler = OverlapSampler(dataset, batch_size=8, overlap_ratio=0.4)
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0, pin_memory=True)
    # DataLoader first checks whether a sampler was supplied and otherwise builds
    # its own. While the loader is iterated it keeps pulling the next item from
    # the sampler's iterator; with batch_sampler, each item is the list of
    # dataset indices that make up one batch.
    # Like a Dataset, a sampler implements two methods: __iter__ (yields the
    # indices the DataLoader should load) and __len__.
    loader_iter = iter(loader)
    # Take at most the first 4 batches. zip() ends the loop quietly when the
    # loader has fewer batches, where a bare next() would raise StopIteration.
    for i, (imgs, labels) in zip(range(4), loader_iter):
        print(f"batch {i}: imgs.shape={imgs.shape}, labels={labels.tolist()}")

batch 0: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[1, 7, 6, 5, 4, 4, 3, 8]
batch 1: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[4, 4, 3, 8, 7, 9, 5, 1]
batch 2: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[7, 9, 5, 1, 8, 9, 7, 7]
batch 3: imgs.shape=torch.Size([8, 1, 28, 28]), labels=[8, 9, 7, 7, 5, 6, 7, 8]
