In [None]:
import sys

# Put the parent directory on the import search path. Presumably this is what lets
# the `src` and `data` packages imported further down resolve — confirm against
# the repository layout.
sys.path.append("../")

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import data.drawings.make_tasks as drawing_tasks
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from src.analysis_utilities import IterativeExperimentAnalyzer
from src.config_builder import ExperimentType

In [None]:
# Ask the inline matplotlib backend to render figures in the 'retina' format.
%config InlineBackend.figure_format = 'retina'

# Single-domain analysis

In [None]:
# Experiment / domain selection for the single-domain analysis.
# Other values that have been used here (swap in by hand):
#   EXPERIMENT_NAME = "gg_drawings"
#   DOMAIN = "drawings_nuts_bolts" | "drawings_furniture" | "drawings_dials" | "drawings_wheels"
#   DOMAIN = "re2"
EXPERIMENT_NAME = "gg_laps_domains"
DOMAIN = "clevr"

# Output layout: figures/<experiment>/<domain>/ (created if it does not exist).
FIGURES_DIR = os.path.join("figures", EXPERIMENT_NAME)
FIGURES_DOMAIN_DIR = os.path.join(FIGURES_DIR, DOMAIN)
os.makedirs(FIGURES_DOMAIN_DIR, exist_ok=True)

# Analyzer bound to the selected experiment; incomplete results are not allowed.
analyzer = IterativeExperimentAnalyzer(
    experiment_name=EXPERIMENT_NAME, allow_incomplete_results=False
)

In [None]:
# Display what the analyzer returns from get_available_experiment_types for DOMAIN
# (shown as this cell's output).
analyzer.get_available_experiment_types(domain=DOMAIN)

In [None]:
# Explicit list of experiment types, currently disabled. Uncomment to pass a
# specific list instead of None.
# experiment_types = [
#     ExperimentType.ORACLE.value,
#     ExperimentType.ORACLE_TRAIN_TEST.value,
#     ExperimentType.STITCH.value,
#     ExperimentType.STITCH_CODEX.value,
#     ExperimentType.STITCH_CODEX_LANGUAGE.value,
#     ExperimentType.STITCH_CODEX_LANGUAGE_ORIGIN_RANDOM_TEST.value,
# ]

# None is handed straight to get_results_for_domain; presumably that means
# "no restriction on experiment type" — confirm in IterativeExperimentAnalyzer.
experiment_types = None

# Results frame for DOMAIN; the plotting cells below read it as `df`.
df = analyzer.get_results_for_domain(domain=DOMAIN, experiment_types=experiment_types)

In [None]:
# Draw the description-length plot for DOMAIN from the results frame, then save
# the current figure into FIGURES_DOMAIN_DIR at 300 dpi.
analyzer.plot_description_length(domain=DOMAIN, df=df)
# Plain string literal: the file name has no placeholders, so no f-string is needed.
plt.savefig(os.path.join(FIGURES_DOMAIN_DIR, "description_length.png"), dpi=300)

In [None]:
# Description-length plot for DOMAIN, requested with plot_type="lineplot" and
# logscale=True. This figure is only displayed, not saved.
analyzer.plot_description_length(
    domain=DOMAIN, df=df, plot_type="lineplot", logscale=True
)

In [None]:
# Call plot_n_frontiers on the results frame for DOMAIN; the figure is only displayed.
analyzer.plot_n_frontiers(domain=DOMAIN, df=df)

## What programs does Codex generate?

In [None]:
# Load the program records for DOMAIN under a single experiment type (STITCH_CODEX).
# The cells that follow read the result as `df_codex`.
# NOTE(review): the enum member itself is passed here, whereas the commented-out
# list in an earlier cell uses `ExperimentType.<NAME>.value` — confirm which form
# get_codex_programs_for_experiment_type expects.
df_codex = analyzer.get_codex_programs_for_experiment_type(
    DOMAIN, experiment_type=ExperimentType.STITCH_CODEX
)

In [None]:
# Bar chart of the `valid` column against batch size, with one hue level per
# value of `origin`. The title is set on the Axes that seaborn returns.
ax = sns.barplot(data=df_codex, x="batch_size", y="valid", hue="origin")
ax.set_title("Percentage of valid programs");

