# Visualizing pipelines of filters 

In [None]:
import os
from glob import glob

import cv2
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots

# Each pipeline is an ordered list of image -> image steps. Every intermediate
# result (plus the untouched input) is rendered as its own row in the figure.
pipelines = {
    "Standard": [
        # Warning: original pipeline skips this step! It is done here because
        # cv2.imread loads images in BGR channel order.
        lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
        # NOTE(review): for the 3-channel images cv2.imread returns by default
        # this should be a pass-through; presumably kept to mirror the original.
        lambda img: cv2.cvtColor(img, cv2.COLOR_RGBA2RGB),
        # The data is RGB after the first step, so RGB2GRAY is required here.
        # BGR2GRAY would swap the red/blue luminance weights and yield a
        # different gray image (and threshold mask) than the original pipeline.
        lambda img: cv2.cvtColor(img, cv2.COLOR_RGB2GRAY),
        lambda img: cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)[1],
        # Excluded: findContours does not change the image; it returns contour
        # data (a tuple), not an image, so it cannot be plotted as a step.
        # lambda img: cv2.findContours(img, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE),
    ],
    # "Legacy pipeline": [
    #     lambda img: cv2.cvtColor(img, cv2.COLOR_BGR2RGB),
    #     lambda img: cv2.pyrMeanShiftFiltering(img, sp=10, sr=100),
    #     lambda img: cv2.cvtColor(img, cv2.COLOR_RGB2GRAY),  # data is RGB here
    #     lambda img: cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)[1],
    # ],
}

# Load the dataset in sorted order: glob's order is arbitrary, which would
# otherwise shuffle the figure's columns between runs and machines.
original_image_names = []
original_images = []
for image_path in sorted(glob("dataset_cc2/*")):
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising
        # (e.g. for non-image files). Skip those so names and images stay
        # aligned and the pipelines don't crash on a None input.
        print(f"Skipping unreadable file: {image_path}")
        continue
    original_image_names.append(image_path)
    original_images.append(image)

if not original_images:
    # Fail early with a clear message instead of inside make_subplots(cols=0).
    raise FileNotFoundError("No readable images found in 'dataset_cc2/'")

# Grid: one row per pipeline stage plus one for the input image, and one column
# per (image, pipeline) pair.
row_cnt = max(len(pipeline) for pipeline in pipelines.values()) + 1
col_cnt = len(pipelines) * len(original_images)
padding = 100  # pixels reserved between/around subplots
avg_img_size = 200  # nominal pixel size of a single subplot
fig_height = row_cnt * avg_img_size + padding * (row_cnt + 1)
fig_width = col_cnt * avg_img_size + padding * (col_cnt + 1)

# One column per (image, pipeline) pair, ordered image-major: all pipelines for
# the first image, then all pipelines for the next image, and so on.
fig = make_subplots(
    rows=row_cnt,
    cols=col_cnt,
    column_titles=[
        # os.path.basename handles any directory depth and Windows "\"
        # separators (which glob returns there); split('/')[1] did not.
        f"{pipeline_name}<br>{os.path.basename(image_name)}"
        for image_name in original_image_names
        for pipeline_name in pipelines
    ],
)

# Run every pipeline on every image and plot each stage. Row r (1-based) shows
# the image after r-1 steps, so row 1 is the untouched input. Each
# (image, pipeline) pair gets its own column.
# NOTE: px.imshow treats 3-channel arrays as RGB, so row 1 (cv2.imread's BGR
# data) is shown with red and blue swapped.
for image_idx, original_image in enumerate(original_images):
    for pipeline_idx, pipeline in enumerate(pipelines.values()):
        # Image-major column order, matching the subplot column titles.
        col = 1 + image_idx * len(pipelines) + pipeline_idx

        # The first entry is a copy, so a step that writes in place cannot
        # modify the loaded source image.
        intermediates = [np.copy(original_image)]
        for step in pipeline:
            intermediates.append(step(intermediates[-1]))

        for intermediate_idx, img in enumerate(intermediates):
            # Only the trace is kept; px.imshow's own layout is discarded.
            fig.add_trace(
                px.imshow(img).data[0],
                row=1 + intermediate_idx,
                col=col,
            )

fig.update_xaxes(visible=False, showticklabels=False)
# px.imshow draws 2-D (grayscale) arrays as heatmaps and flips the y axis via
# its layout, which is lost above. Reverse every y axis so those rows are not
# upside down. Color (go.Image) traces are already reversed by default.
fig.update_yaxes(visible=False, showticklabels=False, autorange="reversed")
# The heatmaps share one color axis. Render it in gray rather than the
# template's default continuous scale, and hide the color bar.
fig.update_coloraxes(showscale=False, colorscale="gray")
fig.update_layout(height=fig_height, width=fig_width)
fig.show()