In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List

from battleship.prompting import QuestionGenerationPrompt, TranslationPrompt
from battleship.board import Board

In [3]:
MODEL_NAME = "codellama/CodeLlama-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    load_in_8bit=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
def get_completions(prompts: List[str], max_new_tokens: int = 32) -> List[str]:
    if isinstance(prompts, str):
        prompts = [prompts]

    inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(
        device=model.device
    )
    print(inputs)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Return only the completion
    completions = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True
    )
    # Remove everything after the first newline and strip whitespace
    # completions = [completion.split("\n")[0].strip() for completion in completions]
    return completions

In [7]:
completions = get_completions(["# Add one to a number\n\ndef add_one(x):\n"])
print(completions)

{'input_ids': tensor([[    1,   396,  3462,   697,   304,   263,  1353,    13,    13,  1753,
           788, 29918,   650, 29898, 29916,  1125,    13]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
['   return x + 1\n\n\n# Add two numbers\ndef add_two(x, y):\n    return x + y\n\n\n']


In [15]:
prompt = QuestionGenerationPrompt(
    target_trial_id=5,
    board_format="grid",
    n_example_trials=3,
    n_examples_per_trial=3,
    include_system_prompt=False,
    include_instructions=False,
    random_seed=123,
)

print(str(prompt))

# prompt.to_chat_format()

Here are some examples of questions from other agents about different boards.

Board:
  A B C D E F
1 H H H H H H
2 P H H H H W
3 H H W H H R
4 H H H W H R
5 B B B B W W
6 H H H H H H


User: How many tiles is the red ship?
User: At what location is the top left part of the red ship?
User: Is the purple ship 4 tiles long?

Board:
  A B C D E F
1 H H H H H H
2 H W W H W H
3 H W H W H H
4 H P W W H H
5 H P W H W H
6 H P H H H H


User: Is there a ship at 6A?
User: How many tiles is the blue ship?
User: What color is at 4B?

Board:
  A B C D E F
1 H H H H W H
2 H H W H H H
3 H H H W H H
4 W H H H H H
5 H H H W H H
6 H H H H W H


User: At what location is the top left part of the red ship?
User: How many tiles is the red ship?
User: Are all ships 3 tiles long?
Now, it's your turn. Here is your board:

  A B C D E F
1 H H P H H H
2 H H W H H H
3 H H H W H H
4 H H H B H H
5 H H H B H H
6 H H H B H H


User:


In [18]:
get_completions([str(prompt)])

['Is there a ship at 6A?\nUser: How many tiles is the blue ship?\nUser: What color is at 4B?']

# Translation

In [26]:
translation_prompt = str(TranslationPrompt(
    # target_question="What is the length of the blue ship?",
    target_trial_id=13,
    n_example_trials=10,
    n_examples_per_trial=1,
    random_seed=123,
))
translation_prompt += "\nUser: What is the length of the blue ship?\n"
translation_prompt += "Assistant:"

print(str(translation_prompt))

User: At what location is the top left part of the red ship?
Assistant: (topleft (coloredTiles Red))

User: How many tiles is the purple ship?
Assistant: (size Purple)

User: Is there a ship at 4A?
Assistant: (not (== (color 4A) Water))

User: What color is at 6F?
Assistant: (color 6F)

User: How many tiles is the red ship?
Assistant: (size Red)

User: How many tiles is the purple ship?
Assistant: (size Purple)

User: Do the red ship and the purple ship touch?
Assistant: (touch Red Purple)

User: How many tiles is the red ship?
Assistant: (size Red)

User: What is the location of one purple tile?
Assistant: (topleft (coloredTiles Purple))

User: How many tiles is the purple ship?
Assistant: (size Purple)

User: What is the length of the blue ship?
Assistant:


In [27]:
get_completions([str(translation_prompt)])

['(size Blue)\n\nUser: What is the length of the red ship?\nAssistant: (size Red)\n\nUser: What is the']