In [None]:
# Automatically re-import edited modules (e.g. the battleship package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Optional

import numpy as np

from battleship.board import Board
from battleship.agents import CodeQuestion, Question
from battleship.game import BattleshipGame
from battleship.captains import create_captain
from battleship.spotters import create_spotter
from battleship.fast_sampler import FastSampler

In [None]:
import logging

# logging.basicConfig(level=logging.DEBUG)


In [None]:
# Alternative: use a recorded board instead of the hand-built one below, e.g.
#   true_board = Board.from_trial_id("B03")

SEED = 123

# Hand-built 8x8 ground-truth board (0 = water, k > 0 = ship k).
# Each entry maps a ship id to the numpy index expression of the cells it occupies.
TRUE_SHIPS = {
    1: (7, slice(3, 5)),  # row 7, cols 3-4
    2: (slice(0, 4), 2),  # rows 0-3, col 2
    3: (6, slice(0, 3)),  # row 6, cols 0-2
    4: (slice(2, 7), 6),  # rows 2-6, col 6
}

true_board = Board(np.zeros((8, 8), dtype=int))
for ship_id, cells in TRUE_SHIPS.items():
    # Guard against accidentally overlapping placements when editing the layout.
    assert (true_board.board[cells] == 0).all(), f"ship {ship_id} overlaps another ship"
    true_board.board[cells] = ship_id

true_board

In [None]:
# Partial observation of `true_board`. Every cell starts at -1 (presumably
# "unrevealed" for the sampler — confirm); a few cells are then revealed.
# Ships 1 and 2 are deliberately left unrevealed.
partial_board = Board(np.full((8, 8), -1, dtype=int), transparent=True)
partial_board.board[6, 1:2] = 3  # a single hit on ship 3: row 6, col 1
partial_board.board[2:7, 6] = 4  # ship 4: rows 2-6, col 6
# Revealed water (misses).
for row, col in [(3, 0), (6, 3), (1, 3), (6, 5), (7, 6)]:
    partial_board.board[row, col] = 0

# Every revealed cell must agree with the ground truth; otherwise the analysis
# below would condition on an observation that true_board could never produce.
revealed = partial_board.board != -1
assert np.array_equal(partial_board.board[revealed], true_board.board[revealed]), \
    "partial_board contradicts true_board"

partial_board

In [None]:
def _make_sampler(board: Board, ship_tracker=None, seed: Optional[int] = None):
    """Construct a FastSampler over `board`, falling back to the notebook-wide SEED.

    SEED is looked up at call time, so changing it in the config cell takes
    effect without re-running this cell.
    """
    return FastSampler(
        board,
        ship_tracker=ship_tracker,
        seed=SEED if seed is None else seed,
    )


def heatmap(
    board: Board,
    n_samples: int = 10000,
    constraints: Optional[list] = None,
    ship_tracker=None,
    seed: Optional[int] = None,
):
    """Render FastSampler's unnormalized posterior for `board` as a heatmap figure.

    Args:
        board: The (partial) board to sample against.
        n_samples: Number of samples passed to ``compute_posterior``.
        constraints: Extra constraints forwarded to ``compute_posterior``
            (presumably ``(question, answer)`` pairs). ``None`` means no constraints.
        ship_tracker: Optional ship tracker forwarded to ``FastSampler``.
        seed: Sampler seed; defaults to ``SEED``.

    Returns:
        The figure produced by ``Board._to_figure(..., mode="heatmap")``.
    """
    sampler = _make_sampler(board, ship_tracker=ship_tracker, seed=seed)
    # normalize=False -> raw counts rather than probabilities.
    posterior = sampler.compute_posterior(
        n_samples=n_samples,
        normalize=False,
        # A fresh list per call: a shared mutable default would persist across calls.
        constraints=[] if constraints is None else constraints,
    )
    return Board._to_figure(posterior, mode="heatmap", transparent=True)


