In [None]:
import sys

# Make the parent directory of the working directory importable so the
# project packages (`src.*`, `data.*`) resolve from this notebook.
# Guarded so re-running the cell does not keep appending duplicate entries.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
# Reload imported modules (e.g. edits under src/) before each cell executes,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import data.drawings.make_tasks as drawing_tasks
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from src.analysis_utilities import IterativeExperimentAnalyzer
from src.config_builder import ExperimentType

In [None]:
# Higher-resolution ('retina') inline rendering for sharper figures in the
# notebook itself; files written with savefig are not affected by this setting.
%config InlineBackend.figure_format = 'retina'

In [None]:
# Display names used in camera-ready figures, keyed first by experiment and
# then by the raw domain identifier passed to the analyzer. Iteration order of
# each inner dict determines the order domains are loaded and faceted.
# To analyze a different experiment, change EXPERIMENT_NAME only.
DOMAIN_NAMES_CAMERA_BY_EXPERIMENT = {
    "gg_drawings": {
        "drawings_nuts_bolts": "nuts & bolts",
        "drawings_wheels": "vehicles",
        "drawings_dials": "gadgets",
        "drawings_furniture": "furniture",
    },
    "gg_laps_domains": {
        "re2": "REGEX",
        "clevr": "CLEVR",
    },
}

EXPERIMENT_NAME = "gg_laps_domains"
DOMAIN_NAMES_CAMERA = DOMAIN_NAMES_CAMERA_BY_EXPERIMENT[EXPERIMENT_NAME]

In [None]:
# Every figure for this experiment is written under figures/<EXPERIMENT_NAME>/;
# create that directory now (no-op if it already exists) so savefig can't fail
# on a missing path.
FIGURES_DIR = os.path.join(
    "figures",
    EXPERIMENT_NAME,
)
os.makedirs(name=FIGURES_DIR, exist_ok=True)

# Multi-domain analysis

In [None]:
# One analyzer shared by every domain of EXPERIMENT_NAME.
# NOTE(review): allow_incomplete_results=False presumably rejects runs with
# missing results — confirm against src.analysis_utilities.
analyzer_multi = IterativeExperimentAnalyzer(
    allow_incomplete_results=False,
    experiment_name=EXPERIMENT_NAME,
)

In [None]:
# Stack every domain's results into one long frame, tagging each row with the
# domain's display name so the plot below can facet on the "domain" column.
# `.assign` returns a new frame, so the objects handed back by the analyzer are
# left untouched, and no scratch variables leak into the kernel namespace.
df_domains = pd.concat(
    [
        analyzer_multi.get_results_for_domain(domain=domain).assign(domain=domain_label)
        for domain, domain_label in DOMAIN_NAMES_CAMERA.items()
    ],
    axis=0,
).reset_index(drop=True)

In [None]:
# Camera-ready version of the combined results. Bound to a new name instead of
# overwriting df_domains, so re-running this cell never applies
# format_dataframe_camera to an already-formatted frame.
df_domains_camera = analyzer_multi.format_dataframe_camera(df_domains)

# Number of facet columns before wrapping; also used to locate the legend.
CAMERA_COL_WRAP = 2

g = sns.catplot(
    data=df_domains_camera,
    x=analyzer_multi.COL_NAMES_CAMERA["batch_size"],
    y=analyzer_multi.COL_NAMES_CAMERA["description_length"],
    hue=analyzer_multi.COL_NAMES_CAMERA["experiment_type"],
    col="domain",
    col_wrap=CAMERA_COL_WRAP,
    kind="point",
    sharex=False,
    sharey=False,
    legend=False,
    aspect=1.5,
    palette=analyzer_multi.EXPERIMENT_TYPES_PALETTE,
)

g.set_axis_labels(
    analyzer_multi.COL_NAMES_CAMERA["batch_size"],
    analyzer_multi.COL_NAMES_CAMERA["description_length"],
    fontsize=14,
)
g.set_xticklabels(size=12)
g.set_yticklabels(size=12)
g.set_titles(col_template="{col_name}", size=18)

# Place the legend just outside the top-right corner of the facet grid. The
# right-most facet of the top row is used (rather than the last-drawn facet)
# so placement is correct for any number of domains: with an odd count the
# last facet sits in the left column, and with a single domain there is only
# one row.
ax_legend = g.axes.flat[min(CAMERA_COL_WRAP, g.axes.size) - 1]
lgd = ax_legend.legend(bbox_to_anchor=(1.0, 1.0), loc="upper left", fontsize=18)

# bbox_extra_artists keeps the out-of-axes legend inside the tight bounding box.
plt.savefig(
    os.path.join(FIGURES_DIR, f"{EXPERIMENT_NAME}_results_camera.pdf"),
    dpi=300,
    bbox_extra_artists=(lgd,),
    bbox_inches="tight",
)