# Hyperparameter Experiment Results

In [None]:
from capo.analysis.utils import (
    generate_comparison_table,
)
from capo.analysis.visualizations import (
    plot_population_scores_comparison,
    plot_length_score,
)

import os

# Switch to the repository root (two levels above this notebook's directory),
# presumably so the capo analysis helpers can resolve their relative result
# paths — confirm. The starting directory is recorded only on the first run:
# a bare os.chdir("../../") climbs two further levels every time the cell is
# re-executed, which silently breaks those relative paths.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir, os.pardir))

In [None]:
# Datasets compared throughout this notebook.
DATASETS = [
    "agnews",
    "gsm8k",
]
# Marker styles to choose from: ["8", "s", "D", "o", "^", "p", "X"]
# (the cells below also use lowercase "d").

## Length Penalty

In [None]:
# Length-penalty sweep, one (run, marker, label) row per curve so the three
# lists can never drift out of alignment. The "dummy" rows are kept exactly as
# before — presumably placeholders that space out the continuous colour map so
# the real runs get distinct colours (confirm in plot_population_scores_comparison).
hp_runs, markers, labels = map(
    list,
    zip(
        ("CAPO_no_lp", "8", r"$\gamma=0$"),
        ("CAPO_gamma_0.01", "s", r"$\gamma=0.01$"),
        ("CAPO_gamma_0.02", "d", r"$\gamma=0.02$"),
        ("CAPO_gamma_0.05", "o", r"$\gamma=0.05$ (CAPO)"),
        ("dummy", None, "Dummy"),
        ("dummy", None, "Dummy"),
        ("CAPO_gamma_0.1", "p", r"$\gamma=0.1$"),
    ),
)

In [None]:
# Mean test score per optimization step for every length-penalty setting,
# one figure per dataset (individual seeds and std-dev bands hidden).
for dataset in DATASETS:
    plot_population_scores_comparison(
        dataset,
        "llama",
        hp_runs,
        "mean",
        x_col="step",
        score_col="test_score",
        plot_seeds=False,
        plot_stddev=False,
        labels=labels,
        markers=markers,
        continuous_colors=True,
    )

- the test score is only slightly better without the length penalty

In [None]:
# Mean prompt length per optimization step for every length-penalty setting.
# Plotted for every dataset rather than only gsm8k: the notes below compare the
# effect on agnews with the smaller effect on gsm8k, which needs both figures
# (matching the other plotting cells, which all loop over DATASETS).
for dataset in DATASETS:
    plot_population_scores_comparison(
        dataset,
        "llama",
        hp_runs,
        "mean",
        plot_seeds=False,
        plot_stddev=False,
        x_col="step",
        score_col="prompt_len",
        continuous_colors=True,
        markers=markers,
        labels=labels,
    )

- with the length penalty, prompts are clearly shorter than without it

- we also see this for gsm8k, but the effect is smaller there

In [None]:
# Prompt length vs. test score for the strongest penalty in the sweep
# (gamma=0.1) and for no penalty, one figure per dataset on linear axes.
# The two "nan" entries are kept as-is; they presumably act as colour-spacing
# placeholders — confirm in plot_length_score.
for dataset in DATASETS:
    plot_length_score(
        dataset,
        "llama",
        ["CAPO_gamma_0.1", "nan", "nan", "CAPO_no_lp"],
        log_scale=False,
        x_col="prompt_len",
        score_col="test_score",
    )

- for agnews, prompts without a length penalty tend to be clearly longer than with a high length penalty, and the best-scoring prompts are also longer without the penalty
- for gsm8k this can only be seen in the extremes

In [None]:
# Results table for the length-penalty sweep. "CAPO" occupies the gamma=0.05
# slot here, matching the "(CAPO)" label used for CAPO_gamma_0.05 in the plots.
generate_comparison_table(
    DATASETS,
    [
        "CAPO_no_lp",
        "CAPO_gamma_0.01",
        "CAPO_gamma_0.02",
        "CAPO",
        "CAPO_gamma_0.1",
    ],
    "llama",
)

- works very well for agnews, but not for gsm8k

