In [43]:
import nest_asyncio
from dotenv import load_dotenv

# NOTE(review): presumably needed so the async agent calls further down can run
# inside the notebook's already-running event loop — confirm.
nest_asyncio.apply()
# Kept as the last expression of the cell: its return value is what the notebook
# echoes (the export shows `True`). NOTE(review): presumably this is where the
# API credentials for the "openai:" models come from — confirm.
load_dotenv()

True

In [51]:
from datasets import load_dataset, Dataset
from pydantic_ai import Agent, RunContext
from transformers import AutoTokenizer
import asyncio
from itertools import chain
from collections import Counter
import json


class MathDataset:
    """Mutable collection of math questions with token-vocabulary metrics.

    The questions and answers are read from the ``problem`` and ``answer``
    columns of the ``train`` split of ``data_path``. Questions can then be
    rewritten in place, either with literal string replacement or through an
    LLM, while the tokenizer-based properties report how many distinct tokens
    the current wording uses.

    Attributes are plain instance fields:
      questions: current (possibly rewritten) question texts.
      answers: answer strings, never modified by this class.
      agent: LLM agent used by ``llm_transform``.
      tokenizer: tokenizer used for all token statistics.
    """

    def __init__(
        self,
        data_path: str = "AI-MO/aimo-validation-aime",
        model_name: str = "openai:gpt-5-mini",
        tokenizer_name: str = "Qwen/Qwen3-0.6B",
    ) -> None:
        """Load the dataset, the rewriting agent and the tokenizer.

        Args:
          data_path: Dataset identifier passed to ``load_dataset``.
          model_name: Model identifier passed to ``Agent``.
          tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        """
        ds = load_dataset(data_path, split="train")
        self.questions = [str(question) for question in ds["problem"]]
        self.answers = [str(answer) for answer in ds["answer"]]
        self.agent = Agent(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def _token_ids(self) -> list[list[int]]:
        """Tokenize every current question and return one id list per question."""
        return self.tokenizer(self.questions)["input_ids"]

    @property
    def num_unique_tokens(self) -> int:
        """Number of distinct token ids used across all current questions."""
        return len(set(chain.from_iterable(self._token_ids())))

    @property
    def token_counts(self) -> dict[str, int]:
        """Frequency of each decoded token across all current questions.

        Several token ids can decode to the same text, so counts are summed
        per decoded string instead of letting a later id overwrite an earlier
        one. The number of keys may therefore be smaller than
        ``num_unique_tokens``.
        """
        id_counts = Counter(chain.from_iterable(self._token_ids()))
        text_counts: Counter[str] = Counter()
        for token_id, count in id_counts.items():
            text_counts[self.tokenizer.decode(token_id)] += count
        return dict(text_counts)

    def replace(self, old: str, new: str) -> None:
        """Replace every literal occurrence of ``old`` with ``new`` in all questions.

        Follows ``str.replace`` semantics, so an empty ``old`` inserts ``new``
        around every character.
        """
        self.questions = [question.replace(old, new) for question in self.questions]

    async def llm_transform(self, transform_instruction: str) -> None:
        """Rewrite every question with the LLM according to ``transform_instruction``.

        One prompt is built per question and all prompts are awaited together.
        ``self.questions`` is only reassigned after every call has returned,
        so an exception from any call leaves the questions untouched.

        Args:
          transform_instruction: Free-text directive inserted into each prompt.
        """
        prompts = [
            (
                "Transform the question according to the instruction.\n\n"
                "**The transformed question must preserve the exact mathematical formulation "
                "and all constraints of the original question. Do not alter numbers, variables, "
                "equations, or assumptions.**\n\n"
                f"# Question\n\n{question}\n\n"
                f"# Instruction\n\n{transform_instruction}\n\n"
                "Respond with the transformed question only."
            )
            for question in self.questions
        ]

        pending_runs = [self.agent.run(prompt) for prompt in prompts]
        results = await asyncio.gather(*pending_runs)
        self.questions = [result.output for result in results]


# Orchestrating agent: it receives a MathDataset as its dependency and drives
# the tools registered below. The instruction text is sent to the model, so the
# wording fixes here ("number of", "mathematical content of") change the prompt.
dataevolve_agent = Agent(
    "openai:gpt-5-mini",
    deps_type=MathDataset,
    instructions=(
        "You must optimize math questions to minimize the number of unique tokens while strictly preserving their "
        "mathematical content. Your goal is to reduce the vocab size to under 512 unique tokens. You must ensure "
        "the mathematical content of all questions is unchanged, and that they still have the same answer."
    ),
)


@dataevolve_agent.tool
def get_questions(ctx: RunContext[MathDataset]) -> str:
    """
    Show every question as it currently stands.

    The dataset's in-memory question list is rendered with Python's own list
    formatting, so the result looks like ``['first question', 'second ...']``.

    Returns:
      str: Text form of the full list of current questions.
    """
    current_questions = ctx.deps.questions
    return str(current_questions)


@dataevolve_agent.tool
def get_num_unique_tokens(ctx: RunContext[MathDataset]) -> int:
    """
    Report how many distinct tokenizer tokens the current questions use.

    All questions are tokenized with the dataset's tokenizer and the distinct
    token ids are counted. Call it before and after an edit to see whether the
    edit reduced the vocabulary.

    Returns:
      int: Number of unique token ids across all current questions.
    """
    dataset = ctx.deps
    return dataset.num_unique_tokens


@dataevolve_agent.tool
def get_token_counts(ctx: RunContext[MathDataset]) -> str:
    """
    List how often each token occurs across the current questions.

    Every question is tokenized, occurrences are tallied, and each token id is
    decoded back to text so the table can be read directly. Decoded pieces may
    be partial subwords or special tokens.

    Returns:
      str: Indented JSON object mapping token text to its frequency, with keys
      in sorted order.
    """
    frequency_table = ctx.deps.token_counts
    return json.dumps(frequency_table, sort_keys=True, indent=2)


@dataevolve_agent.tool
def replace(ctx: RunContext[MathDataset], old: str, new: str) -> None:
    """
    Perform a literal string replacement across all questions.

    Every occurrence of `old` is replaced with `new` in each question, and the
    in-memory questions list is updated with the result. Suited to small,
    pattern-free edits such as 'is equal to' → '='.

    Args:
      old: Non-empty substring to be replaced (literal match, no patterns).
      new: Replacement text; may be empty to delete `old`.

    Raises:
      ModelRetry: If `old` is empty, so the model can retry with a real substring.
    """
    from pydantic_ai import ModelRetry

    # str.replace("", new) inserts `new` between every character, which would
    # silently corrupt every question; reject the call instead.
    if not old:
        raise ModelRetry("`old` must be a non-empty string; nothing was replaced.")
    ctx.deps.replace(old, new)


@dataevolve_agent.tool
async def llm_transform(ctx: RunContext[MathDataset], transform_instruction: str) -> None:
    """
    Rewrite every question with an LLM, following one shared instruction.

    Each question is sent to the LLM together with the instruction and a
    requirement to keep the exact mathematical formulation (numbers, variables,
    equations, assumptions and constraints). The in-memory questions list is
    then replaced by the LLM outputs.

    Args:
      transform_instruction (str): A concise directive, e.g.,
        - "Shorten wording; replace phrases with symbols where safe."
        - "Remove filler; keep all constraints; keep LaTeX intact."
    """
    dataset = ctx.deps
    await dataset.llm_transform(transform_instruction)

In [None]:
# Build the dataset with its default source, model and tokenizer, hand it to
# the agent as its dependency, and print the agent's final output.
math_ds = MathDataset()
run_result = dataevolve_agent.run_sync(deps=math_ds)
print(run_result.output)

In [None]:
# Distinct-token count for the questions as they currently stand; the value
# echoed under this cell in the export is 1167 (the agent's stated target is
# under 512).
math_ds.num_unique_tokens

1167