# Imports

In [None]:
import altair as alt
import pandas as pd
from altair_saver import save

# Altair Settings

In [None]:
# Show the installed Altair version as the cell output.
alt.__version__

In [None]:
# List the renderer names registered with Altair (shown as the cell output).
alt.renderers.names()

In [None]:
def roboto(font="Roboto", color="#2F2F2F"):
    """Return an Altair/Vega-Lite theme config built from one font and one color.

    Both arguments have defaults that reproduce the original Roboto /
    dark-gray look. The function can therefore still be registered as a
    zero-argument theme callable.

    Args:
        font: Font family for the chart title and for the label/title text
            of axes, facet headers and legends.
        color: Color for that same text, plus axis ticks, axis domain lines
            and rule marks.

    Returns:
        A ``{"config": {...}}`` dict suitable for ``alt.themes.register``.
    """
    # Label/title styling shared by axis, header and legend. Each section
    # gets its own copy so mutating one section cannot leak into another.
    text_style = {
        "labelFont": font,
        "titleFont": font,
        "labelColor": color,
        "titleColor": color,
    }

    return {
        "config": {
            "title": {"font": font, "color": color},
            "axis": {**text_style, "tickColor": color, "domainColor": color},
            "header": dict(text_style),
            "legend": dict(text_style),
            "rule": {"color": color},
        }
    }


# Register `roboto` under the theme name "roboto" and make it the active theme.
# NOTE(review): newer Altair releases (5.5+) deprecate `alt.themes` in favour of
# `alt.theme.register` / `alt.theme.enable` — confirm against the installed version.
alt.themes.register("roboto", roboto)
alt.themes.enable("roboto")

# Score Bands Chart

In [None]:
# Score bands as (start, end, label, id) rows. The row order and the
# start/end/label/id key order are kept exactly as first written.
source = [
    {"start": start, "end": end, "label": label, "id": band_id}
    for start, end, label, band_id in (
        (0.4, 0.6, "Uncertain (for review)", "uncertain"),
        (0.0, 0.4, "Non-Fraud", "neg"),
        (0.6, 1.0, "Fraud", "pos"),
    )
]

classification_threshold = [dict(threshold=0.5)]

In [None]:
# Convert the raw records to DataFrames with pandas itself. `alt.pd` is only
# an incidental re-export inside Altair, not part of its documented API, so
# it may disappear between Altair releases.
source = pd.DataFrame(source)

classification_threshold = pd.DataFrame(classification_threshold)

In [None]:
def score_bands_chart(
    source,
    classification_threshold,
    h=100,
    w=500,
    marker_color="#2F2F2F",
    band_colors=("#368F8B", "#EFF2EF", "#EE6C4D"),
):
    """Draw score bands as colored rectangles plus a classification-threshold rule.

    Args:
        source: DataFrame with one row per band and columns ``start`` and
            ``end`` (numeric extent of the band on the x axis) and ``label``
            (text shown in the color legend).
        classification_threshold: DataFrame with a ``threshold`` column. The
            value in its first row is drawn as a vertical rule.
        h: Chart height.
        w: Chart width.
        marker_color: Color of the threshold rule and its legend symbol.
        band_colors: Colors assigned to the band labels in ascending order of
            ``start``. Supply one color per band.

    Returns:
        A configured layered Altair chart (band rectangles + threshold rule).
    """
    # Select the threshold by column name rather than position, so it always
    # matches the "threshold" field the rule layer encodes below.
    ct = classification_threshold["threshold"].iloc[0]

    base = alt.Chart(source)

    # Color domain ordered by band position, so band_colors map left-to-right.
    domain = source.sort_values(by=["start"])["label"].tolist()
    range_ = list(band_colors)

    # SVG path of a short vertical stroke, used as the threshold legend symbol.
    vline_shape = "M 0 1 L 0 -1"

    # One axis tick per distinct band edge. This does not require a band with
    # any particular id to exist (for the default bands: 0, 0.4, 0.6, 1).
    band_edges = set(source["start"].tolist()) | set(source["end"].tolist())
    tick_values = sorted(band_edges)

    rect = base.mark_rect().encode(
        x=alt.X(
            "start:Q",
            axis=alt.Axis(
                domainColor="transparent",
                title=None,
                ticks=False,
                format=".1",
                values=tick_values,
            ),
        ),
        x2=alt.X2("end:Q"),
        color=alt.Color(
            "label:N",
            legend=alt.Legend(title="Output Label", titleFontWeight="normal"),
            scale=alt.Scale(domain=domain, range=range_),
        ),
    )

    # The single-value color scale exists only to get a legend entry for the
    # threshold, drawn with the vertical-line symbol above.
    rule = (
        alt.Chart(classification_threshold)
        .mark_rule()
        .encode(
            x=alt.X("threshold:Q"),
            size=alt.SizeValue(1.0),
            color=alt.Color(
                "threshold:Q",
                scale=alt.Scale(domain=[ct], range=[marker_color]),
                legend=alt.Legend(
                    title="Classification Threshold",
                    titleFontWeight="normal",
                    format=".1",
                    type="symbol",
                    symbolType=vline_shape,
                ),
            ),
        )
    )

    chart = rect + rule

    chart = (
        chart.configure_axis(grid=False)
        .configure_view(strokeWidth=0)
        .properties(width=w, height=h)
    )

    return chart

In [None]:
# Build the score-bands chart with its default size and colors; leaving it as
# the cell's final expression lets the notebook render it.
chart = score_bands_chart(source, classification_threshold=classification_threshold)
chart

In [None]:
# Show the signature/docstring of altair_saver's `save` (IPython `?` help syntax).
?save

In [None]:
# Export the chart to a PNG file through altair_saver's selenium backend,
# driving Chrome. scale_factor=6.0 requests a higher-resolution raster.
save(
    chart, "score_bands_chart.png",
    method="selenium", webdriver="chrome", scale_factor=6.0,
)

---