# Grammar and Sample Statistics

In [None]:
import colorsys
import json
import os
import re
from typing import Any

import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyrootutils
import seaborn as sns

In [None]:
# Resolve the project root, starting the search at the current working
# directory and using ".project-root" as the indicator.
PROJECT_ROOT = pyrootutils.find_root(
    indicator=".project-root",
    search_from=os.path.abspath(""),
)

In [None]:
# Directory under the project root where notebook figures live.
FIGURES_DIR = PROJECT_ROOT.joinpath("notebooks", "figures")

# Width in inches (presumably the target paper's text width -- confirm).
PAPER_WIDTH_IN = 5.5


def darken(
    color: str
    | tuple[float, float, float]
    | dict[str, Any]
    | sns.palettes._ColorPalette,
    by: float = 0.2,
):
    """
    Darken a color (or every color of a palette) by reducing its lightness.

    Args:
        color: A single color (any string seaborn understands, or an RGB
            tuple), a ``{name: color}`` dictionary, or a seaborn color palette.
            RGB tuples may be given on the 0-1 scale or on the 0-255 scale.
        by: Fraction of the HLS lightness to remove; clamped to ``[0, 1]``.

    Returns:
        The darkened color as an ``(r, g, b)`` tuple on the 0-1 scale. For a
        dictionary a new dictionary with the same keys is returned; for a
        seaborn palette a new palette is returned.
    """

    def _darken_color(c: str | tuple[float, float, float], by: float):
        by = min(max(0, by), 1)
        pct_darken = 1 - by

        if isinstance(c, str):
            c = sns.color_palette([c])[0]

        r, g, b = c[0], c[1], c[2]
        # Any channel above 1 means the color is on the 0-255 scale. The
        # previous per-channel loop only rebound its loop variable and never
        # rescaled anything, so such inputs reached colorsys un-normalised.
        if max(r, g, b) > 1:
            r, g, b = r / 255, g / 255, b / 255

        hue, lightness, saturation = colorsys.rgb_to_hls(r, g, b)
        # Darken by scaling the lightness; hue and saturation are untouched.
        return colorsys.hls_to_rgb(hue, lightness * pct_darken, saturation)

    if isinstance(color, dict):
        # A dictionary is treated as a named palette: darken every entry.
        return {k: _darken_color(v, by) for k, v in color.items()}
    elif isinstance(color, sns.palettes._ColorPalette):
        colors = [_darken_color(c, by) for c in color]
        return sns.palettes._ColorPalette(colors)
    else:
        return _darken_color(color, by)


# Colormap for heatmaps, correlations and other -1 to 1 scales
CMAP_HEATMAP = "vlag"

# Colors used whenever hue encodes the sample type
PALETTE_SAMPLE_TYPE = dict(
    positive=darken("#ffcc66"),
    negative="#5c5cff",
    unknown="#000000",
)

# Colors used whenever hue encodes the model. Each family takes the leading
# colors of one seaborn palette ("rocket_r" and "crest" were earlier
# candidates for the gpt-4.1 and o-series families), then everything is
# darkened in one go.
PALETTE_MODEL = darken(
    {
        **dict(
            zip(
                ("gpt-4.1-nano", "gpt-4.1-mini", "gpt-4.1"),
                sns.cubehelix_palette(start=0.5, rot=-0.5, n_colors=4),
            )
        ),
        **dict(
            zip(
                ("o4-mini", "o3"),
                sns.color_palette("YlOrBr", n_colors=2),
            )
        ),
        **dict(
            zip(
                ("gemma-3-1b", "gemma-3-4b", "gemma-3-12b", "gemma-3-27b"),
                sns.cubehelix_palette(n_colors=5),
            )
        ),
    }
)

# Colors used whenever hue encodes the score type
PALETTE_SCORE = dict(
    zip(
        ("Weighted F1", "Macro F1"),
        sns.color_palette("terrain", n_colors=2, desat=0.8),
    )
)

# Registry of the named palettes above
PALETTES = dict(
    model=PALETTE_MODEL,
    sample_type=PALETTE_SAMPLE_TYPE,
    score=PALETTE_SCORE,
)

# Styling of at-chance baseline markers
COLOR_AT_CHANCE = "#ff0000"  # Red
ALPHA_AT_CHANCE = 0.5

# Bar chart outline settings
BAR_EDGE_COLOR = "black"
BAR_EDGE_WIDTH = 0.8


def display_palette(palette: dict[str, Any] | sns.palettes._ColorPalette):
    """Show a palette's swatches; dictionaries are shown via their values."""
    if not isinstance(palette, sns.palettes._ColorPalette):
        palette = sns.color_palette(list(palette.values()))
    display(palette)