In [None]:
# Violin plot of `program_str_len` against batch size, split by `origin`.
# The title is set on the Axes that seaborn returns.
ax = sns.violinplot(data=df_codex, x="batch_size", y="program_str_len", hue="origin")
ax.set_title("Program string length");

In [None]:
# Count of distinct programs per (batch_size, seed) for several subsets of
# `df_codex`, plotted against batch size with one line per subset.
#
# Each entry maps the legend label for a subset to the DataFrame.query
# expression that selects it. Dict order fixes the order of the legend entries.
UNIQUE_PROGRAM_SUBSETS = {
    "train": "origin == 'train'",
    "codex (overall)": "origin == 'codex'",
    "codex (copied from train)": "origin == 'codex' & copied_from_train",
    "codex (original)": "origin == 'codex' & ~copied_from_train",
}

# One pipeline applied to every subset instead of four copy-pasted variants:
# select rows -> number of distinct values per column within each
# (batch_size, seed) group -> overwrite `origin` with the subset's label.
# ignore_index=True gives a clean 0..n-1 index without adding an `index` column.
df_unique_programs = pd.concat(
    [
        df_codex.query(subset_query)
        .groupby(["batch_size", "seed"])
        .nunique()
        .reset_index()
        .assign(origin=subset_label)
        for subset_label, subset_query in UNIQUE_PROGRAM_SUBSETS.items()
    ],
    axis=0,
    ignore_index=True,
)

# `program` now holds the number of distinct programs per group.
ax = sns.pointplot(data=df_unique_programs, x="batch_size", y="program", hue="origin")
ax.set_title("Count of unique programs");

In [None]:
# Sum of the `copied_from_train` flag per (batch_size, seed), plotted against
# batch size.
#
# Only the plotted column is aggregated. A frame-wide `.sum()` would also try to
# add up every other column (including non-numeric ones such as `program`,
# presumably strings), which is wasted work and can raise on recent pandas versions.
df_copied_counts = (
    df_codex.groupby(["batch_size", "seed"])["copied_from_train"].sum().reset_index()
)

ax = sns.pointplot(data=df_copied_counts, x="batch_size", y="copied_from_train")
ax.set_title("Count of programs copied from train");

In [None]:
# Number of rows with origin 'train' in each (batch_size, seed, query_id) group,
# plotted against batch size.
# Only `program` is counted because it is the only column the plot reads.
df_programs_per_prompt = (
    df_codex.query("origin == 'train'")
    .groupby(["batch_size", "seed", "query_id"])["program"]
    .count()
    .reset_index()
)

ax = sns.pointplot(data=df_programs_per_prompt, x="batch_size", y="program")
# Trailing semicolon keeps the Axes repr out of the cell output, matching the
# other plotting cells in this notebook.
ax.set_title(f"{DOMAIN}: Count of programs per prompt");

# Codex program analysis across multiple experiment types

In [None]:
# Fetch the Codex program records for DOMAIN via get_codex_programs. No
# experiment_type argument is passed here, unlike the earlier single-type load.
# NOTE(review): this rebinds `df_codex`. Cells above that read `df_codex` will
# see this frame instead if they are re-run after this cell.
df_codex = analyzer.get_codex_programs(DOMAIN)

In [None]:
# `valid` against batch size for every row whose origin is not 'train', with one
# hue level per experiment type. The frame is passed through
# format_dataframe_camera first, so the camera column names are used.
# plt.title("Percentage of valid programs")
df_codex_non_train = analyzer.format_dataframe_camera(df_codex).query(
    "origin != 'train'"
)
sns.barplot(
    data=df_codex_non_train,
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="valid",
    hue=analyzer.COL_NAMES_CAMERA["experiment_type"],
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
);

In [None]:
# Violin plots of `program_str_len` against batch size, split by `origin`, with
# one facet column per experiment type.
df_codex_camera = analyzer.format_dataframe_camera(df_codex)
sns.catplot(
    data=df_codex_camera,
    kind="violin",
    x=analyzer.COL_NAMES_CAMERA["batch_size"],
    y="program_str_len",
    hue="origin",
    col=analyzer.COL_NAMES_CAMERA["experiment_type"],
);

