In [None]:
!pip install  gymnax  optax jax
# !pip uninstall -y jax jaxlib jax-cuda12-plugin
# # The new, simplified command for installing JAX with CUDA support
# !pip install -U "jax[cuda]"

Collecting gymnax
  Downloading gymnax-0.0.9-py3-none-any.whl.metadata (19 kB)
Downloading gymnax-0.0.9-py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.6/86.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gymnax
Successfully installed gymnax-0.0.9


In [None]:
!pip install --upgrade rlax

Collecting rlax
  Downloading rlax-0.1.8-py3-none-any.whl.metadata (8.3 kB)
Collecting absl-py>=2.3.1 (from rlax)
  Downloading absl_py-2.3.1-py3-none-any.whl.metadata (3.3 kB)
Collecting distrax>=0.1.7 (from rlax)
  Downloading distrax-0.1.7-py3-none-any.whl.metadata (14 kB)
Collecting dm_env>=1.6 (from rlax)
  Downloading dm_env-1.6-py3-none-any.whl.metadata (966 bytes)
Collecting jax>=0.7.0 (from rlax)
  Downloading jax-0.8.0-py3-none-any.whl.metadata (13 kB)
Collecting jaxlib>=0.7.0 (from rlax)
  Downloading jaxlib-0.8.0-cp312-cp312-manylinux_2_27_x86_64.whl.metadata (1.3 kB)
Collecting tfp-nightly (from distrax>=0.1.7->rlax)
  Downloading tfp_nightly-0.26.0.dev20251017-py2.py3-none-any.whl.metadata (13 kB)
