In [114]:
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision.datasets import OxfordIIITPet

In [118]:
def load_dataset(split="trainval", root="/tmp/adl_data") -> Dataset:
    """
    returns a Dataset with:
    - images as normalized float32 tensors
    - labels as uint 8 tensors

    the labels are:
    - 1: foreground
    - 2: background
    - 3: not classified
    """
    image_transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    label_transform = lambda label: torch.from_numpy(np.array(label))
    ds = OxfordIIITPet(
        root=root,
        split=split,
        target_types="segmentation",
        download=True,
        transform=image_transform,
        target_transform=label_transform,
    )
    return ds

In [116]:
def create_datasets(data_dir, valid_frac, labelled_frac):
    trainval_ds = load_dataset('trainval')
    test_ds = load_dataset('test')

    rng = torch.Generator()
    rng.manual_seed(0)
    train_all_ds, valid_ds = torch.utils.data.random_split(trainval_ds, (1-valid_frac, valid_frac), generator=rng)
    _, train_lab_ds = torch.utils.data.random_split(train_ds, (1-labelled_frac, labelled_frac), generator=rng)
    
    return train_all_ds, train_lab_ds, valid_ds

In [None]:
train_all_ds, train_lab_ds, valid_ds = create_datasets()

In [108]:
data_dir = 'data'
valid_frac = 0.2
labelled_frac = 0.125

In [110]:
len(test_ds), len(valid_ds), len(unlab_train_ds), len(lab_train_ds)

(3669, 736, 2576, 368)