In [1]:
# Reload edited project modules (e.g. evallm) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Auto-format code cells (jupyter_black extension).
%load_ext jupyter_black

In [2]:
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import evallm
from automata.fa.dfa import DFA
import tqdm.auto as tqdm

In [3]:
import itertools
from permacache import stable_hash

In [4]:
from evallm.llm.llm import model_specs
from evallm.prompting.transducer_prompt import (
    BasicSequencePrompt,
    BasicSequencePromptSlightlyMoreExplanation,
    BasicSequencePromptNoChat,
    SequencePromptWithExplanation,
    SequencePromptWithExplanationChainOfThought,
    RedGreenRoomPrompt1,
    BasicInstructionTransducerPrompter,
)
from evallm.experiments.transducer_experiment import (
    current_transducer_experiments,
    compute_relative_to_null,
    compute_relative_to_ngram,
    print_example,
    bottom_quartile_outcome,
    current_dfa_sample_spec,
    run_transducer_experiment_just_stats,
    run_transducer_experiment,
    run_brute_force_transducer,
)
from evallm.experiments.transducer_plotting import (
    plot_all_absolute_results_single_graph,
    plot_absolute_results_barchart,
    produce_table,
)
from evallm.experiments.models_display import model_by_display_key
from evallm.utils.bootstrap import boostrap_mean

In [5]:
# Shared experiment configuration used by every cell below.
num_states = 3  # passed to current_dfa_sample_spec when sampling DFAs
num_symbols = 3  # presumably the DFA input alphabet size -- confirm
num_sequence_symbols = 30  # presumably the length of each prompt sequence -- confirm
num_repeats_per_dfa = 30

sample_dfa_spec = current_dfa_sample_spec(num_states=num_states)

# Keyword arguments handed to each prompt class's ``for_setting`` constructor.
setting_kwargs = {
    "num_sequence_symbols": num_sequence_symbols,
    "sample_dfa_spec": sample_dfa_spec,
    "num_states": num_states,
}

In [6]:
# Prompt family -> model interface ("chat" / "non-chat") -> instantiated prompt.
# Families without a "non-chat" entry can only be used with chat models.
prompt_by_key = {
    family: {
        kind: prompt_cls.for_setting(setting_kwargs)
        for kind, prompt_cls in variants.items()
    }
    for family, variants in {
        "Basic": {
            "non-chat": BasicSequencePromptNoChat,
            "chat": BasicSequencePrompt,
        },
        "More-Expl": {"chat": BasicSequencePromptSlightlyMoreExplanation},
        "COT": {"chat": SequencePromptWithExplanationChainOfThought},
        "Red-Green": {"chat": RedGreenRoomPrompt1},
    }.items()
}

In [7]:
def for_model_and_prompt(model, num_dfas, *prompts):
    """Run the transducer experiment for one model under each of ``prompts``.

    Parameters
    ----------
    model : str
        Display name of the model; translated to the internal model key via
        ``model_by_display_key``.
    num_dfas : int
        Number of sampled DFAs, forwarded to ``run_transducer_experiment``.
    *prompts : str
        Prompt-family keys of ``prompt_by_key`` (e.g. ``"Basic"``, ``"COT"``).

    Returns
    -------
    dict
        Maps ``(model, prompt)`` to the value returned by
        ``run_transducer_experiment``, in the order ``prompts`` were given.

    Raises
    ------
    KeyError
        If ``model`` or a prompt key is unknown, or if a prompt has no variant
        for the model's chat/non-chat interface. All prompts are resolved
        before any experiment is launched, so a misconfiguration fails
        immediately instead of after earlier (possibly expensive) runs.
    """
    model_key = model_by_display_key[model]
    # Chat models receive the "chat" prompt variant; all others the "non-chat" one.
    prompt_kind = "chat" if model_specs[model_key].is_chat else "non-chat"

    # Resolve every prompt up front so a missing variant is reported before
    # any experiment runs, with a message naming the prompt and model.
    selected = {}
    for prompt in prompts:
        variants = prompt_by_key[prompt]
        if prompt_kind not in variants:
            raise KeyError(
                f"prompt {prompt!r} has no {prompt_kind!r} variant for model "
                f"{model!r}; available variants: {sorted(variants)}"
            )
        selected[prompt] = variants[prompt_kind]

    return {
        (model, prompt): run_transducer_experiment(
            model_key,
            sample_dfa_spec,
            prompt_template,
            num_repeats_per_dfa=num_repeats_per_dfa,
            num_dfas=num_dfas,
        )
        for prompt, prompt_template in selected.items()
    }

In [8]:
# Prompt-independent baseline statistics. The "none" model key presumably means
# no LLM is queried -- confirm against run_transducer_experiment_just_stats.
deterministic_baseline_outcomes = run_transducer_experiment_just_stats(
    "none",
    sample_dfa_spec,
    BasicInstructionTransducerPrompter(num_sequence_symbols, strip=True),
    num_repeats_per_dfa=num_repeats_per_dfa,
    num_dfas=1000,
)

