In [30]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
import pandas as pd

In [129]:
# Synthetic demo data: one row per month of age, 66 months (5.5 years).
SEED = 0  # fixed so "Restart Kernel -> Run All" reproduces the same figure
rng = np.random.default_rng(SEED)

n = 5 * 12 + 6
x = np.arange(n)
# |N(0, 1)| draws are non-negative, so the cumulative sums never decrease
y = np.abs(rng.standard_normal(n))
y_cs = np.cumsum(y) * 1000

# A, B and C are scaled copies (1x, 2x, 3x) of the same cumulative curve
data = pd.DataFrame({"x": x, "y": y, "A": y_cs, "B": 2 * y_cs, "C": 3 * y_cs})

In [281]:
def sum_lines_traces(data, ycols, ycol_highlight, left, right):
    """Build the main-panel line traces for the cumulative-cost plot.

    Every column in ``ycols`` except ``ycol_highlight`` is drawn as one faint
    grey line over the whole index. ``ycol_highlight`` is split into three
    segments: faded (opacity 0.2) before ``left``, faded after ``right``, and
    solid over ``[left, right]``. A final marker trace marks the highlighted
    column's values at ``left`` and ``right``.

    Args:
        data: DataFrame whose index supplies the x values; ``left`` and
            ``right`` must be labels in ``data.index``.
        ycols: columns to draw.
        ycol_highlight: column to emphasise inside the window.
        left, right: index labels bounding the highlighted window (inclusive).

    Returns:
        list of ``go.Scatter`` traces.

    Raises:
        KeyError: if ``left`` or ``right`` is not in ``data.index``.
    """
    hover_template = "<b>Cumulative Cost</b>: $%{y:,.0f}"
    index = data.index

    traces = []
    for column in ycols:
        if column == ycol_highlight:
            # Each faded segment also contains the boundary point it shares
            # with the window, so the three pieces join into one continuous
            # line (with strict < / > there were gaps at left and right).
            # Trade-off: in hovermode="x" the label exactly at left/right may
            # list this column twice.
            segments = [
                (data.loc[index <= left], 0.2),
                (data.loc[index >= right], 0.2),
                # window drawn last so it sits on top of the faded pieces
                (data.loc[(index >= left) & (index <= right)], 1),
            ]
            for segment, opacity in segments:
                traces.append(
                    go.Scatter(
                        x=segment.index,
                        y=segment[column],
                        mode="lines",
                        line=dict(color="rgba(65,40,200,1)"),
                        showlegend=False,
                        opacity=opacity,
                        hovertemplate=hover_template,
                        name=column,
                    )
                )
        else:
            # context column: one faint grey line over the full range
            traces.append(
                go.Scatter(
                    x=index,
                    y=data[column],
                    mode="lines",
                    line=dict(color="rgba(60,60,60,0.25)"),
                    showlegend=False,
                    hovertemplate=hover_template,
                    name=column,
                )
            )

    # add markers at the highlighted column's values at left and right
    y_l = data.loc[left, ycol_highlight]
    y_r = data.loc[right, ycol_highlight]

    traces.append(
        go.Scatter(
            x=[left, right],
            y=[y_l, y_r],
            mode="markers",
            marker=dict(color="rgba(95,40,200,1)", size=6),
            showlegend=False,
            hoverinfo="skip",
        )
    )

    return traces

In [297]:
def vertical_traces(data, ycol_highlight, left, right):
    """Build the side-panel traces: the highlighted column projected onto x=0.

    The values of ``ycol_highlight`` are drawn as a vertical line at x=0:
    faded (opacity 0.2) for index labels before ``left`` and after ``right``,
    solid inside ``[left, right]``. Markers are added at the column's values
    at ``left`` and ``right`` (the same y positions marked in the main panel).

    Args:
        data: DataFrame indexed by the x-axis values; ``left`` and ``right``
            must be labels in ``data.index``.
        ycol_highlight: column to project.
        left, right: index labels bounding the highlighted window (inclusive).

    Returns:
        list of four ``go.Scatter`` traces: faded-before, faded-after,
        highlighted window, endpoint markers.

    Raises:
        KeyError: if ``left`` or ``right`` is not in ``data.index``.
    """
    values = data[ycol_highlight]
    index = data.index

    # Each faded segment also contains the boundary point it shares with the
    # window; otherwise the projected line has gaps at left and right.
    segments = [
        (values[index <= left], 0.2),
        (values[index >= right], 0.2),
        # window drawn last so it sits on top of the faded pieces
        (values[(index >= left) & (index <= right)], 1),
    ]

    traces = []
    for segment, opacity in segments:
        traces.append(
            go.Scatter(
                x=[0] * len(segment),
                y=segment,
                mode="lines",
                line=dict(color="rgba(65,40,200,1)"),
                showlegend=False,
                opacity=opacity,
                hoverinfo="skip",
            )
        )

    # Use the endpoint values rather than the window's min/max, so the markers
    # agree with the main-panel markers even when the series is not monotone.
    y_l = data.loc[left, ycol_highlight]
    y_r = data.loc[right, ycol_highlight]

    traces.append(
        go.Scatter(
            x=[0, 0],
            y=[y_l, y_r],
            mode="markers",
            marker=dict(color="rgba(95,40,200,1)", size=6),
            showlegend=False,
            hovertemplate="<br><b>Cumulative Cost</b>: $%{y:,.0f}",
            name=ycol_highlight,
        )
    )

    return traces

