In [None]:
import sys

# Make the parent directory importable so the `data.` and `src.` packages
# imported below resolve. Presumably the notebook lives one level below the
# repository root — TODO confirm.
sys.path.append("../")

In [None]:
# Automatically re-import edited modules (e.g. `src.analysis_utilities`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import data.drawings.make_tasks as drawing_tasks
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from src.analysis_utilities import IterativeExperimentAnalyzer
from src.config_builder import ExperimentType

In [None]:
# Render inline figures at high ("retina") resolution.
%config InlineBackend.figure_format = 'retina'

In [None]:
# Global seaborn styling applied to every figure in this notebook.
sns.set_theme(font_scale=1.25, style="whitegrid")

# Single-domain analysis

In [None]:
# Experiment / domain under analysis. Other settings left commented out in
# earlier revisions of this cell:
#   EXPERIMENT_NAME = "gg_drawings" with DOMAIN one of
#     "drawings_nuts_bolts", "drawings_furniture", "drawings_dials", "drawings_wheels"
#   DOMAIN = "re2"
EXPERIMENT_NAME = "gg_laps_domains_dsl_descriptions"
DOMAIN = "clevr"

# Passed through to IterativeExperimentAnalyzer as `compute_likelihoods`.
COMPUTE_LIKELIHOODS = False

# Figures are written under figures/<EXPERIMENT_NAME>/<DOMAIN>/.
FIGURES_DIR = os.path.join("figures", EXPERIMENT_NAME)
FIGURES_DOMAIN_DIR = os.path.join(FIGURES_DIR, DOMAIN)
os.makedirs(FIGURES_DOMAIN_DIR, exist_ok=True)

analyzer = IterativeExperimentAnalyzer(
    experiment_name=EXPERIMENT_NAME,
    compute_likelihoods=COMPUTE_LIKELIHOODS,
    allow_incomplete_results=True,
)

In [None]:
# Show which experiment types have results for DOMAIN (useful for choosing `experiment_types` below).
analyzer.get_available_experiment_types(domain=DOMAIN)

In [None]:
# Which experiment types to load. `None` presumably loads every available type
# (TODO confirm in get_results_for_domain); to restrict, pass a list such as
# [ExperimentType.STITCH.value, ExperimentType.STITCH_CODEX.value].
experiment_types = None

df = analyzer.get_results_for_domain(
    domain=DOMAIN,
    experiment_types=experiment_types,
)

In [None]:
# Description-length plot for the loaded results, saved as a PNG.
analyzer.plot_description_length(domain=DOMAIN, df=df)
description_length_path = os.path.join(FIGURES_DOMAIN_DIR, "description_length.png")
plt.savefig(description_length_path, dpi=300)

In [None]:
# Same metric, requested as a line plot with `logscale=True`.
analyzer.plot_description_length(
    domain=DOMAIN,
    df=df,
    plot_type="lineplot",
    logscale=True,
)

In [None]:
# Frontier-count plot (`plot_n_frontiers`) for the loaded results.
analyzer.plot_n_frontiers(domain=DOMAIN, df=df)

## What programs does Codex generate?

In [None]:
# Programs involved in the Codex queries. The cells below filter on
# `origin` ("train" / "codex"), `valid`, `match_train`, `program`,
# `program_str_len` and `query_id`; presumably "train" rows are programs shown
# in the prompt and "codex" rows are generated ones — TODO confirm.
df_codex = analyzer.get_codex_programs_for_domain(DOMAIN, use_results_by_query=True)

In [None]:
# Distinct valid, non-train programs for one run:
# stitch_codex, batch size 5, seed 111.
run_mask = (
    (df_codex.experiment_type == "stitch_codex")
    & (df_codex.batch_size == 5)
    & (df_codex.seed == 111)
    & (df_codex.origin != "train")
    & (df_codex.valid)
)
df_codex.loc[run_mask, "program"].nunique()