def filter_by_alpha(
    keys: list[str],
    ax,
    palette: dict[str, Any] | None = None,
    alpha=0.3,
    highlight: str | list[str] | None = None,
):
    """
    Fade the face colors of ``ax.collections[0]``, keeping highlighted keys opaque.

    Every face is first given transparency ``alpha``. Faces whose RGB color
    matches the palette color of a key listed in ``highlight`` are then set to
    full opacity.

    Args:
        keys: Palette keys whose colors should be looked up among the faces.
        ax: Axes whose first collection is modified in place.
        palette: Mapping from key to color. When ``None``, the first palette
            in ``PALETTES`` that contains every key is used.
        alpha: Opacity applied to non-highlighted faces.
        highlight: Key or keys that should stay fully opaque.

    Raises:
        ValueError: If ``palette`` is ``None`` and no palette in ``PALETTES``
            contains all of ``keys``.
    """
    from matplotlib.colors import to_rgb

    if highlight is None:
        highlighted = set()
    elif isinstance(highlight, str):
        highlighted = {highlight}
    else:
        highlighted = set(highlight)

    if palette is None:
        # Pick the first registered palette that covers every requested key.
        for candidate in PALETTES.values():
            if all(k in candidate for k in keys):
                palette = candidate
                break
        else:
            raise ValueError(f"No matching palette found for keys: {keys}")

    face_colors = ax.collections[0].get_facecolors()
    face_colors[:, 3] = alpha
    for key in keys:
        # Normalise to an RGB triple so hex strings and RGB(A) tuples both work;
        # slicing a hex string with [:3] would never match a face color.
        key_rgb = np.asarray(to_rgb(palette[key]))
        matches = np.all(np.isclose(face_colors[:, :3], key_rgb), axis=1)
        face_colors[matches, 3] = 1 if key in highlighted else alpha
    ax.collections[0].set_facecolor(face_colors)


def legend_format(
    ax: mpl.axes._axes.Axes | sns.FacetGrid,
    keys: list[str] | None = None,
    title: str | None = None,
    **kwargs,
):
    """
    Rebuild and reposition the legend of an axes (or of a FacetGrid's axes).

    Args:
        ax: Axes to format. A ``sns.FacetGrid`` is also accepted: its own
            legend is removed and the legend is rebuilt on ``fg.ax``.
        keys: When given, invisible spacer entries are inserted into the
            legend. Only key sets containing ``"gpt-4.1"`` are supported.
        title: Optional legend title.
        **kwargs: Forwarded to ``sns.move_legend``. ``loc`` defaults to
            ``"upper left"`` and ``bbox_to_anchor`` to ``(1, 1)``.

    Raises:
        ValueError: If ``keys`` is given but no spacing layout is known for it.
    """
    import matplotlib.patches as mpatches

    if isinstance(ax, sns.FacetGrid):
        fg = ax
        ax = fg.ax

        # The grid may not have a figure-level legend to remove.
        grid_legend = fg.legend
        if grid_legend is not None:
            grid_legend.remove()

    # Legend Formatting
    handles, labels = ax.get_legend_handles_labels()

    if keys is not None:
        if "gpt-4.1" in keys:
            spacing_locs = [3, 6]
        else:
            # Previously the exception was constructed but never raised, which
            # surfaced later as an unrelated NameError on `spacing_locs`.
            raise ValueError(f"Unsure how to space legend for {keys=}")

        # Fully transparent patches act as blank rows between legend groups.
        spacer = mpatches.Patch(alpha=0, linewidth=0)
        for sloc in spacing_locs:
            handles.insert(sloc, spacer)
            labels.insert(sloc, "")

    legend = ax.legend(handles, labels)
    legend.set_frame_on(False)

    if title is not None:
        legend.set_title(title)

    kwargs.setdefault("loc", "upper left")
    kwargs.setdefault("bbox_to_anchor", (1, 1))
    _ = sns.move_legend(ax, **kwargs)


# MARK: Printing, experimentation, etc
# Preview the "terrain" swatches as-is and after darkening by 0.2.
for preview in (
    sns.color_palette("terrain", n_colors=2, desat=0.8),
    darken(sns.color_palette("terrain", n_colors=2, desat=0.8), by=0.2),
):
    display(preview)

# Preview the shared palettes defined above.
for named_palette in (PALETTE_MODEL, PALETTE_SAMPLE_TYPE):
    display_palette(named_palette)

In [None]:
def is_cyclic(g_string: str) -> bool:
    """
    Report whether the grammar in ``g_string`` has a cycle reachable from ``S``.

    Only lines of the form ``LHS -> RHS1 RHS2`` (word-character symbols) are
    parsed; each contributes the edges ``LHS -> RHS1`` and ``LHS -> RHS2``.
    Other lines are ignored. The resulting graph is searched depth-first from
    the symbol ``S``.

    Args:
        g_string: Grammar text, one production per line.

    Returns:
        ``True`` if a cycle is reachable from ``S``; ``False`` otherwise,
        including when ``S`` has no parsed production.
    """
    rule_pattern = re.compile(r"(\w+)\s*->\s*(\w+)\s+(\w+)")

    productions: dict[str, list[str]] = {}
    for line in g_string.strip().split("\n"):
        match = rule_pattern.match(line.strip())
        if match:
            lhs, rhs1, rhs2 = match.groups()
            productions.setdefault(lhs, []).extend([rhs1, rhs2])

    if "S" not in productions:
        return False

    # Iterative depth-first search with an explicit stack, so that deeply
    # nested grammars cannot exhaust Python's recursion limit.
    visited = {"S"}
    on_path = {"S"}  # symbols on the current DFS path
    stack = [("S", iter(productions["S"]))]
    while stack:
        node, children = stack[-1]
        for child in children:
            if child not in visited:
                visited.add(child)
                on_path.add(child)
                stack.append((child, iter(productions.get(child, []))))
                break
            if child in on_path:
                # Back edge to a symbol still being expanded: a cycle.
                return True
        else:
            # All children explored; leave the current path.
            on_path.discard(node)
            stack.pop()
    return False

In [None]:
# Small hand-written grammars used to sanity-check `is_cyclic`.

# Cyclic: C -> D -> F -> C.
rules1 = """
S -> B C
C -> D E
D -> F G
F -> C H
"""

# Acyclic chain: S -> A -> C.
rules2 = """
S -> A B
A -> C D
C -> E F
"""

# Cyclic: A -> B -> A.
rules3 = """
S -> A B
A -> B C
B -> C A
"""

