In [66]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader

In [74]:
def get_datasets(tasks: int):
    classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
    classes_per_task= torch.linspace(0, len(classes), tasks+1, dtype=torch.int)
    trainsets = []
    testsets = []
    train_targets = torch.tensor(trainset.targets)
    test_targets = torch.tensor(testset.targets)
    for i in range(len(classes_per_task)-1):
        train_indices = []
        test_indices = []
        for j in range(classes_per_task[i], classes_per_task[i+1]):
            train_indices.extend((train_targets == j).nonzero(as_tuple=False).flatten().tolist())
            test_indices.extend((test_targets == j).nonzero(as_tuple=False).flatten().tolist())
        trainsets.append(Subset(trainset, train_indices))
        testsets.append(Subset(testset, test_indices))
    return trainsets, testsets


In [75]:
trainsets, testsets = get_datasets(5)

Files already downloaded and verified
Files already downloaded and verified


In [76]:
datasets = trainsets+testsets
print(datasets)

[<torch.utils.data.dataset.Subset object at 0x7d17e2ae1570>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae08e0>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae27a0>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae34f0>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae3520>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae0910>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae0c40>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae2200>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae1f00>, <torch.utils.data.dataset.Subset object at 0x7d17e2ae2e60>]


In [77]:
for dataset in datasets:
    print(f"{dataset}")
    dataloader = DataLoader(dataset, batch_size=len(dataset),shuffle=False, num_workers=0)
    dataiter = iter(dataloader)
    images, labels = next(dataiter)
    print(f"{labels.shape=}")
    print(f"{torch.unique(labels)=}")

<torch.utils.data.dataset.Subset object at 0x7d17e2ae1570>
labels.shape=torch.Size([10000])
torch.unique(labels)=tensor([0, 1])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae08e0>
labels.shape=torch.Size([10000])
torch.unique(labels)=tensor([2, 3])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae27a0>
labels.shape=torch.Size([10000])
torch.unique(labels)=tensor([4, 5])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae34f0>
labels.shape=torch.Size([10000])
torch.unique(labels)=tensor([6, 7])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae3520>
labels.shape=torch.Size([10000])
torch.unique(labels)=tensor([8, 9])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae0910>
labels.shape=torch.Size([2000])
torch.unique(labels)=tensor([0, 1])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae0c40>
labels.shape=torch.Size([2000])
torch.unique(labels)=tensor([2, 3])
<torch.utils.data.dataset.Subset object at 0x7d17e2ae2200>
labels.shape=torch.Size([2000])
torch.unique(la