# Draco Debugging

In [1]:
import toml
import pandas as pd
from typing import Iterable, NamedTuple
import altair as alt

from draco import Draco

In [2]:
# Draco instance built with no arguments; its soft constraints are the
# preference "features" inspected throughout this notebook
default_draco = Draco()
feature_names: set[str] = {*default_draco.soft_constraint_names}
# Maps "<soft constraint name>_weight" keys to their integer weights
weights: dict[str, int] = default_draco.weights

In [3]:
def full_pref_count_dict(
    draco: Draco,
    spec: str | Iterable[str],
    names: Iterable[str] | None = None,
) -> dict[str, int]:
    """Constructs a dict of preference counts, including all the soft constraint names in the keys

    :param draco: the `Draco` instance whose `count_preferences` evaluates `spec`
    :param spec: the chart specification passed to `draco.count_preferences`
    :param names: preference names that must all appear as keys; defaults to the
        module-level `feature_names`
    :return: the counts reported by `draco.count_preferences` (an empty result,
        including `None`, counts as no preferences), followed by a `0` entry for
        every name in `names` that was not reported, in sorted order
    """
    all_names = feature_names if names is None else names
    counts: dict[str, int] = draco.count_preferences(spec) or {}
    # Sorted so the zero-filled entries (and the DataFrame rows built from them)
    # come out in the same order on every run: iterating a set of strings varies
    # with hash randomization.
    missing = sorted({name for name in all_names if name not in counts})
    return counts | dict.fromkeys(missing, 0)


def unnest_pref_count_dict(
    dct: dict[str, dict[str, int]]
) -> list[tuple[str, str, int]]:
    """Flattens the dict of dicts into a list of tuples

    Every entry of an inner dict becomes one `(chart_name, pref_name, count)`
    tuple, emitted in the iteration order of the outer and then the inner dict.
    """
    flattened: list[tuple[str, str, int]] = []
    for outer_key, inner in dct.items():
        flattened.extend((outer_key, *item) for item in inner.items())
    return flattened


def pref_tuple_extended_with_weight(
    tpl: tuple[str, str, int]
) -> tuple[str, str, int, int]:
    """Appends the preference's weight to a `(chart_name, pref_name, count)` tuple

    The weight is read from the module-level `weights` dict under the key
    `"<pref_name>_weight"`; a missing key raises `KeyError`.
    """
    # Unpacking also enforces the expected three-element shape
    chart, pref, n = tpl
    return (chart, pref, n, weights[f"{pref}_weight"])


def pref_tuples_extended_with_weights(
    tuples: list[tuple[str, str, int]]
) -> list[tuple[str, str, int, int]]:
    """Applies `pref_tuple_extended_with_weight` to every tuple, preserving order"""
    return list(map(pref_tuple_extended_with_weight, tuples))

In [4]:
# Construct a `DataFrame` from the `example_charts` file used for analysis.
# TOML documents are UTF-8 by specification, so decode explicitly instead of
# relying on the platform's default encoding.
with open("./data/example_charts.toml", encoding="utf-8") as file:
    specs = toml.load(file)

# chart name -> {preference name -> count}, zero-filled for every feature name
pref_count_dict: dict[str, dict[str, int]] = {
    chart_name: full_pref_count_dict(default_draco, spec)
    for chart_name, spec in specs.items()
}
pref_tuples = unnest_pref_count_dict(pref_count_dict)
pref_tuples_with_weights = pref_tuples_extended_with_weights(pref_tuples)

# One row per (chart, preference) pair.
# Shape: len(specs) * len(feature_names) rows x 4 columns
# (assuming `count_preferences` only reports names from `feature_names`)
chart_preferences = pd.DataFrame(
    pref_tuples_with_weights, columns=["chart_name", "pref_name", "count", "weight"]
)

In [5]:
class ChartConfig(NamedTuple):
    """Settings for one variant of the debugging chart built by `create_chart`."""

    # Chart title; also used as the dropdown label that selects this variant
    title: str
    # Sort applied to the preference-name x axis (weight bars and heatmap alike)
    sort_x: alt.Sort | str
    # Sort applied to the chart-name y axis of the heatmap
    sort_y: alt.Sort | str


# Every configuration below yields one chart variant
configs: list[ChartConfig] = [
    ChartConfig("Sort alphabetically", "ascending", "ascending"),
    ChartConfig(
        "Sort by count sum",
        alt.EncodingSortField(field="count", op="sum", order="descending"),
        alt.EncodingSortField(field="count", op="sum", order="descending"),
    ),
]