# Acyclic: A and B expand to symbols with no productions.
rules4 = """
S -> A B
A -> C D
B -> E F
"""

# Expected output: True, False, True, False.
print(f"Rules 1 have cycles: {is_cyclic(rules1)}")
print(f"Rules 2 have cycles: {is_cyclic(rules2)}")
print(f"Rules 3 have cycles: {is_cyclic(rules3)}")
print(f"Rules 4 have cycles: {is_cyclic(rules4)}")

## Load Grammars

In [None]:
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""), indicator=".project-root"
)

grammars_dir = PROJECT_ROOT / "data" / "grammars"
grammar_stats_filename = "grammar_stats.json"
samples_stats_filename = "filtered_samples_stats.json"

# Keep only grammar directories that contain both statistics files.
grammars = [
    f
    for f in grammars_dir.iterdir()
    if (f.is_dir())
    and (f / grammar_stats_filename).exists()
    and (f / samples_stats_filename).exists()
]

stats = []
for g in grammars:
    # Use context managers so the file handles are closed promptly.
    with open(g / grammar_stats_filename) as fh:
        g_stats = json.load(fh)
    with open(g / samples_stats_filename) as fh:
        s_stats = json.load(fh)

    # Check to see if grammar is cyclic
    g_file = g / f"{g.name}.cfg"
    with open(g_file) as fh:
        g_str = fh.read()
    g_stats["is_cyclic"] = is_cyclic(g_str)

    # Sample statistics win on key collisions with the grammar statistics.
    merged = {**g_stats, **s_stats}
    stats.append(merged)
stats_df = pd.DataFrame(stats)

# Filter grammars to only keep those with more than 80% coverage of positive &
# negative samples to ensure we aren't testing models on languages which can't
# generate strings of the relevant lengths.
# NOTE(review): this comment used to say 90% while the code applies 0.8 --
# confirm which threshold is intended.
good_stats_df = (
    stats_df[stats_df.coverage > 0.8]
    .sort_values(by="grammar_name", ascending=True)
    .reset_index(drop=True)
)

good_stats_df

## Full Dataset

In [None]:
# Histograms of the four grammar hyperparameters (full dataset), one panel
# each, all sharing the first panel's y-axis.
fig = plt.figure(figsize=(13, 3))
gs = gridspec.GridSpec(1, 4)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]
hparams = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]
for hparam, ax in zip(hparams, axes):
    sns.histplot(data=good_stats_df, x=hparam, binwidth=100, ax=ax)
    ax.set_title(hparam)
    ax.set_ylabel(None)

In [None]:
# Compression ratio against each grammar hyperparameter (full dataset).
fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]

y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

fig.suptitle("Compression ratio vs. Grammar HParams")

In [None]:
# n_terminals against the other hyperparameters (full dataset, log-log axes).
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_terminals"
x_tasks = ["n_nonterminals", "n_lexical_productions", "n_nonlexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_yscale("log")
    ax.set_xscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# n_nonterminals against the other hyperparameters (full dataset, log-log axes).
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_nonterminals"
x_tasks = ["n_terminals", "n_lexical_productions", "n_nonlexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_xscale("log")
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# n_lexical_productions against the other hyperparameters (full dataset,
# log-log axes).
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_lexical_productions"
x_tasks = ["n_terminals", "n_nonterminals", "n_nonlexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_xscale("log")
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# n_nonlexical_productions against the other hyperparameters (full dataset,
# log-log axes).
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_nonlexical_productions"
x_tasks = ["n_terminals", "n_nonterminals", "n_lexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_xscale("log")
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# Pairwise correlation of the four grammar hyperparameters (full dataset).
hyp_corr = good_stats_df.loc[
    :,
    [
        "n_terminals",
        "n_nonterminals",
        "n_lexical_productions",
        "n_nonlexical_productions",
    ],
].corr()
# Upper-triangle mask; computed but not applied to the heatmap below.
hyp_mask = np.triu(np.ones_like(hyp_corr, dtype=bool))

_ = sns.heatmap(
    hyp_corr,
    annot=True,
    vmin=-1,
    vmax=1,
    center=0,
    cmap="vlag_r",
)

## Small Subset

To obtain a set of grammars whose hyperparameters are not too strongly correlated, we run grid searches for grammars in which each hyperparameter lies between 1 and 100.

In [None]:
# Small subset: coverage above 0.5 and every size hyperparameter below 100.
# This used to be a chained assignment that also rebound `good_stats_df`,
# silently replacing the full-dataset frame used by the cells above; only
# `good_stats_small` is assigned now.
good_stats_small = (
    stats_df[
        (stats_df.coverage > 0.5)
        & (stats_df.n_terminals < 100)
        & (stats_df.n_nonterminals < 100)
        & (stats_df.n_lexical_productions < 100)
        & (stats_df.n_nonlexical_productions < 100)
    ]
    .sort_values(by="grammar_name", ascending=True)
    .reset_index(drop=True)
)

good_stats_small

In [None]:
# Coverage distribution of the small subset, split by cyclicity.
fig, ax = plt.subplots(figsize=(5, 3))

_ = sns.histplot(
    good_stats_small, x="coverage", hue="is_cyclic", binwidth=0.01, ax=ax
)

_ = ax.set_xlabel("% Coverage")

In [None]:
# Histograms of the four grammar hyperparameters (small subset), one panel
# each, all sharing the first panel's y-axis.
fig = plt.figure(figsize=(13, 3))
gs = gridspec.GridSpec(1, 4)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]
hparams = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]
for hparam, ax in zip(hparams, axes):
    sns.histplot(data=good_stats_small, x=hparam, binwidth=10, ax=ax)
    ax.set_title(hparam)
    ax.set_ylabel(None)

In [None]:
# Compression ratio (log scale) against each hyperparameter for the small
# subset, colored by cyclicity.
fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]

y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

_ = fig.suptitle("Compression ratio vs. Grammar HParams")
_ = ax0.set_ylabel("Compression Ratio")

In [None]:
# Compression ratio (log scale) against each hyperparameter for the small
# subset, colored by coverage.
fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]

