In [2]:
import os
from collections import Counter

import numpy as np
import torch
import torchvision
from torch.utils.data import SubsetRandomSampler
from torchvision import transforms

In [14]:
def sample_indices(ipc, split, seed=None, root=None,
                   out_dir='./random_ipc_index', dataset=None):
    """Draw a class-balanced random subset of a CIFAR training set and save its indices.

    For every class, ``ipc`` (images per class) distinct training indices are
    drawn without replacement. The combined indices are shuffled and written,
    one integer per line, to ``{out_dir}/{split}/ipc_{ipc}.txt``.

    Parameters
    ----------
    ipc : int
        Number of examples to sample from each class.
    split : {'cifar10', 'cifar100'}
        Dataset to sample from; also used as the output sub-directory name.
    seed : int, optional
        Seed for a private ``np.random.default_rng``. ``None`` (default) draws
        from NumPy's global random state, exactly as the original code did.
    root : str, optional
        torchvision dataset root. Defaults to ``'../cifar10'`` / ``'../cifar100'``.
    out_dir : str, optional
        Base directory for index files. The per-split sub-directory is created
        if it does not already exist.
    dataset : object, optional
        Pre-loaded dataset exposing a ``targets`` sequence of integer labels.
        When given, nothing is loaded from disk.

    Returns
    -------
    numpy.ndarray
        The shuffled sampled indices (the same values written to the file).

    Raises
    ------
    ValueError
        If ``split`` is unknown, or a class has fewer than ``ipc`` examples.
    """
    num_classes_by_split = {'cifar10': 10, 'cifar100': 100}
    # Fail fast instead of silently treating any unknown split as CIFAR-100.
    if split not in num_classes_by_split:
        raise ValueError(f"split must be 'cifar10' or 'cifar100', got {split!r}")
    num_classes = num_classes_by_split[split]

    if dataset is None:
        dataset_cls = (torchvision.datasets.CIFAR10 if split == 'cifar10'
                       else torchvision.datasets.CIFAR100)
        dataset_root = root if root is not None else f'../{split}'
        dataset = dataset_cls(root=dataset_root, train=True)

    # Convert the labels once rather than once per class.
    targets = np.asarray(dataset.targets)
    # seed=None keeps the original global-RNG behaviour; a seed makes runs reproducible.
    rng = np.random if seed is None else np.random.default_rng(seed)

    per_class = []
    for class_idx in range(num_classes):
        class_indices = np.flatnonzero(targets == class_idx)
        if class_indices.size < ipc:
            raise ValueError(
                f'class {class_idx} has only {class_indices.size} examples; '
                f'cannot sample ipc={ipc} without replacement'
            )
        per_class.append(rng.choice(class_indices, ipc, replace=False))

    indices = np.concatenate(per_class)
    rng.shuffle(indices)

    # np.savetxt does not create missing directories.
    split_dir = os.path.join(out_dir, split)
    os.makedirs(split_dir, exist_ok=True)
    np.savetxt(os.path.join(split_dir, f'ipc_{ipc}.txt'), indices, fmt='%d')
    return indices

In [30]:
sample_indices(ipc=1, split='cifar10');

In [31]:
sample_indices(ipc=50, split='cifar10');

In [32]:
sample_indices(ipc=500, split='cifar10');

In [33]:
sample_indices(ipc=1000, split='cifar10');

Verify indices: check that the saved CIFAR-10 subset contains no duplicates and is class-balanced

In [35]:
ipc1000 = np.loadtxt('./random_ipc_index/cifar10/ipc_1000.txt', dtype=int)
# The per-class Counter check below cannot reveal duplicate indices, so check size/uniqueness here.
assert ipc1000.size == 10 * 1000, ipc1000.size
assert np.unique(ipc1000).size == ipc1000.size, 'duplicate indices in ipc_1000.txt'
ipc1000

