In [None]:
# The star import stays first so the explicit imports below win any name clash.
from environment import *
from dqn_agent import Agent

import os
import time

import jax
from gymnasium import spaces
from jax import jit, random, numpy as jnp



@jit
def example_reward_function(state, goal_state):
    """Score a state by how close it is to the goal.

    Args:
        state: current state vector.
        goal_state: target state vector, same shape as ``state``.

    Returns:
        A scalar array: ``15.0`` when every component of ``state`` equals
        ``goal_state``; otherwise ``-1``, ``-2`` or ``-3`` depending on whether
        the Euclidean distance, divided by ``sqrt(2)``, is below ``0.33``,
        below ``0.66``, or larger.
    """
    # Fixed normaliser; presumably the corner-to-origin distance of the
    # [-1, 1] box used by this notebook -- confirm if the bounds change.
    max_distance = jnp.sqrt(2)
    fraction = jnp.linalg.norm(state - goal_state) / max_distance

    # Three penalty bands, built from the outermost band inwards.
    outer_bands = jnp.where(fraction < 0.66, -2, -3)
    distance_penalty = jnp.where(fraction < 0.33, -1, outer_bands)

    # jnp.where keeps the goal check traceable under jit (no Python branching).
    reached_goal = jnp.all(state == goal_state)
    return jnp.where(reached_goal, 15.0, distance_penalty)


In [None]:


def example_transition_function(state, action, state_space_shape):
    """Compute the next state from the current state and a discrete action.

    The action selects one scalar derived from the two state components:
    0 -> ``x + y``, 1 -> ``|x - y|``, 2 -> ``x / y``, 3 -> ``y / x``.
    That scalar is scaled by ``10**7`` and cast to an integer to seed a PRNG;
    two uniform samples in ``[-1, 1)`` are drawn from that seed and multiplied
    by the same scalar to form the next state.

    Args:
        state: array holding exactly two components ``(x, y)``.
        action: integer index in ``0..3`` choosing the branch.
        state_space_shape: accepted but not used by this function.

    Returns:
        Array of shape ``(2,)`` with the next state.

    Note:
        The two division branches are unguarded, so a zero denominator
        produces inf/nan in the selected scalar.
    """
    first, second = state

    # One candidate scalar per action; lax.switch picks by index and stays
    # traceable under jit. The operand is unused, hence ``None``.
    branches = [
        lambda _: first + second,
        lambda _: jnp.abs(first - second),
        lambda _: first / second,
        lambda _: second / first,
    ]
    seed_value = jax.lax.switch(action, branches, None)

    # The same scalar serves as PRNG seed (after scaling and integer cast)
    # and as the multiplier applied to the sampled noise.
    seed = jnp.array(seed_value * 10**7, int)
    noise = random.uniform(random.PRNGKey(seed), (1, 2), minval=-1, maxval=1)[0]
    scaled = noise * seed_value

    return jnp.array([scaled[0], scaled[1]])



In [None]:


# ENV SETTINGS

# Lower / upper bound applied to the state space.
env_min, env_max = -1.0, 1.0
# (min, max) pair. NOTE(review): its length (2) is reused below as the width
# of the Box shape -- this is the length of the bounds pair, not an explicit
# dimension count; confirm that is intended.
dimensions = (env_min, env_max)
dtype = jnp.float32

# Start and goal states, both 2-component vectors.
initial_state = jnp.array([-1.0, 1.0], dtype=dtype)
target_state = jnp.array([0.0, 0.0], dtype=dtype)

# Four discrete actions.
action_space_n = 4
action_space = spaces.Discrete(action_space_n)

state_space = spaces.Box(
    low=env_min,
    high=max(env_min, env_max),
    shape=(1, len(dimensions)),
    dtype=dtype,
)

# Bundle everything for the environment factory; both callbacks are passed
# through jit before being handed over.
config = EnvironmentConfig(
    state_space=state_space,
    action_space=action_space,
    initial_state=initial_state,
    target_state=target_state,
    reward_function=jit(example_reward_function),
    transition_function=jit(example_transition_function),
)

In [None]:

env = create_environment(config)
# Feature count of one observation = size of the last axis of the Box shape.
# (Using the rank, len(shape), would only match by coincidence.)
state_size = state_space.shape[-1]
action_size = env.action_space.n

batch_size = 32  # increase by powers of 2
num_episodes = 100  # Number of episodes to simulate
num_iterations = 2000  # Maximum number of steps per episode


output_dir = "results/gen1"
# exist_ok avoids the check-then-create race and is a no-op on re-runs.
os.makedirs(output_dir, exist_ok=True)

agent = Agent(state_size, action_size)

# ------------------- Training -------------------

done = False

start_time = time.time()  # Wall-clock start of the whole training run

for episode in range(num_episodes):

    state, _, _ = env.reset()
    episode_reward = 0  # Track total reward per episode

    for t in range(num_iterations):
        action = agent.act(state)
        next_state, reward, done, _ = env.step(action)

        episode_reward += reward  # Accumulate reward per step
        agent.remember(state, action, reward, next_state, done)

        state = next_state

        # Log once per episode: either the env reported done or the step
        # budget for this episode is exhausted.
        if done or t == (num_iterations - 1):
            # need better logging
            elapsed_time = time.time() - start_time  # time since training began
            hours, rem = divmod(elapsed_time, 3600)
            minutes, seconds = divmod(rem, 60)
            print(
                f"Time: {int(hours):02}h {int(minutes):02}m {int(seconds):02}s, "
                f"Episode: {episode}/{num_episodes}, Score: {episode_reward}, Done: {done}"
            )
            break

    # Replay the experience once per episode, as soon as the memory holds
    # more than one batch. batch_size drives both the guard and the sample
    # size so that tuning it actually takes effect.
    if len(agent.memory) > batch_size:
        agent.replay(batch_size)
        # Decay exploration only after a replay step, down to the floor.
        if agent.epsilon > agent.epsilon_min:
            agent.epsilon *= agent.epsilon_decay