y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small,
        x=x_task,
        y=y_task,
        hue="coverage",
        palette=sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True),
        ax=ax,
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

_ = fig.suptitle("Compression ratio vs. Grammar HParams")
_ = ax0.set_ylabel("Compression Ratio")

In [None]:
# n_terminals against the other hyperparameters (small subset), colored by
# cyclicity.
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)
ax2 = fig.add_subplot(gs[0, 2], sharey=ax0)

axes = [ax0, ax1, ax2]

y_task = "n_terminals"
x_tasks = ["n_nonterminals", "n_lexical_productions", "n_nonlexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small, x=x_task, y=y_task, ax=ax, hue="is_cyclic"
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

# Title typo fixed: "n_terminalals" -> "n_terminals".
_ = fig.suptitle(f"n_terminals vs. other HParams (n={len(good_stats_small)})")

In [None]:
# n_nonterminals against the other hyperparameters (small subset), colored by
# cyclicity.
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_nonterminals"
x_tasks = ["n_terminals", "n_lexical_productions", "n_nonlexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# n_lexical_productions against the other hyperparameters (small subset),
# colored by cyclicity.
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_lexical_productions"
x_tasks = ["n_terminals", "n_nonterminals", "n_nonlexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# n_nonlexical_productions against the other hyperparameters (small subset),
# colored by cyclicity.
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_nonlexical_productions"
x_tasks = ["n_terminals", "n_nonterminals", "n_lexical_productions"]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

In [None]:
# Pairwise correlation of the four hyperparameters in the small subset; the
# upper triangle (including the diagonal) is masked out.
hyp_corr_small = good_stats_small.loc[
    :,
    [
        "n_terminals",
        "n_nonterminals",
        "n_lexical_productions",
        "n_nonlexical_productions",
    ],
].corr()
hyp_mask_small = np.triu(np.ones_like(hyp_corr_small, dtype=bool))

fig, ax = plt.subplots(figsize=(6, 5))

_ = sns.heatmap(
    hyp_corr_small,
    mask=hyp_mask_small,
    annot=True,
    vmin=-1,
    vmax=1,
    center=0,
    cmap="vlag_r",
    ax=ax,
)

_ = ax.set_title(f"Correlation of HParams ≤ 100 (n={len(good_stats_small)})")

### Subsample small dataset

In [None]:
def get_dense_regions(df: pd.DataFrame, cols: list[str], n_bins=5):
    """
    Bin ``cols`` on an equal-width grid and find over-populated grid cells.

    Each column is cut into ``n_bins`` equal-width bins spanning its own
    minimum and maximum. A grid cell is one combination of bin labels across
    ``cols``; it counts as over-represented when it holds more than twice the
    mean number of rows per occupied cell.

    Returns:
        A tuple ``(df_discrete, overrepresented_regions)``: the frame of
        integer bin labels (one column per entry of ``cols``) and the list of
        over-represented cells, each a tuple of bin labels.
    """
    binned = {}
    for name in cols:
        edges = np.linspace(df[name].min(), df[name].max(), n_bins + 1)
        binned[name] = pd.cut(
            df[name], bins=edges, labels=False, include_lowest=True
        )
    df_discrete = pd.DataFrame(binned)

    # Rows per occupied grid cell, most populated first.
    cell_counts = df_discrete.value_counts().sort_values(ascending=False)
    cutoff = 2 * cell_counts.mean()
    crowded = cell_counts[cell_counts > cutoff]
    return df_discrete, crowded.index.tolist()


def subsample(
    df: pd.DataFrame,
    cols: list[str],
    target_max_corr: float = 0.1,
    grid_bins: int = 5,
    removal_fraction: float = 0.1,
) -> pd.DataFrame:
    """
    Iteratively drop rows from dense grid cells to lower pairwise correlation.

    Works on a copy of ``df``. Each round computes the correlation matrix of
    ``cols``, finds the over-represented grid cells via ``get_dense_regions``,
    picks the pair of columns with the largest absolute correlation, and
    removes a fraction of the rows in every dense cell.

    Args:
        df: Input frame; it is not modified.
        cols: Columns whose pairwise correlation should be reduced.
        target_max_corr: Stop once the largest absolute pairwise correlation
            is at or below this value.
        grid_bins: Number of bins per column passed to ``get_dense_regions``.
        removal_fraction: Fraction of each dense cell's rows dropped per round
            (truncated to an integer row count).

    Returns:
        The subsampled frame. Its index is reset after every removal round.
    """

    df_subsampled = df.copy()
    while True:
        corr_matrix = df_subsampled[cols].corr()
        # Largest absolute correlation among distinct column pairs
        # (strict upper triangle of the matrix).
        max_abs_corr = np.abs(
            corr_matrix.values[np.triu_indices_from(corr_matrix, k=1)]
        ).max()
        if max_abs_corr <= target_max_corr:
            break

        df_discrete, dense_regions = get_dense_regions(
            df_subsampled, cols=cols, n_bins=grid_bins
        )
        # Nothing is over-represented, so there is nothing left to thin out.
        if not dense_regions:
            break

        # Find the pair of columns with the largest absolute correlation.
        most_correlated_pair = None
        max_corr = -1
        for i in range(len(cols)):
            for j in range(i + 1, len(cols)):
                if np.abs(corr_matrix.loc[cols[i], cols[j]]) > max_corr:
                    max_corr = np.abs(corr_matrix.loc[cols[i], cols[j]])
                    most_correlated_pair = (cols[i], cols[j])

        if most_correlated_pair:
            col1, col2 = most_correlated_pair
            indices_to_remove = []
            for region in dense_regions:
                # Boolean mask of the rows whose bin labels equal this region;
                # position i in the region tuple is matched against column i
                # of df_discrete.
                region_filter = True
                for i, val in enumerate(region):
                    region_filter &= df_discrete.iloc[:, i] == val
                region_indices = df_subsampled[region_filter].index.tolist()
                if region_indices:
                    sub_df = df_subsampled.loc[region_indices]
                    if not sub_df.empty:
                        # Product of each row's deviations from the region's
                        # own means for the chosen pair of columns.
                        sub_df["correlation_contribution"] = (
                            sub_df[col1] - sub_df[col1].mean()
                        ) * (sub_df[col2] - sub_df[col2].mean())
                        # Mark the rows with the largest products for removal.
                        # NOTE(review): the sort is descending regardless of
                        # the sign of the pair's correlation -- confirm this is
                        # intended for negatively correlated pairs.
                        indices_to_remove.extend(
                            sub_df.sort_values(
                                by="correlation_contribution", ascending=False
                            )
                            .head(int(len(sub_df) * removal_fraction))
                            .index.tolist()
                        )

            if indices_to_remove:
                df_subsampled = df_subsampled.drop(
                    index=list(set(indices_to_remove))
                ).reset_index(drop=True)
            else:
                # No rows were selected this round, so further rounds would
                # not change anything.
                break

        else:
            break

        if df_subsampled.empty:
            break

    return df_subsampled

In [None]:
# Hyperparameter columns whose pairwise correlation should be reduced.
corr_cols = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

# A target of 0.0 keeps thinning until no dense regions or removable rows
# remain.
good_stats_small_subsampled = subsample(
    good_stats_small,
    cols=corr_cols,
    removal_fraction=0.99,
    grid_bins=4,
    target_max_corr=0.0,
)

len(good_stats_small_subsampled)

In [None]:
# Coverage distribution of the subsampled small subset, split by cyclicity.
fig, ax = plt.subplots(figsize=(5, 3))

_ = sns.histplot(
    good_stats_small_subsampled,
    x="coverage",
    hue="is_cyclic",
    binwidth=0.01,
    ax=ax,
)

_ = ax.set_xlabel("% Coverage")

In [None]:
# Compression ratio (log scale) against each hyperparameter for the subsampled
# data, colored by cyclicity.
fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]

y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small_subsampled,
        x=x_task,
        y=y_task,
        hue="is_cyclic",
        ax=ax,
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

_ = fig.suptitle("Compression ratio vs. Grammar HParams")
_ = ax0.set_ylabel("Compression Ratio")

In [None]:
# Compression ratio (log scale) against each hyperparameter for the subsampled
# data, colored by coverage.
fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))

