In [1]:
from wrapper import RelativePosition, FlattenDict, SerializeAction
import numpy as np
from loguru import logger
from IPython.display import clear_output
import time
from gymnasium.envs.registration import register
import gymnasium as gym
import train_params_with_model as params

# Scenario settings are read from the imported `params` module and kept as
# notebook-level names because later cells refer to them directly.
size, relay_config, client_config, init_config = (
    params.size,
    params.relay_config,
    params.client_config,
    params.init_config,
)
is_polar = False

# Register the grid-world environment under a gymnasium id. The `kwargs`
# mapping is forwarded to the entry point when the id is instantiated, and
# `max_episode_steps` caps every episode at 500 steps.
register(
    id="GridWorld-v0",
    entry_point="grid_world:GridWorldEnv",
    max_episode_steps=500,
    kwargs=dict(
        size=size,
        relay_config=relay_config,
        client_config=client_config,
        init_config=init_config,
        is_polar=is_polar,
        is_plot=True,
        is_log=True,
        use_model=True,
    ),
)

# Instantiate the environment, then stack the project wrappers in order:
# RelativePosition -> FlattenDict -> SerializeAction. Each intermediate
# environment keeps its own name so it can still be inspected on its own.
origin_env = gym.make("GridWorld-v0")
relative_env = RelativePosition(origin_env)
flatten_env = FlattenDict(relative_env)
env = SerializeAction(flatten_env, is_polar=is_polar)

In [2]:
import modified_DDPG

# Network dimensions and the action bound are taken from the wrapped
# environment's observation and action spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# Constructor arguments for the modified DDPG agent, assembled in one place.
# Key order matches the order in which the entries were previously inserted.
kwargs = dict(
    state_dim=state_dim,
    action_dim=action_dim,
    max_action=max_action,
    discount=0.5,
    tau=0.005,
    # Coordinate bounds: horizontal range centred on zero, plus the relay
    # height limits from the relay configuration.
    position_range=dict(
        position=[-size / 2, size / 2],
        height=[relay_config.min_height, relay_config.max_height],
    ),
    # Three values per relay and two per client.
    # NOTE(review): presumably (x, y, height) and (x, y) — confirm against modified_DDPG.
    relay_dim=relay_config.num * 3,
    client_dim=client_config.num * 2,
    speed=relay_config.speed,
)

# Build the agent and restore the previously saved checkpoint.
policy = modified_DDPG.DDPG(**kwargs)
policy.load("models/modified_DDPG_GridWorld-v0_with_model_0_2024-10-13_22-50-42")



In [None]:
# Roll out the loaded policy for one seeded episode and record, per step,
# the reward, Q(state, action) and Q(next_state, policy(next_state)).
seed = 0
# Set seeds
env.action_space.seed(seed)
np.random.seed(seed)
state, info = env.reset(seed=seed)
reward_list = []
current_Q_list = []
next_Q_list = []

for i in range(500):

    # action = env.action_space.sample()
    action = policy.select_action(np.array(state))

    next_state, reward, terminated, truncated, info = env.step(action=action)
    reward_list.append(reward)
    logger.info(f"reward: {reward}")

    # Critic value of the transition just taken. It is evaluated once and
    # reused so the logged figures always match what is stored in the lists.
    current_Q = policy.calculate_Q(state, action)
    current_Q_list.append(current_Q)
    logger.info(f"current_Q: {current_Q}")

    # Critic value of the successor state under the policy's own action.
    next_Q = policy.calculate_Q(next_state, policy.select_action(next_state))
    next_Q_list.append(next_Q)
    logger.info(f"next_Q: {next_Q}")

    logger.info(f"reward/Q:{reward/current_Q}")

    state = next_state

    # Stop once the episode has ended: stepping an environment after it
    # reports terminated/truncated without a reset is not supported.
    if terminated or truncated:
        break

    # Slow the loop down so the live output is readable, then queue the
    # cell output to be cleared when the next step produces output.
    time.sleep(0.1)
    clear_output(wait=True)

In [None]:
# print(current_Q_list)
# print(next_Q_list)
# print(reward_list)
# Compare the critic value stored for step i+1 (current_Q_list[i+1]) with the
# successor-state estimate stored on step i (next_Q_list[i]); both refer to
# the same state because the rollout sets state = next_state after each step.
# Slicing and zip keep the original cap of 99 comparisons but shorten safely
# when fewer than 100 steps were recorded, instead of raising IndexError.
for later_Q, bootstrap_Q in zip(current_Q_list[1:100], next_Q_list[:99]):
    print(later_Q - bootstrap_Q)