def sample_boards(
    board: Board,
    n_samples: int = 3,
    constraints: Optional[list] = None,
    ship_tracker=None,
    seed: Optional[int] = None,
):
    """Draw weighted sample boards for `board` using FastSampler.

    Args:
        board: The (partial) board to sample against.
        n_samples: Number of samples passed to ``get_weighted_samples``.
        constraints: Extra constraints forwarded to ``get_weighted_samples``.
            ``None`` means no constraints.
        ship_tracker: Optional ship tracker forwarded to ``FastSampler``.
        seed: Sampler seed; defaults to ``SEED``.

    Returns:
        Whatever ``get_weighted_samples`` returns. The notebook iterates it as
        ``(board, weight)`` pairs.
    """
    sampler = _make_sampler(board, ship_tracker=ship_tracker, seed=seed)
    return sampler.get_weighted_samples(
        n_samples=n_samples,
        constraints=[] if constraints is None else constraints,
    )

In [None]:
# Posterior heatmap for partial_board, using a ship tracker obtained from the
# ground-truth board (NOTE(review): presumably reports which ships are
# sunk/remaining, i.e. information from true_board — confirm). Saved to heatmap.png.
fig = heatmap(partial_board, ship_tracker=true_board.ship_tracker(partial_board))
fig.savefig("heatmap.png", dpi=300, bbox_inches="tight")

In [None]:
# Draw 5 (board, weight) samples from FastSampler for partial_board; print each
# weight, display each board and save it to sample_<i>.png.
# NOTE(review): the figures are never closed. If to_figure registers them with
# pyplot, they accumulate (and may be auto-displayed) — consider closing them.
samples = sample_boards(partial_board, ship_tracker=true_board.ship_tracker(partial_board), n_samples=5)
for i, (board, weight) in enumerate(samples):
    print(f"Weight: {weight}")
    display(board)
    fig = board.to_figure(transparent=True)
    fig.savefig(f"sample_{i}.png", dpi=300, bbox_inches="tight")

In [None]:
# Agent configuration, shared by the spotter and the captain so they cannot drift apart.
# NOTE(review): both agents are built for BOARD_ID "B03", but the game below is
# played on the hand-built `true_board`, not Board.from_trial_id("B03"). If
# either agent uses board_id to look up a ground-truth board, its behaviour
# will not reflect `true_board` — confirm this is intended.
BOARD_ID = "B03"
LLM = "gpt-5"
MAP_SAMPLES = 1000
EIG_SAMPLES = 1000
EIG_K = 10

spotter = create_spotter(
    spotter_type="CodeSpotterModel",
    board_id=BOARD_ID,
    board_experiment="collaborative",
    llm=LLM,
    use_cot=True,
    json_path=None,
)

captain = create_captain(
    captain_type="MAPCaptain",
    seed=SEED,
    llm=LLM,
    board_id=BOARD_ID,
    map_samples=MAP_SAMPLES,
    prob_q_prob=None,
    eig_samples=EIG_SAMPLES,
    eig_k=EIG_K,
    json_path=None,
)

# The game is played against the hand-built true_board defined earlier.
game = BattleshipGame(
    board_target=true_board,
    captain=captain,
    spotter=spotter,
)

In [None]:
# Display the game before any stage has been played.
game

In [None]:
# Posterior heatmap for the game's current state (no ship tracker, no constraints).
heatmap(game.state)

In [None]:
# Advance the game by at most N_STEPS stages, showing the game and the updated
# posterior heatmap after each one. Increase N_STEPS to play further.
# (The loop variable is not `_`, which would clobber IPython's output history.)
N_STEPS = 1
for step in range(N_STEPS):
    if game.is_done():
        break
    game.next_stage()
    print(f"Stage {game.stage_index}")
    display(game)
    display(heatmap(game.state))

In [None]:
# Ask the spotter to translate a natural-language question into code, given the
# current board state and game history. The result exposes `fn_str` and is
# callable on boards (see the next cells).
# NOTE(review): which ship id "red" refers to depends on the colour mapping — confirm.
code_question = spotter.translate(
    question=Question("Is the red ship vertical?"),
    board=game.state,
    history=game.history,
)

In [None]:
# Inspect the generated source code for the translated question.
print(code_question.fn_str)

In [None]:
# Evaluate the translated question against the ground-truth board and the
# current partial board, i.e. the answer a truthful spotter would give.
answer = code_question(true_board=true_board.board, partial_board=game.state.board)
print(answer)

In [None]:
# Posterior heatmap for the current state, constrained on the translated question
# having the answer False.
# NOTE(review): False is hard-coded rather than taken from `answer` computed
# above; confirm whether this counterfactual is intended.
heatmap(game.state, constraints=[(code_question, False)])