axes = [ax0, ax1, ax2, ax3]

y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small_subsampled,
        x=x_task,
        y=y_task,
        hue="coverage",
        palette=sns.color_palette("ch:s=.25,rot=-.25", as_cmap=True),
        ax=ax,
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_yscale("log")

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

_ = fig.suptitle("Compression ratio vs. Grammar HParams")
_ = ax0.set_ylabel("Compression Ratio")

In [None]:
# Pairwise correlation of the four hyperparameters after subsampling; the
# upper triangle (including the diagonal) is masked out.
hyp_corr_small_sub = good_stats_small_subsampled.loc[
    :,
    [
        "n_terminals",
        "n_nonterminals",
        "n_lexical_productions",
        "n_nonlexical_productions",
    ],
].corr()
hyp_mask_small_sub = np.triu(np.ones_like(hyp_corr_small_sub, dtype=bool))

fig, ax = plt.subplots(figsize=(6, 5))

_ = sns.heatmap(
    hyp_corr_small_sub,
    mask=hyp_mask_small_sub,
    annot=True,
    vmin=-1,
    vmax=1,
    center=0,
    cmap="vlag_r",
    ax=ax,
)

_ = ax.set_title(
    f"Correlation of HParams ≤ 100, Subsampled  (n={len(good_stats_small_subsampled)})"
)

In [None]:
# n_terminals against the other hyperparameters (subsampled), colored by
# cyclicity. All panels share both axes and use an equal aspect ratio.
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0, sharex=ax0)
ax2 = fig.add_subplot(gs[0, 2], sharey=ax0, sharex=ax0)

axes = [ax0, ax1, ax2]

y_task = "n_terminals"
x_tasks = ["n_nonterminals", "n_lexical_productions", "n_nonlexical_productions"]

# Track the largest upper axis limit seen on any panel so both axes can use
# the same range.
a_max = 0
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small_subsampled, x=x_task, y=y_task, ax=ax, hue="is_cyclic"
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_aspect("equal", adjustable="box")

    a_max = max(a_max, ax.get_xlim()[1], ax.get_ylim()[1])

# The axes are shared, so setting the limits on the first panel is enough.
axes[0].set_xlim(0, a_max)
axes[0].set_ylim(0, a_max)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

# Title typo fixed: "n_terminalals" -> "n_terminals".
_ = fig.suptitle(
    f"n_terminals vs. other HParams (Subsampled, n={len(good_stats_small_subsampled)})"
)