# Each row is (model display name, number of sampled DFAs, prompt families).
# Experiments run in table order; results are keyed by (model, prompt).
model_outcomes = {
    key: outcomes
    for model, num_dfas, prompt_keys in [
        ("llama3-8B", 1000, ("Basic",)),
        ("llama3-70B", 1000, ("Basic",)),
        ("llama3.1-8B-Instruct", 1000, ("Basic",)),
        ("starcoder2-15b", 100, ("Basic",)),
        ("codestral-22B", 1000, ("Basic",)),
        ("deepseek-coder-33b-instruct", 1000, ("Basic",)),
        ("qwen-2.5-coder-7B", 1000, ("Basic",)),
        ("qwen-2.5-coder-instruct-7B", 1000, ("Basic",)),
        ("mistral-nemo-minitron-8B", 1000, ("Basic",)),
        ("mistral-nemo-base-12B", 1000, ("Basic",)),
        ("mistral-nemo-instruct-12B", 1000, ("Basic",)),
        ("gemma-7b", 1000, ("Basic",)),
        ("falcon-7b", 1000, ("Basic",)),
        ("gpt-3.5-instruct", 100, ("Basic",)),
        ("gpt-3.5-chat", 100, ("Basic",)),
        ("gpt-4o-mini", 100, ("Basic", "More-Expl", "COT", "Red-Green")),
        ("gpt-4o", 30, ("Basic",)),
        ("claude-3.5", 30, ("Basic", "More-Expl", "COT", "Red-Green")),
    ]
    for key, outcomes in for_model_and_prompt(model, num_dfas, *prompt_keys).items()
}

In [9]:
# Baseline rows carry no prompt of their own, so they only fill this column.
no_prompt = "Basic"

# accuracies[row label][prompt key] -> list of per-DFA accuracies.
accuracies = defaultdict(dict)

accuracies[r"\textsc{Null}$_T$"][no_prompt] = [
    outcome.null_success_rate for outcome in deterministic_baseline_outcomes
]

# n-gram baselines for n = 2..6; index ``ngram - 2`` treats entry 0 as the 2-gram.
for ngram in range(2, 7):
    accuracies[rf"{ngram}-\textsc{{Gram}}$_T$"][no_prompt] = [
        outcome.kgram_success_rates_each[ngram - 2]
        for outcome in deterministic_baseline_outcomes
    ]

accuracies[r"\textsc{BruteForce}$_T$"][no_prompt] = run_brute_force_transducer(
    sample_dfa_spec,
    num_states,
    num_symbols,
    num_sequence_symbols,
    num_repeats_per_dfa,
    num_dfas=1000,
)

# LLM rows: one column per prompt family the model was evaluated with.
for model, prompt in model_outcomes:
    accuracies[model][prompt] = [
        outcome.success_rate_binary_ignore_na
        for outcome in model_outcomes[model, prompt]
    ]

In [10]:
prompts = [*prompt_by_key]  # prompt-family keys, in definition order

In [11]:
def display_prompt(p):
    r"""Format prompt key ``p`` as the LaTeX column label ``\textsc{p}$_T$``."""
    template = r"\textsc{{{}}}$_T$"
    return template.format(p)

In [12]:
# Relabel every prompt key as a LaTeX column header, then render the table with
# one column per prompt family in prompt_by_key order.
produce_table(
    {
        model: {display_prompt(prompt): accs for prompt, accs in by_prompt.items()}
        for model, by_prompt in accuracies.items()
    },
    list(map(display_prompt, prompt_by_key)),
)

\begin{tabular}{|r|c|c|c|c|}
\hline
Model & \textsc{Basic}$_T$ & \textsc{More-Expl}$_T$ & \textsc{COT}$_T$ & \textsc{Red-Green}$_T$\\
\hline
\cellcolor{lightgray}\textsc{BruteForce}$_T$ &\cellcolor{lightgray}96.4 (96.2--96.7)&--&--&--\\
\hline
\bf 6-\textsc{Gram}$_T$ &\bf 93.5 (93.1--93.9)&--&--&--\\
\hline
5-\textsc{Gram}$_T$ &93.4 (93.0--93.7)&--&--&--\\
\hline
4-\textsc{Gram}$_T$ &91.1 (90.6--91.6)&--&--&--\\
\hline
mistral-nemo-minitron-8B &88.6 (88.0--89.1)&--&--&--\\
\hline
qwen-2.5-coder-instruct-7B &88.3 (87.8--88.8)&--&--&--\\
\hline
qwen-2.5-coder-7B &88.2 (87.6--88.7)&--&--&--\\
\hline
mistral-nemo-instruct-12B &88.0 (87.5--88.5)&--&--&--\\
\hline
mistral-nemo-base-12B &87.9 (87.4--88.4)&--&--&--\\
\hline
gpt-3.5-instruct &87.8 (85.9--89.6)&--&--&--\\
\hline
llama3-70B &87.7 (87.2--88.3)&--&--&--\\
\hline
starcoder2-15b &87.7 (85.8--89.5)&--&--&--\\
\hline
llama3-8B &87.5 (86.9--88.0)&--&--&--\\
\hline
claude-3.5 &86.9 (83.3--90.0)&87.1 (83.9--90.2)&76.4 (72.9--79.9)&82.9 (7