# Tutorial: Using a Dataset of Planes in PyTorch

In [None]:
from hrtfdata.torch.planar import CIPICPlane, ARIPlane, ListenPlane, BiLiPlane, ITAPlane, HUTUBSPlane, SADIE2Plane, ThreeDThreeAPlane, CHEDARPlane, WidespreadPlane, SONICOMPlane
from hrtfdata.torch import collate_dict_dataset
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import matplotlib.pyplot as plt

## The minimum necessary to get started

The purpose of `hrtfdata` is to provide PyTorch [`Datasets`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) for multiple HRTF data collections through a unified programming interface. Currently, the following data collections are supported:

- CIPIC
- ARI
- Listen
- BiLi
- ITA
- HUTUBS
- SADIE II
- 3D3A
- CHEDAR
- Widespread
- SONICOM

Each of them has a corresponding class that can be imported from `hrtfdata.torch.planar`, as is done above.

At minimum, you need to select which plane to use (out of `horizontal`, `median` or `frontal`), the HRIR/HRTF representation (`time`, `magnitude`, `magnitude_db`, `phase` or `complex`) and the side of the head (`left`, `right`, `both`, `both-left`, `both-right`).

In [None]:
plane = 'median'
domain = 'magnitude_db'
side = 'left'
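To make the relationship between the domain options more concrete, here is a minimal sketch that derives the frequency-domain representations from a time-domain impulse response using the usual textbook definitions. The `dft` and `representations` helpers are hypothetical and not part of `hrtfdata`; the library may scale values or select frequency bins differently.

```python
import cmath
import math


def dft(signal):
    """One-sided discrete Fourier transform (naive, for illustration only)."""
    n = len(signal)
    return [
        sum(x * cmath.exp(-2j * cmath.pi * k * t / n) for t, x in enumerate(signal))
        for k in range(n // 2 + 1)
    ]


def representations(hrir):
    """Derive the frequency-domain views of a single impulse response."""
    spectrum = dft(hrir)
    magnitude = [abs(c) for c in spectrum]
    # Decibels relative to unit magnitude; guard against log10(0)
    magnitude_db = [20 * math.log10(m) if m > 0 else float('-inf') for m in magnitude]
    phase = [cmath.phase(c) for c in spectrum]
    return {
        'time': list(hrir),
        'complex': spectrum,
        'magnitude': magnitude,
        'magnitude_db': magnitude_db,
        'phase': phase,
    }


# A unit impulse delayed by one sample: flat magnitude, linearly decreasing phase
example = representations([0.0, 1.0, 0.0, 0.0])
print(example['magnitude_db'])
```

In practice you would let the dataset do this conversion by choosing the `domain` argument; the sketch only shows how the options relate to each other.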

These parameters can then be used together with an `XPlane` class (where `X` stands for the name of the collection) and the path to the root directory of files for the corresponding collection (having the same directory structure as on the cluster).

In [None]:
base_dir = Path('../HRTF Datasets')
ds = ARIPlane(base_dir / 'ARI', plane, domain=domain, side=side)

The resulting object `ds` is an instance of `torch.utils.data.Dataset`.

In [None]:
isinstance(ds, Dataset)

Therefore, you can get its size with `len()` and access individual data points by indexing it.

In [None]:
len(ds)

In [None]:
p = ds[0]

Each data point is a dictionary that contains the keys `features`, `target` and `group`.

In [None]:
p.keys()

The `features` key gives the plane in the requested domain, the `target` key returns the side of the head and `group` gives the subject id.

In [None]:
p['features'].shape, p['target'], p['group']

This dict format makes it possible to split the dataset while keeping individuals grouped, but it also requires the use of the non-default `collate_dict_dataset` collation function when creating a `torch.utils.data.DataLoader` (to convert the dataset into the expected `(feature, target)` pairs).

In [None]:
DataLoader(ds, collate_fn=collate_dict_dataset)
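As an illustration of how the `group` key can be used to keep individuals together when splitting, the sketch below assigns whole groups to either the training or the test side. The `grouped_split` helper is hypothetical and not part of `hrtfdata`; it assumes the group values are hashable and sortable (for example plain integers), and the `groups` list is a stand-in for the values you would collect with `[ds[i]['group'] for i in range(len(ds))]`.

```python
import random
from collections import defaultdict


def grouped_split(groups, test_fraction=0.2, seed=0):
    """Split data point indices so that each group ends up entirely on one side."""
    by_group = defaultdict(list)
    for index, group in enumerate(groups):
        by_group[group].append(index)

    # Shuffle the group ids (not the data points) with a fixed seed
    group_ids = sorted(by_group)
    random.Random(seed).shuffle(group_ids)

    n_test = max(1, round(test_fraction * len(group_ids)))
    test_idx = sorted(i for g in group_ids[:n_test] for i in by_group[g])
    train_idx = sorted(i for g in group_ids[n_test:] for i in by_group[g])
    return train_idx, test_idx


# Stand-in group labels: five subjects with two data points each
groups = [1, 1, 2, 2, 3, 3, 4, 4, 5, 5]
train_idx, test_idx = grouped_split(groups, test_fraction=0.4)
print(train_idx, test_idx)
```

The resulting index lists can then be passed to `torch.utils.data.Subset`, for example `Subset(ds, train_idx)`, and each subset wrapped in its own `DataLoader`.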

## Additional Dataset Functionality

In addition to the minimal methods provided by PyTorch `Datasets`, `XPlane` objects have some extra functionality.

You can get the sample rate of the HRIR:

In [None]:
ds.hrir_samplerate

or the frequencies of the HRTF:

In [None]:
ds.hrtf_frequencies

The angles in the plane can be obtained with:

In [None]:
ds.plane_angles

By default, the interval for the angles is [-180, 180) (horizontal and frontal plane) or [-90, 270) (median plane). Positive angles in the range [0, 360) can be requested by passing `positive_angles=True` to the `XPlane` constructor.

The angle interval can also be changed after the creation of the dataset, without needing to reload from disk, by setting the `.positive_angles` boolean property.

In [None]:
ds.positive_angles = True
ds.plane_angles
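The different angle conventions describe the same directions and differ only in where the 360-degree interval starts. The sketch below shows the underlying arithmetic with a hypothetical `wrap_angle` helper; it is not the library's implementation, only an illustration of how a value maps from one interval to another.

```python
def wrap_angle(angle, lower):
    """Map an angle in degrees into the half-open interval [lower, lower + 360)."""
    return (angle - lower) % 360 + lower


# Signed convention starting at -180 versus positive convention starting at 0
print(wrap_angle(270, -180))  # a positive angle expressed in the signed interval
print(wrap_angle(-90, 0))     # a signed angle expressed in the positive interval

# An interval starting at -90 keeps the front of the median plane contiguous
print(wrap_angle(200, -90))
```

Because the mapping is a pure relabelling, switching conventions does not change which measurements are in the dataset.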

In all cases, the angle extrema are available as:

In [None]:
ds.min_angle, ds.max_angle

Using this additional info, you could create your own plots, but plotting functionality is also included. You can plot the angles of the plane with the `.plot_angles()` method.

In [None]:
ds.plot_angles()
plt.show()

Its complete type signature is `.plot_angles(ax=None, title=None)`, which contains optional arguments to draw on an existing Matplotlib `Axes` or to override the default title.

You can also plot the HRTF of a data point, by passing its index to `.plot_plane()`.

In [None]:
ds.plot_plane(0)
plt.show()

The type signature of this method is `.plot_plane(idx, ax=None, cmap='viridis', continuous=False, vmin=None, vmax=None, title=None, colorbar=True, log_freq=False)`, which contains quite a few optional arguments. The `ax` and `title` options do the same as in `.plot_angles()`. Any [colour map available in Matplotlib](https://matplotlib.org/stable/gallery/color/colormap_reference.html) can be used by passing it to `cmap`. By default, the minimum and maximum values of the colour map are set to the minimum and maximum value in the entire dataset, but they can be overridden using the `vmin` and `vmax` options. The colour bar that is shown by default can be disabled with `colorbar=False`.

The plane angles are plotted on a linear scale, so if the sampling of angles is non-uniform, certain angles will be drawn over larger areas in the plot than others. By default, the area up to halfway to the next angle is filled with a uniform colour, resulting in a block-like appearance that can be used to visually inspect the distribution of angles in the plane. By passing `continuous=True`, intermediate angle values will be interpolated, leading to a smooth picture. For frequency-domain HRTF representations, the option `log_freq=True` can be used to plot frequency on a logarithmic axis. Finally, the Matplotlib `Axes` get returned, allowing for further customisation of the plot.

An example demonstrating all these options can be found below.

In [None]:
_, ax = plt.subplots(figsize=(8,8))
ax = ds.plot_plane(0, ax=ax, cmap='gray', continuous=True, vmin=-120, vmax=0, title='', colorbar=False, log_freq=True)
ax.set_ylim(0.1, 18)
plt.show()
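To see why non-uniformly sampled angles produce blocks of different sizes in the default (non-continuous) plot, it can help to compute the cell boundaries by hand. The sketch below places each boundary halfway between neighbouring angles. The `cell_edges` helper is hypothetical and only illustrates the idea; how `hrtfdata` handles the outermost cells is an assumption here.

```python
def cell_edges(angles):
    """Return len(angles) + 1 boundaries, each halfway between neighbouring angles."""
    if len(angles) < 2:
        raise ValueError('need at least two angles to derive cell edges')
    # Interior boundaries sit at the midpoints between consecutive angles
    inner = [(a + b) / 2 for a, b in zip(angles[:-1], angles[1:])]
    # Assumption: outer cells extend as far outward as they do inward
    first = angles[0] - (angles[1] - angles[0]) / 2
    last = angles[-1] + (angles[-1] - angles[-2]) / 2
    return [first] + inner + [last]


# Non-uniform sampling: the cell around 30 degrees is wider than the one around 0
angles = [0, 10, 30]
edges = cell_edges(angles)
widths = [hi - lo for lo, hi in zip(edges[:-1], edges[1:])]
print(edges, widths)
```

Edges of this form are what Matplotlib's `pcolormesh` expects when you want to draw such a block plot yourself from `ds.plane_angles`.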

## Customising Dataset Contents

By default, all available subjects will be loaded into a data set. A list of ids of the subjects composing the data set can be obtained as:

In [None]:
ds.subject_ids

A subset of the data can be obtained using the standard [`torch.utils.data.Subset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset) functionality. However, if you want more control over the contents of a data set, you can pass the `subject_ids` argument to an `XPlane` constructor. It should contain an iterable of subject ids.

In [None]:
ds = ARIPlane(base_dir / 'ARI', plane, domain=domain, side=side, subject_ids=(1, 2, 3, 4, 5))

If no subject with the given id exists, it gets silently skipped.

In [None]:
ds.subject_ids

For any dataset, regardless of its contents, you can request the possible subject ids you can pass to the constructor:

In [None]:
ds.available_subject_ids
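Because unknown ids are skipped silently, you may want to check a list of requested ids against the available ones before relying on the result. A minimal sketch, using a hypothetical `check_subject_ids` helper and a made-up list standing in for `ds.available_subject_ids`:

```python
def check_subject_ids(requested, available):
    """Separate requested subject ids into those that exist and those that do not."""
    available = set(available)
    found = [s for s in requested if s in available]
    missing = [s for s in requested if s not in available]
    return found, missing


# Made-up ids for illustration; use ds.available_subject_ids in practice
available = [1, 2, 3, 5, 8]
found, missing = check_subject_ids((1, 2, 3, 4, 5), available)
print(found, missing)
```

If `missing` is not empty, you can decide yourself whether to raise an error or to continue with the ids in `found`.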

For convenience, you can create an empty data set by passing an empty list or tuple to `subject_ids`, then read `.available_subject_ids` to find out what subjects are available and create another data set with a subset of these ids.

In [None]:
ds = ARIPlane(base_dir / 'ARI', plane, domain=domain, side=side, subject_ids=())
ds.subject_ids

If you just need a single example of a data collection, you can instead pass one of the strings `first`, `last` or `random` to `subject_ids`. The first two deterministically load the first and last id in the collection, respectively, whereas `random` loads a random subject.

In [None]:
ds = ARIPlane(base_dir / 'ARI', plane, domain=domain, side=side, subject_ids='first')
ds.subject_ids

In [None]:
ds = ARIPlane(base_dir / 'ARI', plane, domain=domain, side=side, subject_ids='last')
ds.subject_ids

In [None]:
ds = ARIPlane(base_dir / 'ARI', plane, domain=domain, side=side, subject_ids='random')
ds.subject_ids
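This tutorial does not state whether the `random` option can be seeded. If you need a reproducible random selection, one option is to draw the ids yourself from the available ones and pass them explicitly to `subject_ids`. A minimal sketch, with a hypothetical `pick_subjects` helper and made-up ids standing in for `ds.available_subject_ids`:

```python
import random


def pick_subjects(available, k=1, seed=0):
    """Draw k distinct subject ids reproducibly using a dedicated, seeded generator."""
    rng = random.Random(seed)
    # Sort first so the draw does not depend on the order of the input
    return sorted(rng.sample(sorted(available), k))


# Made-up ids for illustration; use ds.available_subject_ids in practice
available = [2, 3, 5, 8, 13, 21]
chosen = pick_subjects(available, k=3, seed=42)
print(chosen)
```

The chosen ids can then be passed to the constructor in the same way as the explicit tuple shown earlier in this section.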

As a concluding example, the snippet below plots the angular distribution and an example HRTF of all available data collections for each of the three fundamental planes.

In [None]:
domain = 'magnitude_db'
side = 'both'
positive_angles = False
subject_ids = 'first'

for collection, data_dir in [
    (CIPICPlane, base_dir / 'CIPIC'),
    (ARIPlane, base_dir / 'ARI'),
    (ListenPlane, base_dir / 'Ircam Listen'),
    (BiLiPlane, base_dir / 'Ircam BiLi'),
    (ITAPlane, base_dir / 'ITA Aachen'),
    (HUTUBSPlane, base_dir / 'HUTUBS'),
    (SADIE2Plane, base_dir / 'SADIE II'),
    (ThreeDThreeAPlane, base_dir / '3D3A'),
    (CHEDARPlane, base_dir / 'CHEDAR'),
    (WidespreadPlane, base_dir / 'Widespread'),
    (SONICOMPlane, base_dir / 'SONICOM'),
]:
    fig = plt.figure(figsize=(16, 18))
    fig.suptitle(collection.__name__)
    for idx, plane in enumerate(('horizontal', 'median', 'frontal')):
        plane_offset = -0.72 if collection == ITAPlane and plane == 'horizontal' else 0
        ds = collection(data_dir, plane, plane_offset=plane_offset, domain=domain, side=side, positive_angles=positive_angles, subject_ids=subject_ids)
        ax0 = fig.add_subplot(3, 2, 2*idx+1, projection='polar')
        ax1 = fig.add_subplot(3, 2, 2*idx+2)
        if domain.startswith('magnitude'):
            ax1.set_ylim((0, 18))
        elif domain == 'time':
            ax1.set_ylim((0, 3))
        ds.plot_angles(ax=ax0)
        ds.plot_plane(0, ax=ax1, continuous=False, log_freq=False)
    plt.show()