In [None]:
# n_nonterminals against the other hyperparameters (subsampled), colored by
# cyclicity. Panels share the y-axis only.
fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2))

axes = [ax0, ax1, ax2]

y_task = "n_nonterminals"
x_tasks = ["n_terminals", "n_lexical_productions", "n_nonlexical_productions"]

# Track the largest upper axis limit seen on any panel.
a_max = 0
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_small_subsampled, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

    a_max = max(a_max, ax.get_xlim()[1], ax.get_ylim()[1])

ax0.set_xlim(0, a_max)
ax0.set_ylim(0, a_max)

# Only the left-most panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

_ = fig.suptitle(
    f"n_nonterminals vs. other HParams (Subsampled, n={len(good_stats_small_subsampled)})"
)

In [None]:
def new_subsample(
    df: pd.DataFrame, end_points: int, cols: list[str], seed: int = 42
) -> pd.DataFrame:
    """
    Select a subset of rows whose `cols` are weakly correlated and spread out.

    Starts from a random subset of `end_points` rows and runs 1000 steps of a
    randomized local search: each step proposes replacing one selected row
    with a currently unselected one and keeps the swap only if it lowers the
    objective. The objective is the summed absolute deviation of the subset's
    correlation matrix from the identity, minus the mean pairwise Euclidean
    distance between the selected rows (both computed on min-max scaled
    `cols`).

    Args:
        df: Frame to subsample.
        end_points: Number of rows to select. Values larger than `len(df)`
            are clamped to `len(df)`.
        cols: Columns of `df` that enter the objective.
        seed: Seed passed to `np.random.seed`. Note that this reseeds NumPy's
            global random state as a side effect.

    Returns:
        A copy of the selected rows, ordered by their position in `df`.

    Raises:
        ValueError: If `end_points` is negative.
    """
    if end_points < 0:
        raise ValueError(f"end_points must be non-negative, got {end_points}")

    np.random.seed(seed)

    n_rows = len(df)
    # Sampling without replacement cannot return more rows than exist.
    n_select = min(end_points, n_rows)
    if n_select == 0:
        # Nothing to select; also avoids min()/max() on an empty array below.
        return df.iloc[:0].copy()

    data = df[cols].to_numpy()
    # Min-max scale every column; the epsilon guards constant columns.
    data_scaled = (data - data.min(axis=0)) / (
        data.max(axis=0) - data.min(axis=0) + 1e-8
    )

    # start with a random subset
    current_idx = np.random.choice(n_rows, size=n_select, replace=False)

    def mean_pairwise_dist(subset: np.ndarray) -> float:
        diffs = subset[:, np.newaxis, :] - subset[np.newaxis, :, :]  # (n, n, d)
        dists = np.sqrt(np.sum(diffs**2, axis=-1))  # (n, n)
        # Average over ordered pairs; the epsilon guards a single-row subset.
        return np.sum(dists) / (len(subset) * (len(subset) - 1) + 1e-8)

    def objective(indices: np.ndarray) -> float:
        subset = data_scaled[indices]
        corr = np.corrcoef(subset, rowvar=False)
        # Constant columns give NaN correlations; count them as uncorrelated.
        corr[np.isnan(corr)] = 0
        off_diag_corr = np.sum(np.abs(corr - np.eye(len(cols))))
        # coverage term: maximize mean pairwise distance
        mean_dist = mean_pairwise_dist(subset)
        return off_diag_corr - mean_dist  # lower is better

    # The current score only changes when a swap is accepted, so it is cached
    # instead of being recomputed on every iteration.
    current_score = objective(current_idx)

    # greedy local search
    for _ in range(1000):
        i = np.random.randint(n_select)
        j = np.random.randint(n_rows)
        if j in current_idx:
            continue
        new_idx = current_idx.copy()
        new_idx[i] = j
        new_score = objective(new_idx)
        if new_score < current_score:
            current_idx = new_idx
            current_score = new_score

    return df.iloc[np.unique(current_idx)].copy()

In [None]:
# Grammar size hyperparameters handed to `new_subsample` as its `cols`.
corr_cols = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

# Draw 99 rows from `good_stats_small` (default seed).
df_new_subsampled = new_subsample(df=good_stats_small, cols=corr_cols, end_points=99)

# Displayed as the cell output: number of rows actually returned.
len(df_new_subsampled)

In [None]:
# Correlation matrix of the hyperparameters in the new subsample, relabelled
# with LaTeX names on both axes for plotting.
_hparam_tex_labels = {
    "n_terminals": r"$n_\text{term}$",
    "n_nonterminals": r"$n_\text{nonterm}$",
    "n_lexical_productions": r"$n_\text{lex}$",
    "n_nonlexical_productions": r"$n_\text{nonlex}$",
}
new_hpy_corr_small = (
    df_new_subsampled[corr_cols]
    .corr()
    .rename(index=_hparam_tex_labels, columns=_hparam_tex_labels)
)

# True on and above the diagonal, so only the strict lower triangle is drawn.
new_hpy_corr_mask = np.triu(np.ones_like(new_hpy_corr_small, dtype=bool))

with sns.plotting_context("notebook"):
    fig, ax = plt.subplots(figsize=(6, 5))

    _ = sns.heatmap(
        new_hpy_corr_small,
        ax=ax,
        mask=new_hpy_corr_mask,
        cmap=CMAP_HEATMAP,
        vmin=-1,
        vmax=1,
        center=0,
        annot=True,
        fmt=".2f",
    )

    _ = ax.set_title("Correlation of G100 HParams")

