## Tasks and Labels

First, we distinguish between a `Task` object and a `Label` object.

A `Label` object represents only the _label_ data for a training example, decoupled from the input image entirely. This could be a list of points as in `PersonPoseLabel`, or a numpy array as in `SegmentationLabel`, or a list of bounding boxes as in `ObjectLabel`.

A `Task` couples a _label_ with its input _data_: it is essentially a pair consisting of the input to the model and the ground truth. For instance, a `SegmentationTask` has two attributes: the `data` (an `np.ndarray` representing an image) and the `label` (a `SegmentationLabel`).

A `MultiTask` couples one input _data_ item with _multiple labels_. This way, we only need to store _one copy_ of the input data for _multiple different task types_. A `MultiTask` instance has two attributes: the `data` (an `np.ndarray` representing an image) and a `labels` list containing `Label` objects for the different tasks.
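To make the three roles concrete, here is a minimal sketch of how such containers could be modelled with Python dataclasses. The class names mirror the ones above, but the fields (`mask`, `boxes`) and the list-based placeholder image are assumptions for illustration, not the library's actual definitions.

```python
from dataclasses import dataclass, field
from typing import Any, List


# Hypothetical, minimal stand-ins for the library's classes (the real ones
# also provide show(), collate(), etc.).
@dataclass
class Label:
    """Base class: holds only label data, never the input image."""


@dataclass
class SegmentationLabel(Label):
    mask: Any  # e.g. an H x W array of per-pixel class ids


@dataclass
class ObjectLabel(Label):
    boxes: List[tuple]  # one (x_min, y_min, x_max, y_max) tuple per object


@dataclass
class SegmentationTask:
    data: Any                 # the model input (an image)
    label: SegmentationLabel  # the ground truth for that input


@dataclass
class MultiTask:
    data: Any                                          # ONE copy of the input ...
    labels: List[Label] = field(default_factory=list)  # ... shared by many labels


image = [[0, 1], [2, 3]]  # tiny placeholder "image"
seg = SegmentationLabel(mask=[[0, 0], [1, 1]])
obj = ObjectLabel(boxes=[(0, 0, 1, 1)])

single = SegmentationTask(image, seg)
multi = MultiTask(image, [seg, obj])
```

Both the single task and the multitask point at the same image object; the multitask simply carries more than one label alongside it.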

Every `Task` object comes with a `show()` method, which displays the input data (image) together with its label(s).

Every `Label` class provides a `collate()` classmethod, which determines how multiple labels of that type are combined into a single batched object.
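As an illustration of why `collate()` lives on each label type, the sketch below batches bounding-box labels. Because each image can have a different number of boxes, they cannot be stacked into one rectangular array; this sketch keeps a flat list of boxes plus a per-box image index. `BatchedObjectLabel` and this batching strategy are hypothetical, one common choice among several.

```python
from dataclasses import dataclass
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x_min, y_min, x_max, y_max)


@dataclass
class ObjectLabel:
    boxes: List[Box]

    @classmethod
    def collate(cls, labels: List["ObjectLabel"]) -> "BatchedObjectLabel":
        # Box counts differ per image, so keep a flat list of boxes and
        # record which image in the batch each box came from.
        all_boxes: List[Box] = []
        batch_index: List[int] = []
        for i, label in enumerate(labels):
            all_boxes.extend(label.boxes)
            batch_index.extend([i] * len(label.boxes))
        return BatchedObjectLabel(all_boxes, batch_index)


@dataclass
class BatchedObjectLabel:
    boxes: List[Box]
    batch_index: List[int]


batch = ObjectLabel.collate([
    ObjectLabel([(0, 0, 10, 10)]),
    ObjectLabel([]),                             # an image with no objects
    ObjectLabel([(5, 5, 8, 8), (1, 1, 2, 2)]),
])
print(batch.batch_index)  # [0, 2, 2]
```

A segmentation label, by contrast, could collate by stacking equally sized masks, which is why each label type decides for itself.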

## Transforms
The basic callable object designed to handle multitask data is a subclass of `AbstractMultiTaskTransform`. These objects, referred to below as "transforms", are meant to be "smart functions": they figure out what type of input they receive and how to transform that input based on its type.

The general workflow is as follows:
1. Implement the `get_shared_parameters` abstract method, which produces the parameters (e.g. whether to flip) shared by the data and all of its labels.
2. Implement a `__call__` method for the input _data_ type (in our case an image, represented by an `np.ndarray`) and for each input _label_ type (`SegmentationLabel`, `ObjectLabel`, etc.).

With these two simple steps you can then:
* Use your transform on `MultiTask` data,
* use your transform on single `Task` data,
* use your transform on lone `Label` data,
* or use your transform on any other type of input data, assuming an implementation for that data type has been provided.

Under the hood, the `AbstractMultiTaskTransform` will automatically figure out what to do.
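Conceptually, this "figuring out" is dispatch on the type of the input. The sketch below is a heavily simplified, hypothetical stand-in (`MiniTransform` is not the library's class, and it uses a plain `handlers` dict instead of `@dispatch_by_type`): for a `MultiTask` it samples the shared parameters once and applies the matching implementation to the data and to each label; for any other input it looks up an implementation for that input's type directly. Images are plain lists of rows here to keep it dependency-free.

```python
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class Label:
    value: List[int]  # toy label: one class id per image column


@dataclass
class MultiTask:
    data: Any
    labels: List[Label]


class MiniTransform:
    """Hypothetical, much-simplified stand-in for AbstractMultiTaskTransform."""

    handlers: Dict[type, Callable] = {}  # input type -> implementation

    def get_shared_parameters(self, inputs) -> dict:
        return {}

    def _find_handler(self, obj) -> Callable:
        # Walk the MRO so a handler registered for a base class also matches subclasses.
        for cls in type(obj).__mro__:
            if cls in self.handlers:
                return self.handlers[cls]
        raise TypeError(f"no implementation for {type(obj).__name__}")

    def __call__(self, inputs):
        params = self.get_shared_parameters(inputs)  # sampled once per call
        if isinstance(inputs, MultiTask):
            # The same params go to the data and to every label, keeping them in sync.
            data = self._find_handler(inputs.data)(self, inputs.data, params)
            labels = [self._find_handler(lab)(self, lab, params) for lab in inputs.labels]
            return MultiTask(data, labels)
        return self._find_handler(inputs)(self, inputs, params)


class FlipColumns(MiniTransform):
    def __init__(self, p: float = 0.5):
        self.p = p  # probability of flipping

    def get_shared_parameters(self, inputs) -> dict:
        return {"do_flip": random.random() < self.p}

    def _flip_image(self, image, params):
        return [row[::-1] for row in image] if params["do_flip"] else image

    def _flip_label(self, label, params):
        return Label(label.value[::-1]) if params["do_flip"] else label

    handlers = {list: _flip_image, Label: _flip_label}


out = FlipColumns(p=1.0)(MultiTask([[1, 2, 3]], [Label([0, 1, 2])]))
print(out)  # MultiTask(data=[[3, 2, 1]], labels=[Label(value=[2, 1, 0])])
```

Sampling the parameters once per call, rather than once per handler, is what keeps an image and its labels consistent with each other.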

In [None]:
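import random

import numpy as np

# AbstractMultiTaskTransform, dispatch_by_type and the label classes are provided by the library.

class HorizontalFlip(AbstractMultiTaskTransform):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # Probability of performing a flip

    def get_shared_parameters(self, inputs):
        # The returned params are passed to every __call__ below, so the image
        # and all of its labels are flipped (or not) together.
        return {"do_flip": random.random() < self.p}

    @dispatch_by_type(np.ndarray)
    def __call__(self, image, params):
        if not params["do_flip"]:
            return image
        return np.flip(image, axis=1)  # axis 1 is the width axis of an H x W (x C) image

    @dispatch_by_type(SegmentationLabel)
    def __call__(self, label, params):
        ...

    @dispatch_by_type(ObjectDetectionLabel)
    def __call__(self, label, params):
        ...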

Above, we've implemented a `HorizontalFlip` transform which, with probability `p` (0.5 by default), horizontally flips its inputs. We wrote an implementation for the _data_ type (`np.ndarray`) and for _two label_ types (`SegmentationLabel` and `ObjectDetectionLabel`). The `@dispatch_by_type` decorators are how `AbstractMultiTaskTransform` figures out, on the fly, which implementation to call for a given input.
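The label implementations are elided above. Here is one way their bodies could look, written as standalone functions. It assumes continuous pixel coordinates in which x ranges over [0, W], so mirroring maps x to W - x; with inclusive integer pixel indices the map would be W - 1 - x instead. The `image_width` field is a hypothetical addition, since a box label alone does not know how wide its image is.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SegmentationLabel:
    mask: np.ndarray  # H x W array of per-pixel class ids


@dataclass
class ObjectDetectionLabel:
    boxes: np.ndarray  # N x 4 array of (x_min, y_min, x_max, y_max)
    image_width: int   # HYPOTHETICAL field: boxes alone don't know the image width


def flip_segmentation(label: SegmentationLabel) -> SegmentationLabel:
    # A mask is pixel-aligned with the image, so it flips exactly like the image.
    return SegmentationLabel(np.flip(label.mask, axis=1))


def flip_boxes(label: ObjectDetectionLabel) -> ObjectDetectionLabel:
    # Mirroring maps x -> W - x, which swaps the roles of the two x edges:
    # new x_min = W - old x_max, new x_max = W - old x_min (y is unchanged).
    w, b = label.image_width, label.boxes
    flipped = b.copy()
    flipped[:, 0] = w - b[:, 2]
    flipped[:, 2] = w - b[:, 0]
    return ObjectDetectionLabel(flipped, w)


seg = SegmentationLabel(np.array([[0, 1, 2]]))
det = ObjectDetectionLabel(np.array([[10, 20, 30, 40]]), image_width=100)
print(flip_segmentation(seg).mask)  # [[2 1 0]]
print(flip_boxes(det).boxes)        # [[70 20 90 40]]
```

Flipping twice returns the original boxes, which makes a convenient sanity check for this kind of transform.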

Now, we can use our transform as follows.

In [None]:
transform = HorizontalFlip()

task = MultiTask(image, [segmentation_label, object_label])

results = transform(task) 
# Even though we haven't specified a __call__ method for the `MultiTask` input type, the transform is able to figure out
# that it needs to apply itself to the task *data* (the image) and the task *labels* (the segmentation_label and object_label)

assert isinstance(results, MultiTask)
results.show() # The image, segmentation label, and object label will be present, and all will either be flipped or not (simultaneously).

The `AbstractMultiTaskTransform` object is extremely versatile: it can operate on `MultiTask` objects, `Task` objects, `Label` objects, or any other object (provided an implementation for the input type exists), and under the hood it automatically figures out what it has to do.

In [None]:
# Transforms can operate on individual tasks:
single_task = SegmentationTask(image, segmentation_label)
results = transform(single_task)
assert isinstance(results, SegmentationTask)
results.show()

# Transforms can operate on individual labels:
results = transform(segmentation_label)
assert isinstance(results, SegmentationLabel)

# It can even operate on an image (represented as an np.ndarray):
results = transform(image)
assert isinstance(results, np.ndarray)

# We don't have to make _any modifications_ to `transform` for this to be possible. This functionality exists from the start!

In this example transform, we explicitly _left out_ an implementation of the transform for `PersonPoseLabel` objects. What happens when we attempt to call it?

In [None]:
task = MultiTask(image, [segmentation_label, object_label, pose_label]) # Now we have a label the transform can't operate on
results = transform(task)

# results is now a tuple, where the first item is the _transformed_ multitask, and the second item is the _untransformed_ multitask:
transformed, untransformed = results

transformed.show() # Displays the segmentation label and object label
untransformed.show() # Displays only the pose label

The `AbstractMultiTaskTransform` was able to figure out that it _couldn't operate_ on all of the input label types - it could only operate on some of them. As a result, it splits the output into the transformed labels and the untransformed labels.
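The splitting step can be pictured as a partition of the label list by whether an implementation exists for each label's type. `split_by_support` below is a hypothetical helper, not a library function; note that both halves keep the original data, so the skipped labels stay aligned with the unmodified image.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class MultiTask:
    data: Any
    labels: List[Any]


def split_by_support(task: MultiTask, supported: Tuple[type, ...]) -> Tuple[MultiTask, MultiTask]:
    """Hypothetical helper: partition a MultiTask's labels by whether a handler exists."""
    todo = [lab for lab in task.labels if isinstance(lab, supported)]
    skipped = [lab for lab in task.labels if not isinstance(lab, supported)]
    # Both halves start from the same, original data; only `todo` would be
    # transformed afterwards, so `skipped` stays consistent with the original image.
    return MultiTask(task.data, todo), MultiTask(task.data, skipped)


class SegmentationLabel: ...
class ObjectLabel: ...
class PoseLabel: ...


seg, obj, pose = SegmentationLabel(), ObjectLabel(), PoseLabel()
todo, skipped = split_by_support(MultiTask("image", [seg, obj, pose]),
                                 (SegmentationLabel, ObjectLabel))
print([type(x).__name__ for x in todo.labels])     # ['SegmentationLabel', 'ObjectLabel']
print([type(x).__name__ for x in skipped.labels])  # ['PoseLabel']
```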

We remark that this behaviour is _only available_ when the input is a `MultiTask` object. If we attempt to call `transform` on a single `PoseTask` or `PoseLabel` object, we'll get an error:

In [None]:
try:
    single_task = PoseTask(image, pose_label)
    transform(single_task)
except DynamicDispatchError:
    print("No implementation for pose labels available")


try:
    transform(pose_label)
except DynamicDispatchError:
    print("No implementation for pose labels available")

We can further control the behaviour of an `AbstractMultiTaskTransform` by changing its `MultiTaskReturnPolicy`. By default, a transform returns:
- One item, the transformed labels (if everything could be transformed).
- Two items, the transformed labels and untransformed labels (if some things couldn't be transformed).
- One item, the _original_ untransformed labels (if none of the labels could be transformed).
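The three default cases, plus the strict policy described next, can be summarized as a small decision function. `ReturnPolicy`, `package_result`, and `NothingTransformedError` are hypothetical names (only `OnlyTransformedStrict` appears in the text), and labels are plain lists here rather than `MultiTask` objects.

```python
from enum import Enum, auto


class ReturnPolicy(Enum):
    # Hypothetical names; only OnlyTransformedStrict is mentioned in the text.
    DEFAULT = auto()
    ONLY_TRANSFORMED_STRICT = auto()


class NothingTransformedError(Exception):
    """Hypothetical stand-in for the library's error type."""


def package_result(transformed, untransformed, original, policy=ReturnPolicy.DEFAULT):
    """Decide what to return, given labels split into transformed/untransformed lists."""
    if policy is ReturnPolicy.ONLY_TRANSFORMED_STRICT:
        if not transformed:
            raise NothingTransformedError("no label could be transformed")
        return transformed                 # untransformed labels are dropped
    if not untransformed:
        return transformed                 # everything was transformed: one item
    if not transformed:
        return original                    # nothing was transformed: the original
    return transformed, untransformed      # partial: two items


print(package_result(["seg*"], ["pose"], ["seg", "pose"]))  # (['seg*'], ['pose'])
```

The variable return shape in the default branch is exactly what makes chaining awkward further down.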

If we _only want_ to return the _transformed_ labels as output, and raise an error if no labels could be transformed, we can do the following:

In [None]:
transform.multitask_return_policy = MultiTaskReturnPolicy.OnlyTransformedStrict

# Now let's try operating on a multitask
task = MultiTask(image, [segmentation_label, object_label, pose_label])
results = transform(task)

# Now, `results` is a single item: the transformed labels
assert isinstance(results, MultiTask)
results.show() # only the segmentation label and object label

try:
    task = MultiTask(image, [pose_label1, pose_label2]) # A task with only pose data
    transform(task)
except HeterogeneousDataError:
    print("Received a task with no transformable labels.")

There are other `MultiTaskReturnPolicy` options you can use to control the behaviour of the transform. See the documentation for details on what each policy does.

## The Problem with `.map`
Consider the following example:

In [None]:
datapipe = datapipe.map(MyTransform())

for item in datapipe:
    # Uh oh, `item` can be lots of things now, depending on whether or not `MyTransform` could operate on it.
    
    # `item` could be a single `MultiTask` object, a tuple of two `MultiTask` objects, or it could be None.
    ...

This is a problem because now we can't easily chain transforms together.

In [None]:
datapipe = datapipe.map(Transform1()).map(Transform2())
# This won't work - the results of the first transform might be a tuple, or None, and Transform2 can't operate on that.

The solution? Instead of using `.map`, use `.transform`:

In [None]:
datapipe = datapipe.transform(MyTransform())

for item in datapipe:
    # Now, `item` is guaranteed to be a single `MultiTask` object.
    ...

# Now this works as well:
datapipe = datapipe.transform(Transform1()).transform(Transform2())

# After each application of `transform`, the datapipe is _guaranteed_ to yield individual `MultiTask` instances one-by-one.

In [None]:
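# Assumes a torchdata pipeline: IterableWrapper turns a plain list into an iterable datapipe
datapipe = IterableWrapper([MultiTask(image, [segmentation_label, pose_label, object_label])])

datapipe = datapipe.transform(MyTransform())  # Assume MyTransform only works on segmentation labels

for i, item in datapipe.enumerate():
    print(f"{i}, {item.labels}")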

# Output:
#   0, [segmentation_label]
#   1, [pose_label, object_label]

The `.transform` datapipe is essentially a "smart" version of `.map`. It can:
1. Automatically flatten the output of the transform.
2. Skip over None values returned from the transform.
3. Modify the MultiTaskReturnPolicy of the transform.
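Points 1 and 2 amount to a small generator wrapped around the transform. `transform_stream` is a hypothetical sketch that treats any tuple output as several items to flatten; in a torchdata pipeline, logic like this could live in an `IterDataPipe` subclass registered under a functional name with `functional_datapipe`.

```python
from typing import Callable, Iterable, Iterator


def transform_stream(source: Iterable, fn: Callable, skip_none: bool = True,
                     flatten: bool = True) -> Iterator:
    """Hypothetical sketch of what a `.transform`-style stage does around `fn`."""
    for item in source:
        out = fn(item)
        if out is None and skip_none:
            continue  # drop items the transform rejected
        if flatten and isinstance(out, tuple):
            # A (transformed, untransformed) pair becomes two separate items.
            yield from (o for o in out if not (skip_none and o is None))
        else:
            yield out


def toy(x):
    if x == 0:
        return None     # e.g. nothing could be transformed
    if x % 2:
        return (x, -x)  # e.g. a (transformed, untransformed) pair
    return x            # a single item


print(list(transform_stream([0, 1, 2], toy)))  # [1, -1, 2]
```

One caveat of this design: flattening every tuple would misfire for a transform that legitimately returns a tuple as one item, so a real implementation needs an unambiguous marker for "several outputs".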

In [None]:
datapipe = datapipe.transform(MyTransform(), skip_none=True, flatten=True, multitask_return_policy=...)
# `skip_none` and `flatten` are True by default, `multitask_return_policy` is None by default, meaning we use the transform's return policy instead

## Notes on Implementing Transforms

**Q:** What if I have a transform that _only_ operates on the data, and not the labels? Do I have to write a bunch of empty `__call__` methods for each label type?

**A:** You can accomplish this by setting the `data_only` class attribute to `True` in the transform body:

In [None]:
class RandomColorJitter(AbstractMultiTaskTransform):

    data_only = True

    def get_shared_parameters(self, item):
        ...

    @dispatch_by_type(np.ndarray)
    def __call__(self, image, params):
        ...


# Equivalently:
class RandomColorJitter(AbstractDataOnlyTransform):
    ...
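A data-only base class can simply pass the labels through. The sketch below is a hypothetical approximation of what `AbstractDataOnlyTransform` might do (its real internals are not shown here): subclasses implement only a data method, and every label is returned unchanged. This suits photometric changes such as color jitter, which alter pixel values but not pixel positions.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class MultiTask:
    data: Any
    labels: List[Any]


class DataOnlyTransform:
    """Hypothetical sketch of a data-only base class: subclasses touch the data only."""

    def get_shared_parameters(self, task) -> dict:
        return {}

    def apply_to_data(self, data, params):
        raise NotImplementedError

    def __call__(self, task: MultiTask) -> MultiTask:
        params = self.get_shared_parameters(task)
        # Labels are passed through unchanged, so no per-label-type method is
        # needed and no label is ever reported as "untransformable".
        return MultiTask(self.apply_to_data(task.data, params), list(task.labels))


class Brighten(DataOnlyTransform):
    def get_shared_parameters(self, task) -> dict:
        return {"delta": 10}

    def apply_to_data(self, data, params):
        # Toy "image": a flat list of 0..255 intensities, clipped at 255.
        return [min(255, v + params["delta"]) for v in data]


out = Brighten()(MultiTask([0, 250], ["seg_label", "pose_label"]))
print(out)  # MultiTask(data=[10, 255], labels=['seg_label', 'pose_label'])
```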

**Q:** What if I have a new type of label data I'd like the premade transforms to support? Do I have to edit the source code for each transform each time?

**A:** You can achieve this by **subclassing** the transform:

In [None]:
class MyRandomAffine(RandomAffine):

    @dispatch_by_type(MyLabelType)
    def __call__(self, my_label, params):
        ...
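Subclassing works well when each class keeps its own handler table that starts as a copy of its parent's. The sketch below shows that idea with a hypothetical `handles` decorator and `Transform` base (not `@dispatch_by_type` itself). It gives each handler a distinct method name because, in plain Python, a second `def __call__` in the same class body simply replaces the first; supporting several same-named `__call__` definitions is exactly what a dispatch decorator has to take care of.

```python
from typing import Callable, Dict


def handles(t: type):
    """Hypothetical decorator: mark a method as the implementation for type `t`."""
    def mark(fn):
        fn._handles = t
        return fn
    return mark


class Transform:
    """Hypothetical registry-based base class; each subclass gets its own handler table."""

    _handlers: Dict[type, Callable] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Copy the parent's table so a subclass extends it instead of mutating it.
        cls._handlers = dict(cls._handlers)
        for attr in vars(cls).values():
            target = getattr(attr, "_handles", None)
            if target is not None:
                cls._handlers[target] = attr

    def __call__(self, obj, params=None):
        for klass in type(obj).__mro__:
            if klass in self._handlers:
                return self._handlers[klass](self, obj, params)
        raise TypeError(f"no implementation for {type(obj).__name__}")


class ToyAffine(Transform):  # stands in for a premade transform you can't edit
    @handles(int)
    def on_int(self, x, params):
        return x * 2


class MyLabelType:
    pass


class MyToyAffine(ToyAffine):  # your extension: adds one handler, inherits the rest
    @handles(MyLabelType)
    def on_my_label(self, label, params):
        return "handled MyLabelType"


print(MyToyAffine()(3))              # 6
print(MyToyAffine()(MyLabelType()))  # handled MyLabelType
```

The parent transform is left untouched: `ToyAffine()` still has no handler for `MyLabelType`.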

**Q:** I want my transform to use the inputs to determine its parameters. For example, I want my `RandomCrop` transform to take into account the positions of the bounding boxes, so that all the boxes are visible after the transform. How can I do this?

**A:** The `get_shared_parameters` method takes as its input the _raw, unprocessed_ input to the transform (meaning it could be a `MultiTask` object, a `Task` object, a `Label` object, or anything else). You can then use the input to calculate the parameters:

In [None]:
class BBoxSafeRandomCrop(RandomCrop):

    def get_shared_parameters(self, item):
        params = super().get_shared_parameters(item)

        has_boxes = (
            (isinstance(item, MultiTask) and item.has_object_labels())
            or isinstance(item, (ObjectDetectionTask, ObjectDetectionLabel))
        )
        if has_boxes:
            # Modify or recalculate the params
            ...

        return params
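For the crop itself, one way to compute bounding-box-safe parameters is to intersect the valid crop origins with the origins whose window covers the union of all boxes, then sample uniformly from what remains. `bbox_safe_crop_origin` is a hypothetical helper returning a top-left corner; the parameter format of the real `RandomCrop` is not specified here, so wiring it into `get_shared_parameters` would depend on that.

```python
import random
from typing import List, Tuple

Box = Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max) in pixels


def bbox_safe_crop_origin(img_w: int, img_h: int, crop_w: int, crop_h: int,
                          boxes: List[Box], rng: random.Random) -> Tuple[int, int]:
    """Sample a top-left corner (x0, y0) whose crop window contains every box."""
    if crop_w > img_w or crop_h > img_h:
        raise ValueError("crop is larger than the image")
    # Without boxes, any origin that keeps the crop inside the image is allowed.
    lo_x, hi_x = 0, img_w - crop_w
    lo_y, hi_y = 0, img_h - crop_h
    if boxes:
        # Union of all boxes: the window [x0, x0 + crop_w] must cover [ux0, ux1],
        # i.e. x0 <= ux0 and x0 >= ux1 - crop_w (likewise for y).
        ux0 = min(b[0] for b in boxes)
        uy0 = min(b[1] for b in boxes)
        ux1 = max(b[2] for b in boxes)
        uy1 = max(b[3] for b in boxes)
        lo_x, hi_x = max(lo_x, ux1 - crop_w), min(hi_x, ux0)
        lo_y, hi_y = max(lo_y, uy1 - crop_h), min(hi_y, uy0)
    if lo_x > hi_x or lo_y > hi_y:
        raise ValueError("no crop of this size contains all boxes")
    # Sample uniformly from the feasible range instead of retrying random crops.
    return rng.randint(lo_x, hi_x), rng.randint(lo_y, hi_y)


rng = random.Random(0)
x0, y0 = bbox_safe_crop_origin(100, 100, 50, 50, [(30, 30, 60, 60)], rng)
print(10 <= x0 <= 30 and 10 <= y0 <= 30)  # True
```

Restricting the sampling range up front avoids rejection sampling, whose retry count grows as the feasible region shrinks.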
            