In [None]:
from hartufo import CipicPlane
from hartufo import ImageSpec, SubjectSpec
from hartufo.torch import collate_dict_dataset
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, RandomRotation, ToPILImage, ToTensor
from pathlib import Path
import numpy as np

In [None]:
base_dir = Path('../hartufo-collections')

In [None]:
ds = CipicPlane(base_dir / 'CIPIC', side='any-left', plane='horizontal', hrir_role='target', other_specs=dict(
        features_spec=ImageSpec(transform=[RandomRotation(10), ToTensor()]),
        group_spec=SubjectSpec(),
    )
)

In [None]:
ToPILImage()(ds[0]['features'])

Photos are loaded as given, which means they could be different sizes. This leads to problems when trying to batch them together in a `DataLoader`.

In [None]:
try:
    ds_loader = DataLoader(ds, batch_size=8, collate_fn=collate_dict_dataset)
    next(iter(ds_loader))
except RuntimeError as err:
    print(err)
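
The failure can be reproduced without the dataset. The following is a minimal sketch, assuming the collate function ends up stacking the image tensors along a new batch dimension (an assumption about its internals, not something confirmed here). The tensor shapes are made up for illustration.

```python
import torch

# Two hypothetical "images" in (channels, height, width) layout with different heights
a = torch.zeros(3, 100, 120)
b = torch.zeros(3, 90, 120)

try:
    torch.stack([a, b])  # stacking requires identical shapes
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print(err)
```

Once both tensors share the same height and width, the same `torch.stack` call succeeds.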

The solution is to add a (PyTorch) Transform that outputs images of equal size, e.g. `RandomCrop` or `CenterCrop`. But what size should we crop to? It has to fit within every original image, so let's find the smallest size in the dataset.

In [None]:
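# The trailing two dimensions of an image tensor are (height, width)
sizes = np.array([i['features'].shape[-2:] for i in ds])
min_size = sizes.min()  # smallest side length over all images
min_height, min_width = sizes.min(axis=0)  # per-dimension minima
min_size, min_height, min_width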
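
To make the difference between the overall minimum and the per-dimension minima concrete, here is a small sketch on made-up `(height, width)` pairs. The shapes are hypothetical and not taken from the CIPIC photos.

```python
import numpy as np

# Hypothetical (height, width) pairs of three images
shapes = [(480, 640), (500, 600), (470, 650)]

# Without an axis, np.min reduces over every entry: the smallest side overall
global_min = np.min(shapes)

# With axis=0, the minimum is taken per column: smallest height and smallest width
min_height, min_width = np.min(shapes, axis=0)

print(global_min, min_height, min_width)
```

Cropping to `global_min` gives a square that fits inside every image, whereas cropping to `(min_height, min_width)` keeps as much of each image as possible.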

Now that we know the minimum size across all images, we can add another Transform to the processing pipeline. A stochastic crop such as `RandomCrop` makes sense as a `transform`, not a `preprocessing`, and can help against overfitting. The cell below uses the deterministic `CenterCrop`, which is appended in the same way. An additional advantage is that we don't need to load the dataset again, but can just append the transformation.

In [None]:
ds.append_transform(ImageSpec, CenterCrop(min_size))

Now batching the data works, and all images have the size we specified.

In [None]:
ds_loader = DataLoader(ds, batch_size=8, collate_fn=collate_dict_dataset)
X, y = next(iter(ds_loader))
X.shape

In [None]:
ToPILImage()(ds[0]['features'])