Downloading rlax-0.1.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.2/116.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading absl_py-2.3.1-py3-none-any.whl (135 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Utils

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# Atomic propositions that formulas may mention; each one gets its own token id below.
propositions = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"]

# Token ids of the LTL operators and of the two boolean constants.
LTL_BASE_VOCAB = {
    "and": 0, "or": 1, "not": 2, "next": 3, "until": 4,
    "always": 5, "eventually": 6, "True": 7, "False": 8,
}

# Proposition tokens are numbered directly after the operator/constant tokens.
PROP_OFFSET = len(LTL_BASE_VOCAB)
LTL_BASE_VOCAB.update(
    {name: PROP_OFFSET + position for position, name in enumerate(propositions)}
)

# Operator/constant ids as explicit module-level constants. They are spelled out
# (rather than injected through globals()) so that linters, IDEs and readers can
# see where AND, OR, ... come from.
AND = LTL_BASE_VOCAB["and"]
OR = LTL_BASE_VOCAB["or"]
NOT = LTL_BASE_VOCAB["not"]
NEXT = LTL_BASE_VOCAB["next"]
UNTIL = LTL_BASE_VOCAB["until"]
ALWAYS = LTL_BASE_VOCAB["always"]
EVENTUALLY = LTL_BASE_VOCAB["eventually"]
TRUE = LTL_BASE_VOCAB["True"]
FALSE = LTL_BASE_VOCAB["False"]

NUM_PROPS = len(propositions)
TRUE_VAL = LTL_BASE_VOCAB["True"]
FALSE_VAL = LTL_BASE_VOCAB["False"]
VOCAB_SIZE = len(LTL_BASE_VOCAB)
# Reverse lookup: token id -> symbol name.
VOCAB_INV = {v: k for k, v in LTL_BASE_VOCAB.items()}

# Per-token arity masks. They are assembled with NumPy (cheap in-place writes) and
# then converted to jax.numpy arrays so they can be indexed inside JIT'd code.
_is_unary_op_np = np.zeros(VOCAB_SIZE, dtype=bool)
_is_unary_op_np[[NOT, NEXT, EVENTUALLY, ALWAYS]] = True
IS_UNARY_OP = jnp.array(_is_unary_op_np)

_is_binary_op_np = np.zeros(VOCAB_SIZE, dtype=bool)
_is_binary_op_np[[AND, OR, UNTIL]] = True
IS_BINARY_OP = jnp.array(_is_binary_op_np)

MAX_NODES = 200 # Maximum size of the formula array

def encode_letters(letter_str: str) -> tuple:
    """Translate every character of ``letter_str`` into its vocabulary id.

    Each character is looked up in ``LTL_BASE_VOCAB``; the ids are returned as a
    tuple in the same order as the characters. An unknown character raises
    ``KeyError``.
    """
    token_ids = [LTL_BASE_VOCAB[symbol] for symbol in letter_str]
    return tuple(token_ids)

def encode_formula(formula):
    """Convert a formula made of nested tuples of symbol names into nested tuples of ids.

    A tuple is encoded element by element (recursively); a string is replaced by
    its ``LTL_BASE_VOCAB`` id. Anything else raises ``ValueError``.
    """
    if isinstance(formula, tuple):
        return tuple(map(encode_formula, formula))
    if isinstance(formula, str):
        return LTL_BASE_VOCAB[formula]
    raise ValueError(f"Unsupported element type: {formula}")


def encode_formula_to_array(formula, vocab, array, index=0):
    """Flatten an id-encoded formula into ``array`` in pre-order, starting at ``index``.

    Every node occupies one row ``[token, left_child_row, right_child_row]``;
    leaves (plain ints) and missing children use 0 for the child slots.

    Args:
        formula: an int (leaf) or a tuple ``(op, child)`` / ``(op, left, right)``.
        vocab: accepted for interface compatibility; not read by this function.
        array: mutable row container that is written in place.
        index: row at which the root of ``formula`` is stored.

    Returns:
        The first row index that is still free after the formula was written.

    Raises:
        ValueError: if a tuple node has neither one nor two children.
    """
    if isinstance(formula, int):
        array[index] = [formula, 0, 0]
        return index + 1

    op, children = formula[0], formula[1:]
    arity = len(children)
    if arity not in (1, 2):
        raise ValueError("Formulas must have 1 or 2 children")

    # In pre-order the first child always sits directly behind its parent.
    first_child_row = index + 1
    if arity == 1:
        array[index] = [op, first_child_row, 0]
        return encode_formula_to_array(children[0], vocab, array, first_child_row)

    # The right child starts where the left subtree ended, so the left subtree
    # has to be written before the parent row can be filled in.
    second_child_row = encode_formula_to_array(children[0], vocab, array, first_child_row)
    array[index] = [op, first_child_row, second_child_row]
    return encode_formula_to_array(children[1], vocab, array, second_child_row)


def decode_array_to_formula(array, node_index, num_valid_nodes, visited_nodes=None):
    """Rebuild the nested-tuple form of the formula rooted at ``node_index``.

    Args:
        array: rows of ``[token, left_row, right_row]``.
        node_index: row of the node to decode.
        num_valid_nodes: rows at or beyond this index are treated as invalid.
        visited_nodes: set of operator rows on the current recursion path; a
            fresh set is created when omitted.

    Returns:
        A symbol name for a terminal, ``(op, child)`` for a unary operator,
        ``(op, left, right)`` for a binary operator, ``"invalid_ref_<i>"`` for an
        out-of-range row, or ``"ref_<i>"`` when an operator row is reached again
        on its own path (a cycle).
    """
    on_path = set() if visited_nodes is None else visited_nodes
    row = int(node_index)
    if row < 0 or row >= num_valid_nodes:
        return f"invalid_ref_{row}"

    token, left_row, right_row = (int(field) for field in array[row])
    unary = bool(IS_UNARY_OP[token])
    binary = bool(IS_BINARY_OP[token])
    label = VOCAB_INV.get(token, f"p{token}")

    if not (unary or binary):
        # Terminals may be shared between branches, so they never stay in the
        # path set (same effect as adding and immediately removing the row).
        on_path.discard(row)
        return label

    # Only operator rows can close a cycle.
    if row in on_path:
        return f"ref_{row}"

    on_path.add(row)
    if unary:
        decoded = (label, decode_array_to_formula(array, left_row, num_valid_nodes, on_path))
    else:
        first = decode_array_to_formula(array, left_row, num_valid_nodes, on_path)
        second = decode_array_to_formula(array, right_row, num_valid_nodes, on_path)
        decoded = (label, first, second)
    on_path.remove(row)
    return decoded




# sampler

In [None]:
"""
This class is responsible for sampling LTL formulas typically from
given template(s).

@ propositions: The set of propositions to be used in the sampled
                formula at random.
"""
import random


class LTLSampler:
    """Common base of the formula samplers: stores the proposition pool.

    Subclasses are expected to override ``sample``.
    """

    def __init__(self, propositions):
        # Pool of proposition names that concrete samplers draw from.
        self.propositions = propositions

    def sample(self):
        """Produce one formula; the base class has no implementation."""
        raise NotImplementedError


# Samples from one of the other samplers at random. The other samplers are sampled by their default args.
class SuperSampler(LTLSampler):
    """Delegates each draw to one of the registered samplers, chosen uniformly at random.

    The registered samplers are built once, with their default arguments, by
    ``getRegisteredSamplers``.
    """

    def __init__(self, propositions):
        super().__init__(propositions)
        self.reg_samplers = getRegisteredSamplers(self.propositions)

    def sample(self):
        """Return a formula produced by a randomly selected registered sampler."""
        # Bind the stdlib module locally: another notebook cell executes
        # `from jax import random`, which rebinds the global name `random`.
        import random

        return random.choice(self.reg_samplers).sample()

# This class samples formulas of form (or, op_1, op_2), where op_1 and 2 can be either specified as samplers_ids
# or by default they will be sampled at random via SuperSampler.
class OrSampler(LTLSampler):
    """Samples formulas of the form ``('or', f1, f2)``.

    ``f1`` and ``f2`` come from the samplers named by ``sampler_ids[0]`` and
    ``sampler_ids[1]``; both default to ``"SuperSampler"``. The samplers are
    looked up through ``getLTLSampler`` on every call to ``sample``.
    """

    def __init__(self, propositions, sampler_ids = ["SuperSampler"]*2):
        super().__init__(propositions)
        # Copy the ids: the default list is a single object shared by every call,
        # so storing it directly would let one instance's edits leak into others.
        self.sampler_ids = list(sampler_ids)

    def sample(self):
        """Return ``('or', left, right)`` with one sample from each configured sampler."""
        left = getLTLSampler(self.sampler_ids[0], self.propositions).sample()
        right = getLTLSampler(self.sampler_ids[1], self.propositions).sample()
        return ('or', left, right)

# This class generates random LTL formulas using the following template:
#   ('until',('not','a'),('and', 'b', ('until',('not','c'),'d')))
# where 'a', 'b', 'c', and 'd' stand for four distinct, randomly sampled propositions.
class DefaultSampler(LTLSampler):
    """Samples the fixed two-level template ``(!p0 U (p1 & (!p2 U p3)))``."""

    def sample(self):
        """Return the template instantiated with four distinct random propositions.

        ``random.sample`` raises ``ValueError`` when fewer than four propositions
        are available.
        """
        # Bind the stdlib module locally: another notebook cell executes
        # `from jax import random`, which rebinds the global name `random`.
        import random

        p = random.sample(self.propositions, 4)
        return ('until', ('not', p[0]), ('and', p[1], ('until', ('not', p[2]), p[3])))

# This class generates random conjunctions of Until-Tasks.
# Each until tasks has *n* levels, where each level consists
# of avoiding a proposition until reaching another proposition.
#   E.g.,
#      Level 1: ('until',('not','a'),'b')
#      Level 2: ('until',('not','a'),('and', 'b', ('until',('not','c'),'d')))
#      etc...
# The number of until-tasks, their levels, and their propositions are randomly sampled.
# This code is a generalization of the DefaultSampler---which is equivalent to UntilTaskSampler(propositions, 2, 2, 1, 1)
class UntilTaskSampler(LTLSampler):
    """Samples a conjunction of nested avoid-until-reach tasks.

    Args:
        propositions: pool of proposition names.
        min_levels, max_levels: inclusive range for the depth of each until-task.
        min_conjunctions, max_conjunctions: inclusive range for the number of
            until-tasks joined with ``'and'``.

    The numeric arguments are passed through ``int()``, so numeric strings are
    accepted as well.
    """

    def __init__(self, propositions, min_levels=2, max_levels=2, min_conjunctions=2 , max_conjunctions=2):
        super().__init__(propositions)
        self.levels       = (int(min_levels), int(max_levels))
        self.conjunctions = (int(min_conjunctions), int(max_conjunctions))
        # Every level consumes two distinct propositions (one to avoid, one to reach).
        assert 2*int(max_levels)*int(max_conjunctions) <= len(propositions), "The domain does not have enough propositions!"

    def sample(self):
        """Return one formula: until-tasks of random depth, combined with ``'and'``."""
        # Bind the stdlib module locally: another notebook cell executes
        # `from jax import random`, which rebinds the global name `random`.
        import random

        # Sampling a conjunction of *n_conjs* (not p[0]) Until (p[1]) formulas of *n_levels* levels
        n_conjs = random.randint(*self.conjunctions)
        # Draw enough distinct propositions for the deepest possible tasks.
        p = random.sample(self.propositions, 2*self.levels[1]*n_conjs)
        ltl = None
        b = 0  # index of the next unused proposition in p
        for _ in range(n_conjs):
            n_levels = random.randint(*self.levels)
            # Sampling an until task of *n_levels* levels
            until_task = ('until', ('not', p[b]), p[b+1])
            b += 2
            for _ in range(1, n_levels):
                until_task = ('until', ('not', p[b]), ('and', p[b+1], until_task))
                b += 2
            # Adding the until task to the conjunction of formulas that the agent has to solve
            if ltl is None:
                ltl = until_task
            else:
                ltl = ('and', until_task, ltl)
        return ltl


# This class generates random LTL formulas that form a sequence of actions.
# @ min_len, max_len: min/max length of the random sequence to generate.
class SequenceSampler(LTLSampler):
    """Samples ordered "reach p1, then p2, ..." tasks.

    Args:
        propositions: pool of proposition names.
        min_len, max_len: inclusive range for the sequence length (passed
            through ``int()``).
    """

    def __init__(self, propositions, min_len=2, max_len=4):
        super().__init__(propositions)
        self.min_len = int(min_len)
        self.max_len = int(max_len)

    def sample(self):
        """Return a nested ``'eventually'`` formula for a random proposition sequence.

        Consecutive entries of the sequence always differ.

        Raises:
            ValueError: if a sequence longer than one step is requested while
                fewer than two distinct propositions exist (the rejection loop
                below could never terminate).
        """
        # Bind the stdlib module locally: another notebook cell executes
        # `from jax import random`, which rebinds the global name `random`.
        import random

        length = random.randint(self.min_len, self.max_len)
        if length > 1 and len(set(self.propositions)) < 2:
            raise ValueError("SequenceSampler needs at least two distinct propositions")

        # Keep the sequence as a list so that proposition names of any length
        # work (building a string would split multi-character names).
        seq = []
        while len(seq) < length:
            c = random.choice(self.propositions)
            if not seq or seq[-1] != c:
                seq.append(c)

        return self._get_sequence(seq)

    def _get_sequence(self, seq):
        """Fold ``seq`` (a list of names or a string of one-letter names) into
        ``F(seq[0] & F(seq[1] & ... F(seq[-1])))``."""
        formula = ('eventually', seq[-1])
        for prop in reversed(seq[:-1]):
            formula = ('eventually', ('and', prop, formula))
        return formula

# This generates several sequence tasks which can be accomplished in parallel.
# e.g. in (eventually (a and eventually c)) and (eventually b)
# the two sequence tasks are "a->c" and "b".
class EventuallySampler(LTLSampler):
    """Samples a conjunction of sequence tasks that can be pursued in parallel.

    Each step of a sequence is either one proposition or, with probability 0.25,
    a disjunction of two propositions.

    Args:
        propositions: pool of proposition names (at least three).
        min_levels, max_levels: inclusive range for the length of each sequence.
        min_conjunctions, max_conjunctions: inclusive range for the number of
            sequences joined with ``'and'``.
    """

    def __init__(self, propositions, min_levels = 1, max_levels=4, min_conjunctions=1, max_conjunctions=3):
        super().__init__(propositions)
        assert(len(propositions) >= 3)
        self.conjunctions = (int(min_conjunctions), int(max_conjunctions))
        self.levels = (int(min_levels), int(max_levels))

    def sample(self):
        """Return one formula: independently sampled sequences combined with ``'and'``."""
        # Bind the stdlib module locally: another notebook cell executes
        # `from jax import random`, which rebinds the global name `random`.
        import random

        conjs = random.randint(*self.conjunctions)
        ltl = None

        for _ in range(conjs):
            task = self.sample_sequence()
            if ltl is None:
                ltl = task
            else:
                ltl = ('and', task, ltl)
        return ltl

    def sample_sequence(self):
        """Sample one sequence task; a step never reuses a proposition of the previous step."""
        import random  # stdlib module, see the note in sample()

        length = random.randint(*self.levels)
        seq = []

        last = []
        while len(seq) < length:
            population = [p for p in self.propositions if p not in last]

            # Randomly replace some propositions with a disjunction to make more
            # complex formulas. The random draw happens first so the random
            # stream does not depend on the population size. A disjunction needs
            # two candidates: with three propositions and a disjunction as the
            # previous step only one candidate is left, so fall back to a single
            # proposition instead of letting random.sample raise.
            wants_disjunction = random.random() < 0.25
            if wants_disjunction and len(population) >= 2:
                c = random.sample(population, 2)
            else:
                c = random.sample(population, 1)

            seq.append(c)
            last = c

        return self._get_sequence(seq)

    def _get_sequence(self, seq):
        """Fold a list of steps (each a list of one or two names) into nested ``'eventually'`` terms."""
        term = seq[0][0] if len(seq[0]) == 1 else ('or', seq[0][0], seq[0][1])
        if len(seq) == 1:
            return ('eventually', term)
        return ('eventually', ('and', term, self._get_sequence(seq[1:])))


class AdversarialEnvSampler(LTLSampler):
    """Returns one of two fixed tasks, ``F(a & F b)`` or ``F(a & F c)``, with equal probability.

    The proposition names 'a', 'b' and 'c' are hard-coded; ``self.propositions``
    is not consulted.
    """

    def sample(self):
        """Return one of the two fixed formulas."""
        # Bind the stdlib module locally: another notebook cell executes
        # `from jax import random`, which rebinds the global name `random`.
        import random

        second_goal = 'b' if random.randint(0, 1) == 0 else 'c'
        return ('eventually', ('and', 'a', ('eventually', second_goal)))

def getRegisteredSamplers(propositions):
    """Instantiate, with default arguments, the samplers that ``SuperSampler`` chooses from.

    The list order is Sequence, UntilTask, Default, Eventually.
    """
    sampler_classes = (SequenceSampler, UntilTaskSampler, DefaultSampler, EventuallySampler)
    return [sampler_cls(propositions) for sampler_cls in sampler_classes]

# The LTLSampler factory method that instantiates the proper sampler
# based on the @sampler_id.
def getLTLSampler(sampler_id, propositions):
    """Build the sampler described by ``sampler_id``.

    ``sampler_id`` is an underscore-separated string whose first token names the
    sampler and whose remaining tokens are its numeric arguments, for example
    ``"Sequence_2_4"``, ``"Until_1_3_1_2"`` or ``"Eventually_1_5_1_4"``. Numeric
    tokens that are left out fall back to the sampler's own defaults. Two ids
    joined by ``"_OR_"`` yield an ``OrSampler`` over both.

    Args:
        sampler_id: the id string, or ``None`` for the default sampler.
        propositions: pool of proposition names handed to the sampler.

    Returns:
        A sampler instance; ``DefaultSampler`` for ``None`` or an unrecognised id.
    """
    if sampler_id is None:
        return DefaultSampler(propositions)

    tokens = sampler_id.split("_")
    kind, args = tokens[0], tokens[1:]

    # Don't change the order of the checks here, otherwise the OR sampler will fail:
    # a combined id such as "Sequence_2_4_OR_UntilTask_3_3_1_1" also starts with a
    # plain sampler name and must be caught before the per-sampler branches.
    if kind == "OrSampler":
        return OrSampler(propositions)
    if "_OR_" in sampler_id:
        return OrSampler(propositions, sampler_id.split("_OR_"))
    if kind == "Sequence":
        return SequenceSampler(propositions, *args[:2])
    # "UntilTask" is the spelling used in the combined-id example above; accept it
    # next to "Until" so that it no longer falls through to the default sampler.
    if kind in ("Until", "UntilTask"):
        return UntilTaskSampler(propositions, *args[:4])
    if kind == "SuperSampler":
        return SuperSampler(propositions)
    if kind == "Adversarial":
        return AdversarialEnvSampler(propositions)
    if kind == "Eventually":
        return EventuallySampler(propositions, *args[:4])
    # "Default" and every unrecognised id
    return DefaultSampler(propositions)



In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from sfl.envs.ltl_env.utils import *
from jax import random
class JaxUntilTaskSampler():
    """JIT-compiled sampler for conjunctions of nested 'Until' tasks.

    Formulas are written into a fixed-size ``(MAX_NODES, 3)`` int32 array whose
    rows are ``[token, left_child_row, right_child_row]``.
    """
    def __init__(self, propositions, min_levels=1, max_levels=3, min_conjunctions=1, max_conjunctions=2):
        """Store the sampling ranges and pre-encode the propositions.

        Args:
            propositions: proposition names; each must be a key of LTL_BASE_VOCAB.
            min_levels, max_levels: inclusive range for the depth of each until-task.
            min_conjunctions, max_conjunctions: inclusive range for the number of
                until-tasks joined with AND.
        """
        # Token ids of the propositions, in the order given.
        self.prop_tokens = jnp.array([LTL_BASE_VOCAB[p] for p in propositions],dtype=jnp.int32)
        self.num_props = len(propositions)
        self.min_levels = min_levels
        self.max_levels = max_levels
        self.min_conjunctions = min_conjunctions
        self.max_conjunctions = max_conjunctions
        # Probability of choosing a disjunction ('or') of two propositions.
        # NOTE(review): this attribute is not read anywhere in this class.
        self.disjunction_prob = 0.25
        # Each level consumes two distinct propositions (one negated, one to reach).
        self.max_props_needed = 2 * max_levels * max_conjunctions
        assert self.max_props_needed <= len(self.prop_tokens), "Not enough propositions for the given max settings!"


    # `self` is marked static, so the instance's settings act as compile-time constants.
    @partial(jax.jit, static_argnames=("self"))
    def sample(self, key):
        """
        JAX-compatible function to sample a complex 'Until' task.

        This function builds a formula of the form:
        (Task_1) AND (Task_2) AND ... AND (Task_n_conjs)
        where each Task_i is a nested 'Until' formula:
        U(!p1, p2 AND U(!p3, p4 AND ...))

        The number of tasks and the depth of each task are drawn from the
        inclusive ranges given to the constructor.

        Args:
            key: A jax.random.PRNGKey for random operations.

        Returns:
            A tuple containing:
            - formula_array (jnp.ndarray): The encoded formula in a (MAX_NODES, 3)
              int32 array; rows that were never written hold -1.
            - num_nodes: The number of rows written (the next free row index).
            - root_idx: The row index of the root node of the final formula.
            - n_conjs: The number of 'Until' tasks joined by AND.
            - avg_levels: float32, the summed depth of all tasks divided by n_conjs.
        """
        # --- 1. Initial Setup and Random Sampling ---
        key, n_conjs_key, p_key = jax.random.split(key, 3)

        # Sample the number of conjunctions (upper bound is exclusive, hence the + 1)
        n_conjs = jax.random.randint(n_conjs_key, (), self.min_conjunctions, self.max_conjunctions + 1)

        # Sample all propositions needed upfront without replacement.
        # We must sample the maximum possible number to ensure a static shape for JIT.

        p = jax.random.choice(p_key, self.prop_tokens, shape=(self.max_props_needed,), replace=False)

        # --- 2. Define Loop Bodies for jax.lax.fori_loop ---

        def build_nested_until_task(key, formula_array, start_node_idx, start_prop_idx, n_levels):
            """Builds one nested 'Until' sub-formula with n_levels levels.

            Rows are written from start_node_idx onwards and propositions are
            taken from p starting at start_prop_idx. Returns
            (key, formula_array, next_free_row, next_prop_idx, root_row).
            """

            # Base case: U(not p[b], p[b+1])
            # This requires 4 nodes: p[b], p[b+1], NOT, and UNTIL.
            p1_idx = start_node_idx
            p0_idx = start_node_idx + 1
            not_p0_idx = start_node_idx + 2
            until_root_idx = start_node_idx + 3

            formula_array = formula_array.at[p1_idx].set(jnp.array([p[start_prop_idx + 1], 0, 0]))
            formula_array = formula_array.at[p0_idx].set(jnp.array([p[start_prop_idx], 0, 0]))
            formula_array = formula_array.at[not_p0_idx].set(jnp.array([NOT, p0_idx, 0]))
            formula_array = formula_array.at[until_root_idx].set(jnp.array([UNTIL, not_p0_idx, p1_idx]))

            prop_idx = start_prop_idx + 2
            node_idx = start_node_idx + 4

            # Inner loop state: (key, formula_array, node_idx, prop_idx, current_until_root)
            initial_inner_carry = (key, formula_array, node_idx, prop_idx, until_root_idx)

            def inner_loop_body(j, carry):
                """Adds one level of nesting: U(not p_new, (and p_next, old_until_task))"""
                l_key, l_formula_array, l_node_idx, l_prop_idx, l_until_root_idx = carry

                # This requires 5 new nodes: p_next, p_new, NOT, AND, UNTIL
                p_next_idx      = l_node_idx
                p_new_idx       = l_node_idx + 1
                not_p_new_idx   = l_node_idx + 2
                and_idx         = l_node_idx + 3
                new_until_root  = l_node_idx + 4

                # p_next
                l_formula_array = l_formula_array.at[p_next_idx].set(jnp.array([p[l_prop_idx + 1], 0, 0]))
                # p_new
                l_formula_array = l_formula_array.at[p_new_idx].set(jnp.array([p[l_prop_idx], 0, 0]))
                # not p_new
                l_formula_array = l_formula_array.at[not_p_new_idx].set(jnp.array([NOT, p_new_idx, 0]))
                # p_next AND old_until_task
                l_formula_array = l_formula_array.at[and_idx].set(jnp.array([AND, p_next_idx, l_until_root_idx]))
                # U(not p_new, (...))
                l_formula_array = l_formula_array.at[new_until_root].set(jnp.array([UNTIL, not_p_new_idx, and_idx]))

                return (l_key, l_formula_array, l_node_idx + 5, l_prop_idx + 2, new_until_root)

            # Loop n_levels - 1 times to add the nested layers
            key, formula_array, node_idx, prop_idx, until_root_idx = jax.lax.fori_loop(
                0, n_levels - 1, inner_loop_body, initial_inner_carry
            )
            return key, formula_array, node_idx, prop_idx, until_root_idx

        def outer_loop_body(i, carry):
            """Builds one 'Until' task and ANDs it with the main formula."""
            key, formula_array, node_idx, prop_idx, ltl_root_idx, total_levels = carry
            key, n_levels_key, build_key = jax.random.split(key, 3)

            # Sample levels for this specific sub-formula
            n_levels = jax.random.randint(n_levels_key, (), self.min_levels, self.max_levels + 1)

            # Running sum of the depths; turned into an average after the loop.
            new_total_levels = total_levels + n_levels

            # Build the sub-formula
            build_key, formula_array, new_node_idx, new_prop_idx, until_task_root = build_nested_until_task(
                build_key, formula_array, node_idx, prop_idx, n_levels
            )

            # If this is the first task, it becomes the root.
            # Otherwise, create an AND node to join it with the existing formula.
            def first_task_fn(_):
                return until_task_root, formula_array, new_node_idx

            def subsequent_task_fn(_):
                and_node_idx = new_node_idx
                new_array = formula_array.at[and_node_idx].set(jnp.array([AND, until_task_root, ltl_root_idx]))
                return and_node_idx, new_array, new_node_idx + 1

            new_ltl_root, formula_array, node_idx = jax.lax.cond(
                ltl_root_idx == -1,  # Use -1 as a sentinel for the first task
                first_task_fn,
                subsequent_task_fn,
                operand=None
            )

            return key, formula_array, node_idx, new_prop_idx, new_ltl_root, new_total_levels

        # --- 3. Execute Main Loop ---

        # Initial state for the main loop
        # carry = (key, formula_array, node_idx, prop_idx, ltl_root_idx, total_levels)
        initial_carry = (
            key,
            jnp.full((MAX_NODES, 3), -1, dtype=jnp.int32), # Formula array
            0,                                             # Next available node index
            0,                                             # Next available proposition index
            -1,                                            # Root of the combined formula (-1 = no task built yet)
             0,                                            # Sum of the levels of all tasks built so far
        )

        # Run the loop for n_conjs iterations
        _, final_array, num_nodes, _, root_idx, total_levels = jax.lax.fori_loop(
            0, n_conjs, outer_loop_body, initial_carry
        )
        # Mean depth per task.
        avg_levels = total_levels.astype(jnp.float32) / n_conjs.astype(jnp.float32)
        return final_array, num_nodes, root_idx, n_conjs, avg_levels



class JaxEventuallySampler:
    """
    A class to generate complex LTL formulas using a JIT-compiled JAX sampler.
    The sampler's static configuration is provided during initialization, and
    the JIT compilation happens once.

    A formula is a conjunction of sequence tasks
    ``F(t1 & F(t2 & ... F(tn)))`` where every step ``ti`` is a proposition or,
    with probability 0.25, a disjunction of two propositions. Formulas are
    written into a ``(MAX_NODES, 3)`` int32 array whose rows are
    ``[token, left_child_row, right_child_row]``.
    """
    def __init__(self, propositions, min_levels=1, max_levels=5, min_conjunctions=1, max_conjunctions=4):
        """
        Args:
            propositions: proposition names (at least three); each must be a key
                of LTL_BASE_VOCAB.
            min_levels, max_levels: inclusive range for the length of each sequence.
            min_conjunctions, max_conjunctions: inclusive range for the number of
                sequences joined with AND.
        """
        self.propositions = jnp.array([LTL_BASE_VOCAB[p] for p in propositions],dtype=jnp.int32)
        self.min_levels = min_levels
        self.max_levels = max_levels
        self.min_conjunctions = min_conjunctions
        self.max_conjunctions = max_conjunctions
        assert(len(propositions) >= 3)

        # The ranges are compile-time constants; the proposition array is an
        # ordinary (traced) argument that is bound here once.
        self._jitted_sampler = partial(
            jax.jit(self._static_sampler, static_argnames=(
                "min_levels", "max_levels", "min_conjunctions", "max_conjunctions"
            )),
            min_levels=self.min_levels,
            max_levels=self.max_levels,
            min_conjunctions=self.min_conjunctions,
            max_conjunctions=self.max_conjunctions,
            propositions=self.propositions
        )

    def sample(self, key):
        """
        Generates a new LTL formula sample.
        Args:
            key (jax.random.PRNGKey): The random key for this specific sample generation.
        Returns:
            A tuple of (formula_array, num_nodes, root_id, num_conjs, avg_levels):
            the encoded formula, the number of rows written, the row of the root
            node, the number of conjoined sequences, and the float32 mean
            sequence length.
        """
        return self._jitted_sampler(key=key)

    @staticmethod
    def _static_sampler(key, propositions, min_levels, max_levels, min_conjunctions, max_conjunctions):
        """
        The core JAX-jittable static method to generate LTL formulas.

        jax.random is always referenced by its full path: the notebook binds the
        bare global name `random` to the stdlib module in one cell and to
        jax.random in another, so relying on it depends on cell execution order.
        """
        formula_array = jnp.zeros((MAX_NODES, 3), dtype=jnp.int32)

        key, subkey = jax.random.split(key)
        num_conjs = jax.random.randint(subkey, shape=(), minval=min_conjunctions, maxval=max_conjunctions + 1)

        def _sample_sequence_task(carry, _):
            """Write one sequence task; returns the new carry, its root row and its length."""
            key, formula_array, next_node_idx = carry
            key, subkey = jax.random.split(key)
            seq_length = jax.random.randint(subkey, shape=(), minval=min_levels, maxval=max_levels + 1)

            def _generate_seq_body(i, state):
                """Write step i of the sequence (a proposition or a disjunction of two)."""
                key, formula_array, next_node_idx, last_prop_ids, seq_node_ids = state
                # Exclude the propositions used by the previous step.
                mask = jnp.all(propositions[:, None] != last_prop_ids[None, :], axis=1)
                safe_mask = jnp.where(mask.sum() == 0, jnp.ones_like(mask), mask)
                probs = safe_mask.astype(jnp.float32) / safe_mask.sum()

                key, subkey_cond, subkey_disj = jax.random.split(key, 3)

                def _create_disjunction(k):
                    p1, p2 = jax.random.choice(k, propositions, shape=(2,), replace=False, p=probs)
                    node_idx, p1_idx, p2_idx = next_node_idx, next_node_idx + 1, next_node_idx + 2
                    arr = formula_array.at[node_idx].set(jnp.array([OR, p1_idx, p2_idx]))
                    arr = arr.at[p1_idx].set(jnp.array([p1, 0, 0]))
                    arr = arr.at[p2_idx].set(jnp.array([p2, 0, 0]))
                    return arr, node_idx, next_node_idx + 3, jnp.array([p1, p2])

                def _create_single_prop(k):
                    p1 = jax.random.choice(k, propositions, shape=(1,), p=probs)[0]
                    node_idx = next_node_idx
                    arr = formula_array.at[node_idx].set(jnp.array([p1, 0, 0]))
                    return arr, node_idx, next_node_idx + 1, jnp.array([p1, -1])

                # A disjunction needs two admissible propositions. With three
                # propositions and a disjunction as the previous step only one is
                # left; fall back to a single proposition in that case instead of
                # drawing a second one that has zero probability.
                use_disjunction = (jax.random.uniform(subkey_cond) < 0.25) & (safe_mask.sum() >= 2)
                arr, node_id, next_idx, new_last_props = jax.lax.cond(
                    use_disjunction, _create_disjunction, _create_single_prop, subkey_disj)
                seq_node_ids = seq_node_ids.at[i].set(node_id)
                return key, arr, next_idx, new_last_props, seq_node_ids

            init_seq_state = (key, formula_array, next_node_idx, jnp.array([-1, -1]), jnp.full((max_levels,), -1, dtype=jnp.int32))
            key, formula_array, next_node_idx, _, seq_node_ids = jax.lax.fori_loop(0, seq_length, _generate_seq_body, init_seq_state)

            def _build_nested_formula(i, state):
                """Wrap the current root as F(step & root), walking the steps back to front."""
                rev_i = seq_length - 2 - i
                _, formula_array, next_node_idx, current_root_id = state
                prop_node_id = seq_node_ids[rev_i]
                and_node_id = next_node_idx
                formula_array = formula_array.at[and_node_id].set(jnp.array([AND, prop_node_id, current_root_id]))
                eventually_node_id = next_node_idx + 1
                formula_array = formula_array.at[eventually_node_id].set(jnp.array([EVENTUALLY, and_node_id, 0]))
                return key, formula_array, next_node_idx + 2, eventually_node_id

            # Innermost term: F(last step).
            last_prop_node_id = seq_node_ids[seq_length - 1]
            initial_root_id = next_node_idx
            formula_array = formula_array.at[initial_root_id].set(jnp.array([EVENTUALLY, last_prop_node_id, 0]))

            init_build_state = (key, formula_array, next_node_idx + 1, initial_root_id)

            _, formula_array, next_node_idx, final_root_id = jax.lax.cond(
                seq_length > 1,
                lambda: jax.lax.fori_loop(0, seq_length - 1, _build_nested_formula, init_build_state),
                lambda: init_build_state)

            return (key, formula_array, next_node_idx), final_root_id, seq_length

        def _main_conj_loop_body(i, state):
            """Sample one sequence task and AND it onto the formula built so far."""
            key, formula_array, next_node_idx, overall_root_id, total_levels = state
            (key, formula_array, next_node_idx), new_task_root_id, seq_length = _sample_sequence_task((key, formula_array, next_node_idx), 0)
            new_total_levels = total_levels + seq_length
            def _combine_with_and(op):
                prev_root_id, new_root_id, arr, idx = op
                and_node_id = idx
                arr = arr.at[and_node_id].set(jnp.array([AND, new_root_id, prev_root_id]))
                return arr, idx + 1, and_node_id

            def _first_task(op):
                _, new_root_id, arr, idx = op
                return arr, idx, new_root_id

            formula_array, next_node_idx, overall_root_id = jax.lax.cond(
                i > 0, _combine_with_and, _first_task, (overall_root_id, new_task_root_id, formula_array, next_node_idx))
            return key, formula_array, next_node_idx, overall_root_id, new_total_levels

        # state = (key, formula_array, next_node_idx, overall_root_id, total_levels)
        init_main_state = (key, formula_array, 0, -1, 0)
        key, formula_array, num_nodes, root_id, total_levels= jax.lax.fori_loop(0, num_conjs, _main_conj_loop_body, init_main_state)
        avg_levels = total_levels.astype(jnp.float32) / num_conjs.astype(jnp.float32)
        return formula_array, num_nodes, root_id, num_conjs, avg_levels



# progression

In [None]:
import jax
import jax.numpy as jnp
import jax.lax as lax
from dataclasses import dataclass
from typing import Dict, Tuple, List, Union
from functools import partial
import numpy as np
import spot
from sfl.envs.ltl_env.utils import *

def simplify_spot(array, node_index, num_valid_nodes):
    """Simplify an array-encoded formula with Spot and re-encode the result.

    The function is used as the target of ``jax.pure_callback`` and therefore
    works on and returns NumPy values only (no jax.numpy calls inside the
    callback), with dtypes that mirror the inputs.

    Args:
        array: formula rows ``[token, left_row, right_row]``.
        node_index: row of the root of the formula to simplify.
        num_valid_nodes: number of valid rows in ``array``.

    Returns:
        ``(new_array, root_index, num_nodes)``: a copy of ``array`` whose first
        ``num_nodes`` rows hold the simplified formula in pre-order, the root row
        (always 0, in the dtype of ``node_index``), and the node count (in the
        dtype of ``num_valid_nodes``). Rows at or beyond ``num_nodes`` keep
        whatever the input held there.

    NOTE(review): the parser used below only understands a subset of Spot's
    prefix output; if the simplifier emits another operator it fails with a
    format error.
    """
    formula = decode_array_to_formula(array, node_index, num_valid_nodes)
    simplified = spot.simplify(spot.formula(_get_spot_format(formula)))
    ltl_std, _ = _get_std_format(simplified.__format__("l").split(' '))

    # np.array copies, so the caller's buffer is never modified.
    new_array = np.array(array)
    num_nodes = encode_formula_to_array(encode_formula(ltl_std), LTL_BASE_VOCAB, new_array)

    # Hand back 0-d NumPy values in the dtypes of the corresponding inputs, which
    # is what the shape/dtype contract declared by the pure_callback caller expects.
    root_index = np.zeros((), dtype=np.asarray(node_index).dtype)
    node_count = np.asarray(num_nodes, dtype=np.asarray(num_valid_nodes).dtype)
    return new_array, root_index, node_count

def spotify(ltl_formula):
    """Return the Spot formula object for ``ltl_formula`` after ``spot.simplify``.

    Args:
        ltl_formula: formula as nested tuples of symbol names, e.g.
            ``('until', ('not', 'a'), 'b')``.
    """
    ltl_spot = _get_spot_format(ltl_formula)
    return spot.simplify(spot.formula(ltl_spot))


def _get_spot_format(ltl_std):
    ltl_spot = str(ltl_std).replace("(","").replace(")","").replace(",","")
    ltl_spot = ltl_spot.replace("'until'","U").replace("'not'","!").replace("'or'","|").replace("'and'","&")
    ltl_spot = ltl_spot.replace("'next'","X").replace("'eventually'","F").replace("'always'","G").replace("'True'","t").replace("'False'","f").replace("\'","\"")
    return ltl_spot

def _get_std_format(ltl_spot):

    s = ltl_spot[0]
    r = ltl_spot[1:]

    if s in ["X","U","&","|"]:
        v1,r1 = _get_std_format(r)
        v2,r2 = _get_std_format(r1)
        if s == "X": op = 'next'
        if s == "U": op = 'until'
        if s == "&": op = 'and'
        if s == "|": op = 'or'
        return (op,v1,v2),r2

    if s in ["F","G","!"]:
        v1,r1 = _get_std_format(r)
        if s == "F": op = 'eventually'
        if s == "G": op = 'always'
        if s == "!": op = 'not'
        return (op,v1),r1

    if s == "f":
        return 'False', r

    if s == "t":
        return 'True', r

    if s[0] == '"':
        return s.replace('"',''), r

    assert False, "Format error in spot2std"



@jax.jit
def progress_and_clean_jax(formula_array, truth_assignment, root_index, num_nodes):
    """
    JIT-compiled entry point: run the JAX step on the formula, then simplify
    the result on the host.

    The formula is first passed through jax_static_iterative_progress_no_copy;
    its output is handed to simplify_spot through jax.pure_callback.

    Returns:
        (simplified_array, simplified_root_idx, simplified_num_nodes), with the
        same shapes and dtypes as the intermediate (un-simplified) values.
    """
    # Stage 1: the JAX-only part. Note the return order: root, array, node count.
    raw_root, raw_array, raw_count = jax_static_iterative_progress_no_copy(
        formula_array, truth_assignment, root_index, num_nodes
    )

    # Stage 2: the callback takes (array, root, count) and returns values of the
    # same layout, so the output contract can be derived from its arguments.
    callback_args = (raw_array, raw_root, raw_count)
    output_contract = tuple(
        jax.ShapeDtypeStruct(arg.shape, arg.dtype) for arg in callback_args
    )

    # Stage 3: run the Python simplifier on the host.
    clean_array, clean_root, clean_count = jax.pure_callback(
        simplify_spot, output_contract, *callback_args
    )
    return clean_array, clean_root, clean_count


@jax.jit
def jax_static_iterative_progress_no_copy(formula_array, truth_assignment, root_index, num_nodes):
    """
    Progress an array-encoded LTL formula by one step under `truth_assignment`.

    Each row of `formula_array` is read as `[op, left_idx, right_idx]`. The
    formula is walked with an explicit stack inside a fixed-length
    `lax.fori_loop` (`MAX_NODES * 2` iterations); iterations after the stack
    has emptied are no-ops. Nodes created by progression are written into
    `formula_array` at rows `num_nodes`, `num_nodes + 1`, ... (the first two
    are always the constant TRUE and FALSE nodes).

    Args:
        formula_array: node table — presumably an int array with MAX_NODES
            rows of 3 columns; TODO confirm against the caller.
        truth_assignment: array compared element-wise with a leaf's `op`
            token; the leaf counts as true when any entry equals the token.
        root_index: row of the formula's root node.
        num_nodes: number of rows currently in use.

    Returns:
        (progressed_root_idx, formula_array, num_nodes) where
        `progressed_root_idx` is the row holding the progressed root and the
        other two values include any nodes appended during progression.
    """
    # Reserve two fresh rows for the constant TRUE / FALSE result nodes.
    true_node_idx = num_nodes
    formula_array = formula_array.at[true_node_idx].set(jnp.array([TRUE_VAL, 0, 0]))
    num_nodes += 1

    false_node_idx = num_nodes
    formula_array = formula_array.at[false_node_idx].set(jnp.array([FALSE_VAL, 0, 0]))
    num_nodes += 1

    # results[i] is the row of node i's progressed form; -1 means "not computed yet".
    results = jnp.full(MAX_NODES, -1, dtype=jnp.int32)
    # Explicit stack of (node_index, processed) pairs. processed == 1 marks a
    # node whose children have already been pushed.
    stack = jnp.zeros((MAX_NODES, 2), dtype=jnp.int32)
    stack_ptr = 0
    stack = stack.at[stack_ptr].set(jnp.array([root_index, 0]))
    stack_ptr += 1

    init_state = (formula_array, results, stack, stack_ptr, num_nodes, true_node_idx, false_node_idx)

    def main_loop_body(i, state):
        # One loop iteration pops and handles at most one stack entry.
        fa, res, st, sp, nn, true_idx, false_idx = state

        def process_stack_top(carry):
            fa, res, st, sp, nn, t_idx, f_idx = carry
            # Pop the top entry.
            sp -= 1
            node_index, processed = st[sp]
            op, left_idx, right_idx = fa[node_index]

            # A leaf has no children and an op that is neither unary nor binary.
            is_atomic = (left_idx == 0) & (right_idx == 0) & ~(IS_UNARY_OP[op] | IS_BINARY_OP[op])


            def compute_node(compute_carry):
                fa, res, st, sp, nn, t_idx, f_idx = compute_carry

                def process_parent(parent_carry):
                    # Second visit: the children's results are available in `res`.
                    fa, res, st, sp, nn, t_idx, f_idx = parent_carry

                    def handle_unary(unary_carry):
                        fa_u, res_u, nn_u, sp_u = unary_carry
                        child_res_idx = res_u[left_idx]
                        child_res_val = fa_u[child_res_idx, 0]

                        def case_not(c):
                            # FALSE if the child progressed to TRUE, otherwise TRUE.
                            # NOTE(review): a child that progressed to a non-constant
                            # formula is also mapped to TRUE here — confirm negation is
                            # only ever applied to propositions.
                            res = c[0]
                            res_idx = jnp.where(child_res_val == TRUE_VAL, f_idx, t_idx)
                            return res.at[node_index].set(res_idx)

                        def case_next(c):
                            # NOTE(review): this records the *progressed* child. Textbook
                            # progression of `next f` yields `f` itself (left_idx) — confirm
                            # which is intended.
                            res = c[0]
                            return res.at[node_index].set(child_res_idx)

                        def case_temporal(op_val, c):
                            # Append a node [op_val, this node, progressed child] and use it
                            # as this node's result.
                            fa, res, nn = c
                            new_fa = fa.at[nn].set(jnp.array([op_val, node_index, child_res_idx]))
                            return res.at[node_index].set(nn), new_fa, nn + 1

                        # First dispatch: not / next (anything else leaves `res` unchanged).
                        res = lax.cond(
                            op == LTL_BASE_VOCAB["not"], case_not,
                            lambda c: lax.cond(op == LTL_BASE_VOCAB["next"], case_next, lambda x: x[0], c),
                            (res_u,)
                        )
                        # Second dispatch: always -> "and" node, eventually -> "or" node;
                        # the fall-through only reorders the carry to (res, fa, nn).
                        res, fa, nn = lax.cond(
                            op == LTL_BASE_VOCAB["always"], lambda c: case_temporal(LTL_BASE_VOCAB["and"], c),
                            lambda c: lax.cond(
                                op == LTL_BASE_VOCAB["eventually"], lambda c2: case_temporal(LTL_BASE_VOCAB["or"], c2),
                                lambda c3: (c3[1], c3[0], c3[2]), c), # (res, fa, nn)
                            (fa_u, res, nn_u)
                        )
                        # st / sp come from the enclosing process_parent scope (unchanged here).
                        return fa, res, st, sp, nn, t_idx, f_idx

                    def handle_binary(binary_carry):
                        fa_b, res_b, nn_b, sp_b = binary_carry
                        left_res_idx, right_res_idx = res_b[left_idx], res_b[right_idx]
                        left_res_val, right_res_val = fa_b[left_res_idx, 0], fa_b[right_res_idx, 0]

                        def handle_and_or(carry):
                            fa_ao, res_ao, nn_ao, sp_ao = carry
                            is_and = op == LTL_BASE_VOCAB["and"]
                            # term_val short-circuits the operator (FALSE for and, TRUE for or);
                            # ident_val is its neutral element.
                            term_val = jnp.where(is_and, FALSE_VAL, TRUE_VAL)
                            ident_val = jnp.where(is_and, TRUE_VAL, FALSE_VAL)

                            def make_new_op_node(c):
                                # Neither operand is constant: append [op, left, right].
                                fa, nn = c
                                new_fa = fa.at[nn].set(jnp.array([op, left_res_idx, right_res_idx]))
                                return new_fa, nn + 1, nn

                            is_term = (left_res_val == term_val) | (right_res_val == term_val)
                            res_idx = jnp.where(is_and, f_idx, t_idx)

                            # Cascade: short-circuit -> both neutral -> left neutral ->
                            # right neutral -> build a new operator node.
                            fa, nn, res_idx = lax.cond(
                               is_term, lambda c: (c[0],c[1], res_idx),
                               lambda c: lax.cond( (left_res_val==ident_val) & (right_res_val==ident_val), lambda c2: (c2[0],c2[1], jnp.where(is_and, t_idx, f_idx)),
                               lambda c2: lax.cond( left_res_val==ident_val, lambda c3: (c3[0],c3[1],right_res_idx),
                               lambda c3: lax.cond( right_res_val==ident_val, lambda c4: (c4[0],c4[1],left_res_idx), make_new_op_node,c3),c2),c),
                               (fa_ao, nn_ao))

                            res = res_ao.at[node_index].set(res_idx)
                            return fa, res, st, sp, nn, t_idx, f_idx

                        def handle_until(carry):
                            # Builds  right' or (left' and <this until node>)  with constant folding.
                            fa_u, res_u, nn_u, sp_u = carry

                            def make_and_for_f1(c):
                                fa, nn = c
                                new_fa = fa.at[nn].set(jnp.array([LTL_BASE_VOCAB["and"], left_res_idx, node_index]))
                                return new_fa, nn + 1, nn

                            # f1 = left' and (this node): FALSE if left' is FALSE, the node
                            # itself if left' is TRUE, otherwise a new "and" node.
                            fa, nn, f1_idx = lax.cond(
                                left_res_val == FALSE_VAL, lambda c: (c[0], c[1], f_idx),
                                lambda c: lax.cond(left_res_val == TRUE_VAL, lambda c2: (c2[0], c2[1], node_index), make_and_for_f1, c),
                                (fa_u, nn_u))

                            def make_or_for_final(c):
                                fa_in, nn_in = c
                                new_fa = fa_in.at[nn_in].set(jnp.array([LTL_BASE_VOCAB["or"], right_res_idx, f1_idx]))
                                return new_fa, nn_in + 1, nn_in

                            # result = right' or f1: TRUE if right' is TRUE, f1 if right' is
                            # FALSE, otherwise a new "or" node.
                            fa, nn, final_res_idx = lax.cond(
                                right_res_val == TRUE_VAL, lambda c: (c[0], c[1], t_idx),
                                lambda c: lax.cond(right_res_val == FALSE_VAL, lambda c2: (c2[0], c2[1], f1_idx), make_or_for_final, c),
                                (fa, nn))

                            res = res_u.at[node_index].set(final_res_idx)
                            return fa, res, st, sp, nn, t_idx, f_idx

                        # Any binary op that is not and/or is handled as "until".
                        return lax.cond(
                            (op == LTL_BASE_VOCAB["and"]) | (op == LTL_BASE_VOCAB["or"]),
                            handle_and_or,
                            handle_until,
                            binary_carry
                        )

                    return lax.cond(
                        IS_UNARY_OP[op],
                        handle_unary,
                        handle_binary,
                        (fa, res, nn, sp)
                    )

                def push_children(pre_carry):
                    # First visit: re-push this node flagged as processed, then its
                    # children (right first, so the left child is popped first).
                    fa, res, st, sp, nn, t_idx, f_idx = pre_carry
                    st = st.at[sp].set(jnp.array([node_index, 1]))
                    sp += 1
                    st, sp = lax.cond(IS_BINARY_OP[op], lambda c: (c[0].at[c[1]].set(jnp.array([right_idx, 0])), c[1] + 1), lambda c: c, (st, sp))
                    st = st.at[sp].set(jnp.array([left_idx, 0]))
                    sp += 1
                    return fa, res, st, sp, nn, t_idx, f_idx

                return lax.cond(processed == 1, process_parent, push_children, compute_carry)

            def process_atomic(atomic_carry):
                # Leaves: constants map to themselves, propositions to TRUE/FALSE
                # depending on whether their token occurs in truth_assignment.
                fa, res, st, sp, nn, t_idx, f_idx = atomic_carry
                is_true = jnp.any(op == truth_assignment)
                prop_res = jnp.where(is_true, t_idx, f_idx)
                final_res = jnp.where(op == TRUE_VAL, t_idx, jnp.where(op == FALSE_VAL, f_idx, prop_res))
                return fa, res.at[node_index].set(final_res), st, sp, nn, t_idx, f_idx

            # Skip nodes whose result is already known (the entry is still popped).
            return lax.cond(
                res[node_index] != -1, lambda x: x,
                lambda y: lax.cond(is_atomic, process_atomic, compute_node, y),
                (fa, res, st, sp, nn, true_idx, false_idx)
            )

        # Nothing to do once the stack is empty.
        return lax.cond(sp > 0, process_stack_top, lambda x: x, state)

    final_state = lax.fori_loop(0, MAX_NODES * 2, main_loop_body, init_state)
    final_fa, final_res, _, _, final_nn, _, _ = final_state

    return final_res[root_index], final_fa, final_nn




# AST

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Callable, List, NamedTuple, Optional
from sfl.envs.ltl_env.utils import *
from collections import namedtuple, OrderedDict

# Integer id for each kind of graph edge written by JaxASTBuilder:
# "self" = self-loop, "arg" = operand of a unary op,
# "arg1"/"arg2" = left/right operand of a binary op.
edge_types = {"self": 0, "arg": 1, "arg1": 2, "arg2": 3}
NUM_EDGE_TYPES = len(edge_types)




class JaxASTBuilder:
    """
    Builds a GNN-compatible graph representation from an LTL formula array.

    The output is fully padded (static shapes) so it can be produced under
    `jax.jit` and consumed by `jax.ops.segment_sum`: graph node 0 is a
    dedicated padding node, formula row `i` becomes graph node `i + 1`, and
    every unused edge slot is a self-loop on node 0.

    Args:
        vocab_size: number of token classes used for the one-hot node features.
        max_formula_nodes: maximum number of formula rows that are encoded.
    """
    def __init__(self, vocab_size: int, max_formula_nodes: int):
        self.vocab_size = vocab_size
        # The total number of nodes in the graph is max_formula_nodes + 1 (for the padding node)
        self.max_graph_nodes = max_formula_nodes + 1
        self.max_formula_nodes = max_formula_nodes

    @staticmethod
    @partial(jax.jit, static_argnames=['max_formula_nodes', 'vocab_size', 'max_graph_nodes'])
    def _build_graph_static(encoded_array, num_nodes: int, max_formula_nodes: int, vocab_size: int, max_graph_nodes: int):
        """
        JIT-compiled static method that constructs the padded graph.

        Args:
            encoded_array: formula table whose rows are `[op, left_idx, right_idx]`;
                row 0 is treated as the root.
            num_nodes: number of valid rows in `encoded_array` (may be traced).
            max_formula_nodes: static cap on encoded rows.
            vocab_size: static width of the one-hot token encoding.
            max_graph_nodes: static number of graph nodes (`max_formula_nodes + 1`).

        Returns:
            OrderedDict with
            - 'nodes': float32 `(max_graph_nodes, vocab_size + 1)`; one-hot token
              plus a trailing is-root flag. Padding rows are all zero.
            - 'senders' / 'receivers' / 'edge_types': int32 `(3 * max_formula_nodes,)`;
              valid edges first, remaining slots are `0 -> 0` edges of type 'self'.
            - 'n_node': `[num_nodes]`.
        """
        # === 1. Node Feature Construction ===
        # We have max_graph_nodes = max_formula_nodes + 1 total nodes. Node 0 is for padding.
        node_tokens = encoded_array[:max_formula_nodes, 0]
        one_hot_features = jax.nn.one_hot(node_tokens, num_classes=vocab_size, dtype=jnp.float32)

        # Initialize all node features to zero, including the padding node at index 0.
        node_features = jnp.zeros((max_graph_nodes, vocab_size), dtype=jnp.float32)
        # Place the one-hot features for the real nodes starting from index 1.
        node_features = node_features.at[1:].set(one_hot_features)

        # Zero out rows beyond the last real node (graph indices > num_nodes).
        # Index 0 passes the mask but is already all-zero.
        valid_node_mask = jnp.arange(max_graph_nodes) < (num_nodes + 1)
        node_features = node_features * valid_node_mask[:, None]

        # Formula row 0 (the root) lives at graph index 1.
        is_root_feature = jnp.zeros((max_graph_nodes, 1), dtype=jnp.float32).at[1].set(1.0)
        node_features = jnp.concatenate([node_features, is_root_feature], axis=-1)

        # === 2. Edge Construction Loop ===
        # Formula row i owns edge slots 3*i (self-loop), 3*i+1 and 3*i+2 (operands).
        def edge_construction_body(i, carry):
            senders, receivers, edge_type_indices = carry
            op, left_idx, right_idx = encoded_array[i]

            # Shift all node indices by +1 to account for the padding node.
            current_node_idx = i + 1
            left_child_idx = left_idx + 1
            right_child_idx = right_idx + 1

            # Self-loop for the current real node.
            senders = senders.at[3 * i].set(current_node_idx)
            receivers = receivers.at[3 * i].set(current_node_idx)
            edge_type_indices = edge_type_indices.at[3 * i].set(edge_types['self'])

            # Unary op: a single child -> parent edge of type 'arg'.
            def unary_true_fn(vals):
                s, r, e = vals
                s = s.at[3 * i + 1].set(left_child_idx)
                r = r.at[3 * i + 1].set(current_node_idx)
                e = e.at[3 * i + 1].set(edge_types['arg'])
                return s, r, e
            senders, receivers, edge_type_indices = jax.lax.cond(
                IS_UNARY_OP[op], unary_true_fn, lambda vals: vals, (senders, receivers, edge_type_indices)
            )

            # Binary op: left ('arg1') and right ('arg2') child -> parent edges.
            def binary_true_fn(vals):
                s, r, e = vals
                s = s.at[3 * i + 1].set(left_child_idx)
                r = r.at[3 * i + 1].set(current_node_idx)
                e = e.at[3 * i + 1].set(edge_types['arg1'])
                s = s.at[3 * i + 2].set(right_child_idx)
                r = r.at[3 * i + 2].set(current_node_idx)
                e = e.at[3 * i + 2].set(edge_types['arg2'])
                return s, r, e
            senders, receivers, edge_type_indices = jax.lax.cond(
                IS_BINARY_OP[op], binary_true_fn, lambda vals: vals, (senders, receivers, edge_type_indices)
            )
            return senders, receivers, edge_type_indices

        # Initialize with a sentinel value to detect validly created edges.
        num_potential_edges = max_formula_nodes * 3
        init_val = jnp.full(num_potential_edges, -1, dtype=jnp.int32)
        senders_padded, receivers_padded, edge_types_padded = jax.lax.fori_loop(
            0, num_nodes, edge_construction_body, (init_val, init_val, init_val)
        )

        # === 3. Gather Valid Edges and Pad for segment_sum ===
        valid_edge_mask = senders_padded != -1
        num_valid_edges = valid_edge_mask.sum()

        # Sort valid slots (keyed by their own position) ahead of the unused ones.
        sort_key = jnp.where(valid_edge_mask, jnp.arange(num_potential_edges), num_potential_edges)
        permutation = jnp.argsort(sort_key)
        senders_clean = senders_padded[permutation]
        receivers_clean = receivers_padded[permutation]
        edge_types_clean = edge_types_padded[permutation]

        # Everything after the valid edges becomes a 'self' edge on the padding node (index 0).
        is_padded_edge_mask = jnp.arange(num_potential_edges) >= num_valid_edges
        final_senders = jnp.where(is_padded_edge_mask, 0, senders_clean)
        final_receivers = jnp.where(is_padded_edge_mask, 0, receivers_clean)
        final_edge_types = jnp.where(is_padded_edge_mask, edge_types['self'], edge_types_clean)

        # Edge features are not materialised here; only the integer edge types are returned.
        return OrderedDict([
            ('nodes', node_features),
            ('senders', final_senders),
            ('receivers', final_receivers),
            ('n_node', jnp.array([num_nodes])),
            ('edge_types', final_edge_types)
        ])


    @partial(jax.jit, static_argnums=0)
    def __call__(self, encoded_array: jnp.ndarray, num_nodes: int):
        """
        Processes an encoded formula array to produce a padded graph representation
        that is compatible with jax.ops.segment_sum.

        `self` is a static jit argument, so a builder is traced once per
        instance; the sizes forwarded below are the ones given to `__init__`.
        """
        # Forward the instance's own vocab_size rather than a module-level
        # constant, so the `vocab_size` constructor argument is actually honoured.
        return JaxASTBuilder._build_graph_static(
            encoded_array,
            num_nodes,
            max_formula_nodes=self.max_formula_nodes,
            max_graph_nodes=self.max_graph_nodes,
            vocab_size=self.vocab_size,
        )


# Environment

In [None]:
import jax
import jax.numpy as jnp
import chex
from flax import struct
from gymnax.environments import spaces  # Assuming gymnax is available per the wrapper
from functools import partial
from typing import Tuple

# --- JAX-compatible State ---
# This dataclass holds all dynamic variables of the environment.
# An instance of this is passed to step() and returned by reset() and step().
@struct.dataclass
class SimpleLTLState:
    """Dynamic state of JaxSimpleLTLEnv, returned by reset() and threaded through step()."""
    time: chex.Array  # Step counter: 0 after reset, incremented by one on every step
    proposition: chex.Array  # Stores the integer index of the last action (-1 right after reset)
    num_episodes: chex.Array  # Set to 0 by reset; step() carries it over unchanged
    key: chex.PRNGKey  # PRNG key handed to the most recent reset()/step() call

# --- JAX-compatible Parameters ---
# This holds static configuration.
@struct.dataclass
class SimpleLTLEnvParams:
    """Configuration of JaxSimpleLTLEnv."""
    timeout: int  # step() reports done once the step counter exceeds this value
    num_letters: int  # Number of distinct propositions (= size of the discrete action space)

class JaxSimpleLTLEnv:
    """
    A JAX-compatible, functional version of SimpleLTLEnv.

    This environment's logic is simple:
    - The state consists of the current time and the last proposition (action).
    - An action IS a proposition.
    - The episode ends when 'timeout' is exceeded.
    - The reward is always 0.
    - The observation is always 0.
    """

    def __init__(self, letters: str, timeout: int):
        """
        letters:
            - (str) String of propositions, e.g., "abcdef" (duplicates are ignored)
        timeout:
            - (int) Maximum length of the episode
        """
        # --- Static (compile-time) attributes ---
        unique_letters = sorted(list(set(letters)))
        self.num_letter_types = len(unique_letters)

        # This string map is NOT JAX-compatible and cannot be used
        # inside jitted functions. It's only for non-jit helpers.
        self._letter_map = {i: letter for i, letter in enumerate(unique_letters)}

        # Store default parameters
        self.default_params = SimpleLTLEnvParams(
            timeout=timeout,
            num_letters=self.num_letter_types
        )

    @property
    def params(self) -> SimpleLTLEnvParams:
        """Default environment parameters."""
        return self.default_params

    @partial(jax.jit, static_argnames=("self",))
    def reset(
        self, key: chex.PRNGKey, params: SimpleLTLEnvParams
    ) -> Tuple[chex.Array, SimpleLTLState]:
        """Resets the environment and returns `(obs, state)`."""

        # Get observation (always 0)
        obs = self._get_observation(None) # State-independent

        state = SimpleLTLState(
            time=jnp.array(0),
            proposition=jnp.array(-1), # -1 indicates no proposition yet
            num_episodes=jnp.array(0), # This counter resets with the env
            key=key
        )
        return obs, state

    @partial(jax.jit, static_argnames=("self",))
    def step(
        self,
        key: chex.PRNGKey,
        state: SimpleLTLState,
        action: int,
        params: SimpleLTLEnvParams,
    ) -> Tuple[chex.Array, SimpleLTLState, chex.Array, chex.Array, dict]:
        """
        Executes `action` and returns `(obs, new_state, reward, done, info)`.

        The action is stored as the new proposition; `done` becomes true once
        the incremented time exceeds `params.timeout`.
        """

        # Update time and check for timeout
        new_time = state.time + 1
        done = new_time > params.timeout

        # Reward is always 0
        reward = jnp.array(0.0)

        # Observation is always 0
        obs = self._get_observation(state)

        # The new proposition is the action taken
        new_proposition = jnp.array(action)

        # Create the new, immutable state
        new_state = SimpleLTLState(
            time=new_time,
            proposition=new_proposition,
            num_episodes=state.num_episodes, # Does not increment
            key=key
        )

        return obs, new_state, reward, done, {}

    # === Property & Helper Functions ===
    # These match the API expected by your LTLEnv wrapper

    def action_space(self, params: SimpleLTLEnvParams) -> spaces.Discrete:
        """Action space: one discrete action per letter."""
        return spaces.Discrete(params.num_letters)

    def observation_space(self, params: SimpleLTLEnvParams) -> spaces.Discrete:
        """Observation space: always 0."""
        return spaces.Discrete(1)

    @partial(jax.jit, static_argnames=("self",))
    def _get_observation(self, state: SimpleLTLState) -> chex.Array:
        """Returns the observation (which is always 0)."""
        return jnp.array(0)

    # `params` must be static here: `params.num_letters` is used as an array
    # shape and as `num_classes`, both of which have to be concrete Python
    # ints at trace time. With `params` traced, the call fails while tracing.
    @partial(jax.jit, static_argnames=("self", "params"))
    def get_events(self, state: SimpleLTLState, params: SimpleLTLEnvParams) -> chex.Array:
        """
        Gets the current "truth assignment" based on the state.

        In this env, the truth assignment is a boolean one-hot vector of
        length `params.num_letters` marking the last action (proposition)
        taken. `params` is a static jit argument and therefore must be
        hashable (plain Python ints in its fields).
        """
        # If proposition is -1 (at reset), return all-false vector
        return jax.lax.cond(
            state.proposition == -1,
            lambda: jnp.zeros(params.num_letters, dtype=jnp.bool_),
            lambda: jax.nn.one_hot(state.proposition, params.num_letters, dtype=jnp.bool_)
        )

    def get_propositions(self, params: SimpleLTLEnvParams) -> chex.Array:
        """
        Returns the set of all possible propositions *as integer indices*.
        """
        return jnp.arange(params.num_letters)

    # --- Non-JAX helper for debugging ---

    def get_events_str(self, state: SimpleLTLState) -> str:
        """
        Non-JITtable helper to get the string name of the current proposition.
        Returns the string "None" when the index has no letter (e.g. -1 after reset).
        DO NOT use this inside a jitted function.
        """
        prop_idx = int(state.proposition)
        if prop_idx in self._letter_map:
            return self._letter_map[prop_idx]
        return "None"

# Example of how to use it (similar to the original SimpleLTLEnvDefault)
class JaxSimpleLTLEnvDefault(JaxSimpleLTLEnv):
    """JaxSimpleLTLEnv preset: the twelve letters 'a'..'l' with a timeout of 75."""
    def __init__(self):
        super().__init__(letters="abcdefghijkl", timeout=75)


  TEST 1: STANDARD 5x5 (WORLD-CENTRIC)


Step: 0
--- True Map View ---
-----------
|A|.|e|a|.|
|.|.|.|.|b|
|d|.|.|b|c|
|.|.|.|.|e|
|.|a|c|.|d|
-----------
Current Event: None
Move (w/a/s/d) and press Enter, or q to quit: a

Step: 1
--- True Map View ---
-----------
|.|.|e|a|A|
|.|.|.|.|b|
|d|.|.|b|c|
|.|.|.|.|e|
|.|a|c|.|d|
-----------
Current Event: None
Move (w/a/s/d) and press Enter, or q to quit: s

Step: 2
--- True Map View ---
-----------
|.|.|e|a|.|
|.|.|.|.|A|
|d|.|.|b|c|
|.|.|.|.|e|
|.|a|c|.|d|
-----------
Current Event: b
Move (w/a/s/d) and press Enter, or q to quit: a

Step: 3
--- True Map View ---
-----------
|.|.|e|a|.|
|.|.|.|A|b|
|d|.|.|b|c|
|.|.|.|.|e|
|.|a|c|.|d|
-----------
Current Event: None
Move (w/a/s/d) and press Enter, or q to quit: w

Step: 4
--- True Map View ---
-----------
|.|.|e|A|.|
|.|.|.|.|b|
|d|.|.|b|c|
|.|.|.|.|e|
|.|a|c|.|d|
-----------
Current Event: a
Move (w/a/s/d) and press Enter, or q to quit: w

Step: 5
--- True Map View ---
-----------
|.|.|e|

# Wrapper

In [None]:
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from gymnax.environments.environment import Environment # Corrected import
from gym import spaces
from dataclasses import dataclass, replace
from flax.struct import PyTreeNode

class LTLEnvState(PyTreeNode):
    """State carried by LTLEnv: the wrapped env's state plus the encoded LTL goal."""
    env_state: any  # Underlying env state
    ltl_goal: jnp.ndarray   # Current LTL formula
    ltl_original: jnp.ndarray  # Original LTL formula
    key: jnp.ndarray  # PRNG key
    num_nodes: jnp.ndarray  # Node count belonging to `ltl_goal` (from the sampler / progression)
    root_idx: jnp.ndarray  # Root index belonging to `ltl_goal` (from the sampler / progression)

class LTLEnv:
    """
    Functional wrapper adding LTL goals to an environment.

    `reset` samples an LTL formula; `step` progresses it with the events of
    the wrapped environment, adds an LTL reward (+1 when the formula becomes
    True, -1 when it becomes False, `intrinsic` otherwise) and auto-resets
    when either the wrapped environment or the formula terminates. The
    observation returned by both methods is the graph dictionary produced by
    `JaxASTBuilder` for the current formula.
    """
    def __init__(self, env: Environment, params: any, progression_mode: str = "full",
                 ltl_sampler: str = None, intrinsic: float = 0.0):
        """
        Args:
            env: wrapped environment; must provide `reset_env`, `step_env`,
                `get_events`, `get_propositions` and `observation_space`.
            params: parameters forwarded to the wrapped environment.
            progression_mode: only used to choose the declared 'text' space.
            ltl_sampler: currently unused; the sampler below is hard-wired.
            intrinsic: reward added on steps where the formula is undecided.
        """
        self.env = env
        self.params = params
        self.progression_mode = progression_mode
        self.propositions = env.get_propositions(params)
        # NOTE(review): the sampler is built from the module-level `propositions`
        # list, not from `self.propositions` — confirm this is intended.
        self.sampler = JaxUntilTaskSampler(propositions,propositions, min_levels=1, max_levels=3, min_conjunctions=1, max_conjunctions=2)
        self.intrinsic = intrinsic
        self.observation_space = spaces.Dict({
            'features': env.observation_space(params),
            'text': spaces.Box(low=0, high=0, shape=(), dtype=object) if progression_mode in ["full", "none"]
                   else spaces.Box(low=-1, high=1, shape=(len(self.propositions),), dtype=jnp.float32)
        })
        self.ast_builder = JaxASTBuilder(VOCAB_SIZE, MAX_NODES)

    @partial(jax.jit, static_argnames=("self", "params"))
    def reset(self, key: jnp.ndarray, params: any) -> tuple[dict, LTLEnvState]:
        """Reset env, sample LTL goal, return the formula graph and the new state."""
        key, subkey, sample_key= jax.random.split(key,3)
        # The wrapped env's own observation is not part of the returned observation.
        obs, env_state = self.env.reset_env(subkey, params)

        final_array, num_nodes, root_idx = self.sample_ltl_goal(sample_key)
        ltl_state = LTLEnvState(
            env_state=env_state,
            ltl_goal=final_array,
            ltl_original=final_array,
            key=key,
            num_nodes=num_nodes,
            root_idx=root_idx
        )
        graph = self.ast_builder(final_array, num_nodes)
        ltl_obs=graph
        return ltl_obs, ltl_state

    @partial(jax.jit, static_argnames=("self", "params"))
    def step(self, key: jnp.ndarray, state: LTLEnvState, action: int, params: any) -> tuple[dict, float, bool, dict, LTLEnvState]:
        """
        Step env, progress LTL, and reset automatically if done.

        Returns:
            (obs, reward, done, info, state). `reward` and `done` describe the
            transition just taken; `obs` and `state` already belong to the
            next step (i.e. to a fresh episode when `done` is true).
        """
        # Split key for the step and a potential reset
        step_key, reset_key = jax.random.split(key)

        # --- 1. Perform the normal environment step ---
        # NOTE(review): this unpacking order (obs, reward, done, info, state) is what
        # the wrapped env is expected to return — verify against its step_env.
        _env_obs, reward, done, info, new_env_state = self.env.step_env(step_key, state.env_state, action, params)

        # --- 2. Progress the LTL goal ---
        truth_assignment = self.get_events(new_env_state, params)
        ltl_goal, root_index, num_nodes = progress_and_clean_jax(state.ltl_goal, truth_assignment, state.root_idx, state.num_nodes)

        # --- 3. Compute LTL reward and done status ---
        # The root token is read from row 0 of the progressed formula
        # (presumably the simplifier re-encodes with the root first — TODO confirm).
        is_true = (ltl_goal[0][0] == LTL_BASE_VOCAB['True'])
        is_false = (ltl_goal[0][0] == LTL_BASE_VOCAB['False'])
        ltl_done = jnp.logical_or(is_true, is_false)
        # jnp.where promotes its operands, so an integer `intrinsic` (e.g. 0) is
        # accepted; nested lax.cond branches would reject the int/float mismatch.
        ltl_reward = jnp.where(
            is_true,
            1.0,
            jnp.where(is_false, -1.0, self.intrinsic)
        )

        # --- 4. Determine final reward and done for the current transition ---
        final_reward = reward + ltl_reward
        final_done = jnp.logical_or(done, ltl_done)

        # --- 5. Conditionally determine the next state and observation ---
        # Auto-reset: `jax.lax.cond` picks one of the two branches based on
        # `final_done`. Both must return pytrees with the exact same structure,
        # here a tuple of (observation, state).

        def reset_case(_):
            """Called when final_done is True. Resets the environment."""
            # We pass the separate reset_key to the reset function.
            return self.reset(reset_key, params)

        def step_case(_):
            """Called when final_done is False. Returns the result of the normal step."""
            new_state = LTLEnvState(
                env_state=new_env_state,
                ltl_goal=ltl_goal,
                ltl_original=state.ltl_original,
                key=step_key, # The new state gets the used step_key
                num_nodes=num_nodes,
                root_idx=root_index
            )
            graph = self.ast_builder(ltl_goal, num_nodes)
            new_obs = graph
            return new_obs, new_state

        # The branches read everything from the enclosing scope; the operand is a dummy.
        final_obs, final_state = jax.lax.cond(
            final_done,
            reset_case,
            step_case,
            None
        )

        # --- 6. Return the final transition ---
        return final_obs, final_reward, final_done, info, final_state

    def sample_ltl_goal(self,key) -> any:
        """Sample an LTL formula; returns `(formula_array, num_nodes, root_idx)` from the sampler."""
        final_array, num_nodes, root_idx = self.sampler.sample(key)
        return final_array, num_nodes, root_idx

    def get_events(self, env_state: EnvState, params: EnvParams) -> chex.Array:
        """Get current propositions from the underlying env using its state."""
        return self.env.get_events(env_state, params)

class NoLTLWrapper:
    """Pass-through wrapper that exposes the plain environment without an LTL goal."""

    def __init__(self, env: Environment, params: any):
        """Keep ``env`` and ``params`` and cache ``env.observation_space(params)``."""
        self.env, self.params = env, params
        self.observation_space = env.observation_space(params)

    def reset(self, key: jnp.ndarray, params: any) -> tuple[any, any]:
        """Forward to ``env.reset_env`` and return its result unchanged."""
        reset_result = self.env.reset_env(key, params)
        return reset_result

    def step(self, key: jnp.ndarray, state: any, action: int, params: any) -> tuple[any, float, bool, dict, any]:
        """Forward to ``env.step_env`` and return its result unchanged."""
        step_result = self.env.step_env(key, state, action, params)
        return step_result

    def get_propositions(self, params: any) -> list:
        """Always return an empty list; ``params`` is not used."""
        return list()



# --- import your LetterEnv implementation ---
# from letter_env_fixed import LetterEnv, EnvParams   # <-- adjust to your filename
# from ltl_env_wrapper import LTLEnv                  # <-- adjust if you saved wrapper separately

def main_wrap():
    """Drive one episode of the LTL-wrapped ``LetterEnv`` from the keyboard.

    Builds a ``LetterEnv`` (``grid_size=5``, ``timeout=10``), wraps it in
    ``LTLEnv`` and resets it with ``PRNGKey(0)``. It then repeatedly reads a
    command from stdin, steps the wrapped environment and prints the rewards,
    truth assignment and progressed goal. The loop ends on ``q`` or as soon as
    the step reports ``done``.

    The step counter advances on every prompt, including unrecognised commands.
    """
    env = LetterEnv()
    params = EnvParams(grid_size=5, letters="aabbccddee", use_fixed_map=False, use_agent_centric_view=False, timeout=10)
    ltl_env = LTLEnv(env, params, progression_mode="full", ltl_sampler=None, intrinsic=0.0)

    def _render(current_obs, current_state):
        # Grid first, then the feature view, both taken from the wrapper's state/obs.
        show(env, current_state.env_state, params)
        show_features(env, current_obs['features'], params)

    obs, state = ltl_env.reset(jax.random.PRNGKey(0), params)
    print("\n=== RESET ===")
    _render(obs, state)
    print("Initial LTL goal:", state.ltl_goal)

    key_to_action = {"w": 0, "s": 1, "a": 2, "d": 3}
    step_idx = 0

    while True:
        step_idx += 1
        print("\n--- Step", step_idx, "---")
        cmd = input("Action? (w/a/s/d)  r=random  q=quit  > ").strip().lower()

        if cmd == "q":
            print("Quitting.")
            break
        if cmd != "r" and cmd not in key_to_action:
            print("Unknown command; try w/a/s/d, r, or q.")
            continue
        action = random.choice([0, 1, 2, 3]) if cmd == "r" else key_to_action[cmd]

        # The key stored in the wrapper state is handed back in for this step.
        obs, total_reward, done, info, state = ltl_env.step(state.key, state, action, params)
        _render(obs, state)

        print("Action:", action, "(w/up=0, s/down=1, a/left=2, d/right=3)")
        print("Underlying env reward:", info.get('env_reward'))
        print("LTL reward:", info.get('ltl_reward'))
        print("Total reward returned:", total_reward)
        print("Truth assignment at new agent pos:", repr(info.get('truth_assignment')))
        print("Progressed LTL goal (in state):", state.ltl_goal)
        print("Is episode done?:", done)

        if done:
            print("\nEpisode finished (env or LTL termination).")
            break

# GCN Model

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from typing import Callable, List, NamedTuple, Optional

def segment_sum(data: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: int) -> jnp.ndarray:
    """Sum the rows of ``data`` that share a segment id.

    Thin wrapper around ``jax.ops.segment_sum`` that always passes
    ``num_segments`` explicitly.

    Args:
        data: Values to reduce; the leading axis is the one being segmented.
        segment_ids: Integer segment id for each leading-axis entry of ``data``.
        num_segments: Number of output segments (size of the result's leading axis).

    Returns:
        Array with leading dimension ``num_segments`` holding the per-segment sums.
    """
    summed = jax.ops.segment_sum(
        data,
        segment_ids,
        num_segments=num_segments,
    )
    return summed

class RelationalUpdate(nn.Module):
    """
    Per-edge message function with one weight matrix per relation type.

    Each edge's message is its sender's feature vector multiplied by the
    kernel selected by that edge's type.
    """
    # Output width of every message.
    features: int
    # Number of distinct edge types, i.e. how many kernels are stacked.
    num_relations: int

    @nn.compact
    def __call__(self, nodes: jnp.ndarray, senders: jnp.ndarray, edge_types: jnp.ndarray) -> jnp.ndarray:
        """
        Args:
            nodes: Node features, `[num_nodes, in_features]`.
            senders: Sender node index of each edge, `[num_edges]`.
            edge_types: Integer relation id of each edge, `[num_edges]`.

        Returns:
            Messages of shape `[num_edges, features]`.
        """
        # Stacked kernels: one `[in_features, features]` matrix per relation.
        kernel_shape = (self.num_relations, nodes.shape[-1], self.features)
        kernels = self.param('kernels', nn.initializers.lecun_normal(), kernel_shape)

        # Gather, per edge, the kernel for its type and the features of its sender.
        per_edge_kernel = kernels[edge_types]  # [num_edges, in_features, features]
        per_edge_input = nodes[senders]        # [num_edges, in_features]

        # Batched vector-matrix product: out[e, f] = sum_i kernel[e, i, f] * input[e, i].
        return jnp.einsum('eif,ei->ef', per_edge_kernel, per_edge_input)

def CustomRelationalGraphConvolution(
    update_node_module: nn.Module,
    symmetric_normalization: bool = True
) -> Callable[[dict], dict]:
    """
    Build a relational graph-convolution function around a message module.

    Args:
        update_node_module: Called as `module(nodes, senders, edge_types)`; must
            return one message per edge.
        symmetric_normalization: When true, each message is scaled by
            `1/sqrt(deg)` of its sender (out-degree) before aggregation and the
            aggregate by `1/sqrt(deg)` of its receiver (in-degree). Degrees are
            clipped to at least 1.

    Returns:
        A function mapping a graph dict (keys "nodes", "senders", "receivers",
        "edge_types", plus anything else, which is passed through) to a copy
        whose "nodes" entry holds the aggregated messages.
    """
    def _ApplyRGCN(graph: dict) -> dict:
        nodes = graph["nodes"]
        senders = graph["senders"]
        receivers = graph["receivers"]
        edge_types = graph["edge_types"]

        messages = update_node_module(nodes, senders, edge_types)
        node_count = nodes.shape[0]

        if not symmetric_normalization:
            # Plain sum of incoming messages per receiver.
            return {**graph, "nodes": segment_sum(messages, receivers, node_count)}

        # Degree counts from the edge list; clipping keeps isolated nodes at 1.
        edge_ones = jnp.ones_like(senders, dtype=jnp.float32)
        out_degree = segment_sum(edge_ones, senders, node_count).clip(1.0)
        in_degree = segment_sum(edge_ones, receivers, node_count).clip(1.0)
        sender_scale = jax.lax.rsqrt(out_degree)
        receiver_scale = jax.lax.rsqrt(in_degree)

        scaled_messages = messages * sender_scale[senders, None]
        pooled = segment_sum(scaled_messages, receivers, node_count)
        return {**graph, "nodes": pooled * receiver_scale[:, None]}

    return _ApplyRGCN


# --- GNN Module (Provided) ---

class RGCNRootShared_no_jraph(nn.Module):
    """An RGCN with shared weights and root-based readout.

    The same relational update is applied `num_layers` times. Each layer sees
    the current hidden state concatenated with the initial projection (a skip
    connection), followed by `tanh`. The graph embedding is the sum of the
    hidden states of the nodes flagged as root, passed through a final dense
    layer.
    """
    # Width of the hidden node representation.
    hidden_dim: int
    # Number of times the shared relational layer is applied.
    num_layers: int
    # Width of the returned graph embedding.
    output_dim: int
    # Number of distinct edge types handled by the relational update.
    num_edge_types: int

    @nn.compact
    def __call__(self, graph: dict) -> jnp.ndarray:
        """Embed the graph(s) described by `graph`.

        Args:
            graph: Dict with "nodes" (`[num_nodes, feat + 1]`; the last column
                is the is-root flag), "senders", "receivers", "edge_types" and
                "n_node" (its leading dimension is the number of graphs).

        Returns:
            `[output_dim]` when `graph` holds exactly one graph (leading axis
            squeezed, as before), otherwise `[num_graphs, output_dim]`.
        """
        # Separate the 'is_root' flag from the node features.
        h_features = graph["nodes"][:, :-1]
        is_root_nodes = graph["nodes"][:, -1:]

        # Initial linear projection.
        h_0 = nn.Dense(features=self.hidden_dim, name='input_dense')(h_features)
        h = h_0

        # Define the single, shared convolutional layer module.
        # Its input has size 2 * hidden_dim due to the skip connection.
        shared_update_module = RelationalUpdate(
            features=self.hidden_dim,
            num_relations=self.num_edge_types,
            name='shared_rgcn_update'
        )
        rgcn_layer = CustomRelationalGraphConvolution(update_node_module=shared_update_module)

        # Graph structure (everything except the node features) is reused by every layer.
        conv_graph = {key: val for key, val in graph.items() if key != 'nodes'}

        for _ in range(self.num_layers):
            h_cat = jnp.concatenate([h, h_0], axis=-1)
            current_layer_graph = {**conv_graph, "nodes": h_cat}

            graph_after_rgcn = rgcn_layer(current_layer_graph)
            # Use tanh activation as in the DGL example.
            h = nn.tanh(graph_after_rgcn["nodes"])

        # Graph readout: select and sum root node embeddings.
        num_graphs = graph["n_node"].shape[0]

        if num_graphs > 0:
            # Nodes are assumed to be packed contiguously with the same number
            # of (padded) nodes per graph, so segment ids are a plain repeat.
            num_total_nodes = h.shape[0]
            num_nodes_per_graph = num_total_nodes // num_graphs
            segment_ids = jnp.repeat(jnp.arange(num_graphs), repeats=num_nodes_per_graph)
            graph_embeddings = segment_sum(h * is_root_nodes, segment_ids, num_segments=num_graphs)
        else:
            graph_embeddings = jnp.zeros((0, h.shape[-1]))

        output = nn.Dense(features=self.output_dim, name='output_dense')(graph_embeddings)

        # Only drop the graph axis when there is exactly one graph. An
        # unconditional squeeze(axis=0) raises for any other batch size, which
        # made the batched and empty readout paths above unusable.
        if num_graphs == 1:
            return jnp.squeeze(output, axis=0)
        return output



# ACModel

In [None]:
class ActorCritic(nn.Module):
    """
    JAX-native actor-critic model driven by a GNN embedding of the observation.

    The observation (a batched graph pytree) is embedded by a vmapped
    `RGCNRootShared_no_jraph`. A linear actor head turns the embedding into
    categorical logits and a linear critic head into a scalar value.
    """
    # Width of the GNN embedding fed to both heads.
    text_embedding_size: int = 32
    # Number of discrete actions (size of the logits vector).
    output_dim: int = 12


    def setup(self):
        """Initializes the sub-modules of the actor-critic model."""
        # Declared but not used by `__call__` below.
        self.env_model = EnvModel()
        VmappedGNN = nn.vmap(
            RGCNRootShared_no_jraph,
            in_axes=0,             # Map over the first axis of the input PyTree (the Graph dict)
            out_axes=0,            # Stack outputs along the first axis
            variable_axes={'params': None}, # Do not map/split the model parameters
            split_rngs={'params': False}    # Do not split RNGs for parameter initialization
        )

        self.gnn = VmappedGNN(
            output_dim=self.text_embedding_size,
            hidden_dim=32,
            num_layers=8,
            num_edge_types=4
        )

        # Actor head: a single dense layer producing unnormalised logits.
        # Kept inside nn.Sequential so the parameter tree layout is unchanged.
        self.actor_net = nn.Sequential([
            nn.Dense(features=self.output_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))
        ])

        # Critic head: a single dense layer producing one value per example.
        self.critic_net = nn.Sequential([
            nn.Dense(features=1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))
        ])

    def __call__(self, obs: Dict, carry=None, reset=None) -> Tuple[distrax.Distribution, jnp.ndarray]:
        """
        Forward pass for the actor-critic model.

        Args:
            obs: Batched graph observation; every leaf has a leading batch axis
                that the vmapped GNN maps over.
            carry: Unused; accepted for interface compatibility.
            reset: Unused; accepted for interface compatibility.

        Returns:
            A tuple `(distribution, v)`:
            - distribution (distrax.Distribution): categorical policy over actions.
            - v (jnp.ndarray): state-value estimate with the trailing size-1
              axis removed, i.e. one scalar per batch entry.
        """
        # The GNN embedding is the only feature source. The original body fed
        # an undefined name `embedding` to both heads and raised NameError.
        embedding = self.gnn(obs)

        # --- Actor pass: unnormalised logits -> categorical distribution ---
        logits = self.actor_net(embedding)
        distribution = distrax.Categorical(logits=logits)

        # --- Critic pass ---
        # Dense(features=1) yields [..., 1]. Drop that axis so the value lines
        # up with log-probs/rewards in the rollout, GAE and PPO loss code.
        v = jnp.squeeze(self.critic_net(embedding), axis=-1)

        return distribution, v

  return datetime.utcnow().replace(tzinfo=utc)


