In [1]:
def pre_process(df):
    """Derive per-worker display columns used to group the Gantt rows.

    The dataframe is modified in place: ``worker_name``, ``worker_thread``
    and ``thread_number`` columns are added to it. It must already provide
    the ``worker`` and ``thread`` columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw task records.

    Returns
    -------
    pandas.DataFrame
        A copy of the records sorted by ``worker_name`` (descending) and
        then ``thread_number`` (ascending).
    """

    def _short_name(raw):
        # Drop everything up to the first "-", glue the remaining pieces
        # together, then cut at the first ".".
        after_dash = "".join(raw.split("-")[1:])
        return after_dash.split(".")[0]

    df["worker_name"] = df.worker.map(str).map(_short_name)

    # Coarse categories for the traced functions.
    # NOTE: currently unused, the renaming step below is disabled.
    func_map = {
        "read_img": "Read",
        "save_results": "Write",
        "save_histogram": "Write",
        "increment": "Compute",
        "calculate_histogram": "Compute",
        "combine_histogram": "Compute",
        "run_participant": "Participant",
        "run_group": "Group",
        "flatten": "Compute",
    }
    #     df.func = df.func.apply(lambda x: func_map[x])

    # Number the threads of each worker 1, 2, 3, ... in order of first
    # appearance, keyed by "<worker_name>::<raw thread id>".
    thread_worker = {}
    for name in df.worker_name.unique():
        seen_threads = df[df.worker_name == name].thread.unique()
        for rank, thread_id in enumerate(seen_threads, start=1):
            thread_worker[name + "::" + str(thread_id)] = rank

    raw_key = df.worker_name.map(str) + "::" + df.thread.map(str)
    df["worker_thread"] = raw_key
    df["thread_number"] = df.worker_thread.map(lambda key: thread_worker[key])

    # Replace the raw key with the readable "<worker_name>::thread<N>" label.
    readable_key = df.worker_name.map(str) + "::thread" + df.thread_number.map(str)
    df["worker_thread"] = readable_key

    return df.sort_values(
        by=["worker_name", "thread_number"], ascending=[False, True]
    )