In [None]:
# Per run (experiment type x batch size x seed): sum over the distinct
# Codex-generated programs, so boolean columns such as `valid` and
# `match_train` become counts of unique programs.
#
# De-duplicate *within* each run. De-duplicating on `program` alone keeps a
# program only in the first run (by row order) that produced it, which would
# under-count every other run that generated the same program.
run_keys = ["experiment_type", "batch_size", "seed"]
df_unique_counts = (
    df_codex.query("origin == 'codex'")
    .drop_duplicates(subset=run_keys + ["program"])
    .groupby(run_keys)
    .sum()
    .reset_index()
)
df_unique_counts

In [None]:
# Overlay two bar plots on the same axes: unique valid Codex programs per run
# (`valid`), and in a darker shade of the same hue the programs that match a
# training program (`match_train`).
from matplotlib import colors as mcolors

# Darken each palette color by 25% by scaling its RGB channels by 0.75.
# `to_rgb` already returns channels normalized to [0, 1], so no manual
# 0-255 -> 0-1 conversion (and no off-by-one divisor) is needed.
DARKEN_RATIO = 0.75
PALLETE_DARKENED = {
    k: tuple(channel * DARKEN_RATIO for channel in mcolors.to_rgb(hex_str))
    for k, hex_str in analyzer.EXPERIMENT_TYPES_PALETTE.items()
}

df_unique_counts_camera = analyzer.format_dataframe_camera(df_unique_counts)

g = sns.barplot(
    data=df_unique_counts_camera,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="valid",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)
g = sns.barplot(
    data=df_unique_counts_camera,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="match_train",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=PALLETE_DARKENED,
)

# Both overlaid plots use the same hue labels; drop the legend entirely.
g.legend_.remove()

In [None]:
# Percentage of valid programs among non-train programs: seaborn's default
# estimator is the mean, so each bar is the mean of the boolean `valid` column.
df_codex_non_train = analyzer.format_dataframe_camera(df_codex).query("origin != 'train'")
sns.barplot(
    data=df_codex_non_train,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="valid",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
);

In [None]:
# Same valid-program percentage as the bar plot above, drawn as points.
df_codex_non_train = analyzer.format_dataframe_camera(df_codex).query("origin != 'train'")
sns.pointplot(
    data=df_codex_non_train,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="valid",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
);

In [None]:
def _count_unique_per_run(df_subset, origin_label):
    """Count distinct values of every column per run and tag the rows with `origin_label`.

    A run is one (experiment_type, batch_size, seed) combination; the
    `program` column of the result is the number of distinct programs.
    """
    counts = (
        df_subset.groupby(["experiment_type", "batch_size", "seed"])
        .nunique()
        .reset_index()
    )
    counts["origin"] = origin_label
    return counts


plt.title("Count of unique programs")

# Codex-generated programs: overall, copied from the training set, and
# original. Training-prompt programs (origin == 'train') are not plotted here.
df_tmp = pd.concat(
    [
        _count_unique_per_run(df_codex.query("origin == 'codex'"), "codex (overall)"),
        _count_unique_per_run(
            df_codex.query("origin == 'codex' & match_train"), "codex (copied from train)"
        ),
        _count_unique_per_run(
            df_codex.query("origin == 'codex' & ~match_train"), "codex (original)"
        ),
    ],
    axis=0,
    ignore_index=True,
)

sns.lineplot(data=df_tmp, x="batch_size", y="program", style="origin", hue="experiment_type");

In [None]:
# Inspect the Codex program table (Jupyter truncates the displayed rows).
df_codex

In [None]:
# Per-run column sums over Codex-generated rows (boolean columns such as
# `valid` / `match_train` become counts; unlike df_unique_counts, duplicates are kept).
(
    df_codex.query("origin == 'codex'")
    .groupby(["experiment_type", "batch_size", "seed"])
    .sum()
)

In [None]:
# Reload the Codex programs before the next cell relabels rows of `df_codex`
# in place, so re-running this section starts from the analyzer's data.
df_codex = analyzer.get_codex_programs_for_domain(DOMAIN, use_results_by_query=True)

