In [None]:
import random

import nltk

# Upper bound on the number of tokens in a generated sequence; used as the
# default budget of generate().
MAX_LENGTH = 512

# Define a PCFG with probabilities favoring recursion: the three recursive
# rules (concatenation and the two bracket wrappers) carry 0.80 of the
# probability mass, while the terminal rules 'a' and 'b' share the other 0.20.
grammar = nltk.PCFG.fromstring("""
S -> S S            [0.40]
S -> '(' S ')'      [0.20]
S -> '[' S ']'      [0.20]
S -> 'a'            [0.10]
S -> 'b'            [0.10]
""")


def min_length(symbol):
    """
    Returns the minimal number of tokens that can be generated from a symbol.

    A terminal is itself exactly one token, and the only nonterminal of the
    grammar defined in this file (S) can bottom out in a single terminal via
    S -> 'a' or S -> 'b'. The answer is therefore 1 whatever the symbol is,
    so the symbol does not need to be inspected.

    symbol: a grammar symbol (terminal or nonterminal).

    NOTE: this value is specific to the grammar in this file. If a nonterminal
    whose shortest derivation is longer than one token is ever added, this
    function must be updated, because generate() uses it for budgeting.
    """
    return 1


def generate(symbol, current_length=0, max_length=MAX_LENGTH):
    """
    Recursively generate a sequence of tokens from the given symbol,
    keeping the total length within max_length whenever the grammar allows it.

    symbol: an nltk.Nonterminal to expand, or a terminal that is emitted as-is.
    current_length: tokens generated so far.
    max_length: upper bound on the total token count once this symbol is expanded.

    Returns a (tokens, new_length) tuple: the list of generated terminals and
    current_length increased by the number of tokens produced here.

    Raises ValueError if a nonterminal has no productions in the grammar.
    """
    # A terminal is emitted verbatim and always costs exactly one token.
    if not isinstance(symbol, nltk.Nonterminal):
        return [symbol], current_length + 1

    # Gather productions for the nonterminal.
    productions = grammar.productions(lhs=symbol)
    if not productions:
        # Without this guard the fallback below would fail with an opaque
        # IndexError on productions[0].
        raise ValueError(f"No productions found for nonterminal {symbol!r}")

    def min_cost(prod):
        # Fewest tokens this production can possibly expand to.
        return sum(min_length(sym) for sym in prod.rhs())

    # Keep only the productions whose minimal expansion still fits the budget.
    feasible = [
        prod
        for prod in productions
        if current_length + min_cost(prod) <= max_length
    ]

    if feasible:
        # random.choices treats the weights as relative, so the probabilities
        # are implicitly renormalized over the feasible subset.
        weights = [prod.prob() for prod in feasible]
        prod = random.choices(feasible, weights=weights, k=1)[0]
    else:
        # Nothing fits: force a purely terminal production if one exists.
        term_prods = [
            prod
            for prod in productions
            if all(not isinstance(sym, nltk.Nonterminal) for sym in prod.rhs())
        ]
        if term_prods:
            prod = random.choice(term_prods)
        else:
            # Last resort: take the production with the smallest minimal
            # length so the budget is exceeded by as little as possible.
            prod = min(productions, key=min_cost)

    result = []
    rhs = prod.rhs()
    # Tokens that must stay reserved for the symbols not yet expanded.
    reserved = sum(min_length(sym) for sym in rhs)
    for sym in rhs:
        # The symbol being expanded now no longer needs a reservation; what
        # remains is the minimal cost of the symbols to its right.
        reserved -= min_length(sym)
        generated, current_length = generate(sym, current_length, max_length - reserved)
        result.extend(generated)
    return result, current_length


# Expand the start symbol S from an empty sequence, capped at MAX_LENGTH
# (512) tokens, then report the token count followed by the tokens themselves.
challenging_sequence, length = generate(
    nltk.Nonterminal("S"),
    current_length=0,
    max_length=MAX_LENGTH,
)
print(f"Generated sequence with {length} tokens (max {MAX_LENGTH}):")
# The grammar's terminals are strings, so printing them space-separated
# yields the same line as joining them with a single space.
print(*challenging_sequence, sep=" ")