In [3]:
from dataclasses import dataclass
import tiktoken

In [4]:

@dataclass
class BudgetResult:
    in_per_convo: int
    out_per_convo: int
    total_input_tokens: int
    total_output_tokens: int
    total_cost: float

def estimate_budget(
    n_games: int,
    n_baselines: int,
    num_conversations: int,
    max_turns: int,
    max_thinking_tokens: int,
    max_response_tokens: int,
    len_system_prompt_1: int,
    len_system_prompt_2: int,
    input_token_cost: float,
    output_token_cost: float,
    len_initial_message: int = 0,
) -> BudgetResult:
    """
    Worst-case token budget for 2-LLM simulated dialogues where, per turn:
      M1: think (≤ H) -> respond (≤ R)
      M2: think (≤ H) -> respond (≤ R)
    Only responses are appended to the running transcript. Each response pass
    also sees that turn's think trace for that model.

    Parameters
    ----------
    num_conversations : C
    max_turns : T
    max_thinking_tokens : H
    max_response_tokens : R
    len_system_prompt_1 : S1
    len_system_prompt_2 : S2
    input_token_cost : c_in   (cost per input token)
    output_token_cost : c_out (cost per output token)
    len_initial_message : L0  (seed user message at start of each convo; default 0)

    Returns
    -------
    BudgetResult with:
      - in_per_convo
      - out_per_convo
      - total_input_tokens
      - total_output_tokens
      - total_cost
    """
    # Aliases
    n_games = int(n_games)
    n_baselines = int(n_baselines)
    C = int(num_conversations)
    T = int(max_turns)
    H = int(max_thinking_tokens)
    R = int(max_response_tokens)
    S1 = int(len_system_prompt_1)
    S2 = int(len_system_prompt_2)
    L0 = int(len_initial_message)
    c_in = float(input_token_cost)
    c_out = float(output_token_cost)

    # Guard against negatives
    for name, v in {
        "num_conversations": C, "max_turns": T, "max_thinking_tokens": H,
        "max_response_tokens": R, "len_system_prompt_1": S1,
        "len_system_prompt_2": S2, "len_initial_message": L0
    }.items():
        if v < 0:
            raise ValueError(f"{name} must be >= 0")

    # Sum_{t=1..T} (t-1) = T(T-1)/2
    triangular = T * (T - 1) // 2

    # Input tokens per conversation (upper bound)
    # For each turn t:
    #   M1 think:    S1 + L0 + 2R(t-1)
    #   M1 respond:  S1 + L0 + H + 2R(t-1)
    #   M2 think:    S2 + L0 + R + 2R(t-1)      (M1's new response included)
    #   M2 respond:  S2 + L0 + H + R + 2R(t-1)
    in_per_convo = (
        T * (2*S1 + 2*S2 + 4*L0 + 2*H + 2*R)
        + 4*R * triangular // 1  # integer already; //1 keeps int type explicit
    )
    in_per_convo = int(in_per_convo)

    # Output tokens per conversation (upper bound)
    # Per turn: H (M1 think) + R (M1 resp) + H (M2 think) + R (M2 resp) = 2H + 2R
    out_per_convo = T * (2*H + 2*R)

    # Totals across conversations 
    total_input_tokens = C * in_per_convo * n_games * n_baselines
    total_output_tokens = C * out_per_convo * n_games * n_baselines

    total_cost = total_input_tokens * c_in + total_output_tokens * c_out

    return BudgetResult(
        in_per_convo=in_per_convo,
        out_per_convo=out_per_convo,
        total_input_tokens=total_input_tokens,
        total_output_tokens=total_output_tokens,
        total_cost=total_cost,
    )


In [5]:
from hangman.prompts.

SyntaxError: invalid syntax (2258881609.py, line 1)

In [10]:
n_baselines=8
n_games=2
num_conversations=50
max_turns=20
max_thinking_tokens=1024
max_response_tokens=512
len_system_prompt_1=300
len_system_prompt_2=300
input_token_cost=0.072 / 1e6   
output_token_cost=0.28 / 1e6  
len_initial_message=100
# --- Example ---
res = estimate_budget(
    n_games=n_games,
    n_baselines=n_baselines,
    num_conversations=num_conversations,
    max_turns=max_turns,
    max_thinking_tokens=max_thinking_tokens,
    max_response_tokens=max_response_tokens,
    len_system_prompt_1=len_system_prompt_1,
    len_system_prompt_2=len_system_prompt_2,
    input_token_cost=input_token_cost,
    output_token_cost=output_token_cost,
    len_initial_message=len_initial_message,
)
print(res)

BudgetResult(in_per_convo=482560, out_per_convo=61440, total_input_tokens=386048000, total_output_tokens=49152000, total_cost=41.558015999999995)