In [None]:
# Give prompt programs (origin == "train") their own experiment-type label so
# the plots below show them as a separate group. NOTE: mutates `df_codex`.
is_train = df_codex["origin"] == "train"
df_codex["experiment_type"] = df_codex["experiment_type"].where(~is_train, "train")

In [None]:
# Distribution of program string lengths per batch size, with one hue per
# experiment type (prompt programs carry the "train" label set above).
df_codex_camera = analyzer.format_dataframe_camera(df_codex)
sns.violinplot(
    data=df_codex_camera,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="program_str_len",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
);

In [None]:
# One violin panel per experiment type; within a panel, hue is `origin`.
sns.catplot(
    data=analyzer.format_dataframe_camera(df_codex),
    kind="violin",
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="program_str_len",
    hue="origin",
    col=analyzer.COL_NAMES_CAMERA["experiment_type"],
);

In [None]:
# Number of `origin == 'train'` rows per Codex prompt (`query_id`), averaged
# per batch size and experiment type.
#
# An earlier cell relabels `df_codex` in place so that every
# `origin == "train"` row has `experiment_type == "train"`. Grouping those rows
# by experiment type would therefore collapse all prompts into a single
# "train" hue. Load an unmodified copy instead. Assumes
# `get_codex_programs_for_domain` builds a fresh frame on every call (the
# reload before the relabeling relies on the same thing) — TODO confirm.
df_codex_unlabeled = analyzer.get_codex_programs_for_domain(
    DOMAIN, use_results_by_query=True
)
df_prompt_sizes = (
    analyzer.format_dataframe_camera(df_codex_unlabeled)
    .query("origin == 'train'")
    .groupby(
        [
            analyzer.COL_NAMES_CAMERA["batch_size"],
            analyzer.COL_NAMES_CAMERA["experiment_type"],
            "seed",
            "query_id",
        ]
    )
    .count()
    .reset_index()
)

plt.title(f"{DOMAIN}: Count of programs per prompt")
sns.pointplot(
    data=df_prompt_sizes,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="program",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
);

# Invention analysis

In [None]:
# Library inventions learned by each run. The overlap analysis below relies on
# the `experiment_type`, `random_seed`, `batch_size` and `dreamcoder` columns.
df_libraries = analyzer.get_library_inventions(DOMAIN)
df_libraries

In [None]:
def _library_inventions(df_libraries, experiment_type, random_seed, batch_size):
    """Return the set of `dreamcoder` inventions of one run (type, seed, batch size)."""
    mask = (
        (df_libraries["experiment_type"] == experiment_type)
        & (df_libraries["random_seed"] == random_seed)
        & (df_libraries["batch_size"] == batch_size)
    )
    return set(df_libraries.loc[mask, "dreamcoder"])


def compute_overlap_metrics(
    df_libraries,
    experiment_type_a,
    experiment_type_b,
    random_seed_a,
    random_seed_b,
    batch_size_a=None,
    batch_size_b=None,
):
    """Compare the library inventions of two runs, batch size by batch size.

    Run A is (``experiment_type_a``, ``random_seed_a``) and run B is
    (``experiment_type_b``, ``random_seed_b``). For each batch size the
    ``dreamcoder`` inventions of both runs are compared as sets.

    Args:
        df_libraries: DataFrame with ``experiment_type``, ``random_seed``,
            ``batch_size`` and ``dreamcoder`` columns.
        experiment_type_a, experiment_type_b: experiment types to compare.
        random_seed_a, random_seed_b: seeds selecting one run of each type.
        batch_size_a, batch_size_b: optional fixed batch size for that side
            (e.g. 200 for "Stitch@200"). ``None`` means "use the batch size
            currently being iterated".

    Returns:
        A list with one dict per batch size, holding ``batch_size``,
        ``"{experiment_type_a}_only"``, ``"{experiment_type_b}_only"`` and
        ``overlap`` counts. If both types format to the same key (comparing a
        type with itself), the keys are ``"{type}_a_only"`` and
        ``"{type}_b_only"`` so that neither count overwrites the other.
    """
    # Iterate over the batch sizes of the side that is not pinned. A pinned
    # base (e.g. Stitch@200) is then compared against every batch size the
    # other run actually has. Otherwise run A's batch sizes drive the loop.
    if batch_size_a is not None and batch_size_b is None:
        iter_type, iter_seed = experiment_type_b, random_seed_b
    else:
        iter_type, iter_seed = experiment_type_a, random_seed_a
    iter_mask = (df_libraries["experiment_type"] == iter_type) & (
        df_libraries["random_seed"] == iter_seed
    )
    batch_sizes = sorted(df_libraries.loc[iter_mask, "batch_size"].unique())

    key_a = f"{experiment_type_a}_only"
    key_b = f"{experiment_type_b}_only"
    if key_a == key_b:
        key_a = f"{experiment_type_a}_a_only"
        key_b = f"{experiment_type_b}_b_only"

    data = []
    for batch_size in batch_sizes:
        # `is None` (not truthiness) so that a pinned batch size of 0 is honored.
        set_a = _library_inventions(
            df_libraries,
            experiment_type_a,
            random_seed_a,
            batch_size if batch_size_a is None else batch_size_a,
        )
        set_b = _library_inventions(
            df_libraries,
            experiment_type_b,
            random_seed_b,
            batch_size if batch_size_b is None else batch_size_b,
        )
        data.append(
            {
                "batch_size": batch_size,
                key_a: len(set_a - set_b),
                key_b: len(set_b - set_a),
                "overlap": len(set_a & set_b),
            }
        )

    return data

