In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import asyncio
import os
import pandas as pd

from hfppl.llms import CachedCausalLM
from hfppl.inference import smc_standard

from battleship.board import Board, TRIAL_IDS
from battleship.prompting import QuestionGenerationPrompt, TranslationPrompt
from battleship.scoring import compute_score
from battleship.models import QuestionGenerationModel, SingleStepQuestionGenerationModel

In [3]:
# Load the HuggingFace causal LM used by every cell below.
# Swap MODEL_NAME to try a different checkpoint.
MODEL_NAME = "codellama/CodeLlama-7b-hf"
# MODEL_NAME = "codellama/CodeLlama-13b-hf"
lm = CachedCausalLM.from_pretrained(MODEL_NAME)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



# Run SMC

In [9]:
async def run_smc_baseline(
    n_particles=5,
    trial_ids=TRIAL_IDS,
    board_format="textual",
    n_example_trials=3,
    results_file="hfppl_results.csv",
    random_seed=123,
    verbose=False,
):
    """Run `smc_standard` question generation on each trial and collect the results.

    For every trial this builds a question prompt and a translation prompt,
    clears the caches of the module-level `lm`, runs SMC with a
    `QuestionGenerationModel`, and gathers each particle's final results.
    The cumulative results are rewritten to `results_file` after every trial,
    so the CSV always reflects all trials completed so far.

    Args:
        n_particles: Number of particles passed to `smc_standard`.
        trial_ids: Iterable of trial ids to evaluate. The default is bound to
            `TRIAL_IDS` at the time this function is defined.
        board_format: Forwarded to `QuestionGenerationPrompt`.
        n_example_trials: Number of example trials for the question prompt.
        results_file: Destination passed to `DataFrame.to_csv`. When it is a
            path, missing parent directories are created before the run starts.
        random_seed: Seed forwarded to both prompt constructors.
        verbose: Forwarded to `QuestionGenerationModel`.

    Returns:
        A DataFrame holding the results of all trials, with `particle` and
        `trial_id` columns added. Empty when `trial_ids` yields no trials.
    """
    # Create the output directory up front: failing inside `to_csv` would
    # otherwise only happen after the first (expensive) SMC run has finished.
    if isinstance(results_file, (str, os.PathLike)):
        results_dir = os.path.dirname(os.fspath(results_file))
        if results_dir:
            os.makedirs(results_dir, exist_ok=True)

    results_all = []
    # Bound before the loop so that an empty `trial_ids` returns an empty
    # frame instead of raising UnboundLocalError at the return statement.
    df_results = pd.DataFrame()
    for trial_id in trial_ids:
        print("-" * 80)
        print(f"Trial {trial_id}")
        print("-" * 80)

        # TODO: Sample the question prompt each trial?
        question_prompt = QuestionGenerationPrompt(
            target_trial_id=trial_id,
            board_format=board_format,
            n_example_trials=n_example_trials,
            n_examples_per_trial=3,
            include_board=False,
            include_system_prompt=False,
            include_instructions=False,
            random_seed=random_seed,
        )

        # TODO: Sample the translation prompt each trial?
        translation_prompt = TranslationPrompt(
            target_trial_id=trial_id,
            n_example_trials=10,
            n_examples_per_trial=1,
            include_system_prompt=False,
            random_seed=random_seed,
        )
        print("-" * 80)
        print("QUESTION PROMPT")
        print("-" * 80)
        print(str(question_prompt))
        print("-" * 80)
        print("TRANSLATION PROMPT")
        print("-" * 80)
        print(str(translation_prompt))
        print("-" * 80)

        # Reset the LM caches before each trial.
        lm.clear_cache()
        lm.clear_kv_cache()

        # Significantly speeds up performance, but may result in CUDA out of memory error
        # lm.cache_kv(lm.tokenizer.encode(str(question_prompt)))
        # lm.cache_kv(lm.tokenizer.encode(str(translation_prompt)))

        model = QuestionGenerationModel(
            lm=lm,
            board=Board.from_trial_id(trial_id),
            question_prompt=str(question_prompt),
            translation_prompt=str(translation_prompt),
            verbose=verbose,
        )

        particles = await smc_standard(model, n_particles=n_particles)

        # One frame per particle, tagged with the particle's index.
        results_trial = []
        for i, p in enumerate(particles):
            df_p = pd.DataFrame(p.get_final_results())
            df_p["particle"] = i
            results_trial.append(df_p)
        df_trial = pd.concat(results_trial).reset_index(drop=True)
        df_trial["trial_id"] = trial_id
        results_all.append(df_trial)

        # Rewrite the full CSV after every trial so progress is checkpointed.
        df_results = pd.concat(results_all).reset_index(drop=True)
        df_results.to_csv(results_file, index=False)
    return df_results

