# Preliminaries

## Imports

In [None]:
import functools
from pathlib import Path

import holoviews as hv
import panel as pn

from bridge.display.vision import PanelDetectionClassification
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.uri_components import URIComponents
from bridge.utils import pmap

# Scratch directory shared by this tutorial: the COCO download, the cached
# augmented images and the cached tensors are all written underneath it.
TMP_NOTEBOOK_ROOT = Path("/tmp/bridge-ds/tutorials")

# Enable the Bokeh plotting backend for HoloViews, then Panel's notebook extension.
hv.extension("bokeh")
pn.extension()

## Load Dataset

In [None]:
from bridge.providers.vision import Coco2017Detection

# COCO-2017 val split, stored under the tutorial scratch directory.
root_dir = TMP_NOTEBOOK_ROOT.joinpath("coco")
provider = Coco2017Detection(root_dir, split="val", img_source="download")

ds = provider.build_dataset()
ds  # last expression: rendered as the cell output

# Demo: Data Processing - From Sources to PyTorch

In this demo, we'll be working with COCO-val. We began by loading it into a Bridge Dataset. Next, we will apply data augmentations and visualize the results. Once we're satisfied with our augmentation pipeline, we will convert the augmented Dataset into a training-ready PyTorch dataset.

### Applying Data Augmentations
We want to apply data augmentations to our Dataset before feeding it to our model for training.
For this purpose, we have `ds.transform_samples()`, which accepts **SampleTransform** objects. One such SampleTransform is **AlbumentationsCompose**, our adapter that lets users apply [albumentations](https://albumentations.ai/) transforms to a Dataset.

First, let's define our transforms:

In [None]:
import albumentations as A

from bridge.primitives.sample.transform.vision import AlbumentationsCompose

# Albumentations steps wrapped as a Bridge SampleTransform.
# bbox_format="coco" means boxes are given as [x_min, y_min, width, height].
# NOTE: scale=(0.01, 0.05) samples crops covering only 1-5% of the source image area.
augmentation_steps = [
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.RandomResizedCrop((448, 448), scale=(0.01, 0.05)),
]
transforms = AlbumentationsCompose(augmentation_steps, bbox_format="coco")

Let's apply these transforms to our Dataset using `transform_samples()`. Note that `transform_samples()` adheres to the Sample API, _not_ the Table API. This means that behind the scenes we iterate over all samples, rather than using a vectorized pandas implementation.

In [None]:
import random
import warnings

import numpy as np

# Seed the global Python and NumPy RNGs before sampling augmentations.
# NOTE(review): `map_fn` below runs the transform in worker processes; confirm that
# the workers' RNG state actually derives from these seeds.
for seed_fn in (random.seed, np.random.seed):
    seed_fn(0)

# Augmented images are cached under ${TMP_NOTEBOOK_ROOT}/ds_augs.
augs_cache_dir = str(TMP_NOTEBOOK_ROOT / "ds_augs")
caches = dict(image=CacheMechanism(URIComponents.from_str(augs_cache_dir)))

# `transform_samples` needs a map-like callable to apply the SampleTransform to each
# sample. Plain `map` would work; pmap's "dataloader" backend is a multi-process variant.
map_fn = functools.partial(pmap, progress_bar=False, backend="dataloader")

with warnings.catch_warnings():
    # hide "low contrast" warnings
    warnings.simplefilter("ignore", category=UserWarning)
    ds_augs = ds.transform_samples(
        transform=transforms,
        map_fn=map_fn,
        cache_mechanisms=caches,
        display_engine=PanelDetectionClassification(bbox_format="xywh"),
    )

After a few seconds, we have our augmented dataset. By inspecting the `samples` table, we can see that the new images were saved locally in our directory of choice:

In [None]:
# Show the first three rows of the augmented samples table; the image entries should
# point at the cache directory configured for the augmentation run.
ds_augs.samples.head(3)

And we can browse this augmented Dataset just like the original one:

In [None]:
# Interactive browser for the augmented Dataset (presumably rendered via the
# PanelDetectionClassification display engine passed to transform_samples).
ds_augs.show()

By manually browsing our Dataset, we can see that we badly mis-parameterized the `RandomResizedCrop` augmentation. With `scale=(0.01, 0.05)`, each crop covers only 1–5% of the original image area, so the crops are far too small!

We can confirm this by extracting statistics over the remaining annotations:

In [None]:
print(f"num annotations ds: {len(ds.annotations)}")
print(f"num annotations ds_augs: {len(ds_augs.annotations)}")


def mean_annotations_per_image(dataset):
    """Return the mean number of annotations per sample in ``dataset``.

    Samples with no annotations won't have a group in ``groupby("sample_id")``,
    so the per-sample counts are reindexed over every sample id with
    ``fill_value=0`` before averaging. Without this, annotation-free images
    would be silently dropped and the mean would be biased upward.
    """
    return (
        dataset.annotations.groupby("sample_id")
        .size()
        .reindex(dataset.samples.index.get_level_values("sample_id"), fill_value=0)
        .mean()
    )


n_annotations_per_image_ds = mean_annotations_per_image(ds)
n_annotations_per_image_ds_augs = mean_annotations_per_image(ds_augs)

print(f"mean num annotations per image ds: {n_annotations_per_image_ds}")
print(f"mean num annotations per image ds_augs: {n_annotations_per_image_ds_augs}")

We can see that the numbers tell the same story: we've lost many annotations. Let's widen the crop's `scale` range and reapply the transforms:

In [None]:
# Same pipeline as before, except RandomResizedCrop now samples crops covering
# 30-100% of the source image area instead of 1-5%.
fixed_augmentation_steps = [
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.RandomResizedCrop((448, 448), scale=(0.3, 1.0)),
]
transforms = AlbumentationsCompose(fixed_augmentation_steps, bbox_format="coco")

# NOTE(review): this reuses `caches` (the same ds_augs directory) from the first run;
# confirm CacheMechanism overwrites the earlier files rather than reusing them.
with warnings.catch_warnings():
    # hide "low contrast" warnings
    warnings.simplefilter("ignore", category=UserWarning)
    ds_augs = ds.transform_samples(
        display_engine=PanelDetectionClassification(bbox_format="xywh"),
        cache_mechanisms=caches,
        map_fn=map_fn,
        transform=transforms,
    )

In [None]:
print(f"num annotations ds: {len(ds.annotations)}")
print(f"num annotations ds_augs: {len(ds_augs.annotations)}")

# Mean annotations per image for each dataset. Samples with no annotations have no
# group in the groupby, so counts are reindexed over all sample ids with 0 as filler.
mean_counts = {}
for label, dataset in (("ds", ds), ("ds_augs", ds_augs)):
    all_sample_ids = dataset.samples.index.get_level_values("sample_id")
    per_sample_counts = dataset.annotations.groupby("sample_id").size()
    mean_counts[label] = per_sample_counts.reindex(all_sample_ids, fill_value=0).mean()

n_annotations_per_image_ds = mean_counts["ds"]
n_annotations_per_image_ds_augs = mean_counts["ds_augs"]

print(f"mean num annotations per image ds: {n_annotations_per_image_ds}")
print(f"mean num annotations per image ds_augs: {n_annotations_per_image_ds_augs}")

This time, we've lost significantly fewer annotations to the random crop operation. We can also inspect the samples manually, if we'd like:

In [None]:
# Browse the re-augmented Dataset to visually confirm the larger crops.
ds_augs.show()

### Converting to tensors
At this point, we're satisfied with our augmented Dataset. The next step is converting this dataset into viable input for a deep learning model, that is, converting the dataset to tensors. We'll demonstrate with PyTorch as our engine of choice, but this technique should generalize just as well to other deep learning frameworks.

NOTE: up to this point in the tutorial, nothing we've done depends on which deep learning framework we use. All of this works just as well if our DL framework of choice were Keras or TensorFlow.

The conversion to tensors uses the same mechanism as before, `transform_samples()`, this time with a tensor-conversion transform:

In [None]:
import warnings

from albumentations.pytorch import ToTensorV2

# Final conversion pipeline: ensure 3-channel input, normalize, then convert to tensors.
to_tensor_transform = AlbumentationsCompose(
    [
        A.ToRGB(),  # some COCO images are greyscale, and if not converted to RGB, they crash in `A.Normalize()`
        A.Normalize(),
        ToTensorV2(),
    ],
    bbox_format="coco",
)

tensor_cache_uri = URIComponents.from_str(str(TMP_NOTEBOOK_ROOT / "ds_tensors"))

with warnings.catch_warnings():
    # Applying A.ToRGB() on an image that is already RGB emits a warning. Filter only
    # UserWarning (the category albumentations uses for it; NOTE(review): confirm for the
    # pinned version), as the augmentation cells do, instead of silencing every category.
    warnings.simplefilter("ignore", category=UserWarning)
    ds_tensors = ds_augs.transform_samples(
        transform=to_tensor_transform,
        map_fn=map_fn,
        display_engine=None,  # the output is not images anymore, so a Panel DisplayEngine won't work
        cache_mechanisms={"image": CacheMechanism(tensor_cache_uri)},
    )

Since we can't use `PanelDetectionClassification` rendering anymore, let's just use a few prints to make sure the data is in our required format:

In [None]:
# Pull the first sample's image payload to check it is in the expected tensor format.
first_sample = ds_tensors.iget(0)
img_data = first_sample.data
print(f"shape: {img_data.shape} \n")
print(img_data)

The last step is to convert `ds_tensors` into a torch Dataset. We will do this using the `PytorchEngineDataset` class, which inherits directly from `torch.utils.data.Dataset`:

In [None]:
import torch

from bridge.engines.pytorch import PytorchEngineDataset

# Wrap the tensor Dataset so it can be handed to PyTorch tooling.
ds_pytorch = PytorchEngineDataset(ds_tensors)

is_torch_dataset = isinstance(ds_pytorch, torch.utils.data.Dataset)
print(is_torch_dataset)
print(type(ds_pytorch))
print(len(ds_pytorch))

In [None]:
# Each item maps an element type to a list of values; take the first image and all bboxes.
item = ds_pytorch[0]
img, bboxes = item["image"][0], item["bbox"]

print("Image: ")
print(img, img.shape, end="\n\n")
print("Bbox Classes: ")
print([box.class_label for box in bboxes], end="\n\n")
print("Bbox Coords: ")
print([box.coords for box in bboxes])

As we can see, every item in `PytorchEngineDataset` is a dictionary whose string keys match the etypes (in our case, 'image' and 'bbox'). Each value is a list of objects: the images are `torch.Tensor` objects, and the bboxes are instances of a class we created, but you can use whatever types you like.