In [None]:
# Inventions shared between the Stitch run (seed 111) at batch size 200 and
# the same run at each batch size. For an oracle-vs-Stitch comparison instead:
#   compute_overlap_metrics(df_libraries, ExperimentType.ORACLE, ExperimentType.STITCH, 111, 111)
overlap_metrics = compute_overlap_metrics(
    df_libraries,
    experiment_type_a=ExperimentType.STITCH,
    experiment_type_b=ExperimentType.STITCH,
    random_seed_a=111,
    random_seed_b=111,
    batch_size_a=200,
)

In [None]:
# One dict per batch size with exclusive and shared invention counts.
overlap_metrics

In [None]:
def get_df_overlap(
    df_libraries,
    experiment_type_base=ExperimentType.ORACLE_TRAIN_TEST,
    random_seed_base=111,
    batch_size_base=None,
    experiment_types=None,
):
    """Overlap between a base run's inventions and every run of several experiment types.

    Args:
        df_libraries: DataFrame with ``experiment_type``, ``random_seed``,
            ``batch_size`` and ``dreamcoder`` columns.
        experiment_type_base: experiment type of the reference library.
        random_seed_base: seed of the reference run.
        batch_size_base: optional fixed batch size of the reference run
            (e.g. 200 for "Stitch@200"); ``None`` matches each batch size.
        experiment_types: experiment types compared against the base. Defaults
            to ORACLE, ORACLE_TRAIN_TEST, STITCH, STITCH_CODEX,
            STITCH_CODEX_LANGUAGE and STITCH_CODEX_LANGUAGE_ORIGIN_RANDOM_TEST.

    Returns:
        DataFrame with one row per (experiment type, seed, batch size) and the
        columns ``experiment_type`` (the enum's ``.value``), ``random_seed``,
        ``batch_size``, ``overlap`` and ``baseline``. ``baseline`` is True for
        every type except ``ExperimentType.STITCH``. NOTE(review): this also
        flags the Codex variants as baselines; confirm that grouping is intended.
    """
    if experiment_types is None:
        experiment_types = [
            ExperimentType.ORACLE,
            ExperimentType.ORACLE_TRAIN_TEST,
            ExperimentType.STITCH,
            ExperimentType.STITCH_CODEX,
            ExperimentType.STITCH_CODEX_LANGUAGE,
            ExperimentType.STITCH_CODEX_LANGUAGE_ORIGIN_RANDOM_TEST,
        ]

    columns = ["experiment_type", "random_seed", "batch_size", "overlap", "baseline"]
    data = []
    for experiment_type in experiment_types:
        seeds = df_libraries.loc[
            df_libraries["experiment_type"] == experiment_type, "random_seed"
        ].unique()
        for seed in seeds:
            overlap_metrics = compute_overlap_metrics(
                df_libraries,
                experiment_type_a=experiment_type_base,
                experiment_type_b=experiment_type,
                random_seed_a=random_seed_base,
                random_seed_b=seed,
                batch_size_a=batch_size_base,
            )
            for result in overlap_metrics:
                data.append(
                    {
                        "experiment_type": experiment_type.value,
                        "random_seed": seed,
                        "batch_size": result["batch_size"],
                        "overlap": result["overlap"],
                        "baseline": experiment_type != ExperimentType.STITCH,
                    }
                )

    # Explicit columns keep the schema even when no run matches, so that
    # downstream filters such as `df_overlap.experiment_type` still work.
    df_overlap = pd.DataFrame(data, columns=columns)
    return df_overlap

