In [None]:
import json
import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyrootutils
import seaborn as sns

In [None]:
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""), indicator=".project-root"
)

grammars_dir = PROJECT_ROOT / "data" / "grammars"
grammar_stats_filename = "grammar_stats.json"
samples_stats_filename = "filtered_samples_stats.json"
# Grammars whose `coverage` is below this fraction are dropped (see filter below).
MIN_COVERAGE = 0.9

# Every sub-directory of `grammars_dir` that contains both stats files. Sorted so
# the load order does not depend on filesystem iteration order.
grammars = sorted(
    f
    for f in grammars_dir.iterdir()
    if f.is_dir()
    and (f / grammar_stats_filename).exists()
    and (f / samples_stats_filename).exists()
)
assert grammars, f"no grammar directories with both stats files under {grammars_dir}"

stats = []
for g in grammars:
    # `with` closes each handle deterministically; json.load(open(...)) leaks it.
    with open(g / grammar_stats_filename) as fh:
        g_stats = json.load(fh)
    with open(g / samples_stats_filename) as fh:
        s_stats = json.load(fh)
    # On duplicate keys the samples stats win (later mapping in the merge).
    stats.append({**g_stats, **s_stats})
stats_df = pd.DataFrame(stats)

# Filter grammars to only keep those with at least MIN_COVERAGE (90%) coverage of
# positive & negative samples to ensure we aren't testing models on languages which
# can't generate strings of the relevant lengths.
good_stats_df = (
    stats_df[stats_df["coverage"] >= MIN_COVERAGE]
    .sort_values(by="grammar_name", ascending=True)
    .reset_index(drop=True)
)
del stats_df

good_stats_df

In [None]:
# Spot-check a fixed window of the filtered, name-sorted table: 0-based row
# positions 51 through 59 (out-of-range positions are simply clipped).
good_stats_df.iloc[slice(51, 60)]

In [None]:
# Distribution of each grammar-size hparam across the kept grammars, one panel
# per hparam on a shared count axis.
hparams = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]
# Bin width used for all four histograms, in raw units of the plotted column.
# NOTE(review): one width for four differently-scaled columns may be too coarse
# or too fine for some of them — tune here rather than inside the loop.
HIST_BINWIDTH = 100

fig, axes = plt.subplots(1, len(hparams), figsize=(13, 3), sharey=True)
for ax, hparam in zip(axes, hparams):
    sns.histplot(data=good_stats_df, x=hparam, binwidth=HIST_BINWIDTH, ax=ax)
    ax.set_title(hparam)
    # Shared y-axis: keep the count label/tick labels only on the left-most panel.
    ax.label_outer()

# Trailing ";" keeps the Text object returned by suptitle from being echoed.
fig.suptitle("Distribution of grammar hparams");

In [None]:
# Compression ratio of each kept grammar against each of its four size hparams,
# one panel per hparam on a shared y-axis.
y_task = "compression_ratio"
x_tasks = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]

fig, axes = plt.subplots(
    1, len(x_tasks), figsize=(10, 2.5), sharey=True, gridspec_kw={"wspace": 0.1}
)
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    # Shared y-axis: drop y tick labels and y label on all but the left-most panel.
    ax.label_outer()

# Trailing ";" keeps the Text object returned by suptitle from being echoed.
fig.suptitle("Compression ratio vs. Grammar HParams");

In [None]:
# n_terminals of each kept grammar against each of the other size hparams,
# one panel per hparam on a shared y-axis.
y_task = "n_terminals"
x_tasks = ["n_nonterminals", "n_lexical_productions", "n_nonlexical_productions"]

fig, axes = plt.subplots(
    1, len(x_tasks), figsize=(10, 3.3), sharey=True, gridspec_kw={"wspace": 0.1}
)
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    # Shared y-axis: drop y tick labels and y label on all but the left-most panel.
    ax.label_outer()

# Trailing ";" keeps the Text object returned by suptitle from being echoed.
fig.suptitle(f"{y_task} vs. other Grammar HParams");

In [None]:
# n_nonterminals of each kept grammar against each of the other size hparams,
# one panel per hparam on a shared y-axis.
y_task = "n_nonterminals"
x_tasks = ["n_terminals", "n_lexical_productions", "n_nonlexical_productions"]

fig, axes = plt.subplots(
    1, len(x_tasks), figsize=(10, 3.3), sharey=True, gridspec_kw={"wspace": 0.1}
)
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    # Shared y-axis: drop y tick labels and y label on all but the left-most panel.
    ax.label_outer()

# Trailing ";" keeps the Text object returned by suptitle from being echoed.
fig.suptitle(f"{y_task} vs. other Grammar HParams");

In [None]:
# n_lexical_productions of each kept grammar against each of the other size
# hparams, one panel per hparam on a shared y-axis.
y_task = "n_lexical_productions"
x_tasks = ["n_terminals", "n_nonterminals", "n_nonlexical_productions"]

fig, axes = plt.subplots(
    1, len(x_tasks), figsize=(10, 3.3), sharey=True, gridspec_kw={"wspace": 0.1}
)
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    # Shared y-axis: drop y tick labels and y label on all but the left-most panel.
    ax.label_outer()

# Trailing ";" keeps the Text object returned by suptitle from being echoed.
fig.suptitle(f"{y_task} vs. other Grammar HParams");

In [None]:
# n_nonlexical_productions of each kept grammar against each of the other size
# hparams, one panel per hparam on a shared y-axis.
y_task = "n_nonlexical_productions"
x_tasks = ["n_terminals", "n_nonterminals", "n_lexical_productions"]

fig, axes = plt.subplots(
    1, len(x_tasks), figsize=(10, 3.3), sharey=True, gridspec_kw={"wspace": 0.1}
)
for ax, x_task in zip(axes, x_tasks):
    sns.scatterplot(data=good_stats_df, x=x_task, y=y_task, ax=ax)
    ax.set_xlabel(x_task)
    ax.set_ylabel(y_task)
    # Shared y-axis: drop y tick labels and y label on all but the left-most panel.
    ax.label_outer()

# Trailing ";" keeps the Text object returned by suptitle from being echoed.
fig.suptitle(f"{y_task} vs. other Grammar HParams");

### Hyperparameter Correlations

In [None]:
# Pairwise correlations (pandas' default .corr() method, i.e. Pearson) between
# the four grammar-size hparams of the kept grammars.
hyp_cols = [
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
]
hyp_corr = good_stats_df[hyp_cols].corr()

# True on and above the diagonal; seaborn hides masked cells. Set
# MASK_UPPER_TRIANGLE = True to show only the non-redundant lower half of the
# symmetric matrix (replaces the previously commented-out `mask=` argument).
hyp_mask = np.triu(np.ones_like(hyp_corr, dtype=bool))
MASK_UPPER_TRIANGLE = False

ax = sns.heatmap(
    hyp_corr,
    mask=hyp_mask if MASK_UPPER_TRIANGLE else None,
    annot=True,
    cmap="coolwarm",
    center=0,
    vmin=-1,
    vmax=1,
)
# Trailing ";" keeps the Text object returned by set_title from being echoed.
ax.set_title("Correlation between Grammar HParams");