In [None]:
# Number of rows with origin 'train' in each
# (batch size, experiment type, seed, query_id) group, plotted against batch size
# with one line per experiment type.
batch_size_col = analyzer.COL_NAMES_CAMERA["batch_size"]
experiment_type_col = analyzer.COL_NAMES_CAMERA["experiment_type"]

# Only `program` is counted because it is the only column the plot reads.
df_programs_per_prompt = (
    analyzer.format_dataframe_camera(df_codex)
    .query("origin == 'train'")
    .groupby([batch_size_col, experiment_type_col, "seed", "query_id"])["program"]
    .count()
    .reset_index()
)

ax = sns.pointplot(
    data=df_programs_per_prompt,
    x=batch_size_col,
    y="program",
    hue=experiment_type_col,
    palette=analyzer.EXPERIMENT_TYPES_PALETTE,
)
# Trailing semicolon keeps the Axes repr out of the cell output, matching the
# other plotting cells in this notebook.
ax.set_title(f"{DOMAIN}: Count of programs per prompt");

In [None]:
# Configuration for the multi-domain (drawings) analysis.
# NOTE(review): this rebinds EXPERIMENT_NAME, which the single-domain config cell
# also sets. FIGURES_DIR / FIGURES_DOMAIN_DIR are not recomputed here.
EXPERIMENT_NAME = "gg_drawings"

# Domain identifier -> display label written into the `domain` column of the
# combined results. Insertion order is preserved.
DRAWING_DOMAINS = dict(
    drawings_nuts_bolts="nuts & bolts",
    drawings_wheels="vehicles",
    drawings_dials="gadgets",
    drawings_furniture="furniture",
)

# Separate analyzer for the drawings experiment; incomplete results are not allowed.
analyzer_multi = IterativeExperimentAnalyzer(
    experiment_name=EXPERIMENT_NAME, allow_incomplete_results=False
)

In [None]:
# Load the results of every drawing domain and stack them into one frame, tagging
# each row with its domain's display label in a `domain` column.
#
# Built as a comprehension so that no notebook-level names are rebound. A loop
# variable called `df` would overwrite the single-domain results frame that the
# earlier plotting cells read.
df_domains = pd.concat(
    [
        analyzer_multi.get_results_for_domain(domain=domain_key).assign(
            domain=domain_label
        )
        for domain_key, domain_label in DRAWING_DOMAINS.items()
    ],
    axis=0,
    ignore_index=True,  # fresh 0..n-1 index across the stacked frames
)

In [None]:
# Multi-panel figure for the drawings experiment: description length against
# batch size, one panel per domain, one hue level per experiment type. The result
# is saved as a PDF in the working directory.

# Format with `analyzer_multi`, the analyzer these results came from and the one
# whose column names and palette are used below. The formatted frame gets its own
# name so that re-running this cell never feeds an already-formatted frame back
# into format_dataframe_camera.
df_domains_camera = analyzer_multi.format_dataframe_camera(df_domains)

g = sns.catplot(
    data=df_domains_camera,
    x=analyzer_multi.COL_NAMES_CAMERA["batch_size"],
    y=analyzer_multi.COL_NAMES_CAMERA["description_length"],
    hue=analyzer_multi.COL_NAMES_CAMERA["experiment_type"],
    col="domain",
    col_wrap=2,
    kind="point",
    sharex=False,
    sharey=False,
    legend=False,  # a single legend is added by hand further down
    aspect=1.5,
    palette=analyzer_multi.EXPERIMENT_TYPES_PALETTE,
)

g.set_axis_labels(
    analyzer_multi.COL_NAMES_CAMERA["batch_size"],
    analyzer_multi.COL_NAMES_CAMERA["description_length"],
    fontsize=14,
)
g.set_xticklabels(size=12)
g.set_yticklabels(size=12)
g.set_titles(col_template="{col_name}", size=18)

# One legend attached to the current axes and anchored outside its bounding box.
# NOTE(review): the (1.0, 2.2) anchor looks tuned to the 2x2 layout — confirm if
# the number of domains changes.
lgd = plt.legend(bbox_to_anchor=(1.0, 2.2), loc="upper left", fontsize=18)

# The legend is passed as an extra artist so the tight bounding box includes it.
plt.savefig(
    "drawings_results_camera.pdf",
    dpi=300,
    bbox_extra_artists=(lgd,),
    bbox_inches="tight",
)