In [None]:
# Automatically reload edited modules (e.g. the local `battleship` package)
# before each cell executes, so library edits are picked up without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
import asyncio
import os
import pandas as pd

from hfppl.llms import CachedCausalLM
from hfppl.inference import smc_standard
from hfppl.distributions import LMContext

from battleship.v1.board import Board, TRIAL_IDS
from battleship.prompting import QuestionGenerationPrompt, TranslationPrompt
from battleship.scoring import compute_score
from battleship.models import QuestionGenerationModel, SingleStepQuestionGenerationModel

In [None]:
# Load the HuggingFace CodeLlama checkpoint through hfppl's CachedCausalLM wrapper.
# Alternative (larger) checkpoint: "codellama/CodeLlama-13b-hf".
LM_CHECKPOINT = "codellama/CodeLlama-7b-hf"
lm = CachedCausalLM.from_pretrained(LM_CHECKPOINT)

# Run SMC

In [None]:
async def run_smc_baseline(
    n_particles=5,
    trial_ids=TRIAL_IDS,
    board_format="textual",
    include_board=False,
    include_instructions=False,
    include_system_prompt=False,
    # question prompt
    q_n_example_trials=3,
    q_n_examples_per_trial=3,
    q_cache_prompt=False,
    # translation prompt
    t_n_example_trials=10,
    t_n_examples_per_trial=1,
    t_cache_prompt=False,
    results_file="hfppl_results.csv",
    random_seed=123,
    verbose=False,
):
    """Run the hfppl SMC question-generation baseline over a set of trials.

    For each trial this builds a question-generation prompt and a translation
    prompt, prints both, clears the global ``lm``'s caches, optionally
    pre-caches prompt KV states, and runs ``smc_standard`` on a
    ``QuestionGenerationModel``. Each particle's final results are collected
    into a single DataFrame.

    Args:
        n_particles: Number of SMC particles per trial.
        trial_ids: Iterable of trial IDs to evaluate.
        board_format: Board rendering format (question prompt only).
        include_board, include_instructions, include_system_prompt:
            Prompt-construction flags passed to both prompts.
        q_n_example_trials, q_n_examples_per_trial: Few-shot settings for the
            question prompt.
        q_cache_prompt: Pre-cache the question prompt's KV states.
        t_n_example_trials, t_n_examples_per_trial: Few-shot settings for the
            translation prompt.
        t_cache_prompt: Pre-cache the translation prompt's KV states.
        results_file: CSV path. It is rewritten after every trial so partial
            results survive an interrupted run. Its parent directory is
            created if it does not exist.
        random_seed: Seed passed to both prompt constructors.
        verbose: Passed through to ``QuestionGenerationModel``. Prompts are
            printed regardless of this flag.

    Returns:
        A DataFrame with one row per final result, plus ``particle`` and
        ``trial_id`` columns. It is empty if ``trial_ids`` is empty.
    """
    # Create the output directory up front. Otherwise a missing directory
    # (e.g. "results/") only surfaces as an OSError from ``to_csv`` *after*
    # the first expensive SMC run, and that run's results are lost.
    results_dir = os.path.dirname(results_file)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    results_all = []
    # Pre-initialized so an empty ``trial_ids`` returns an empty frame instead
    # of raising UnboundLocalError at the final ``return``.
    df_results = pd.DataFrame()
    for trial_id in trial_ids:
        print("-" * 80)
        print(f"Trial {trial_id}")
        print("-" * 80)

        # TODO: Sample the question prompt each trial?
        question_prompt = QuestionGenerationPrompt(
            target_trial_id=trial_id,
            board_format=board_format,
            n_example_trials=q_n_example_trials,
            n_examples_per_trial=q_n_examples_per_trial,
            include_board=include_board,
            include_instructions=include_instructions,
            include_system_prompt=include_system_prompt,
            random_seed=random_seed,
        )

        # TODO: Sample the translation prompt each trial?
        translation_prompt = TranslationPrompt(
            target_trial_id=trial_id,
            n_example_trials=t_n_example_trials,
            n_examples_per_trial=t_n_examples_per_trial,
            include_board=include_board,
            include_instructions=include_instructions,
            include_system_prompt=include_system_prompt,
            random_seed=random_seed,
        )

        # Render each prompt once. This guarantees that the printed text, the
        # KV-cached prefix and the text handed to the model are identical.
        question_prompt_str = str(question_prompt)
        translation_prompt_str = str(translation_prompt)

        print("-" * 80)
        print("QUESTION PROMPT")
        print("-" * 80)
        print(question_prompt_str)
        print("-" * 80)
        print("TRANSLATION PROMPT")
        print("-" * 80)
        print(translation_prompt_str)
        print("-" * 80)

        lm.clear_cache()
        lm.clear_kv_cache()

        # Caching speeds up performance, but may result in CUDA out of memory error.
        if q_cache_prompt:
            lm.cache_kv(lm.tokenizer.encode(question_prompt_str))
        if t_cache_prompt:
            # Additionally, this currently degrades the quality of the translations for an unknown reason.
            lm.cache_kv(lm.tokenizer.encode(translation_prompt_str))

        model = QuestionGenerationModel(
            lm=lm,
            board=Board.from_trial_id(trial_id),
            question_prompt=question_prompt_str,
            translation_prompt=translation_prompt_str,
            verbose=verbose,
        )

        particles = await smc_standard(model, n_particles=n_particles)

        results_trial = []
        for i, p in enumerate(particles):
            df_p = pd.DataFrame(p.get_final_results())
            df_p["particle"] = i
            results_trial.append(df_p)
        df_trial = pd.concat(results_trial).reset_index(drop=True)
        df_trial["trial_id"] = trial_id
        results_all.append(df_trial)

        # Checkpoint the cumulative results after every trial.
        df_results = pd.concat(results_all).reset_index(drop=True)
        df_results.to_csv(results_file, index=False)
    return df_results