# Pixel size shared by all generated charts
width = 1200
height = 300

## All Debugging Variants

In [6]:
def create_chart(cfg: ChartConfig) -> alt.VConcatChart:
    """Builds the debugging chart for one sorting configuration

    Two views over the module-level `chart_preferences` frame are stacked with
    no gap. Both share the preference-name x axis sorted with `cfg.sort_x`:

    - a bar chart of each preference's weight, titled with `cfg.title`;
    - a heatmap of each preference's count per chart (rows sorted with
      `cfg.sort_y`), with zero counts drawn white.
    """
    weight_bar = (
        alt.Chart(chart_preferences)
        .mark_bar()
        .encode(
            x=alt.X(field="pref_name", type="nominal", sort=cfg.sort_x, axis=None),
            # `chart_preferences` repeats each preference's weight once per chart,
            # and Vega-Lite stacks bar marks by default, which would draw
            # weight * number-of-charts. With stacking disabled the duplicate bars
            # overlap at equal height, so each bar shows the weight itself.
            y=alt.Y(field="weight", type="quantitative", stack=None),
            tooltip=["pref_name", "weight"],
        )
        .properties(width=width, height=height / 3, title=cfg.title)
    )
    pref_rect = (
        alt.Chart(chart_preferences)
        .mark_rect(stroke="lightgray", strokeWidth=0.25)
        .encode(
            x=alt.X(field="pref_name", type="nominal", sort=cfg.sort_x),
            y=alt.Y(field="chart_name", type="nominal", sort=cfg.sort_y),
            # Set rect color to white if `count == 0`
            color=alt.condition(
                alt.datum.count == 0,
                alt.value("white"),
                alt.Color(field="count", type="quantitative"),
            ),
            tooltip=chart_preferences.columns.tolist(),
        )
        .properties(width=width, height=height)
    )
    return alt.VConcatChart(vconcat=[weight_bar, pref_rect], spacing=0)


# One chart per configuration, stacked vertically; being the cell's final
# expression, the combined chart is what the notebook renders
charts = list(map(create_chart, configs))
(
    alt.VConcatChart(vconcat=charts, spacing=50)
    .configure_axis(labelFontSize=8)
)

## Interactive Selection of Debugging Variants

_Interactions work only when this notebook is connected to a running Python kernel; static renderings of the notebook are not interactive!_

In [7]:
import ipywidgets as widgets
from IPython.display import display, clear_output

DEFAULT_CFG = configs[0]

config_selector = widgets.Dropdown(
    options=[cfg.title for cfg in configs],
    value=DEFAULT_CFG.title,
    description="Sorting:",
    disabled=False,
)

# Charts are rendered into a dedicated `Output` widget. Output produced inside
# widget callbacks is not routed to the cell in every frontend (JupyterLab, for
# example, does not show it), whereas output captured by an `Output` widget is
# shown wherever that widget is displayed.
chart_output: widgets.Output = widgets.Output()


def on_config_selected(cfg_title: str):
    """Generates and displays a chart for the config identified by `cfg_title`

    The new chart replaces whatever `chart_output` currently shows.

    :raises RuntimeError: if no entry of `configs` has the title `cfg_title`
    """
    matching = [cfg for cfg in configs if cfg.title == cfg_title]
    if not matching:
        raise RuntimeError(f'No chart configuration with the title "{cfg_title}"')
    chart = create_chart(matching[0])
    with chart_output:
        # `wait=True` defers clearing until the new chart arrives, avoiding flicker
        clear_output(wait=True)
        display(chart)


def handle_config_selection(event):
    """Handler registered to the `config_selector` dropdown widget

    Only changes of the `value` trait trigger a redraw. A new selection also
    changes the `index` and `label` traits, which would otherwise render the
    chart up to three times per selection.
    """
    if event["type"] == "change" and event["name"] == "value":
        on_config_selected(event["new"])


# Register the event handler for value changes of the dropdown only
config_selector.observe(handle_config_selection, names="value")
# Show the selector above the chart area once, then render the initial chart
display(widgets.VBox([config_selector, chart_output]))
on_config_selected(DEFAULT_CFG.title)

Dropdown(description='Sorting:', options=('Sort alphabetically', 'Sort by count sum'), value='Sort alphabetica…