In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List

from battleship.prompting import QuestionGenerationPrompt, TranslationPrompt
from battleship.board import Board

In [3]:
MODEL_NAME = "codellama/CodeLlama-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    load_in_8bit=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def get_completions(prompts: List[str], max_new_tokens: int = 32) -> List[str]:
    if isinstance(prompts, str):
        prompts = [prompts]

    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(
        device=model.device
    )
    print(inputs)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Return only the completion
    completions = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True
    )
    # Remove everything after the first newline and strip whitespace
    # completions = [completion.split("\n")[0].strip() for completion in completions]
    return completions

In [5]:
completions = get_completions(["# Add one to a number\n\ndef add_one(x):\n"])
print(completions)

{'input_ids': tensor([[    1,   396,  3462,   697,   304,   263,  1353,    13,    13,  1753,
           788, 29918,   650, 29898, 29916,  1125,    13]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
['   return x + 1\n\n\n# Add two numbers\ndef add_two(x, y):\n    return x + y\n\n\n']


In [23]:
prompt = QuestionGenerationPrompt(
    target_trial_id=5,
    board_format="grid",
    n_example_trials=3,
    n_examples_per_trial=3,
    include_system_prompt=True,
    include_instructions=True,
    include_board=True,
    random_seed=123,
)

print(str(prompt))

# prompt.to_chat_format()

You are a game-playing agent. Read the game instructions and examples carefully. Respond with a single question that can be answered with one word. Do not include any other explanation or prose.

You are playing the board game Battleship. There are three ships on the board: Red, Blue, and Purple. Ships are oriented either horizontally or vertically and can be 2, 3, or 4 tiles in length. The board is a 6x6 grid, with numbered rows 1, 2, 3, 4, 5, 6 and lettered columns A, B, C, D, E, F. Coordinates are specified as a row, column pair. For example, 2-C is the tile in row 2, column C.

You will be given a partially-revealed game board. Your task is to ask a single question that will help you gain information about the position of the remaining hidden ships on the board. You can ask any question, but it must be answerable with a single word answer. 
The board is represented as a grid with the following symbols:

H: Hidden
W: Water
R: Red ship
B: Blue ship
P: Purple ship

Here are some examp

In [18]:
get_completions([str(prompt)])

['Is there a ship at 6A?\nUser: How many tiles is the blue ship?\nUser: What color is at 4B?']

# Translation

In [27]:
translation_prompt = str(TranslationPrompt(
    # target_question="What is the length of the blue ship?",
    target_trial_id=13,
    n_example_trials=10,
    n_examples_per_trial=1,
    random_seed=123,
    include_instructions=False,
))

print(str(translation_prompt))

User: At what location is the top left part of the red ship?
Query: (topleft (coloredTiles Red))
User: How many tiles is the purple ship?
Query: (size Purple)
User: Is there a ship at 4A?
Query: (not (== (color 4A) Water))
User: What color is at 6F?
Query: (color 6F)
User: How many tiles is the red ship?
Query: (size Red)
User: How many tiles is the purple ship?
Query: (size Purple)
User: Do the red ship and the purple ship touch?
Query: (touch Red Purple)
User: How many tiles is the red ship?
Query: (size Red)
User: What is the location of one purple tile?
Query: (topleft (coloredTiles Purple))
User: How many tiles is the purple ship?
Query: (size Purple)


In [27]:
get_completions([str(translation_prompt)])

['(size Blue)\n\nUser: What is the length of the red ship?\nAssistant: (size Red)\n\nUser: What is the']