In [2]:
def gantt_tracker(
    df,
    *,
    pre_process=None,
    group,
    x_limit=None,
    save_name=None,
    framework,
    xaxis_label,
    yaxis_label,
):
    """Create an interactive Gantt chart from a pandas dataframe.

    Every record is drawn as a rectangle spanning ``start`` to ``end`` on
    the row of its ``group`` value, coloured by ``func``. Selecting a
    rectangle highlights every record sharing the same ``filename``.

    The dataframe is modified in place: helper columns (``bottom``, ``top``,
    ``color``, ``original_color``, ``runtime``) are added, and for Spark the
    ``process`` column is copied into ``thread``.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot. After pre-processing it must contain the columns
        ``func``, ``start``, ``end``, ``filename``, ``thread`` and the one
        named by ``group``.
    pre_process : callable, optional
        Function taking the dataframe and returning the dataframe to plot.
    group : str
        Column name of the element to group together (one row per value).
    x_limit : int, optional
        Maximum value for the x axis. Defaults to the largest ``end``.
    save_name : str, optional
        When given, the chart is written to
        ``interactive-figures/<save_name>.html``; otherwise it is shown in
        the notebook.
    framework : str
        Name of the framework from which the data were collected
        (case-insensitive). Only "dask" and "spark" are supported; "spark"
        additionally requires a ``process`` column.
    xaxis_label : str
        Label for the x axis.
    yaxis_label : str
        Label for the y axis.

    Returns
    -------
    str or None
        "DONE" once the chart has been shown; ``None`` when the input is
        rejected (an error message is printed in that case).
    """
    # Verify that the framework is supported.
    framework_name = framework.lower()
    if framework_name == "spark":
        if "process" not in df.columns:
            print(
                "fatal error : dataframe for spark must contain a process column."
            )
            return None
        # Spark records carry the executing unit in "process"; the rest of
        # the pipeline reads it from "thread".
        df["thread"] = df["process"]
    elif framework_name != "dask":
        print(f"fatal error : {framework} is not a supported framework.")
        return None

    # Pre-processing of the dataframe.
    if pre_process:
        df = pre_process(df)

    # Make sure the dataframe provides every column the chart relies on.
    required_columns = ["func", "start", "end", "filename", "thread", group]
    for column_name in required_columns:
        if column_name not in df.columns:
            print(
                f"fatal error: the dataframe must contain '{column_name}' after the pre-processing."
            )
            return None

    # Bokeh is imported only once the input has been validated, so invalid
    # input is reported without paying for (or requiring) the plotting stack.
    from bokeh.models import CustomJS, ColumnDataSource, Grid, LinearAxis, Plot, Range1d
    from bokeh.models.annotations import Legend, LegendItem
    from bokeh.models.glyphs import Quad
    from bokeh.models.tools import (
        BoxZoomTool,
        HoverTool,
        PanTool,
        ResetTool,
        SaveTool,
        TapTool,
        WheelZoomTool,
    )
    from bokeh.io import curdoc, output_file, output_notebook, show
    from bokeh.palettes import Colorblind8

    if x_limit is None:
        x_limit = df.end.max()

    n_groups = len(df[group].unique())
    plot = Plot(
        plot_width=1250 if save_name is not None else 800,
        plot_height=700 if save_name is not None else 600,
        x_range=Range1d(-x_limit * 0.05, x_limit * 1.05, bounds="auto"),
        y_range=Range1d(
            -max(n_groups * 0.05, 1),
            n_groups * 1.05,
            bounds="auto",
        ),
    )

    # One horizontal lane per group value, in order of first appearance.
    # Lanes are derived from the group value itself rather than from index
    # labels, so a dataframe with a duplicated index is still laid out
    # correctly.
    labels = list(df[group].dropna().unique())
    lane_of = {label: lane for lane, label in enumerate(labels)}
    lane = df[group].map(lane_of).astype(float)
    df["bottom"] = lane - 0.5
    df["top"] = lane + 0.5
    y = len(labels)

    # Plot overhead: a hatched background covering the whole run, which
    # stays visible wherever no task rectangle is drawn on top of it.
    overhead = Quad(
        left=0,
        right=df.end.max(),
        top=(y - 0.5),
        bottom=(-0.5),
        hatch_pattern="@",
        hatch_color="red",
        hatch_alpha=0.3,
        fill_color="red",
        fill_alpha=0.1,
        line_color="red",
        line_width=0.75,
    )

    plot.add_glyph(ColumnDataSource({}), overhead)

    # Define color map for the functions. The palette only has a fixed
    # number of entries, so it is cycled when there are more functions.
    functions = sorted(df.func.unique())
    color_of = {
        name: Colorblind8[i % len(Colorblind8)] for i, name in enumerate(functions)
    }
    df["color"] = [color_of[name] for name in df.func]

    # Empty glyphs, one per function, used only as legend entries.
    legend_renderers = [
        plot.add_glyph(
            ColumnDataSource({}),
            Quad(
                fill_color=color_of[name],
                fill_alpha=0.66,
                line_color=color_of[name],
                line_width=0.75,
            ),
        )
        for name in functions
    ]

    # Kept so the tap callback can restore colors after a highlight.
    df["original_color"] = df["color"]

    df["runtime"] = df.end - df.start

    source = ColumnDataSource(df)

    glyph = Quad(
        left="start",
        right="end",
        top="top",
        bottom="bottom",
        fill_color="color",
        fill_alpha=0.66,
        line_color="color",
        line_width=0.75,
    )

    plot.add_glyph(source, glyph)

    # Legend
    legend = Legend(
        items=[
            LegendItem(label=name, renderers=[renderer])
            for name, renderer in zip(functions, legend_renderers)
        ]
    )
    plot.add_layout(legend, "above")
    plot.legend.orientation = "horizontal"

    # Axis
    xaxis = LinearAxis()
    plot.add_layout(xaxis, "below")
    plot.xaxis.axis_label = xaxis_label

    yaxis = LinearAxis()
    plot.add_layout(yaxis, "left")
    plot.yaxis.axis_label = yaxis_label
    plot.yaxis.major_label_text_font_size = (
        "6pt"  # Reduce font size to fit all group together.
    )

    plot.add_layout(Grid(dimension=0, ticker=xaxis.ticker))
    plot.add_layout(Grid(dimension=1, ticker=yaxis.ticker))

    # Set y axis tick label
    plot.yaxis.ticker = list(range(len(labels)))
    plot.yaxis.major_label_overrides = dict(enumerate(labels))

    # Hover tool
    # NOTE(review): recent Bokeh releases document the formatter keys with
    # the "@" prefix (e.g. "@runtime") - confirm against the installed version.
    hover = HoverTool(
        tooltips=[
            ("filename", "@filename"),
            ("worker", f"@{group}"),
            ("function", "@func"),
            ("runtime", "@runtime{%8.3f sec}"),
            ("start time", "@start{%8.3f sec}"),
            ("end time", "@end{%8.3f sec}"),
        ],
        formatters={
            "runtime": "printf",
            "start": "printf",
            "end": "printf",
        },
        attachment="left",
    )

    # Tap tool custom select: restore every color, then highlight and
    # select all records that share the filename of the tapped record.
    # `same_file` is declared explicitly so the code also runs in strict mode.
    cb_click = CustomJS(
        args=dict(source=source),
        code="""
        const inds = source.selected.indices;
        const d = source.data;

        for (var i = 0; i < d['color'].length; i++){
            d['color'][i] = d['original_color'][i]
        }

        if (inds.length == 0)
            return;

        const same_file = []
        for (var i = 0; i < d['color'].length; i++){
            if (d['filename'][i] == d['filename'][inds[0]]){
                same_file.push(i)
            }
        }

        for (var i = 0; i < same_file.length; i++){
            d['color'][same_file[i]] = "firebrick"
        }

        source.selected.indices = same_file
        source.change.emit();
    """,
    )
    source.selected.js_on_change("indices", cb_click)

    ## Tool
    plot.add_tools(BoxZoomTool())
    plot.add_tools(hover)
    plot.add_tools(PanTool())
    plot.add_tools(ResetTool())
    plot.add_tools(SaveTool())
    plot.add_tools(TapTool(callback=cb_click))
    plot.add_tools(WheelZoomTool())

    curdoc().add_root(plot)

    # Display mode
    if save_name:
        output_file(f"interactive-figures/{save_name}.html")
    else:
        output_notebook()

    show(plot)

    return "DONE"

In [3]:
# Driver cell: load one Spark trace and render it as a saved Gantt chart.
import pandas as pd

# Trace to visualise.
filename = "../inc/data-1/results-spark_inc-2node.csv"

# The CSV is read without a header row, so the column names are supplied here.
# "process" is required by gantt_tracker when framework="spark".
col_name = ["func", "start", "end", "filename", "worker", "thread", "process"]
df = pd.read_csv(filename, header=None, names=col_name)

# Rows are grouped by "worker_thread", the column added by pre_process.
# Passing save_name makes gantt_tracker target an HTML file under
# interactive-figures/ instead of notebook output.
gantt_tracker(
    df,
    pre_process=pre_process,
    group="worker_thread",
    x_limit=1400,
    save_name="spark-histo-showcase",
    framework="spark",
    xaxis_label="Time [s]",
    yaxis_label="Worker",
)

'DONE'