# Base Algo

In [None]:
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Callable, Dict, Any, List
from gymnax.environments.environment import Environment
import gymnax
import rlax
import chex
from flax.struct import PyTreeNode


# This dataclass is still useful for defining the expected structure of experience data.
# NOTE(review): flax's PyTreeNode is expected to turn subclasses into dataclasses
# on its own; confirm whether the extra stdlib @dataclass here is needed.
@dataclass
class Experience(PyTreeNode):
    """Stores all data for a batch of experiences.

    The field names match the keys that ``BaseAlgo.collect_experiences`` keeps
    in the experience dict it returns.
    """
    # Observation the policy was evaluated on at that step.
    obs: chex.ArrayTree
    # Mask carried into the step (the rollout sets it to ``1.0 - done`` of the previous step).
    mask: chex.Array
    # Action sampled from the policy distribution.
    action: chex.Array
    # Critic estimate recorded at collection time.
    value: chex.Array
    # Reward as stored by the rollout (scaled by ``discount ** ep_num_frames`` there).
    reward: chex.Array
    # Generalized advantage estimate.
    advantage: chex.Array
    # Return target: ``advantage + value``.
    returnn: chex.Array
    # Log-probability of ``action`` under the policy that collected it.
    log_prob: chex.Array

# NOTE(review): a class with the same name and fields is declared again in the
# PPO cell further down; once that cell runs, the later definition shadows this one.
@dataclass
class AlgoState:
    """High-level state of the algorithm, including model parameters and non-JAX logs.

    This is a plain Python dataclass; the three ``log_*`` lists are ordinary
    Python lists that ``BaseAlgo.collect_experiences`` extends with the
    statistics of episodes that finished during a rollout.
    """
    # Model parameters, passed to ``acmodel.apply`` as ``{'params': params}``.
    params: Dict[str, Any]
    # Optimizer state matching ``params``.
    opt_state: Any
    # Undiscounted return of each finished episode.
    log_return: List[float] = field(default_factory=list)
    # Sum of the stored (reshaped) rewards of each finished episode.
    log_reshaped_return: List[float] = field(default_factory=list)
    # Length in frames of each finished episode.
    log_num_frames: List[int] = field(default_factory=list)

