In [104]:
import jax
import jax.numpy as jnp
from jax import random, jit

from typing import Any, Callable

from gymnasium import spaces, Env


In [105]:

class EnvConfig:
    def __init__(self, state_space: Any, action_space: Any, start_state: Any, goal_state: Any, reward_fn: Callable, transition_fn: Callable):
        self.state_space = state_space
        self.action_space = action_space
        self.start_state = start_state
        self.goal_state = goal_state
        self.reward_fn = reward_fn
        self.transition_fn = transition_fn

def create_env(config: EnvConfig):
    return MDPEnv(config)


class MDPEnv(Env):
    def __init__(self, config: EnvConfig):
        super(MDPEnv, self).__init__()
        # config unload
        self.config = config
        
        self.state_space = config.state_space
        self.action_space = config.action_space 
        
        self.state = config.start_state 
        self.goal_state = config.goal_state

        self.reward_fn = config.reward_fn
        self.transition_fn = config.transition_fn
       
        # setups, overide the functions with jax.jit      
        self._transition = jax.jit(lambda state, action: self.transition_fn(state, action, self.state_space.shape))
        self._compute_reward = jax.jit(lambda state: self.reward_fn(state, self.goal_state))
        self._is_done = jax.jit(lambda state: jnp.array_equal(state, self.goal_state))
        
        # modify gym spaces
        # self.observation_space = spaces.Box(low=-jnp.inf, high=jnp.inf, shape=self.state_space.shape, dtype=jnp.float32)

    def reset(self):
        """Resets the environment to the start state and returns the initial observation."""
        self.state = self.config.start_state
        reward = self._compute_reward(self.state)
        done = self._is_done(self.state)
        return self.state, reward, done

    def step(self, action):
        """Takes a step in the environment based on the action provided.
        
        Args:
            action: The action to take.
        
        Returns:
            (next_state, reward, done, info).
        """
        next_state = self._transition(self.state, action)
        reward = self._compute_reward(next_state)
        done = self._is_done(next_state)
        self.state = next_state
        return next_state, reward, done, {}
    


In [106]:
    # EXTRA METHODS INSTEAD OF ONE LINE FUNCTIONS
    #--------------------------------------------  
    @jax.jit
    def _transition(self, state, action):
        return self.config.transition_structure(state, action, self.state_space.shape)
    
    @jax.jit
    def _compute_reward(self, state):
        return self.config.reward_structure(state, self.goal_state)
    @jax.jit
    def _is_done(self, state):
        return  jnp.array_equal(self.state, self.goal_state)

    #--------------------------------------------  

In [107]:
max(5,5)

5

In [108]:
# Define state space and action space
env_width, env_height = 5, 5
state_space = spaces.Box(low=0, high=max(env_width, env_height), shape=(env_width, env_height), dtype=jnp.int32)

action_space_num = 4
action_space = spaces.Discrete(action_space_num)

start_state = jnp.array([0, 0], dtype=jnp.int32)
goal_state = jnp.array([4, 4], dtype=jnp.int32)

def calculate_reward(state, goal_state):
    '''Define your reward logic here.'''
    #if jnp.array_equal(state, jnp.array([4, 4])):
    #    return 10
    #else:
    #    return -1
    return jnp.where(jnp.array_equal(state, goal_state), 10, -1) 

def get_next_state(state, action, state_space_shape):
    """Define your state transition logic here."""
    x, y = state

    def move_up(_):
        return jax.lax.max(0, x - 1), y

    def move_down(_):
        return jax.lax.min(state_space_shape[0] - 1, x + 1), y

    def move_left(_):
        return x, jax.lax.max(0, y - 1)

    def move_right(_):
        return x, jax.lax.min(state_space_shape[1] - 1, y + 1)

    x, y = jax.lax.switch(action, [move_up, move_down, move_left, move_right], None)
    return jnp.array([x, y])

# Create the environment
config = EnvConfig(
    state_space=state_space,
    action_space=action_space,
    start_state=start_state,
    goal_state=goal_state,
    reward_fn=jit(calculate_reward),
    transition_fn=jit(get_next_state)
)



NameError: name 'jit' is not defined

In [None]:
type(int(action))


int

In [None]:
env = create_env(config)
# Simulation
key = random.PRNGKey(0)
num_episodes = 10  # Number of episodes to simulate
for e in range(num_episodes):
    
    state, reward, done = env.reset()
    print(f"episode: {e} state: {state}, reward: {reward}, action: none done: {done}")

    episode_reward = 0  # Track total reward per episode

    for _ in range(200): # Maximum number of steps per episode or until done
        key, subkey = random.split(key)
        action = random.choice(subkey, 4)
        %timeit next_state, reward, done, _ = env.step(action)
        episode_reward += reward  # Accumulate reward per step
        state = next_state
        print(f"episode: {e} state: {state}, reward: {reward}, action: {action}, done: {done}")
        if (done or _ == 199):
            print(f"Total Reward: {episode_reward}")  # Print total reward per episode
            episode_reward = 0  # Reset episode reward
            break

episode: 0 state: [0 0], reward: -1, action: none done: False
episode: 0 state: [0 0], reward: -1, action: 0, done: False
episode: 0 state: [1 0], reward: -1, action: 1, done: False
episode: 0 state: [0 0], reward: -1, action: 0, done: False
episode: 0 state: [0 0], reward: -1, action: 0, done: False
episode: 0 state: [0 0], reward: -1, action: 0, done: False
episode: 0 state: [0 0], reward: -1, action: 0, done: False
episode: 0 state: [0 0], reward: -1, action: 0, done: False
episode: 0 state: [1 0], reward: -1, action: 1, done: False
episode: 0 state: [2 0], reward: -1, action: 1, done: False
episode: 0 state: [2 0], reward: -1, action: 2, done: False
episode: 0 state: [2 0], reward: -1, action: 2, done: False
episode: 0 state: [2 0], reward: -1, action: 2, done: False
episode: 0 state: [3 0], reward: -1, action: 1, done: False
episode: 0 state: [4 0], reward: -1, action: 1, done: False
episode: 0 state: [3 0], reward: -1, action: 0, done: False
episode: 0 state: [3 1], reward: -1, a

In [None]:
a=key

In [None]:
b=subkey