In [None]:
# RUN SMC BASELINE
df_results = await run_smc_baseline(
    n_particles=1,
    trial_ids=[13],
    board_format="textual",
    include_board=False,
    include_instructions=False,
    include_system_prompt=False,
    q_n_example_trials=3,
    q_n_examples_per_trial=10,
    t_n_example_trials=12,
    t_n_examples_per_trial=1,
    q_cache_prompt=True,
    t_cache_prompt=False,
    results_file="results/eval_smc.csv",
    verbose=True,
)

In [None]:
df_results.sort_values(by="score", ascending=False).head(n=30)  # 30 highest-scoring rows

In [None]:
df_results.loc[df_results["type"] == "final"]  # rows whose `type` is "final"

# Debugging: inspect the prompts and spot-check translations for trial 13

In [None]:
lm.clear_cache()
lm.clear_kv_cache()

question_prompt = QuestionGenerationPrompt(
    target_trial_id=13,
    board_format="grid",
    n_example_trials=3,
    n_examples_per_trial=3,
    include_board=False,
    include_instructions=False,
    include_system_prompt=False,
    random_seed=123,
)

translation_prompt = TranslationPrompt(
    target_trial_id=13,
    n_example_trials=10,
    n_examples_per_trial=1,
    include_system_prompt=False,
    include_instructions=False,
    random_seed=123,
)

print("-" * 80)
print(str(question_prompt))
print("-" * 80)
print(str(translation_prompt))
print("-" * 80)

# lm.cache_kv(lm.tokenizer.encode(str(translation_prompt)))

# Alternative way to force the LM to cache the translation prompt
# from hfppl.distributions import LMContext
# from hfppl.modeling import Model
# model2 = Model()
# ctx = LMContext(
#     lm,
#     str(translation_prompt),
#     temp=0.1,
# )
# token = await model2.sample(ctx.next_token())

model = QuestionGenerationModel(
    lm=lm,
    board=Board.from_trial_id(13),
    question_prompt=str(question_prompt),
    translation_prompt=str(translation_prompt),
    verbose=True,
)

# partial_questions = [
#     "What",
#     "Is",
#     "How many",
# ]

# print("-" * 80)
# print("QUESTIONS")
# for q in partial_questions:
#     completion = await model._complete_question(q)
#     print(completion)

questions = [
    "What is the length of the red ship?",
    "What is the top left corner of the blue ship?",
    "Is there a ship at 3C?",
]

print("-" * 80)
print("TRANSLATIONS")
for q in questions:
    print("-" * 80)
    print(q)
    completion = await model._translate_question(q)
    print(completion)

In [None]:
# Decode token IDs one at a time to see how the tokenizer segments a prompt.
# Earlier note: token 29871 decodes to an empty/space piece.

# WITH CACHING
# tokens = [13, 2659, 29901, 1317, 727, 263, 7751, 472, 29871, 29906, 29923, 29973, 13, 3010, 29901, 1317, 727, 263, 7751, 472, 29871, 29906, 29923, 29973, 13]

# WITHOUT CACHING
tokens = [1, 4911, 29901, 2180, 825, 4423, 338, 278, 2246, 2175, 760, 310, 278, 2654, 7751, 29973, 13, 3010, 29901, 313, 3332, 1563, 313, 2780, 287, 29911, 5475, 4367, 876, 13, 2659, 29901, 1128, 1784, 260, 5475, 338, 278, 3708, 552, 7751, 29973, 13, 3010, 29901, 313, 2311, 15247, 552, 29897, 13, 2659, 29901, 1317, 727, 263, 7751, 472, 29871, 29946, 29909, 29973, 13, 3010, 29901, 313, 1333, 313, 1360, 313, 2780, 29871, 29946, 29909, 29897, 13062, 876, 13, 2659, 29901, 1724, 2927, 338, 472, 29871, 29953, 29943, 29973, 13, 3010, 29901, 313, 2780, 29871, 29953, 29943, 29897, 13, 2659, 29901, 1128, 1784, 260, 5475, 338, 278, 2654, 7751, 29973, 13, 3010, 29901, 313, 2311, 4367, 29897, 13, 2659, 29901, 1128, 1784, 260, 5475, 338, 278, 3708, 552, 7751, 29973, 13, 3010, 29901, 313, 2311, 15247, 552, 29897, 13, 2659, 29901, 1938, 278, 2654, 7751, 322, 278, 3708, 552, 7751, 6023, 29973, 13, 3010, 29901, 313, 16747, 4367, 15247, 552, 29897, 13, 2659, 29901, 1128, 1784, 260, 5475, 338, 278, 2654, 7751, 29973, 13, 3010, 29901, 313, 2311, 4367, 29897, 13, 2659, 29901, 1724, 338, 278, 4423, 310, 697, 3708, 552, 25900, 29973, 13, 3010, 29901, 313, 3332, 1563, 313, 2780, 287, 29911, 5475, 15247, 552, 876, 13, 2659, 29901, 1128, 1784, 260, 5475, 338, 278, 3708, 552, 7751, 29973, 13, 3010, 29901, 313, 2311, 15247, 552, 29897, 13, 2659, 29901, 1317, 727, 263, 7751, 472, 29871, 29945, 29909, 29973, 13, 3010, 29901, 313, 1333, 313, 1360, 313, 2780, 29871, 29945, 29909, 29897, 13062, 876, 13]

# One line per token: "<id> <repr of decoded text>".
for token_id in tokens:
    token_text = lm.tokenizer.decode([token_id])
    print(f"{token_id} {token_text!r}")