plt.savefig(FIGURES_DIR / "g100_corr.pdf", bbox_inches="tight")

In [None]:
# Scatter `n_terminals` (y) against each of the other three grammar size
# hyperparameters (x) for `df_new_subsampled`, colored by `is_cyclic`.
y_task = "n_terminals"
x_tasks = ["n_nonterminals", "n_lexical_productions", "n_nonlexical_productions"]

fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

# Panels 2 and 3 share both axes with the first panel, so one set of limits
# applies to the whole row.
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0, sharex=ax0)
ax2 = fig.add_subplot(gs[0, 2], sharey=ax0, sharex=ax0)

axes = [ax0, ax1, ax2]

a_max = 0
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=df_new_subsampled, x=x_task, y=y_task, ax=ax, hue="is_cyclic"
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_aspect("equal", adjustable="box")

    # Track the largest upper limit seen on either axis across all panels.
    a_max = max(a_max, ax.get_xlim()[1], ax.get_ylim()[1])

# Use the same [0, a_max] window on both axes.
axes[0].set_xlim(0, a_max)
axes[0].set_ylim(0, a_max)

# Only the leftmost panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

# Title previously misspelled the column name as "n_terminalals".
_ = fig.suptitle(
    f"n_terminals vs. other HParams (Subsampled, n={len(df_new_subsampled)})"
)

In [None]:
# Compression ratio (y) against each grammar size hyperparameter (x) for
# `df_new_subsampled`, one panel per hyperparameter, colored by `is_cyclic`.
y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

# Every panel after the first shares the first panel's y-axis.
ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))
axes = [ax0, ax1, ax2, ax3]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=df_new_subsampled, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Keep the legend, y tick labels and y-axis label on the leftmost panel only.
for ax in axes[1:]:
    ax.legend_.remove()
    ax.set_ylabel(None)
    plt.setp(ax.get_yticklabels(), visible=False)

_ = fig.suptitle(f"Compression ratio vs. Grammar HParams (n={len(df_new_subsampled)})")
_ = axes[0].set_ylabel("Compression Ratio")

In [None]:
# Histogram of `coverage` in `df_new_subsampled`, split by `is_cyclic`,
# using bins of width 0.01.
fig, ax = plt.subplots(figsize=(5, 3))

_ = sns.histplot(
    df_new_subsampled, x="coverage", hue="is_cyclic", binwidth=0.01, ax=ax
)

_ = ax.set(
    xlabel="% Coverage",
    title=f"Coverage Histogram (n={len(df_new_subsampled)})",
)

## Larger Subset

In [None]:
# Row filters on `stats_df`, built up as named boolean masks.

# Coverage strictly above 0.5.
_covered = stats_df.coverage > 0.5

# At least one of the four size hyperparameters exceeds 100.
_any_above_100 = (
    (stats_df.n_terminals > 100)
    | (stats_df.n_nonterminals > 100)
    | (stats_df.n_lexical_productions > 100)
    | (stats_df.n_nonlexical_productions > 100)
)

# Upper caps: terminals below 400, the other three hyperparameters below 10000.
_within_caps = (
    (stats_df.n_nonterminals < 10000)
    & (stats_df.n_terminals < 400)
    & (stats_df.n_lexical_productions < 10000)
    & (stats_df.n_nonlexical_productions < 10000)
)

# NOTE(review): the chained assignment also rebinds `good_stats_df` to this
# same filtered frame — confirm that overwriting it is intended.
good_stats_large = good_stats_df = (
    stats_df[_covered & _any_above_100 & _within_caps]
    .sort_values("grammar_name")
    .reset_index(drop=True)
)

good_stats_large

In [None]:
# Histogram of `coverage` in `good_stats_large`, split by `is_cyclic`,
# using bins of width 0.01.
fig, ax = plt.subplots(figsize=(5, 3))

_ = sns.histplot(
    good_stats_large, x="coverage", hue="is_cyclic", binwidth=0.01, ax=ax
)

_ = ax.set(xlabel="% Coverage")

In [None]:
# Compression ratio (y, log scale) against each grammar size hyperparameter (x)
# for `good_stats_large`, one panel per hyperparameter, colored by `is_cyclic`.
y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

# Every panel after the first shares the first panel's y-axis.
ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))
axes = [ax0, ax1, ax2, ax3]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=good_stats_large, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_yscale("log")

# Hide the y tick labels and y-axis label on all but the leftmost panel.
for ax in axes[1:]:
    ax.set_ylabel(None)
    plt.setp(ax.get_yticklabels(), visible=False)

_ = fig.suptitle("Compression ratio vs. Grammar HParams")
_ = axes[0].set_ylabel("Compression Ratio")

In [None]:
# Grammar size hyperparameters handed to `new_subsample` as its `cols`.
corr_cols = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

# Draw 99 rows from `good_stats_large` (default seed).
df_large_subsampled = new_subsample(df=good_stats_large, cols=corr_cols, end_points=99)

# Displayed as the cell output: number of rows actually returned.
len(df_large_subsampled)

In [None]:
# Correlation matrix of the hyperparameters in the large subsample, relabelled
# with LaTeX names on both axes for plotting.
_hparam_tex_labels = {
    "n_terminals": r"$n_\text{term}$",
    "n_nonterminals": r"$n_\text{nonterm}$",
    "n_lexical_productions": r"$n_\text{lex}$",
    "n_nonlexical_productions": r"$n_\text{nonlex}$",
}
new_hpy_corr_large = (
    df_large_subsampled[corr_cols]
    .corr()
    .rename(index=_hparam_tex_labels, columns=_hparam_tex_labels)
)

