### Test out Altair `extract_data`
* https://altair-viz.github.io/user_guide/transform/index.html

In [None]:
# Names exported by a star-import of this module.
__all__ = [
    "apply",
    "extract_data",
    "transform_chart",
]

In [None]:
def apply(
    df: pd.DataFrame,
    transform: Union[alt.Transform, List[alt.Transform]],
    inplace: bool = False,
) -> pd.DataFrame:
    """Run one transform, or a list of transforms, against a dataframe.

    Parameters
    ----------
    df : pd.DataFrame
        The data to transform.
    transform : list|dict
        A single transform specification or a list of them. Every
        specification must be valid according to Altair's transform
        schema. ``alt.Undefined`` is accepted and leaves the data as is.
    inplace : bool
        If False (default), the work is done on a copy of ``df``.
        If True, ``df`` itself is passed on and may be modified.

    Returns
    -------
    df_transformed : pd.DataFrame
        The dataframe after the transforms have been applied.

    Example
    -------
    >>> import pandas as pd
    >>> data = pd.DataFrame({'x': range(5), 'y': list('ABCAB')})
    >>> chart = alt.Chart(data).transform_aggregate(sum_x='sum(x)', groupby=['y'])
    >>> apply(data, chart.transform)
       y  sum_x
    0  A      3
    1  B      5
    2  C      2
    """
    # Work on the caller's object only when explicitly asked to.
    frame = df if inplace else df.copy()
    # ``alt.Undefined`` means there is nothing to apply.
    return frame if transform is alt.Undefined else visit(transform, frame)

In [12]:
def extract_data(
    chart: alt.Chart, apply_encoding_transforms: bool = True
) -> pd.DataFrame:
    """Return a chart's data with the chart's transforms applied.

    Only data and transforms defined at the top level of the chart
    are taken into account.

    Parameters
    ----------
    chart : alt.Chart
        The chart whose data and transform are read.
    apply_encoding_transforms : bool
        If True (default), transforms expressed inside an encoding are
        applied in addition to those listed in the chart's transform
        attribute.

    Returns
    -------
    df_transformed : pd.DataFrame
        The extracted and transformed dataframe.

    Example
    -------
    >>> import pandas as pd
    >>> data = pd.DataFrame({'x': range(5), 'y': list('ABCAB')})
    >>> chart = alt.Chart(data).mark_bar().encode(x='sum(x)', y='y')
    >>> extract_data(chart)
       y  sum_x
    0  A      3
    1  B      5
    2  C      2
    """
    # Optionally swap in the chart returned by ``extract_transform`` so that
    # encoding-level transforms are included.
    source = extract_transform(chart) if apply_encoding_transforms else chart
    frame = to_dataframe(source.data, source)
    return apply(frame, source.transform)

In [13]:
# Small demo frame: five rows, with the category labels A/B repeated.
data = pd.DataFrame({"x": range(5), "y": ["A", "B", "C", "A", "B"]})

In [14]:
# Bar chart whose x encoding carries an aggregate ("sum(x)") per value of y.
chart = (
    alt.Chart(data)
    .mark_bar()
    .encode(x="sum(x)", y="y")
)

In [16]:
# Show the transformed dataframe extracted from `chart`.
extract_data(chart)

Unnamed: 0,y,sum_x
0,A,3
1,B,5
2,C,2


In [17]:
def extract_data_altair(chart):
    """Flatten the first dataset embedded in a chart's dict form.

    Reads ``chart.to_dict()["datasets"]``, loads it into a dataframe with
    one column per dataset, and expands the records held in the first
    column into regular columns via ``pd.json_normalize``.

    Parameters
    ----------
    chart
        Any object whose ``to_dict()`` result contains a ``"datasets"`` key.

    Returns
    -------
    pd.DataFrame
        The remaining dataset columns (if any) followed by the expanded
        fields of the first dataset.
    """
    datasets = chart.to_dict()["datasets"]
    raw = pd.DataFrame(datasets)

    # Each cell of the first column is a record; expand it into columns.
    first_name = raw.columns[0]
    expanded = pd.json_normalize(raw[first_name])

    # Put the expanded fields next to whatever other columns were present.
    remaining = raw.drop(columns=first_name)
    return pd.concat([remaining, expanded], axis=1)

In [23]:
def heatmap(
    df: pd.DataFrame,
    color_col: str,
    title: str,
    subtitle1: str,
    subtitle2: str,
    subtitle3: str,
):
    """Build a layered heatmap of a categorized column by date and time period.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain the columns ``time_period`` (strings), ``service_date``,
        ``direction_id``, ``route_combined_name``, ``organization_name``,
        ``color_col`` and the un-categorized counterpart of ``color_col``
        (``color_col`` without its trailing ``_cat``).
    color_col : str
        Column used for the rectangle color, encoded as nominal.
    title : str
        Chart title.
    subtitle1, subtitle2, subtitle3 : str
        Subtitle lines shown below the title, in this order.

    Returns
    -------
    A layered Altair chart: the rectangle layer plus a text layer that
    prints ``direction_id`` in white on top of it.
    """
    # Present time periods as "Title Case" words instead of snake_case.
    df = df.assign(
        time_period=df.time_period.str.replace("_", " ").str.title()
    ).reset_index(drop=True)

    # Grab original column that wasn't categorized. Strip only a trailing
    # "_cat" so that names containing "_cat" elsewhere are left intact.
    suffix = "_cat"
    if color_col.endswith(suffix):
        original_col = color_col[: -len(suffix)]
    else:
        original_col = color_col

    tooltip_cols = [
        "direction_id",
        "time_period",
        "route_combined_name",
        "organization_name",
        color_col,
        original_col,
    ]

    chart = (
        alt.Chart(df)
        .mark_rect(size=30)
        .encode(
            x=alt.X(
                "yearmonthdate(service_date):O",
                axis=alt.Axis(labelAngle=-45, format="%b %Y"),
                title=["Grouped by Direction ID", "Service Date"],
            ),
            y=alt.Y("time_period:O", title=["Time Period"]),
            # Use the channel class that matches the xOffset encoding.
            xOffset=alt.XOffset("direction_id:N", title="Direction ID"),
            color=alt.Color(
                f"{color_col}:N",
                title=labeling(color_col),
                scale=alt.Scale(range=cp.CALITP_SEQUENTIAL_COLORS),
            ),
            tooltip=tooltip_cols,
        )
        .properties(
            title={"text": [title], "subtitle": [subtitle1, subtitle2, subtitle3]},
            width=500,
            height=300,
        )
    )

    # Overlay the direction id as white text on each cell.
    text = chart.mark_text(baseline="middle").encode(
        alt.Text("direction_id"), color=alt.value("white")
    )

    final_chart = chart + text
    return final_chart