In [None]:
# How many inventions of the ORACLE library each run also discovers, per batch
# size. ORACLE's own bars are drawn first, then every other type on the same axes.
df_overlap = get_df_overlap(df_libraries, experiment_type_base=ExperimentType.ORACLE)

# get_df_overlap stores `ExperimentType.<X>.value` in `experiment_type`, so
# compare against `.value`. Comparing against the enum member only matches
# when ExperimentType happens to be a str-mixin Enum.
is_oracle = df_overlap.experiment_type == ExperimentType.ORACLE.value

fig = sns.barplot(
    data=analyzer.format_dataframe_camera(df_overlap[is_oracle]),
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="overlap",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)

fig = sns.barplot(
    data=analyzer.format_dataframe_camera(df_overlap[~is_oracle]),
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="overlap",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)
sns.despine()

fig.set_ylabel("Oracle inventions discovered")

lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc="upper left")

In [None]:
# How many inventions of the Stitch run at batch size 200 (seed 111) each run
# discovers, per batch size.
df_overlap = get_df_overlap(
    df_libraries, experiment_type_base=ExperimentType.STITCH, batch_size_base=200
)

fig = sns.barplot(
    data=analyzer.format_dataframe_camera(df_overlap),
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="overlap",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)
sns.despine()

fig.set_ylabel("Stitch@200 inventions discovered")

lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc="upper left")
# Use a distinct filename: "oracle_discovery_barplot.pdf" is written by the
# oracle-based bar plot cell further down and would overwrite this figure.
plt.savefig(
    os.path.join(FIGURES_DOMAIN_DIR, "stitch200_discovery_barplot.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Oracle inventions discovered by each run as batch size grows.
#
# Compute the oracle-based overlap here rather than reusing `df_overlap`. The
# previous cell rebinds `df_overlap` to the Stitch@200 comparison, which
# contradicts the "Oracle inventions" label and filename below. The base
# (ORACLE_TRAIN_TEST) matches get_df_overlap's default — TODO confirm this is
# the intended oracle library.
df_overlap_oracle = get_df_overlap(
    df_libraries, experiment_type_base=ExperimentType.ORACLE_TRAIN_TEST
)

fig = sns.lineplot(
    data=analyzer.format_dataframe_camera(df_overlap_oracle),
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="overlap",
    size="baseline",
    style="baseline",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
    legend=True,
)
plt.ylabel("Number of Oracle inventions discovered")
lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc="upper left", fontsize=14)
plt.savefig(
    os.path.join(FIGURES_DOMAIN_DIR, "oracle_discovery_rate.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# How many inventions of the ORACLE_TRAIN_TEST library each run discovers, per
# batch size. The base type is drawn first; both oracle types are then
# excluded from the second layer.
df_overlap = get_df_overlap(
    df_libraries, experiment_type_base=ExperimentType.ORACLE_TRAIN_TEST
)

# get_df_overlap stores `ExperimentType.<X>.value` in `experiment_type`, so
# filter with `.value` strings. Enum members only match (for both `==` and
# `isin`) when ExperimentType is a str-mixin Enum.
is_base = df_overlap.experiment_type == ExperimentType.ORACLE_TRAIN_TEST.value
is_any_oracle = df_overlap.experiment_type.isin(
    [ExperimentType.ORACLE_TRAIN_TEST.value, ExperimentType.ORACLE.value]
)

fig = sns.barplot(
    data=analyzer.format_dataframe_camera(df_overlap[is_base]),
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="overlap",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)

fig = sns.barplot(
    data=analyzer.format_dataframe_camera(df_overlap[~is_any_oracle]),
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="overlap",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)
sns.despine()

fig.set_ylabel("Oracle inventions discovered")

lgd = plt.legend(bbox_to_anchor=(1.0, 1.0), loc="upper left")
plt.savefig(
    os.path.join(FIGURES_DOMAIN_DIR, "oracle_discovery_barplot.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
from matplotlib_venn import venn3

In [None]:
def make_venn3(
    df_libraries,
    codex_type=ExperimentType.STITCH_CODEX,
    codex_seed=111,
    stitch_type=ExperimentType.STITCH,
    oracle_type=ExperimentType.ORACLE_TRAIN_TEST,
):
    """Draw one 3-way Venn diagram of library inventions per batch size.

    Each panel compares the ``dreamcoder`` inventions of ``stitch_type``,
    ``codex_type`` (restricted to ``codex_seed``) and ``oracle_type`` at one
    batch size. Stitch and oracle rows are not filtered by seed, so inventions
    from several seeds would be pooled. NOTE(review): confirm this is intended.

    Panels are laid out four per row, with at least two rows (so up to eight
    batch sizes keep a 2x4 grid). The figure is created with ``plt.figure`` and
    is therefore the current figure for a following ``plt.savefig``.
    """
    n_cols = 4
    groups = list(df_libraries.groupby("batch_size"))
    # Ceiling division, with at least two rows.
    n_rows = max(2, -(-len(groups) // n_cols))
    fig = plt.figure(figsize=(20, 4 * n_rows))

    # matplotlib_venn ids of the single-set regions, in `venn3` subset order:
    # "100" = first set only (Stitch), "010" = second (Codex), "001" = third (oracle).
    region_types = [("100", stitch_type), ("010", codex_type), ("001", oracle_type)]

    for i, (batch_size, group) in enumerate(groups, start=1):
        fns_stitch = set(group[group.experiment_type == stitch_type].dreamcoder)
        fns_codex = set(
            group[
                (group.experiment_type == codex_type)
                & (group.random_seed == codex_seed)
            ].dreamcoder
        )
        fns_oracle = set(group[group.experiment_type == oracle_type].dreamcoder)

        ax = fig.add_subplot(n_rows, n_cols, i)
        venn = venn3(
            [fns_stitch, fns_codex, fns_oracle],
            set_labels=("", "", ""),
            ax=ax,
        )
        ax.set_title("Batch size: " + str(batch_size), fontweight="bold")

        legend_handles, legend_labels = [], []
        for region_id, experiment_type in region_types:
            patch = venn.get_patch_by_id(region_id)
            if patch is None:
                # Empty region: matplotlib_venn draws no patch for it.
                continue
            # EXPERIMENT_TYPES_CAMERA is keyed by the enum's `.value`.
            label = analyzer.EXPERIMENT_TYPES_CAMERA[experiment_type.value]
            color = analyzer.EXPERIMENT_TYPES_PALETTE.get(label)
            if color is not None:
                patch.set_color(color)
            legend_handles.append(patch)
            legend_labels.append(label)

        if i == 1:
            ax.legend(
                handles=legend_handles,
                labels=legend_labels,
                fontsize=16,
                bbox_to_anchor=(0.0, 1.5),
                loc="lower left",
                ncol=3,
            )

In [None]:
# Venn diagrams for the Stitch+Codex run with seed 111; other Codex runs can
# be selected through `codex_type` / `codex_seed`.
make_venn3(df_libraries, codex_seed=111, codex_type=ExperimentType.STITCH_CODEX)
venn_path = os.path.join(FIGURES_DOMAIN_DIR, "invention_venn3.pdf")
plt.savefig(venn_path, dpi=300, bbox_inches="tight")