# Draco Debugging

In [6]:
import toml
import pandas as pd
from typing import Iterable, NamedTuple
import altair as alt

from draco import Draco

In [7]:
# Draco instance built with no constructor arguments, i.e. the library's
# default configuration; its soft constraint names and preference counts are
# what the rest of this notebook inspects.
default_draco = Draco()

In [None]:
# Names of every soft constraint (preference) exposed by `default_draco`,
# collected into a set so missing names can be found with set difference.
feature_names: set[str] = set(default_draco.soft_constraint_names)


def full_pref_count_dict(draco: Draco, spec: str | Iterable[str]) -> dict[str, int]:
    """Return preference counts for ``spec`` with every soft constraint as a key.

    The mapping returned by ``draco.count_preferences`` need not contain every
    soft constraint name (and may be falsy), so any name it omits is filled in
    with a count of 0. The full set of names is taken from the ``draco``
    instance that was passed in, so the keys always match the knowledge base
    that produced the counts.

    :param draco: Draco instance used both to count preferences and to list
        its soft constraint names.
    :param spec: specification to evaluate; passed unchanged to
        ``draco.count_preferences``.
    :return: mapping from soft constraint name to count. Names reported by
        ``count_preferences`` keep their counts; every other name maps to 0.
        Keys start with the sorted soft constraint names, so the order is
        deterministic across runs.
    """
    counts: dict[str, int] = draco.count_preferences(spec) or {}
    # Zero-fill first, then overlay the real counts. Dict union keeps the
    # left operand's (sorted) key order, and the right operand's values win.
    zero_filled = {name: 0 for name in sorted(draco.soft_constraint_names)}
    return zero_filled | counts


def unnest_pref_count_dict(
    dct: dict[str, dict[str, int]]
) -> list[tuple[str, str, int]]:
    """Flatten a nested ``{chart_name: {pref_name: count}}`` mapping into rows.

    :param dct: outer keys are chart names; each inner dict maps a
        preference name to its count.
    :return: one ``(chart_name, pref_name, count)`` tuple per inner entry,
        ordered by the outer dict's iteration order and then by each inner
        dict's iteration order.
    """
    rows: list[tuple[str, str, int]] = []
    for chart, counts in dct.items():
        rows.extend((chart, pref, n) for pref, n in counts.items())
    return rows


# Load the example chart specs. TOML documents are UTF-8 by specification, so
# the encoding is given explicitly rather than relying on the platform locale.
with open("./data/example_charts.toml", encoding="utf-8") as file:
    specs = toml.load(file)

# Only the parse needs the open file; compute the per-chart preference counts
# after it has been closed.
pref_count_dict: dict[str, dict[str, int]] = {
    chart_name: full_pref_count_dict(default_draco, spec)
    for chart_name, spec in specs.items()
}
pref_tuples = unnest_pref_count_dict(pref_count_dict)

# Shape: len(specs) * len(feature_names) rows x 3 columns
# (assumes count_preferences only reports known soft constraint names).
chart_preferences = pd.DataFrame(
    pref_tuples, columns=["chart_name", "pref_name", "count"]
)

In [None]:
class ChartConfig(NamedTuple):
    """Settings for one heatmap of the chart-by-preference count table."""

    # Title displayed above the heatmap.
    title: str
    # Sort for the x axis (preference names): either a sort-order string such
    # as "ascending" or an Altair sort object such as EncodingSortField.
    sort_x: alt.Sort | str
    # Sort for the y axis (chart names); accepts the same kinds of value.
    sort_y: alt.Sort | str


def _sum_of_count_descending() -> alt.EncodingSortField:
    """Return a fresh sort spec ordering categories by summed `count`, largest first."""
    return alt.EncodingSortField(field="count", op="sum", order="descending")


# Every entry below produces exactly one heatmap.
configs: list[ChartConfig] = [
    ChartConfig(title="Sort alphabetically", sort_x="ascending", sort_y="ascending"),
    ChartConfig(
        title="Sort by count sum",
        # Separate objects per axis, as when each was constructed inline.
        sort_x=_sum_of_count_descending(),
        sort_y=_sum_of_count_descending(),
    ),
]

# Pixel size shared by every generated heatmap.
width = 1200
height = 300

In [None]:
def create_chart(cfg: ChartConfig) -> alt.Chart:
    """Build one heatmap of ``chart_preferences`` according to ``cfg``.

    Preference names run along x and chart names along y, each ordered by the
    matching sort from ``cfg``. Cell colour encodes ``count``, and every table
    column appears in the tooltip. Size comes from the module-level ``width``
    and ``height``; the title comes from ``cfg.title``.
    """
    x_enc = alt.X(field="pref_name", type="nominal", sort=cfg.sort_x)
    y_enc = alt.Y(field="chart_name", type="nominal", sort=cfg.sort_y)
    color_enc = alt.Color(field="count", type="quantitative")
    tooltip_fields = list(chart_preferences.columns)

    # Thin light-gray outlines keep individual cells distinguishable.
    heatmap = alt.Chart(chart_preferences).mark_rect(
        stroke="lightgray", strokeWidth=0.25
    )
    heatmap = heatmap.encode(
        x=x_enc, y=y_enc, color=color_enc, tooltip=tooltip_fields
    )
    return heatmap.properties(width=width, height=height, title=cfg.title)


# One heatmap per configuration, stacked vertically, with small axis labels.
# The final expression is left bare so the notebook renders it.
charts = list(map(create_chart, configs))
alt.vconcat(*charts).configure_axis(labelFontSize=8)