# NOTE(review): flax's PyTreeNode is expected to turn subclasses into dataclasses
# on its own; confirm whether the extra stdlib @dataclass here is needed.
@dataclass
class RolloutState(PyTreeNode):
    """State carried through the JAX scan loop. Must be JAX-compatible.

    One instance describes all parallel environments at once; it is the scan
    carry in ``BaseAlgo.collect_experiences``.
    """
    # PRNG key; split every step into a policy-sampling key and an env-step key.
    rng: chex.PRNGKey
    # Batched environment state handed to the vmapped ``env.step``.
    env_state: Any
    # Current observation, fed to the actor-critic model.
    obs: chex.ArrayTree
    # ``1.0 - done`` from the previous step.
    mask: chex.Array
    # Running sum of raw rewards; multiplied by the mask so it restarts after ``done``.
    ep_return: chex.Array
    # Running sum of the stored (reshaped) rewards; restarted the same way.
    ep_reshaped_return: chex.Array
    # Frames elapsed in the current episode (int32); restarted the same way.
    ep_num_frames: chex.Array


class BaseAlgo(ABC):
    """Base class for JAX-based RL algorithms with JIT-compiled experience collection.

    The rollout itself (policy forward pass, vmapped environment step, GAE) is
    compiled with ``jax.jit``. Episode logging needs concrete Python values, so
    it runs on the host after the compiled part returns. Subclasses implement
    ``update_parameters``.
    """
    def __init__(self, env: Environment, env_params: Any, acmodel: nn.Module, num_procs: int,
                 num_frames_per_proc: int, discount: float, gae_lambda: float,
                 preprocess_obss: Callable = None, reshape_reward: Callable = None):
        """Store the configuration and build the vmapped/jitted helpers.

        Args:
            env: Environment whose ``step(key, state, action, params)`` is vmapped.
            env_params: Parameters passed (unbatched) to every environment step.
            acmodel: Model whose ``apply`` returns ``(distribution, value)``.
            num_procs: Number of parallel environments.
            num_frames_per_proc: Steps collected per environment per rollout.
            discount: Discount factor.
            gae_lambda: GAE mixing parameter.
            preprocess_obss: Stored on the instance; defaults to the identity.
                Not used by this class.
            reshape_reward: Stored on the instance. Not used by this class.
        """
        self.env = env
        self.env_params = env_params
        self.acmodel = acmodel
        self.num_procs = num_procs
        self.num_frames_per_proc = num_frames_per_proc
        self.num_frames = num_procs * num_frames_per_proc
        self.discount = discount
        self.gae_lambda = gae_lambda
        self.preprocess_obss = preprocess_obss or (lambda x, rng=None: x)
        self.reshape_reward = reshape_reward
        # Vmap the environment step function for parallel execution.
        step_environment = lambda key, state, action, params: self.env.step(key, state, action, params)
        self.vmapped_env_step = jax.vmap(
            step_environment, in_axes=(0, 0, 0, None)
        )
        # jit the bound method so `self` is closed over instead of being traced.
        self._jitted_rollout = jax.jit(self._rollout)

    def _rollout(self, params: Dict[str, Any], rollout_state: RolloutState) -> tuple[dict, RolloutState]:
        """Pure, jit-compatible part of experience collection.

        Scans ``num_frames_per_proc`` steps over all environments and adds GAE
        advantages and return targets.

        Returns:
            ``(experiences, final_rollout_state)``, where ``experiences`` is a
            dict of arrays shaped ``(num_frames_per_proc, num_procs, ...)``.
        """
        def step_fn(carry: RolloutState, _):
            """One time step for ALL parallel environments."""
            rng, env_state, obs, mask = carry.rng, carry.env_state, carry.obs, carry.mask
            ep_return, ep_reshaped_return, ep_num_frames = carry.ep_return, carry.ep_reshaped_return, carry.ep_num_frames

            rng, subkey_policy, subkey_step = jax.random.split(rng, 3)

            dist, value = self.acmodel.apply({'params': params}, obs)

            # Sample actions for all environments from the batched distribution.
            action = dist.sample(seed=subkey_policy)
            log_prob = dist.log_prob(action)

            step_keys = jax.random.split(subkey_step, self.num_procs)
            next_obs, reward, done, info, next_env_state = self.vmapped_env_step(
                step_keys,
                env_state,
                action,
                self.env_params
            )

            # The stored reward is scaled by discount ** (frames elapsed in the episode).
            reshaped_reward = reward * (self.discount ** ep_num_frames)

            new_ep_return = ep_return + reward
            new_ep_reshaped_return = ep_reshaped_return + reshaped_reward
            new_ep_num_frames = ep_num_frames + 1

            experience_step = {
                "obs": obs, "mask": mask, "action": action, "value": value,
                "reward": reshaped_reward, "log_prob": log_prob,
                "done": done, "ep_return_at_done": new_ep_return,
                "ep_reshaped_return_at_done": new_ep_reshaped_return,
                "ep_num_frames_at_done": new_ep_num_frames
            }
            # Multiplying by the mask restarts the running statistics after `done`.
            next_mask = 1.0 - done
            next_carry = RolloutState(
                rng=rng, env_state=next_env_state, obs=next_obs,
                mask=next_mask, ep_return=new_ep_return * next_mask,
                ep_reshaped_return=new_ep_reshaped_return * next_mask,
                ep_num_frames=jnp.int32(new_ep_num_frames * next_mask)
            )
            return next_carry, experience_step

        # step_fn already handles the whole batch of environments, so it is scanned directly.
        final_rollout_state, experiences = jax.lax.scan(
            step_fn, rollout_state, None, length=self.num_frames_per_proc
        )

        # Bootstrap value for the observation after the last collected step.
        _, next_value = self.acmodel.apply({'params': params}, final_rollout_state.obs)

        # GAE needs values v_0..v_k: the k collected values plus the bootstrap.
        all_values = jnp.concatenate(
            [experiences['value'], next_value[None, :]], axis=0
        )

        # Zero discount at terminal steps stops advantages leaking across episodes.
        discounts = self.discount * (1.0 - experiences['done'])

        # rlax's GAE works on one trajectory; vmap it over the environment axis.
        # Positional order: (r_t, discount_t, lambda_, values).
        vmapped_gae = jax.vmap(
            rlax.truncated_generalized_advantage_estimation,
            in_axes=(1, 1, None, 1),
            out_axes=1
        )
        advantages = vmapped_gae(
            experiences['reward'],
            discounts,
            self.gae_lambda,
            all_values
        )

        experiences['advantage'] = advantages
        experiences['returnn'] = advantages + experiences['value']
        return experiences, final_rollout_state

    def collect_experiences(self, algo_state: AlgoState, rollout_state: RolloutState) -> tuple[dict, dict, AlgoState, RolloutState]:
        """Collect one rollout, compute advantages and update the episode logs.

        This method is deliberately NOT decorated with ``jax.jit``. ``self`` and
        the plain-dataclass ``AlgoState`` are not valid jit arguments, and the
        logging below needs concrete values (boolean-mask indexing,
        ``.tolist()``, ``int(...)``). Only ``_rollout`` is compiled.

        Args:
            algo_state: Supplies the model parameters and the running logs.
            rollout_state: Scan carry to start the rollout from.

        Returns:
            ``(exps, logs, new_algo_state, final_rollout_state)``. ``exps`` is a
            dict with keys obs/mask/action/value/reward/log_prob/advantage/returnn,
            each flattened to ``(num_procs * num_frames_per_proc, ...)``.
        """
        experiences, final_rollout_state = self._jitted_rollout(algo_state.params, rollout_state)

        keys_to_keep = ["obs", "mask", "action", "value", "reward", "log_prob", "advantage", "returnn"]
        exps_dict = {key: experiences[key] for key in keys_to_keep}

        # Reshape data to (num_procs * num_frames_per_proc, ...) for the update step.
        exps = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1).reshape(-1, *x.shape[2:]), exps_dict)

        # Host-side logging: keep the statistics of episodes that ended in this rollout.
        done_mask = experiences['done'].flatten().astype(bool)

        log_return = algo_state.log_return + experiences['ep_return_at_done'].flatten()[done_mask].tolist()
        log_reshaped_return = algo_state.log_reshaped_return + experiences['ep_reshaped_return_at_done'].flatten()[done_mask].tolist()
        log_num_frames = algo_state.log_num_frames + experiences['ep_num_frames_at_done'].flatten()[done_mask].tolist()

        new_algo_state = AlgoState(
            params=algo_state.params,
            opt_state=algo_state.opt_state,
            log_return=log_return,
            log_reshaped_return=log_reshaped_return,
            log_num_frames=log_num_frames
        )

        # Report at least `num_procs` of the most recent finished episodes.
        keep = max(int(done_mask.sum()), self.num_procs)
        logs = {
            "return_per_episode": new_algo_state.log_return[-keep:],
            "reshaped_return_per_episode": new_algo_state.log_reshaped_return[-keep:],
            "num_frames_per_episode": new_algo_state.log_num_frames[-keep:],
            "num_frames": self.num_frames,
        }
        return exps, logs, new_algo_state, final_rollout_state



    @abstractmethod
    def update_parameters(self, exps: dict, state: AlgoState, rng: jnp.ndarray) -> tuple[Dict, AlgoState]:
        """Update the model from ``exps``; return ``(logs, new_state)``."""
        pass

