In [22]:
import gymnasium as gym
import gymnasium_robotics

# PyTorch
import torch

import os

# from collections import deque
import numpy as np
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback

In [23]:
env_id = 'FrankaKitchen-v1'
task = 'kettle'  # key used for tasks_to_complete and for the achieved/desired goal sub-dicts
# Register the gymnasium-robotics environments so `env_id` can be passed to gym.make.
gym.register_envs(gymnasium_robotics)

# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_actions = 9  # continuous action components; the discretisation helpers below hard-code the same 9

In [24]:
# Quantisation levels per joint. NOTE(review): an odd value puts 0.0 on the grid,
# which makes one index per joint a duplicate of the trailing no-op action.
granularity = 4
# Flattened observation length: achieved goal for `task` (presumably 7 entries)
# followed by the env's observation vector (presumably 59 entries).
flat_dim = 7 + 59
# One index per (level, joint) pair, plus one trailing index for the all-zeros action.
action_space_size = 9 * granularity + 1
transform_action_space = np.linspace(-1.0, 1.0, granularity)

def transform_action_from_int(action: int):
    """Decode a discrete action index into a 9-dim continuous action vector.

    Index layout: ``action_space_size - 1`` is the all-zeros (no-op) action.
    Every other index ``k`` sets joint ``k % 9`` to the quantised level
    ``transform_action_space[k // 9]`` and leaves the other joints at 0.
    A list is assumed to already be a continuous action and is returned as-is.

    Raises:
        AssertionError: if ``action`` lies outside ``[0, action_space_size - 1]``.
    """
    if isinstance(action, list):
        return action
    noop_index = action_space_size - 1
    if action < 0 or action > noop_index:
        raise AssertionError("transform_action_from_int")
    vector = np.zeros(9)
    if action == noop_index:
        return vector
    level_index, joint = divmod(action, 9)
    vector[joint] = transform_action_space[level_index]
    return vector

def transform_action_to_int(action) -> int:
    """Encode an action as its index in the discrete action space.

    This is the inverse of ``transform_action_from_int``:
      * A scalar integer (python int, numpy integer, or 0-d integer array such
        as the output of ``model.predict``) is already an index. It is returned
        as a python int.
      * An all-zeros 9-vector maps to the no-op index ``action_space_size - 1``.
      * Otherwise, take the first non-zero joint ``i`` and quantise its value
        to the nearest level ``j`` of ``transform_action_space``. The result is
        ``i + 9 * j``. Further non-zero joints are ignored, because the discrete
        space only contains single-joint actions.

    Raises:
        AssertionError: for a non-integer scalar, or a vector whose shape is not (9,).
    """
    arr = np.asarray(action)
    if arr.ndim == 0:
        if not np.issubdtype(arr.dtype, np.integer):
            raise AssertionError(f"transform_action_to_int: expected an integer index, got {action!r}")
        return int(arr)

    vector = arr.astype(np.float64)
    if vector.shape != (9,):
        raise AssertionError(f"transform_action_to_int: expected a 9-dim action, got shape {vector.shape}")

    nonzero = np.flatnonzero(vector)
    # Check "every component is zero", not "components sum to zero":
    # [0.5, -0.5, 0, ...] sums to 0 but is not the no-op.
    if nonzero.size == 0:
        return action_space_size - 1

    joint = int(nonzero[0])
    level = int(np.argmin(np.abs(transform_action_space - vector[joint])))
    return joint + 9 * level

def flatten_observation(observation):
    """Flatten a FrankaKitchen Dict observation into one float32 vector.

    The layout is ``[achieved_goal[task], observation]``. Any input that is
    not a dict is returned unchanged.
    """
    if isinstance(observation, dict):
        pieces = (
            observation['achieved_goal'][task],
            observation['observation'],
        )
        return np.concatenate([piece.astype(np.float32) for piece in pieces], dtype=np.float32)
    return observation

In [25]:
def custom_reward(observation, task_name=None, n_components=4):
    """Dense shaping reward: ``1 - ||achieved - desired||``, clipped to [0, 1].

    Compares the first ``n_components`` entries of the achieved and desired
    goal vectors for ``task_name``.

    NOTE(review): the default of 4 mixes the first three goal entries with a
    fourth one. The kettle exploration cell treats entries 0:3 as xyz, so
    entry 3 is presumably an orientation component. Confirm that mixing them
    is intended.

    Args:
        observation: raw Dict observation with 'achieved_goal' and
            'desired_goal' sub-dicts keyed by task name.
        task_name: task key to read; ``None`` falls back to the global ``task``.
        n_components: number of leading goal entries to compare.

    Returns:
        float in [0, 1]; 1.0 means the compared entries match exactly.
    """
    key = task if task_name is None else task_name
    achieved = np.asarray(observation['achieved_goal'][key][:n_components], dtype=np.float64)
    desired = np.asarray(observation['desired_goal'][key][:n_components], dtype=np.float64)
    distance = float(np.linalg.norm(achieved - desired))
    # Far from the goal the raw value goes negative. Clip instead of asserting,
    # so an object that drifts away does not abort training mid-episode.
    return max(0.0, 1.0 - distance)