In [10]:
# TRIAL_IDS = range(1, 19)
TRIAL_IDS = [13]
N_PARTICLES = 3
BOARD_FORMAT = "grid"
RESULTS_FILE = "results/eval_smc.csv"

await run_smc_baseline(
    n_particles=N_PARTICLES,
    trial_ids=TRIAL_IDS,
    board_format=BOARD_FORMAT,
    results_file=RESULTS_FILE,
    verbose=True,
)

--------------------------------------------------------------------------------
Trial 13
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
QUESTION PROMPT
--------------------------------------------------------------------------------
Question: How many tiles is the red ship?
Question: At what location is the top left part of the red ship?
Question: Is the purple ship 4 tiles long?
Question: Is there a ship at 6A?
Question: How many tiles is the blue ship?
Question: What color is at 4B?
Question: At what location is the top left part of the red ship?
Question: How many tiles is the red ship?
Question: Are all ships 3 tiles long?
Question: 
--------------------------------------------------------------------------------
TRANSLATION PROMPT
--------------------------------------------------------------------------------
User: At what location is the top left part of the red sh

Unnamed: 0,prefix,completion,translation,score,type,particle,trial_id
0,3A represents what color?,3A represents what color?,3A,0.0,final,0,13
1,3,325A is a ship?,325A,0.0,rollout,0,13
2,3,3B is a ship.,3B,0.0,rollout,0,13
3,3,3B is a ship?,3B,0.0,rollout,0,13
4,3A,3A - What color is that tile?,3A,0.0,rollout,0,13
...,...,...,...,...,...,...,...
100,What is the color of the ship next to the shi...,What is the color of the ship next to the ship...,4 (color (right (coloredTiles Red))),0.0,rollout,2,13
101,What is the color of the ship next to the shi...,What is the color of the ship next to the ship...,4 (color (right (color 4))),0.0,rollout,2,13
102,What is the color of the ship next to the shi...,What is the color of the ship next to the ship...,4 (color (right (coloredTiles Red))),0.0,rollout,2,13
103,What is the color of the ship next to the shi...,What is the color of the ship next to the ship...,4 (color (neighbor 4)),0.0,rollout,2,13


# Debugging

In [8]:
lm.clear_cache()
lm.clear_kv_cache()

question_prompt = QuestionGenerationPrompt(
    target_trial_id=13,
    board_format="grid",
    n_example_trials=3,
    n_examples_per_trial=3,
    include_board=False,
    include_instructions=False,
    include_system_prompt=False,
    random_seed=123,
)

# TODO: Sample the translation prompt each trial?
translation_prompt = TranslationPrompt(
    target_trial_id=13,
    n_example_trials=10,
    n_examples_per_trial=1,
    include_system_prompt=False,
    random_seed=123,
)

print("-" * 80)
print(str(question_prompt))
print("-" * 80)
print(str(translation_prompt))
print("-" * 80)

model = QuestionGenerationModel(
    lm=lm,
    board=Board.from_trial_id(13),
    question_prompt=str(question_prompt),
    translation_prompt=str(translation_prompt),
    verbose=True,
)

partial_questions = [
    "What",
    "Is",
    "How many",
]

print("-" * 80)
print("QUESTIONS")
for q in partial_questions:
    completion = await model._complete_question(q)
    print(completion)


questions = [
    "What is the length of the red ship?",
    "Is there a ship at 3C?",
    "What is the top left corner of the blue ship?",
]

print("-" * 80)
print("TRANSLATIONS")
for q in questions:
    print("-" * 80)
    print(q)
    completion = await model._translate_question(q)
    print(completion)

--------------------------------------------------------------------------------
Question: How many tiles is the red ship?
Question: At what location is the top left part of the red ship?
Question: Is the purple ship 4 tiles long?
Question: Is there a ship at 6A?
Question: How many tiles is the blue ship?
Question: What color is at 4B?
Question: At what location is the top left part of the red ship?
Question: How many tiles is the red ship?
Question: Are all ships 3 tiles long?
Question: 
--------------------------------------------------------------------------------
User: At what location is the top left part of the red ship?
Query: (topleft (coloredTiles Red))
User: How many tiles is the purple ship?
Query: (size Purple)
User: Is there a ship at 4A?
Query: (not (== (color 4A) Water))
User: What color is at 6F?
Query: (color 6F)
User: How many tiles is the red ship?
Query: (size Red)
User: How many tiles is the purple ship?
Query: (size Purple)
User: Do the red ship and the purple sh