# PPO Algo

In [None]:
#
# PPO JAX Implementation
#
from typing import Callable, Dict, Any, Tuple

import chex
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from gymnax.environments.environment import Environment

# This cell relies on names from the "Base Algo" cell above (`dataclass`, `field`,
# `Experience`, `BaseAlgo`); run that cell first.
# Note: `AlgoState` is declared again below with the same fields.


# NOTE(review): this re-declares the `AlgoState` from the "Base Algo" cell with the
# same fields (only the list annotations are looser); this later definition is
# the one in effect once this cell has run.
@dataclass
class AlgoState:
    """High-level state of the algorithm, including model parameters and optimizer state."""
    # Model parameters, passed to ``acmodel.apply`` as ``{'params': params}``.
    params: Dict[str, Any]
    # Optax optimizer state matching ``params``.
    opt_state: Any
    # Non-JAX-compatible fields for logging are kept from the base class
    log_return: list = field(default_factory=list)
    log_reshaped_return: list = field(default_factory=list)
    log_num_frames: list = field(default_factory=list)

# `BaseAlgo` must already be defined in the kernel (see the "Base Algo" cell above)
# for the class below to be executable.

class PPO(BaseAlgo):
    """
    The Proximal Policy Optimization (PPO) algorithm implemented in JAX.

    This class inherits from the JAX-based `BaseAlgo` and implements the
    parameter update step according to the PPO objective function.
    """

    def __init__(self,
                 env: Environment,
                 env_params: Any,
                 acmodel: nn.Module,
                 num_procs: int,
                 num_frames_per_proc: int,
                 discount: float = 0.99,
                 lr: float = 0.001,
                 gae_lambda: float = 0.95,
                 entropy_coef: float = 0.01,
                 value_loss_coef: float = 0.5,
                 max_grad_norm: float = 0.5,
                 adam_eps: float = 1e-8,
                 clip_eps: float = 0.2,
                 epochs: int = 4,
                 batch_size: int = 256,
                 reshape_reward: Callable = None):
        """Initializes the PPO algorithm with its specific hyperparameters.

        Args:
            env, env_params, acmodel, num_procs, num_frames_per_proc, discount,
                gae_lambda, reshape_reward: Forwarded to `BaseAlgo.__init__`.
            lr: Adam learning rate.
            entropy_coef: Weight of the entropy bonus in the total loss.
            value_loss_coef: Weight of the value loss in the total loss.
            max_grad_norm: Global-norm gradient clipping threshold.
            adam_eps: Adam epsilon.
            clip_eps: Clipping range for both the policy ratio and the value update.
            epochs: Number of passes over each rollout.
            batch_size: Minibatch size; must divide `num_procs * num_frames_per_proc`.
        """
        # `reshape_reward` has to be passed by keyword: the base class takes
        # `preprocess_obss` before it, so a positional argument here would be
        # bound to the wrong parameter.
        super().__init__(env, env_params, acmodel, num_procs, num_frames_per_proc, discount,
                         gae_lambda, reshape_reward=reshape_reward)

        # PPO-specific hyperparameters
        self.entropy_coef = entropy_coef
        self.value_loss_coef = value_loss_coef
        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size

        # The total number of frames is flattened for minibatching
        assert self.num_frames % self.batch_size == 0, "Total frames must be divisible by batch_size."

        # Initialize the optimizer using Optax
        # We chain gradient clipping with the Adam optimizer
        self.optimizer = optax.chain(
            optax.clip_by_global_norm(max_grad_norm),
            optax.adam(learning_rate=lr, eps=adam_eps)
        )

        # JIT-compile the core update loop for maximum performance
        self._jitted_update_epochs = jax.jit(self._update_epochs)

    def update_parameters(self, exps: Experience, algo_state: AlgoState, rng: chex.PRNGKey) -> Tuple[Dict, AlgoState]:
        """
        Updates the actor-critic model parameters using PPO.

        This method serves as a JAX-friendly wrapper around the core, JIT-compiled
        update logic contained in `_update_epochs`.

        Args:
            exps: Collected experiences, either an `Experience` instance or the
                plain dict returned by `BaseAlgo.collect_experiences` (its keys
                are the `Experience` field names).
            algo_state: The current state of the algorithm (parameters, optimizer state).
            rng: A JAX random key for shuffling data.

        Returns:
            A tuple containing a dictionary of logs and the updated `AlgoState`.
        """
        # `_loss_fn` reads the batch by attribute, so a dict rollout is wrapped
        # into the `Experience` pytree first.
        if isinstance(exps, dict):
            exps = Experience(**exps)

        # Execute the JIT-compiled update function
        (params, opt_state), logs = self._jitted_update_epochs(
            algo_state.params,
            algo_state.opt_state,
            exps,
            rng
        )

        # Create the new state with updated parameters and optimizer state
        new_algo_state = AlgoState(
            params=params,
            opt_state=opt_state,
            log_return=algo_state.log_return,
            log_reshaped_return=algo_state.log_reshaped_return,
            log_num_frames=algo_state.log_num_frames
        )

        return logs, new_algo_state

    def _loss_fn(self, params: Dict, batch: Experience) -> Tuple[chex.Array, Tuple]:
        """
        Computes the PPO loss for a single minibatch of experience.
        This function is designed to be used with `jax.grad`.

        Args:
            params: The model parameters.
            batch: A PyTree representing a minibatch of experiences.

        Returns:
            A tuple containing the total loss and auxiliary metrics (policy loss,
            value loss, entropy) for logging.
        """
        # Forward pass to get policy distribution and value estimates
        dist, value = self.acmodel.apply({'params': params}, batch.obs)

        # --- Policy Loss (Clipped Surrogate Objective) ---
        new_log_prob = dist.log_prob(batch.action)
        ratio = jnp.exp(new_log_prob - batch.log_prob)

        # Normalize advantages within the minibatch for stability
        advantage = (batch.advantage - batch.advantage.mean()) / (batch.advantage.std() + 1e-8)

        surr1 = ratio * advantage
        surr2 = jnp.clip(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantage
        policy_loss = -jnp.minimum(surr1, surr2).mean()

        # --- Value Loss (Clipped Value Function) ---
        # The new value may move at most clip_eps away from the value recorded at collection time.
        value_clipped = batch.value + jnp.clip(value - batch.value, -self.clip_eps, self.clip_eps)
        surr1_v = (value - batch.returnn)**2
        surr2_v = (value_clipped - batch.returnn)**2
        value_loss = 0.5 * jnp.maximum(surr1_v, surr2_v).mean()

        # --- Entropy Bonus ---
        entropy = dist.entropy().mean()

        # --- Total Loss ---
        total_loss = (
            policy_loss
            + self.value_loss_coef * value_loss
            - self.entropy_coef * entropy
        )

        return total_loss, (policy_loss, value_loss, entropy)


    def _update_epochs(self, params: Dict, opt_state: Any, exps: Experience, rng: chex.PRNGKey) -> Tuple[Tuple, Dict]:
        """
        The core update logic that iterates over epochs and minibatches.
        This entire function is JIT-compiled.

        Returns:
            `((final_params, final_opt_state), logs)` where `logs` holds the
            policy loss, value loss and entropy averaged over all minibatches
            and epochs.
        """
        def _epoch_step(carry, _):
            """Represents one full pass (epoch) over the entire dataset."""
            p, o, r = carry # params, opt_state, rng

            # Shuffle the experience data at the start of each epoch
            r, perm_key = jax.random.split(r)
            permutation = jax.random.permutation(perm_key, self.num_frames)
            shuffled_exps = jax.tree.map(lambda x: x[permutation], exps)

            # Reshape data into minibatches
            num_minibatches = self.num_frames // self.batch_size
            minibatches = jax.tree.map(
                lambda x: x.reshape((num_minibatches, self.batch_size) + x.shape[1:]),
                shuffled_exps
            )

            def _minibatch_step(carry, batch):
                """Updates parameters using a single minibatch."""
                params_mb, opt_state_mb = carry
                # Compute gradients and auxiliary loss data
                grad, (pi_loss, v_loss, ent) = jax.grad(self._loss_fn, has_aux=True)(params_mb, batch)
                # Update parameters and optimizer state
                updates, new_opt_state = self.optimizer.update(grad, opt_state_mb, params_mb)
                new_params = optax.apply_updates(params_mb, updates)
                return (new_params, new_opt_state), (pi_loss, v_loss, ent)

            # Scan the update function over all minibatches
            (new_p, new_o), (policy_losses, value_losses, entropies) = jax.lax.scan(
                _minibatch_step, (p, o), minibatches
            )

            # Return updated state and logs for the epoch
            return (new_p, new_o, r), (policy_losses.mean(), value_losses.mean(), entropies.mean())

        # Scan the epoch function over the configured number of epochs
        (final_params, final_opt_state, _), (pl, vl, ent) = jax.lax.scan(
            _epoch_step, (params, opt_state, rng), None, length=self.epochs
        )

        # Aggregate logs by taking the mean across all epochs
        logs = {
            "policy_loss": pl.mean(),
            "value_loss": vl.mean(),
            "entropy": ent.mean()
        }

        return (final_params, final_opt_state), logs

# main


In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import wandb  # For logging
import numpy as np
import time
from tqdm import tqdm  # For progress bar
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from typing import Callable, Dict, Any, List, Tuple
from gymnax.environments.environment import Environment
import gymnax
import rlax
import chex
from flax.struct import PyTreeNode

def main():
    """
    Main training loop for PPO with wandb logging.

    Builds the hyperparameter config, optionally opens a W&B run, and hands
    the actual setup/training work to `_run_training`. The W&B run is closed
    in a `finally` block so it is finished even if setup or training raises.
    """

    # --- Hyperparameters ---
    config = {
        "LR": 0.005,
        "NUM_PROCS": 64, # Increased for more parallel data
        "NUM_FRAMES_PER_PROC": 128,
        "DISCOUNT": 0.9,
        "GAE_LAMBDA": 0.5,
        "ENTROPY_COEF": 0.01,
        "VALUE_LOSS_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ADAM_EPS": 1e-8,
        "CLIP_EPS": 0.1,
        "EPOCHS": 2,
        "NUM_MINIBATCHES": 8,  # Minibatches per epoch; determines BATCH_SIZE below
        "TOTAL_FRAMES": 10_000_000,
        # Time-based seed (a fresh one per run), stored in the config so it is
        # logged with the run and can be replaced by a fixed value to reproduce it.
        "SEED": int(time.time()),
        "USE_WANDB": True,
        "WANDB_PROJECT": "jax_ppo_ltl_example", # Change this
        "WANDB_ENTITY": "your_username",       # <--!! CHANGE THIS !!
        "CHECKPOINT_DIR": "./checkpoints/ppo_ltl",
        "CHECKPOINT_INTERVAL": 50, # Save every 50 updates
    }

    # Calculate batch_size
    total_frames_per_update = config["NUM_PROCS"] * config["NUM_FRAMES_PER_PROC"]
    # We set batch_size so the frames of one update split into NUM_MINIBATCHES minibatches
    config["BATCH_SIZE"] = total_frames_per_update // config["NUM_MINIBATCHES"]
    assert total_frames_per_update % config["BATCH_SIZE"] == 0, "BATCH_SIZE must divide (NUM_PROCS * NUM_FRAMES_PER_PROC)"


    # --- W&B Setup ---
    if config["USE_WANDB"]:
        wandb.init(
            project=config["WANDB_PROJECT"],
            entity=config["WANDB_ENTITY"],
            config=config,
            monitor_gym=False, # gymnax not supported by default
            save_code=True,
        )

    try:
        _run_training(config, total_frames_per_update)
        # --- End of Training ---
        print("Training finished.")
    finally:
        # Always close the W&B run, including when setup or training fails.
        if config["USE_WANDB"]:
            wandb.finish()


def _run_training(config, total_frames_per_update):
    """Builds the environment, model and PPO algorithm, then runs the update loop.

    Args:
        config: Hyperparameter dict assembled by `main` (must already contain
            "BATCH_SIZE" and "SEED").
        total_frames_per_update: NUM_PROCS * NUM_FRAMES_PER_PROC, i.e. the
            number of frames gathered by one `collect_experiences` call.
    """
    # Local import: `os` is not part of this cell's import list.
    import os

    # NOTE(review): `ocp` is not imported in this cell either; it is assumed to
    # be brought into scope by an earlier notebook cell -- confirm.
    os.makedirs(config["CHECKPOINT_DIR"], exist_ok=True)
    options = ocp.CheckpointManagerOptions(max_to_keep=5, create=True)
    checkpoint_manager = ocp.CheckpointManager(
        config["CHECKPOINT_DIR"],
        options=options
    )

    # --- JAX Key Setup ---
    print(f"Using PRNG seed: {config['SEED']}")
    key = jax.random.PRNGKey(config["SEED"])
    key, model_key, reset_key, rollout_key, update_key = jax.random.split(key, 5)

    # --- Environment Setup (using your provided snippet) ---
    print("Setting up environment...")
    letters = encode_letters("abcdefghijkl")
    env_params = EnvParams(grid_size=7, letters=letters, use_fixed_map=False, use_agent_centric_view=True, timeout=75, num_unique_letters=len(set(letters)))

    env = LetterEnv(num_letters=len(set(env_params.letters)), timeout=env_params.timeout)
    ltl_env = LTLEnv(env, env_params, progression_mode="full", ltl_sampler=None, intrinsic=0.0)

    # --- Algorithm and Model Setup ---
    print("Setting up model and algorithm...")
    acmodel = ActorCritic(output_dim=len(set(env_params.letters)))

    # Instantiate the PPO algorithm with its hyperparameters
    algo = PPO(
        env=ltl_env,
        env_params=env_params,
        acmodel=acmodel,
        num_procs=config["NUM_PROCS"],
        num_frames_per_proc=config["NUM_FRAMES_PER_PROC"],
        discount=config["DISCOUNT"],
        lr=config["LR"],
        gae_lambda=config["GAE_LAMBDA"],
        entropy_coef=config["ENTROPY_COEF"],
        value_loss_coef=config["VALUE_LOSS_COEF"],
        max_grad_norm=config["MAX_GRAD_NORM"],
        adam_eps=config["ADAM_EPS"],
        clip_eps=config["CLIP_EPS"],
        epochs=config["EPOCHS"],
        batch_size=config["BATCH_SIZE"],
        reshape_reward=None  # Add your reward shaping function here if you have one
    )

    # --- Initialization ---
    print("Initializing environment and model parameters...")
    # Vmap the environment reset for parallel processes
    vmapped_env_reset = jax.vmap(ltl_env.reset, in_axes=(0, None))
    reset_keys = jax.random.split(reset_key, config["NUM_PROCS"])

    # Get initial state
    init_obs, init_env_state = vmapped_env_reset(reset_keys, env_params)

    # Initialize model parameters
    init_params = algo.acmodel.init(model_key, init_obs)['params']

    # Initialize optimizer state
    init_opt_state = algo.optimizer.init(init_params)

    # Initialize the algorithm state
    algo_state = AlgoState(
        params=init_params,
        opt_state=init_opt_state,
        # Lists are initialized by default factory
    )

    # Initialize the rollout state
    rollout_state = RolloutState(
        rng=rollout_key,
        env_state=init_env_state,
        obs=init_obs,
        mask=jnp.ones((config["NUM_PROCS"],), dtype=jnp.float32),
        ep_return=jnp.zeros((config["NUM_PROCS"],)),
        ep_reshaped_return=jnp.zeros((config["NUM_PROCS"],)),
        ep_num_frames=jnp.zeros((config["NUM_PROCS"],), dtype=jnp.int32)
    )

    # --- Training Loop ---
    num_updates = config["TOTAL_FRAMES"] // total_frames_per_update
    print(f"Starting training for {num_updates} updates ({config['TOTAL_FRAMES']} total frames)...")
    print(f"Total frames per update: {total_frames_per_update}")
    print(f"Batch size: {config['BATCH_SIZE']}, Minibatches per epoch: {total_frames_per_update // config['BATCH_SIZE']}")

    start_time = time.time()

    for update_idx in tqdm(range(num_updates), desc="Training Updates"):
        # 1. Collect Experiences
        exps, rollout_logs, algo_state, rollout_state = algo.collect_experiences(
            algo_state, rollout_state
        )

        # 2. Update Parameters
        update_key, sub_key = jax.random.split(update_key)
        update_logs, algo_state = algo.update_parameters(
            exps, algo_state, sub_key
        )

        # 3. Logging
        total_frames_so_far = (update_idx + 1) * total_frames_per_update
        # Clamp the elapsed time so a zero-length interval cannot divide by zero.
        elapsed = max(time.time() - start_time, 1e-9)
        fps = total_frames_so_far / elapsed

        # Combine logs
        logs = {
            "update": update_idx,
            "total_frames": total_frames_so_far,
            "fps": fps,
            "policy_loss": float(update_logs["policy_loss"]),
            "value_loss": float(update_logs["value_loss"]),
            "entropy": float(update_logs["entropy"]),
        }

        # Process episode logs (mean of all episodes finished *in this update*)
        num_finished = len(rollout_logs["return_per_episode"])
        if num_finished > 0:
            logs["mean_return"] = np.mean(rollout_logs["return_per_episode"])
            logs["mean_reshaped_return"] = np.mean(rollout_logs["reshaped_return_per_episode"])
            logs["mean_episode_length"] = np.mean(rollout_logs["num_frames_per_episode"])

        logs["num_episodes_finished_this_update"] = num_finished

        if config["USE_WANDB"]:
            wandb.log(logs, step=total_frames_so_far)

        if (update_idx + 1) % config["CHECKPOINT_INTERVAL"] == 0:
            print(f"\nSaving checkpoint at update {update_idx}...")
            # Save the entire algo_state (params, opt_state, etc.)
            checkpoint_manager.save(
                update_idx,
                args=ocp.args.StandardSave(algo_state)
            )
            checkpoint_manager.wait_until_finished() # Wait for save to complete
            print("Checkpoint saved.")

        # Clear the per-update episode logs from the algo_state
        # to prevent logging stale data
        algo_state = algo_state.replace(
            log_return=[],
            log_reshaped_return=[],
            log_num_frames=[]
        )


# Entry point: start training when this code is executed as the main module.
if __name__ == "__main__":
    main()