# Tutorial: Using an HRTF Dataset with PyTorch

In [None]:
from hartufo.torch import collate_dict_dataset
from hartufo import CipicPlane, AriPlane
from hartufo import CollectionSpec
from torch.utils.data import DataLoader, ConcatDataset, Subset
from pathlib import Path

In [None]:
base_dir = Path('../HRTF Datasets')

## Creating a PyTorch DataLoader

All instances of `hartufo.Dataset` expose a class interface that is directly compatible with PyTorch [`Datasets`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset), specifically [map-style datasets](https://pytorch.org/docs/stable/data.html#map-style-datasets). They can therefore be passed straight to a [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). However, `hartufo.Dataset` returns each datapoint as a dict, which makes advanced splitting strategies based on indivisible groups possible. For this reason, you must pass the provided `collate_dict_dataset` collation function when creating the `DataLoader`, so that batches are converted into the expected `(features, target)` pairs.

In [None]:
cipic_ds = CipicPlane(base_dir / 'CIPIC', 'horizontal', 'magnitude_db', 'both-left', plane_angles=(-30, 0, 30),
                      other_specs={'target_spec': CollectionSpec()}, subject_ids=(3, 8))
cipic_loader = DataLoader(cipic_ds, collate_fn=collate_dict_dataset)
features, target = next(iter(cipic_loader))
features.shape, target, len(cipic_ds)
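To make the role of the collation function concrete, here is a minimal, dependency-free sketch of what a dict collation step does. Without a custom `collate_fn`, PyTorch's default collation would return a dict of batched values instead of a `(features, target)` tuple. `collate_dicts_sketch` is a hypothetical name, the datapoints hold dummy values, and the real `collate_dict_dataset` presumably also converts the batch into tensors, which this sketch leaves out.

```python
from typing import Any, Dict, List, Tuple


def collate_dicts_sketch(batch: List[Dict[str, Any]]) -> Tuple[List[Any], List[Any]]:
    """Hypothetical stand-in for a dict collation function.

    Splits a list of dict datapoints into a (features, target) pair of lists.
    Unlike the real collate_dict_dataset (assumed), nothing is stacked into tensors.
    """
    features = [example['features'] for example in batch]
    targets = [example['target'] for example in batch]
    return features, targets


# Dummy datapoints shaped like the dicts returned by a hartufo dataset (values made up).
batch = [
    {'features': [0.1, 0.2, 0.3], 'target': 0},
    {'features': [0.4, 0.5, 0.6], 'target': 0},
]
features, target = collate_dicts_sketch(batch)
print(features)  # [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
print(target)    # [0, 0]
```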

Thanks to this compatible interface, other `torch.utils.data` functionality can be used as well, for instance to [concatenate](https://pytorch.org/docs/stable/data.html#torch.utils.data.ConcatDataset) datasets, take [subsets](https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) or draw datapoints with [custom samplers](https://pytorch.org/docs/stable/data.html#torch.utils.data.Sampler).

In [None]:
ari_ds = AriPlane(base_dir / 'ARI', 'horizontal', 'magnitude_db', 'both-left', plane_angles=(-30, 0, 30),
                  other_specs={'target_spec': CollectionSpec()}, subject_ids=(2, 4))

In [None]:
combined_ds = ConcatDataset([cipic_ds, ari_ds])
for ex in combined_ds:
    print(ex['target'])
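`ConcatDataset` keeps the cumulative lengths of its parts and maps a global index to a dataset and a local index with a binary search (roughly `bisect.bisect_right` over those cumulative sizes in PyTorch's implementation). The standard-library sketch below mirrors that lookup for non-negative indices only. `locate` is a hypothetical helper, and the lengths assume each dataset above holds four datapoints (two subjects, two sides each).

```python
import bisect
from itertools import accumulate
from typing import List, Tuple


def locate(global_index: int, dataset_lengths: List[int]) -> Tuple[int, int]:
    """Map an index into a concatenation of datasets to (dataset number, local index)."""
    cumulative = list(accumulate(dataset_lengths))  # e.g. [4, 8] for two datasets of length 4
    if not 0 <= global_index < cumulative[-1]:
        raise IndexError(global_index)  # negative indices are not handled in this sketch
    dataset_idx = bisect.bisect_right(cumulative, global_index)  # first part whose end exceeds the index
    start = cumulative[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, global_index - start


lengths = [4, 4]  # assumed: two subjects x two sides in each of cipic_ds and ari_ds
for global_index in (0, 3, 4, 7):
    print(global_index, locate(global_index, lengths))
# 0 (0, 0) / 3 (0, 3) / 4 (1, 0) / 7 (1, 3)
```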

## Creating data splits

The compatibility of `hartufo.Dataset` with `torch.utils.data.Dataset` means that you can choose between two approaches when splitting a dataset: either pass a `subject_ids` argument when constructing a `hartufo.Dataset`, or use `torch.utils.data.Subset`.

The former strategy has the advantage that splits are specified by subject ids rather than by sequential indices (keep in mind that each subject contributes one or two datapoints, depending on the choice of `side`). The latter strategy is advisable when a single dataset is split repeatedly, such as during K-fold cross-validation: all datapoints are read from disk only once, whereas the former strategy would read each datapoint K times.

In [None]:
cipic_selected_ids = CipicPlane(base_dir / 'CIPIC', 'horizontal', 'magnitude_db', 'both-left', plane_angles=(-30, 0, 30),
                                other_specs={'target_spec': CollectionSpec()}, subject_ids=(8,))  # reads again from disk
cipic_subset = Subset(cipic_ds, [2, 3])  # subject 8's datapoints, taken from the dataset already in memory

In [None]:
(cipic_selected_ids[-1]['features'] == cipic_subset[-1]['features']).all()
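For K-fold cross-validation, the in-memory approach boils down to computing one pair of index lists per fold and wrapping each list in `Subset`. The sketch below is a minimal version built on stated assumptions: `subject_kfold_indices` is a hypothetical helper, and the list mapping each datapoint index to its subject id is made up (this tutorial does not cover how to extract it from a `hartufo.Dataset`). Subjects are assigned to folds round-robin so that all datapoints of a subject always land in the same fold. The code uses only the standard library, and the `Subset` calls appear as comments.

```python
from typing import Iterator, List, Sequence, Tuple


def subject_kfold_indices(
    subject_per_datapoint: Sequence[int], k: int
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) for k folds, keeping each subject's
    datapoints together so that no subject appears in both train and test."""
    subjects = list(dict.fromkeys(subject_per_datapoint))  # unique ids, first-seen order
    if not 1 < k <= len(subjects):
        raise ValueError('k must be between 2 and the number of subjects')
    for fold in range(k):
        held_out = set(subjects[fold::k])  # round-robin assignment of subjects to folds
        train = [i for i, s in enumerate(subject_per_datapoint) if s not in held_out]
        test = [i for i, s in enumerate(subject_per_datapoint) if s in held_out]
        yield train, test


# Hypothetical index-to-subject mapping: three subjects, two datapoints (sides) each.
subject_per_datapoint = [3, 3, 8, 8, 10, 10]
for train_idx, test_idx in subject_kfold_indices(subject_per_datapoint, k=3):
    # With PyTorch available, each fold would reuse the in-memory dataset, e.g.
    # Subset(cipic_ds, train_idx) and Subset(cipic_ds, test_idx).
    print(train_idx, test_idx)
```

Because every fold only produces index lists, the underlying dataset is loaded once and shared by all K splits.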

Finally, when a split covers a contiguous range of indices, the dataset can also be indexed directly with a slice. This does not work for arbitrary, non-contiguous selections.

In [None]:
cipic_indexed = cipic_ds[2:4]
(cipic_indexed['features'][-1] == cipic_subset[-1]['features']).all()
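The indices `[2, 3]` and the slice `2:4` used above presumably select subject 8 because, with `side='both-left'`, each subject contributes two datapoints and subject 8 is the second entry of `subject_ids=(3, 8)`. If datapoints are stored subject by subject in that order (an assumption, not something this tutorial states), a hypothetical helper can compute the index range for any subject:

```python
from typing import Sequence


def subject_index_range(subject_ids: Sequence[int], subject: int, datapoints_per_subject: int) -> range:
    """Return the range of dataset indices holding one subject's datapoints.

    Assumes datapoints are ordered subject by subject, in the order of
    subject_ids, and that every subject contributes the same number of datapoints.
    """
    position = list(subject_ids).index(subject)  # raises ValueError if the subject is absent
    start = position * datapoints_per_subject
    return range(start, start + datapoints_per_subject)


r = subject_index_range((3, 8), 8, 2)
print(list(r))                 # [2, 3]
print(slice(r.start, r.stop))  # slice(2, 4, None)
```

This arithmetic breaks down if subjects contribute different numbers of datapoints, in which case explicit `Subset` indices are the safer choice.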