In [26]:
# Effectively unbounded, symmetric Box limits for the flattened observation.
obs_low = np.full(flat_dim, -1e10, dtype=np.float32)
obs_high = -obs_low

class FlattenDictWrapper(gym.Wrapper):
    """Expose FrankaKitchen as a flat-Box / Discrete-action env for SB3's DQN.

    * Observations: the Dict observation is flattened by ``flatten_observation``
      into a float32 vector of length ``flat_dim``.
    * Actions: the agent picks an index in ``Discrete(action_space_size)``.
      It is decoded by ``transform_action_from_int`` into the 9-dim continuous
      action passed to the wrapped env.
    * Reward: a wrapped-env reward of exactly 0.0 is replaced by ``custom_reward``.
    """

    def __init__(self, env):
        super().__init__(env)
        self.keys = env.observation_space.spaces.keys()
        self.observation_space = gym.spaces.Box(low=obs_low, high=obs_high, shape=(flat_dim,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(n=action_space_size)

    def observation(self, observation):
        """Flatten a Dict observation (non-dict inputs are returned unchanged)."""
        return flatten_observation(observation)

    def action(self, action):
        """Convert an agent action into the continuous vector for the wrapped env.

        A single-element action (python int, numpy integer, 0-d array from
        ``model.predict``, or a 1-element array) is treated as a discrete index
        and decoded with ``transform_action_from_int``. Anything else is assumed
        to already be a continuous action vector and is passed through as floats.
        """
        arr = np.asarray(action)
        if arr.size == 1:
            return transform_action_from_int(int(arr.item()))
        return arr.astype(np.float64)

    def step(self, action):
        # The wrapped env is stepped with a 9-dim continuous action, so decode
        # the discrete index first instead of forwarding the index itself.
        obs, reward, terminated, truncated, info = self.env.step(self.action(action))
        if reward == 0.0:
            reward = custom_reward(obs)
        return flatten_observation(obs), reward, terminated, truncated, info

    def reset(self, **kwargs):
        # gymnasium's reset returns (obs, info). Flatten the obs itself and keep
        # info, so callers doing `obs, info = env.reset()` get a flat vector.
        obs, info = self.env.reset(**kwargs)
        return flatten_observation(obs), info
    
    
def make_env():
    """Build one FrankaKitchen env restricted to `task`, wrapped in FlattenDictWrapper.

    The wrapper provides a flat Box observation space and a Discrete action space.
    """
    return FlattenDictWrapper(
        gym.make(env_id, render_mode=None, tasks_to_complete=[task])
    )


In [None]:
n_training_envs = 16
# One env per training worker, all built by the same factory.
env = DummyVecEnv([make_env for _ in range(n_training_envs)])
eval_env = DummyVecEnv([make_env])  # single env used by EvalCallback below

In [None]:
max_timesteps = 50000
exploration_fraction = 0.95
# Example: "dqn_4p_50000_reward_shaping_95_kettle". Used for the eval-log dir
# and the saved-model path.
run_name = "_".join([
    "dqn",
    f"{granularity}p",
    str(max_timesteps),
    "reward_shaping",
    str(int(exploration_fraction * 100)),
    task,
])

In [None]:
model = DQN(
    "MlpPolicy",
    env,
    learning_rate=0.001,
    exploration_fraction=exploration_fraction,
    device=device,
)

In [None]:
# Optional: resume from a previously saved model instead of training from scratch.
# NOTE(review): this path does not follow the `run_name` scheme used by model.save(run_name);
# DQN.load(run_name, env=env) is presumably what is wanted now — confirm before uncommenting.
# model = DQN.load("dqn_3_10000_"+task)

In [None]:
eval_log_dir = os.path.join("eval_logs", run_name)
# eval_freq is counted per training env. Dividing by the number of envs gives
# an evaluation roughly every 500 total timesteps (496 with 16 envs).
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=eval_log_dir,
    log_path=eval_log_dir,
    eval_freq=max(500 // n_training_envs, 1),
    n_eval_episodes=5,
    deterministic=True,
    render=False,
)

In [None]:

# Train. EvalCallback evaluates on eval_env periodically and saves the best model under eval_log_dir.
model.learn(
    callback=eval_callback,
    total_timesteps=max_timesteps,
)

Eval num_timesteps=496, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
New best mean reward!
Eval num_timesteps=992, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=1488, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=1984, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=2480, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=2976, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=3472, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=3968, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=4464, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=4960, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_timesteps=5456, episode_reward=166.28 +/- 0.00
Episode length: 280.00 +/- 0.00
Eval num_tim

<stable_baselines3.dqn.dqn.DQN at 0x7f65a9ce03d0>

In [None]:
# Persist the final (not necessarily best) model under the run name.
model.save(path=run_name)

In [None]:
# Roll out the trained policy deterministically and print each episode's return.
n_manual_eval_episodes = 10
# Reuse one env for every episode and always close it, so rollouts don't leak env resources.
env_eval = make_env()
try:
    for _episode in range(n_manual_eval_episodes):
        obs, _ = env_eval.reset()
        # flatten_observation returns non-dict inputs unchanged, so this is a
        # no-op for an already-flat obs and flattens a raw Dict obs.
        obs = flatten_observation(obs)
        done = False
        ep_reward = 0.0

        while not done:
            action, _ = model.predict(obs, deterministic=True)
            # Hand the wrapper the decoded 9-dim continuous action.
            obs, reward, terminated, truncated, _ = env_eval.step(transform_action_from_int(action))
            obs = flatten_observation(obs)
            done = terminated or truncated
            ep_reward += reward
        print(f"Episode reward: {ep_reward}")
finally:
    env_eval.close()

Episode reward: 166.28406681999144
Episode reward: 166.2840668199915
Episode reward: 166.28406681999152
Episode reward: 166.28406681999152
Episode reward: 166.2840668199915
Episode reward: 166.28406681999152
Episode reward: 166.28406681999144
Episode reward: 166.28406681999144
Episode reward: 166.28406681999147
Episode reward: 166.28406681999144


In [None]:

# Unwrapped env, used only to inspect the raw Dict observation layout.
env_ = gym.make(env_id, tasks_to_complete=[task], render_mode=None)

({'observation': array([ 1.47801565e-01, -1.76829107e+00,  1.84395217e+00, -2.47610622e+00,
          2.60691996e-01,  7.12663739e-01,  1.59498559e+00,  4.87271619e-02,
          3.67021062e-02, -2.53472528e-04, -2.65877959e-04,  2.26002676e-04,
         -3.65829334e-04,  5.72075854e-04,  6.87336170e-04,  8.14281501e-04,
         -1.63403919e-05, -1.42713245e-04, -2.67131394e-04, -5.12181348e-05,
          3.13329051e-05, -4.53444766e-05, -3.82073543e-06, -4.20677973e-05,
          6.28998687e-05,  4.04362722e-05,  4.62748053e-04, -2.26011323e-04,
         -4.67009093e-04, -6.44076225e-03, -1.79233727e-03,  1.05146686e-03,
         -2.69397033e-01,  3.50382421e-01,  1.61944820e+00,  9.99968903e-01,
          4.00662752e-03, -6.59483102e-03, -2.88975688e-04, -1.05641845e-06,
         -1.57191980e-06,  1.32664513e-06,  1.38601982e-06, -1.43032028e-06,
          1.43688572e-06, -8.37722653e-08,  1.46978715e-06, -3.77669880e-07,
          1.92717639e-06, -1.84364257e-06, -1.22620412e-05,  

In [None]:
# gymnasium's reset returns an (observation, info) pair; seed=None is the default (unseeded) reset.
state = env_.reset(seed=None)

In [None]:
# Display the observation dict (first element of the (observation, info) pair returned by reset()).
state[0]

{'observation': array([ 1.49064696e-01, -1.76751037e+00,  1.84415087e+00, -2.47686538e+00,
         2.59389739e-01,  7.12103377e-01,  1.59467207e+00,  4.76014193e-02,
         3.72559511e-02,  1.24158133e-04, -4.72168496e-04,  2.44411661e-04,
        -7.30645572e-05, -9.94192991e-04,  1.76777846e-04, -4.32781443e-04,
         2.40148757e-04,  5.28851150e-04, -2.40551674e-04, -5.30422618e-05,
         3.29306957e-05, -4.48706136e-05, -3.81012541e-06, -4.22421163e-05,
         6.28081484e-05,  4.02719850e-05,  4.62976087e-04, -2.29308999e-04,
        -4.66343836e-04, -6.43987546e-03, -1.72089090e-03,  1.05306444e-03,
        -2.69415644e-01,  3.50383656e-01,  1.61944537e+00,  9.99971732e-01,
         3.99291977e-03, -6.57423291e-03, -3.15833682e-04, -1.10305015e-06,
         1.05751294e-06, -1.67263545e-06,  1.35567182e-06, -1.21150897e-06,
        -2.21962015e-07,  2.02304157e-06,  8.22931005e-07, -1.77298674e-06,
         1.34866114e-06,  3.18721696e-07,  2.37975197e-05,  2.23025188e-0

In [None]:
sample = state[0]
# NOTE(review): observation[32:35] is presumably the kettle's xyz position; verify
# against the FrankaKitchen observation layout. desired_goal[...][0:3] is the target xyz.
obs_xyz = sample['observation'][32:35]
desired_xyz = sample['desired_goal']['kettle'][0:3]
distance_for_kettle = desired_xyz - obs_xyz  # per-axis offset still to cover

In [None]:
# Per-axis (desired - current) offset for the kettle.
distance_for_kettle

array([0.03941564, 0.39961634, 0.00055463])