# True on and above the diagonal, so only the strict lower triangle is drawn.
new_hpy_corr_mask = np.triu(np.ones_like(new_hpy_corr_large, dtype=bool))

with sns.plotting_context("notebook"):
    fig, ax = plt.subplots(figsize=(4, 3))

    _ = sns.heatmap(
        new_hpy_corr_large,
        ax=ax,
        mask=new_hpy_corr_mask,
        cmap=CMAP_HEATMAP,
        vmin=-1,
        vmax=1,
        center=0,
        annot=True,
        fmt=".2f",
    )

    _ = ax.set_title("Correlation of G400 HParams")

plt.savefig(FIGURES_DIR / "g400_corr.pdf", bbox_inches="tight")

In [None]:
# Side-by-side correlation heatmaps (`new_hpy_corr_small` on the left,
# `new_hpy_corr_large` on the right) with a single colorbar, saved as one PDF.
with sns.plotting_context("notebook", font_scale=0.8):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(PAPER_WIDTH_IN, 2.8))
    # Separate axes for the colorbar: [left, bottom, width, height] in
    # figure coordinates.
    cbar_ax = fig.add_axes([0.92, 0.01, 0.04, 0.8])

    # Options common to both panels.
    heatmap_args = dict(
        mask=new_hpy_corr_mask,
        annot=True,
        fmt=".2f",
        cmap=CMAP_HEATMAP,
        center=0,
        vmin=-1,
        vmax=1,
    )

    # Left panel: no colorbar.
    sns.heatmap(data=new_hpy_corr_small, ax=ax1, cbar=False, **heatmap_args)
    # Right panel: colorbar drawn into the dedicated axes.
    sns.heatmap(
        data=new_hpy_corr_large,
        ax=ax2,
        cbar=True,
        cbar_ax=cbar_ax,
        **heatmap_args,
    )

    ax1.set_title("Correlation of G100 HParams")
    ax2.set_title("Correlation of G400 HParams")

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

plt.savefig(FIGURES_DIR / "g_corr.pdf", bbox_inches="tight")

In [None]:
# Display `new_hpy_corr` as the cell output.
# NOTE(review): the nearby cells only define `new_hpy_corr_small`,
# `new_hpy_corr_large` and `new_hpy_corr_mask`; confirm `new_hpy_corr` is still
# defined earlier in the notebook, otherwise this cell raises NameError.
new_hpy_corr

In [None]:
# Scatter `n_terminals` (y) against each of the other three grammar size
# hyperparameters (x) for `df_large_subsampled`, colored by `is_cyclic`.
y_task = "n_terminals"
x_tasks = ["n_nonterminals", "n_lexical_productions", "n_nonlexical_productions"]

fig = plt.figure(figsize=(10, 3.3))
gs = gridspec.GridSpec(1, 3, wspace=0.1)

# Panels 2 and 3 share both axes with the first panel, so one set of limits
# applies to the whole row.
ax0 = fig.add_subplot(gs[0, 0])
ax1 = fig.add_subplot(gs[0, 1], sharey=ax0, sharex=ax0)
ax2 = fig.add_subplot(gs[0, 2], sharey=ax0, sharex=ax0)

axes = [ax0, ax1, ax2]

a_max = 0
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=df_large_subsampled, x=x_task, y=y_task, ax=ax, hue="is_cyclic"
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    ax.set_aspect("equal", adjustable="box")

    # Track the largest upper limit seen on either axis across all panels.
    a_max = max(a_max, ax.get_xlim()[1], ax.get_ylim()[1])

# Use the same [0, a_max] window on both axes.
axes[0].set_xlim(0, a_max)
axes[0].set_ylim(0, a_max)

# Only the leftmost panel keeps its y tick labels and y-axis label.
for ax in axes[1:]:
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.set_ylabel(None)

# Title previously misspelled the column name as "n_terminalals".
_ = fig.suptitle(
    f"n_terminals vs. other HParams (Subsampled, n={len(df_large_subsampled)})"
)

In [None]:
# Compression ratio (y) against each grammar size hyperparameter (x) for
# `df_large_subsampled`, one panel per hyperparameter, colored by `is_cyclic`.
y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

fig = plt.figure(figsize=(10, 2.5))
gs = gridspec.GridSpec(1, 4, wspace=0.1)

# Every panel after the first shares the first panel's y-axis.
ax0 = fig.add_subplot(gs[0, 0])
ax1, ax2, ax3 = (fig.add_subplot(gs[0, col], sharey=ax0) for col in (1, 2, 3))
axes = [ax0, ax1, ax2, ax3]

for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(
        data=df_large_subsampled, x=x_task, y=y_task, hue="is_cyclic", ax=ax
    )
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)

# Keep the legend, y tick labels and y-axis label on the leftmost panel only.
for ax in axes[1:]:
    ax.legend_.remove()
    ax.set_ylabel(None)
    plt.setp(ax.get_yticklabels(), visible=False)

_ = fig.suptitle(
    f"Compression ratio vs. Grammar HParams (n={len(df_large_subsampled)})"
)
_ = axes[0].set_ylabel("Compression Ratio")

In [None]:
# Histogram of `coverage` in `df_large_subsampled`, split by `is_cyclic`,
# using bins of width 0.01.
fig, ax = plt.subplots(figsize=(5, 3))

_ = sns.histplot(
    df_large_subsampled, x="coverage", hue="is_cyclic", binwidth=0.01, ax=ax
)

_ = ax.set(xlabel="% Coverage")

In [None]:
# Display the `grammar_name` values of the large subsample as a plain list.
df_large_subsampled["grammar_name"].to_list()