### Conclusion
- a higher length penalty lets us perform more steps, because the resulting prompts are shorter

## Population Size

In [None]:
# Population-size sweep, one (run, marker, label) row per curve; the "Dummy"
# row is kept exactly as before (presumably a colour-spacing placeholder).
hp_runs, markers, labels = map(
    list,
    zip(
        ("CAPO_pop_6", "8", r"$\mu=6$"),
        ("CAPO_pop_8", "s", r"$\mu=8$"),
        ("CAPO_pop_10", "o", r"$\mu=10$ (CAPO)"),
        ("Dummy", None, "Dummy"),
        ("CAPO_pop_12", "p", r"$\mu=12$"),
    ),
)

In [None]:
# Mean test score per step for each population size, with std-dev bands,
# one compact figure per dataset (legend laid out in three columns).
for dataset in DATASETS:
    plot_population_scores_comparison(
        dataset,
        "llama",
        hp_runs,
        "mean",
        x_col="step",
        score_col="test_score",
        plot_seeds=False,
        plot_stddev=True,
        labels=labels,
        markers=markers,
        continuous_colors=True,
        figsize=(5.4, 3),
        ncols=3,
    )

In [None]:
# Results table for the population-size sweep; "CAPO" occupies the mu=10 slot,
# matching the "(CAPO)" label on CAPO_pop_10 in the plot.
generate_comparison_table(
    DATASETS,
    [
        "CAPO_pop_6",
        "CAPO_pop_8",
        "CAPO",
        "CAPO_pop_12",
    ],
    "llama",
)

- population size does not have a measurable effect on performance
- higher standard deviation for smaller population sizes
- we are quite robust to this hyperparameter (I would not say so? I would rather call it a tuning parameter with some influence that is not trivial to choose / depends on the task)

## Number of Crossovers

In [None]:
# Crossover-count sweep, one (run, marker, label) row per curve; the two
# leading "Dummy" rows are kept exactly as before (presumably colour spacing).
hp_runs, markers, labels = map(
    list,
    zip(
        ("Dummy", None, "Dummy"),
        ("Dummy", None, "Dummy"),
        ("CAPO_ncrossovers_4", "o", r"$c=4$ (CAPO)"),
        ("CAPO_ncrossovers_7", "p", r"$c=7$"),
        ("CAPO_ncrossovers_10", "d", r"$c=10$"),
    ),
)

In [None]:
# Mean test score per step for each number of crossovers, with std-dev bands,
# one figure per dataset.
for dataset in DATASETS:
    plot_population_scores_comparison(
        dataset,
        "llama",
        hp_runs,
        "mean",
        x_col="step",
        score_col="test_score",
        plot_seeds=False,
        plot_stddev=True,
        labels=labels,
        markers=markers,
        continuous_colors=True,
    )

In [None]:
# Results table for the crossover sweep; "CAPO" occupies the c=4 slot,
# matching the "(CAPO)" label on CAPO_ncrossovers_4 in the plot.
generate_comparison_table(
    DATASETS,
    [
        "CAPO",
        "CAPO_ncrossovers_7",
        "CAPO_ncrossovers_10",
    ],
    "llama",
)

- with fewer crossovers we can perform many more steps (we do not have to evaluate as many new prompts)
- agnews has a higher variance with 10 crossovers
- slightly better performance for n_crossovers = 7 (this might have been a better choice; we are more sensitive to this hyperparameter than to the one we looked at before)

In [None]:
# Every hyperparameter variant in one table, grouped by the swept parameter
# and preceded by the default "CAPO" run.
all_hp_runs = (
    ["CAPO"]
    # length penalty
    + ["CAPO_no_lp", "CAPO_gamma_0.01", "CAPO_gamma_0.02", "CAPO_gamma_0.1"]
    # population size
    + ["CAPO_pop_6", "CAPO_pop_8", "CAPO_pop_12"]
    # number of crossovers
    + ["CAPO_ncrossovers_7", "CAPO_ncrossovers_10"]
)

generate_comparison_table(DATASETS, all_hp_runs, "llama", score_col="test_score")