array([24726,  9094, 42213, ..., 14137, 24313, 48237])

In [36]:
# NOTE(review): these mean/std values look like the standard ImageNet statistics rather than
# CIFAR-10's. That is harmless here because only labels are inspected, but confirm before
# training on these tensors.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

cifar_dataset = torchvision.datasets.CIFAR10(root='../cifar10', train=True, transform=transform)
batch_size = 64
data_loader = torch.utils.data.DataLoader(
    cifar_dataset,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(ipc1000),
)

# Example usage of the data loader: pull every sampled example through it and keep its label.
all_targets = []
for _, targets in data_loader:
    all_targets.extend(targets.tolist())

# SubsetRandomSampler must visit every saved index exactly once.
assert len(all_targets) == len(ipc1000), (len(all_targets), len(ipc1000))

In [37]:
# Class-balance check: each of the 10 CIFAR-10 classes should appear exactly 1000 times.
counter = Counter(all_targets)
assert len(counter) == 10 and set(counter.values()) == {1000}, counter
counter

Counter({3: 1000,
         8: 1000,
         7: 1000,
         9: 1000,
         1: 1000,
         4: 1000,
         2: 1000,
         5: 1000,
         0: 1000,
         6: 1000})

CIFAR-100

In [15]:
sample_indices(ipc=1, split='cifar100');

In [16]:
sample_indices(ipc=10, split='cifar100');

In [17]:
sample_indices(ipc=50, split='cifar100');

In [18]:
ipc50 = np.loadtxt('./random_ipc_index/cifar100/ipc_50.txt', dtype=int)
# The per-class Counter check below cannot reveal duplicate indices, so check size/uniqueness here.
assert ipc50.size == 100 * 50, ipc50.size
assert np.unique(ipc50).size == ipc50.size, 'duplicate indices in ipc_50.txt'
ipc50

array([35640, 47103, 38839, ..., 48602, 17907, 37111])

In [19]:
# NOTE(review): these mean/std values look like the standard ImageNet statistics rather than
# CIFAR-100's. That is harmless here because only labels are inspected, but confirm before
# training on these tensors.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

cifar_dataset = torchvision.datasets.CIFAR100(root='../cifar100', train=True, transform=transform)
batch_size = 64
data_loader = torch.utils.data.DataLoader(
    cifar_dataset,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(ipc50),
)

# Example usage of the data loader: pull every sampled example through it and keep its label.
all_targets = []
for _, targets in data_loader:
    all_targets.extend(targets.tolist())

# SubsetRandomSampler must visit every saved index exactly once.
assert len(all_targets) == len(ipc50), (len(all_targets), len(ipc50))

In [20]:
# Class-balance check: each of the 100 CIFAR-100 classes should appear exactly 50 times.
counter = Counter(all_targets)
assert len(counter) == 100 and set(counter.values()) == {50}, counter
counter

Counter({89: 50,
         39: 50,
         98: 50,
         92: 50,
         38: 50,
         19: 50,
         18: 50,
         83: 50,
         57: 50,
         85: 50,
         67: 50,
         74: 50,
         54: 50,
         51: 50,
         44: 50,
         1: 50,
         5: 50,
         56: 50,
         12: 50,
         9: 50,
         25: 50,
         90: 50,
         94: 50,
         20: 50,
         4: 50,
         11: 50,
         29: 50,
         24: 50,
         15: 50,
         10: 50,
         22: 50,
         37: 50,
         78: 50,
         69: 50,
         3: 50,
         40: 50,
         36: 50,
         70: 50,
         64: 50,
         31: 50,
         0: 50,
         84: 50,
         33: 50,
         21: 50,
         62: 50,
         91: 50,
         41: 50,
         34: 50,
         81: 50,
         42: 50,
         46: 50,
         99: 50,
         66: 50,
         73: 50,
         13: 50,
         72: 50,
         47: 50,
         45: 50,
         43: 50,
   