In [334]:
def plot_trend(
    data,
    xcol,
    columns_included,
    column_highlight,
    left,
    right,
    title="Cumulative Cost of Childcare",
):
    """Plot cumulative series and highlight the amount accrued over one window.

    Every column in ``columns_included`` is re-based so it reads 0 at
    ``xcol == left``. The wide left panel draws all columns (via
    ``sum_lines_traces``), emphasising ``column_highlight`` between ``left``
    and ``right``. The narrow right panel projects that column onto a vertical
    strip (via ``vertical_traces``) and labels it with
    ``value(right) - value(left)``. The figure is displayed with
    ``fig.show()``; nothing is returned.

    Args:
        data: DataFrame containing ``xcol`` and ``columns_included``; it is not
            modified.
        xcol: column used for the x-axis (becomes the index of a copy).
        columns_included: columns to plot.
        column_highlight: column whose window is emphasised and totalled.
        left, right: ``xcol`` values bounding the window (inclusive).
        title: title text drawn inside the top-left of the main panel.

    Raises:
        ValueError: if ``left > right``.
        KeyError: if ``left`` or ``right`` is not a value of ``xcol``.
    """
    if left > right:
        raise ValueError(f"left ({left}) must not be greater than right ({right})")

    # set_index returns a new frame, so the caller's ``data`` is untouched
    data = data.set_index(xcol)

    # re-base every plotted series so it reads 0 at ``left``
    data[columns_included] = (
        data[columns_included] - data.loc[left, columns_included]
    )

    # wide main panel + narrow side strip sharing one y-axis
    fig = make_subplots(
        rows=1,
        cols=2,
        column_widths=[8, 1],
        row_heights=[8],
        shared_yaxes=True,
        horizontal_spacing=0.02,
    )

    # highlighted series at the window edges and the total accrued between them
    y_l = data.loc[left, column_highlight]
    y_r = data.loc[right, column_highlight]
    duration_sum = y_r - y_l

    for trace in sum_lines_traces(
        data, columns_included, column_highlight, left, right
    ):
        fig.add_trace(trace, row=1, col=1)

    for trace in vertical_traces(data, column_highlight, left, right):
        fig.add_trace(trace, row=1, col=2)

    # faint dotted guides at the window edges, across both panels
    for y in (y_l, y_r):
        fig.add_hline(y=y, line_dash="dot", row=1, col="all", opacity=0.2)

    # window total, written vertically beside the strip in the right panel
    fig.add_annotation(
        x=0,
        y=(y_r + y_l) / 2,
        text=f"${duration_sum:,.0f}",
        showarrow=False,
        font=dict(size=20, color="rgba(95,40,200,1)", weight="bold"),
        xshift=15,
        yshift=0,
        row=1,
        col=2,
        textangle=90,
    )

    # title placed inside the top-left corner of the main panel
    fig.add_annotation(
        text=title,
        font=dict(
            size=20,
            color="rgba(95,40,200,1)",
            weight="bold",
            family="Helvetica",
        ),
        xref="x domain",
        yref="y domain",
        x=0,  # left edge of the subplot
        y=1,  # top edge of the subplot
        xshift=20,
        yshift=-15,
        xanchor="left",
        yanchor="top",
        showarrow=False,
    )

    # right panel: a narrow strip around x=0 with no tick labels
    fig.update_xaxes(range=[-0.1, 0.1], showticklabels=False, row=1, col=2)

    # no gridlines on any axis of either panel
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)

    fig.update_layout(
        font=dict(size=14),
        width=800,
        height=600,
        # y-range spans all plotted (re-based) values
        yaxis=dict(
            range=[
                data[columns_included].min().min(),
                data[columns_included].max().max(),
            ],
        ),
        xaxis_title="Age (Month)",
        hovermode="x",
    )

    fig.show()

In [335]:
plot_trend(
    data,
    "x",
    ["A", "B", "C"],
    "B",
    6,  # left edge of the window: age 6 months
    4 * 12 + 6,  # right edge of the window: age 54 months (4.5 years)
)

## Old code: earlier shading and guide-line helpers


In [None]:
def add_vrect(fig, l, r, x_min, x_max):
    """Shade the x-ranges outside ``[l, r]`` in subplot (row 1, col 1).

    Adds two translucent black vertical bands, ``x_min -> l`` and
    ``r -> x_max``, via ``fig.add_vrect``.
    """
    band_style = dict(fillcolor="black", opacity=0.2, line_width=0, row=1, col=1)
    for band_start, band_end in ((x_min, l), (r, x_max)):
        fig.add_vrect(x0=band_start, x1=band_end, **band_style)


def add_hrect(fig, b, t, y_min, y_max):
    """Shade the y-ranges outside ``[b, t]`` in subplot (row 1, col 1).

    Adds two translucent black horizontal bands, ``y_min -> b`` and
    ``t -> y_max``, via ``fig.add_hrect``.
    """
    band_style = dict(fillcolor="black", opacity=0.2, line_width=0, row=1, col=1)
    for band_low, band_high in ((y_min, b), (t, y_max)):
        fig.add_hrect(y0=band_low, y1=band_high, **band_style)


def add_lines(fig, l, r, b, t):
    """Draw dotted black guide lines at x = ``l``, ``r`` and y = ``b``, ``t``.

    No ``row``/``col`` is passed, so placement follows plotly's defaults for
    ``add_vline`` / ``add_hline``.
    """
    guide_style = dict(line_width=1, line_dash="dot", line_color="black", opacity=0.5)
    for x_pos in (l, r):
        fig.add_vline(x=x_pos, **guide_style)
    for y_pos in (b, t):
        fig.add_hline(y=y_pos, **guide_style)