In [None]:
#default_exp visualizations

# Visualizations

> Functions for showing detection or segmentation results

In [None]:
# export

from icevision.all import *
from fastcore.dispatch import *
from fastai.data.all import *
from fastai.vision.all import *
from matplotlib.cm import ScalarMappable


## Object detection/instance segmentation (IceVision)

> IceVision plotting functions, modified to show the raw image, the expert annotations, and the model predictions side by side.

In [None]:
#export

def plot_grid_preds_actuals_raws(
    raws, actuals, predictions, figsize=None, show=False, annotations=None, **kwargs
):
    """Plot raw image, expert annotation and prediction side by side, one row per sample.

    Args:
        raws: images shown in the first column ("Aerial image").
        actuals: images shown in the second column ("Expert annotations").
            Its length determines the number of rows in the grid.
        predictions: images shown in the third column ("Predicted annotations").
        figsize: figure size passed to `plt.subplots`; when falsy a size
            derived from the number of rows is used.
        show: when true, `plt.show()` is called after the grid is drawn.
        annotations: optional indexable; `annotations[i][0]` is written as the
            x-label under the expert image of row `i` and `annotations[i][1]`
            under the predicted image. When `None`, all axes are switched off.
        **kwargs: forwarded to `plt.subplots`.

    Returns:
        The axes object returned by `plt.subplots`.

    Raises:
        ValueError: if `actuals` is empty.
    """
    n_rows = len(actuals)
    if n_rows == 0:
        # Fail before any figure is created so that no empty figure is left open.
        raise ValueError("Cannot plot an empty grid: `actuals` contains no images")

    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=3,
        figsize=figsize or (6, 6 * n_rows / 2 / 0.75),
        **kwargs,
    )

    titles = ("Aerial image", "Expert annotations", "Predicted annotations")
    samples = zip(raws, actuals, predictions)
    # By default `plt.subplots` squeezes a one-row grid to 1-D; reshaping makes
    # every row indexable as a triple of axes.
    for row_idx, (images, row_axes) in enumerate(zip(samples, axs.reshape(-1, 3))):
        for ax, image, title in zip(row_axes, images, titles):
            ax.imshow(image)
            ax.set_title(title)

        # The raw image never carries a caption, so its axis is always hidden.
        row_axes[0].set_axis_off()
        if annotations is None:
            row_axes[1].set_axis_off()
            row_axes[2].set_axis_off()
            continue

        # Keep the x-axis alive (without ticks or frame) so that its label can
        # serve as a caption underneath the image.
        for col, ax in enumerate(row_axes[1:]):
            ax.get_xaxis().set_ticks([])
            ax.set_frame_on(False)
            ax.get_yaxis().set_visible(False)
            ax.set_xlabel(annotations[row_idx][col], ma="left")

    # Lay out the figure created above rather than whichever figure is current.
    fig.tight_layout()
    if show:
        plt.show()
    return axs

def show_raw_mask_pred(
    samples: Union[Sequence[np.ndarray], Sequence[dict]],
    preds: Sequence[dict],
    class_map: Optional[ClassMap] = None,
    denormalize_fn: Optional[callable] = denormalize_imagenet,
    display_label: bool = True,
    display_bbox: bool = True,
    display_mask: bool = True,
    ncols: int = 1,
    figsize=None,
    show=False,
    annotations=None,
) -> None:
    """Plot samples next to their predictions.

    When every element of `samples` is exactly a `dict`, each sample is drawn
    three times -- without any overlay, with its own annotations, and with the
    corresponding prediction -- and the results are handed to
    `plot_grid_preds_actuals_raws`. Otherwise each `(sample, pred)` pair is
    rendered through `show_pred` inside a `plot_grid` with `ncols` columns.

    Args:
        samples: sample dicts or plain images, one per prediction.
        preds: predictions, same length as `samples`.
        class_map: forwarded to the drawing functions.
        denormalize_fn: forwarded to the drawing functions.
        display_label, display_bbox, display_mask: overlay switches forwarded
            to the drawing functions.
        ncols: number of grid columns; only used when `samples` are not dicts.
        figsize, show, annotations: forwarded to the grid plotting function.

    Raises:
        ValueError: if `samples` and `preds` differ in length.
    """
    if len(samples) != len(preds):
        raise ValueError(
            f"Number of imgs ({len(samples)}) should be the same as "
            f"the number of preds ({len(preds)})"
        )

    # Overlay switches requested by the caller, shared by every rendering below.
    overlay_flags = dict(
        display_label=display_label,
        display_bbox=display_bbox,
        display_mask=display_mask,
    )

    # Exact type check: dict subclasses are treated like plain images.
    samples_are_dicts = all(type(item) is dict for item in samples)

    if not samples_are_dicts:
        deferred_plots = []
        for image, prediction in zip(samples, preds):
            deferred_plots.append(
                partial(
                    show_pred,
                    img=image,
                    pred=prediction,
                    class_map=class_map,
                    denormalize_fn=denormalize_fn,
                    **overlay_flags,
                    show=False,
                )
            )
        plot_grid(
            deferred_plots,
            ncols=ncols,
            figsize=figsize,
            show=show,
            annotations=annotations,
        )
        return

    def _render(record, flags):
        # Draw one sample dict with the given overlay switches.
        return draw_sample(
            sample=record,
            class_map=class_map,
            denormalize_fn=denormalize_fn,
            **flags,
        )

    # First column: the image with every overlay switched off.
    no_overlays = dict.fromkeys(overlay_flags, False)
    raws = [_render(record, no_overlays) for record in samples]
    # Second column: the image with the sample's own annotations.
    actuals = [_render(record, overlay_flags) for record in samples]

    # Third column: the image with the prediction drawn on top.
    images = [record["img"] for record in samples]
    predictions = []
    for image, prediction in zip(images, preds):
        predictions.append(
            draw_pred(
                img=image,
                pred=prediction,
                class_map=class_map,
                denormalize_fn=denormalize_fn,
                **overlay_flags,
            )
        )

    plot_grid_preds_actuals_raws(
        raws,
        actuals,
        predictions,
        figsize=figsize,
        show=show,
        annotations=annotations,
    )