
Image input into TD3 #869

Closed

C-monC opened this issue May 26, 2020 · 20 comments
Labels: bug, custom gym env, question

Comments

C-monC commented May 26, 2020

Hi,

I have a custom env with an image observation space and a continuous action space. After training TD3 policies, when I evaluate them there seems to be no reaction to the image observation (I manually drag objects in front of the camera to see what happens).

import gym

from stable_baselines import TD3
from stable_baselines.bench import Monitor
from stable_baselines.td3.policies import CnnPolicy as td3CnnPolicy

# log_dir and SaveOnBestTrainingRewardCallback are defined elsewhere in my script
env = gym.make('GripperEnv-v0')
env = Monitor(env, log_dir)
ExperimentName = "TD3_test"
policy_kwargs = dict(layers=[64, 64])
model = TD3(td3CnnPolicy, env, verbose=1, policy_kwargs=policy_kwargs, tensorboard_log="tmp/", buffer_size=15000,
            batch_size=2200, train_freq=2200, learning_starts=10000, learning_rate=1e-3)

callback = SaveOnBestTrainingRewardCallback(check_freq=1100, log_dir=log_dir)
time_steps = 50000
model.learn(total_timesteps=int(time_steps), callback=callback)
model.save("128128/" + ExperimentName)

I can view the observation using OpenCV and it is the right image (single channel, pixel values between 0 and 1).

As I understand it, the CNN is three Conv2D layers that connect to two fully connected layers 64 units wide. Is it possible that I somehow disconnected these two parts, or could it be that my hyper-parameters are just that bad? The behaviour learnt by the policies is similar to what I get if I feed all-zero pixels into the network.

Miffyli added the question label May 26, 2020

Miffyli commented May 26, 2020

I first thought this would be related to #854, but it seems this code correctly gets the parameters of the whole model for parameter updates. However, the code seems to append some small fully connected layers after the CNN, so you might want to try layers=[] to avoid such a bottleneck (I do not expect this to affect much).

TD3 has not been tested much on image-based tasks, so this could very well be a hyperparameter issue. Your batch_size seems a bit too high. I am not intimately familiar with TD3, but you can check the rl-zoo parameters for a general idea of what they should be.
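For illustration only, a sketch of what a reduced configuration could look like (the values here are placeholders, not tuned recommendations; the rl-zoo hyperparameters are the better reference):

# Placeholder values for illustration; consult rl-zoo for tuned TD3 hyperparameters.
model = TD3(td3CnnPolicy, env, verbose=1,
            policy_kwargs=dict(layers=[64, 64]),
            buffer_size=15000,
            batch_size=128,       # much smaller mini-batch per gradient update
            train_freq=100,       # update every 100 env steps instead of 2200
            learning_starts=1000,
            learning_rate=1e-3)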


C-monC commented May 26, 2020

Thanks for the help.
I'll try the layers=[] and the batch_size reduction. Are batch_size and buffer_size based on env episodes or env timesteps?

I see layers can't be empty, so I'll try a larger layer width like [128, 128] instead.

araffin added the custom gym env label May 26, 2020

Miffyli commented May 26, 2020

Answers are in the docs: batch_size is the traditional mini-batch size for updates, and buffer_size is based on env steps (experiences).


C-monC commented May 27, 2020

I copied the hyper-parameters from the cheetah env and trained until convergence. The policy always takes exactly the same action regardless of the observation.

I trained SAC as well, using the cheetah hyper-parameters for SAC, and exactly the same thing happens. In both cases, when I evaluate, I change the observation a lot but the action stays exactly the same.

Surely there should be some change propagating through the policy. This is what env.step(action) returns as the observation (an 80x80 image):
[80x80 observation image]

I see the nature_cnn has no pooling and has quite a large filter size for the first layer. Could it be that the CNN can't detect any of the features in the image due to its architecture?
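(For context, the architecture in question is roughly the following; this is my paraphrase of stable-baselines' nature_cnn using plain TensorFlow ops, not the library's actual code, which uses its own conv/linear helpers.)

import tensorflow as tf

def nature_cnn_sketch(scaled_images):
    """Rough structure of the default nature_cnn feature extractor (TF1-style)."""
    activ = tf.nn.relu
    # first layer: 8x8 filters with stride 4 -- large relative to an 80x80 input
    h = activ(tf.layers.conv2d(scaled_images, filters=32, kernel_size=8, strides=4))
    h = activ(tf.layers.conv2d(h, filters=64, kernel_size=4, strides=2))
    h = activ(tf.layers.conv2d(h, filters=64, kernel_size=3, strides=1))
    h = tf.layers.flatten(h)
    # no pooling; a single 512-unit fully connected layer produces the features
    return activ(tf.layers.dense(h, 512))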


Miffyli commented May 27, 2020

I see the nature_cnn has no pooling and has quite a large filter size for the first layer. Could it be that the CNN can't detect any of the features in the image due to its architecture?

Even with a simple/small network like nature_cnn the agent learns to some degree. I have never seen it be the failure point.

Could you debug a few more things? Namely:

  • Create the agent, get the parameters with the get_parameters() function, train it, and get the new parameters. Have all of the parameters changed (i.e. is the CNN updated as well)?
  • And/or, could you check whether the actions change during training?

Being stuck executing one action could be a sign of too hard environment / bad learning result, but I do not have such an environment at hand to test this out. @araffin Do you have any experience with this? I'd personally be happy to try out SAC/TD3 in more image-based envs, but this will go to SB3 side.


araffin commented May 27, 2020

Being stuck executing one action could be a sign of too hard environment / bad learning result, but I do not have such an environment at hand to test this out. @araffin Do you have any experience with this?

SAC/TD3 are very slow with images; I recommend doing something like here or here, where you decouple policy learning from feature extraction.

This does not completely answer the question, but I don't have much time for this right now.
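For anyone who wants a concrete starting point for that decoupling idea, a minimal sketch (encode_fn here is a hypothetical, separately trained feature extractor, e.g. the encoder half of an autoencoder; it is not something stable-baselines provides):

import gym
import numpy as np

class EncodedObsWrapper(gym.ObservationWrapper):
    """Replace image observations with features from a pretrained encoder,
    so the RL algorithm only has to learn a policy on top of a small vector."""

    def __init__(self, env, encode_fn, n_features):
        super().__init__(env)
        self.encode_fn = encode_fn  # hypothetical: maps one image to a 1D feature vector
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(n_features,), dtype=np.float32)

    def observation(self, obs):
        return self.encode_fn(obs).astype(np.float32)

# Usage sketch:
# env = EncodedObsWrapper(gym.make('GripperEnv-v0'), encode_fn=my_encoder, n_features=32)
# model = TD3('MlpPolicy', env, ...)  # MlpPolicy on the encoded features instead of CnnPolicy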


C-monC commented May 27, 2020

The first convolutional layer before training:

OrderedDict([('model/pi/c1/w:0', array([[[[-1.04753301e-01, -7.85421357e-02,  9.05572921e-02, ...,
          -2.74648249e-01, -2.92049676e-01,  2.62660027e-01]],

        [[ 1.60399958e-01,  9.32855457e-02,  6.64883479e-02, ...,
          -1.34873893e-02, -3.82574797e-02,  2.41547972e-02]],

        [[-3.66759658e-01,  5.54259829e-02,  2.27044344e-01, ...,
          -5.28310239e-03,  3.47178280e-02, -6.06268235e-02]],

The first convolutional layer after training:

OrderedDict([('model/pi/c1/w:0', array([[[[-1.04753301e-01, -7.85421357e-02,  9.05572921e-02, ...,
          -2.74648249e-01, -2.92049676e-01,  2.62660027e-01]],

        [[ 1.60399958e-01,  9.32855457e-02,  6.64883479e-02, ...,
          -1.34873893e-02, -3.82574797e-02,  2.41547972e-02]],

        [[-3.66759658e-01,  5.54259829e-02,  2.27044344e-01, ...,
          -5.28310239e-03,  3.47178280e-02, -6.06268235e-02]],

All the partly printed arrays match up before and after training. The actions definitely change during training (visible through the PyBullet GUI), but could that just be due to random exploration while the model stays the same?

Thanks for those links @araffin. I'll definitely try them out.


Miffyli commented May 27, 2020

All the partly printed arrays match up before and after training. The actions definitely change during training (visible through the PyBullet GUI), but could that just be due to random exploration while the model stays the same?

That seems worrying and looks like the result of a bug. Even if the changes to the CNN were minimal, some digits should change. I will look into this more later. Could you do one more thing and check which parameters changed? Something like this should do the trick (not tested):

# True means the parameter did not change between the two snapshots
is_same = dict([(parameter_name, np.all(np.isclose(old_parameters[parameter_name], new_parameters[parameter_name]))) for parameter_name in old_parameters.keys()])

Edit: Thanks a lot for bringing this up and for the informative replies!

Miffyli added the bug label May 27, 2020

C-monC commented May 27, 2020

Thanks for the help, I've been struggling with this for days now.

Is it possible it's the custom env? I've double-checked that the observation and action spaces are normalized. I've also made sure that manually inputting an action array causes the correct action.

The result of that code after 3k steps (I forgot to print after the first run, so I made it shorter):
{'model/pi/c1/w:0': True, 'model/pi/c1/b:0': True, 'model/pi/c2/w:0': True, 'model/pi/c2/b:0': True, 'model/pi/c3/w:0': True, 'model/pi/c3/b:0': True, 'model/pi/fc1/w:0': True, 'model/pi/fc1/b:0': True, 'model/pi/fc0/kernel:0': True, 'model/pi/fc0/bias:0': True, 'model/pi/fc1/kernel:0': True, 'model/pi/fc1/bias:0': True, 'model/pi/dense/kernel:0': True, 'model/pi/dense/bias:0': False, 'model/pi/dense_1/kernel:0': True, 'model/pi/dense_1/bias:0': False, 'model/values_fn/c1/w:0': True, 'model/values_fn/c1/b:0': True, 'model/values_fn/c2/w:0': True, 'model/values_fn/c2/b:0': True, 'model/values_fn/c3/w:0': True, 'model/values_fn/c3/b:0': True, 'model/values_fn/fc1/w:0': True, 'model/values_fn/fc1/b:0': True, 'model/values_fn/vf/fc0/kernel:0': True, 'model/values_fn/vf/fc0/bias:0': True, 'model/values_fn/vf/fc1/kernel:0': True, 'model/values_fn/vf/fc1/bias:0': True, 'model/values_fn/vf/vf/kernel:0': True, 'model/values_fn/vf/vf/bias:0': False, 'model/values_fn/qf1/fc0/kernel:0': False, 'model/values_fn/qf1/fc0/bias:0': False, 'model/values_fn/qf1/fc1/kernel:0': False, 'model/values_fn/qf1/fc1/bias:0': False, 'model/values_fn/qf1/qf1/kernel:0': False, 'model/values_fn/qf1/qf1/bias:0': False, 'model/values_fn/qf2/fc0/kernel:0': False, 'model/values_fn/qf2/fc0/bias:0': False, 'model/values_fn/qf2/fc1/kernel:0': False, 'model/values_fn/qf2/fc1/bias:0': False, 'model/values_fn/qf2/qf2/kernel:0': False, 'model/values_fn/qf2/qf2/bias:0': False, 'target/values_fn/vf/fc0/kernel:0': True, 'target/values_fn/vf/fc0/bias:0': True, 'target/values_fn/vf/fc1/kernel:0': True, 'target/values_fn/vf/fc1/bias:0': True, 'target/values_fn/vf/vf/kernel:0': True, 'target/values_fn/vf/vf/bias:0': False}


Miffyli commented May 27, 2020

Is it possible it's the custom env? I've double-checked that the observation and action spaces are normalized. I've also made sure that manually inputting an action array causes the correct action.

I doubt the custom env is at fault here. Even if normalizations were off / too large, the parameters should change one way or another, but that does not seem to be the case.

Thanks for providing the list! I will take a look at this later today.


Miffyli commented May 27, 2020

@C-monC

I am having trouble replicating this. Could you share the code you used to obtain the result above, and also the relevant versions (OS, Python, stable-baselines, TensorFlow, NumPy)? Below is the code I am using to debug this (Python 3.7, current master of stable-baselines, TF 1.15). For me it shows that all parameters are updated as expected, but I did not make it an exact match to your setup.

You could also check your env with the env checker, if you have not already, for possible bugs.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow
tensorflow.logging.set_verbosity(tensorflow.logging.ERROR)

import numpy as np
from gym.spaces import Box

from stable_baselines import SAC, TD3
from stable_baselines.ddpg import NormalActionNoise
from stable_baselines.common.identity_env import IdentityEnv, IdentityEnvBox
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common.evaluation import evaluate_policy

from pprint import pprint
from collections import OrderedDict

class IdentityEnvImageBox(IdentityEnv):
    def __init__(self, low=-1, high=1, eps=0.05, ep_length=100):
        """
        Identity environment for testing purposes

        :param low: (float) the lower bound of the box dim
        :param high: (float) the upper bound of the box dim
        :param eps: (float) the epsilon bound for correct value
        :param ep_length: (int) the length of each episode in timesteps
        """
        space = Box(low=low, high=high, shape=(40, 40, 1), dtype=np.float32)
        super().__init__(ep_length=ep_length, space=space)
        self.observation_space = space
        self.action_space = Box(low=low, high=high, shape=(1,), dtype=np.float32)
        self.eps = eps

    def reset(self):
        self.current_step = 0
        self.num_resets += 1
        self._choose_next_state()
        observation = np.ones((40, 40, 1)) * self.state
        return observation

    def step(self, action):
        reward = self._get_reward(action)
        self._choose_next_state()
        self.current_step += 1
        done = self.current_step >= self.ep_length
        observation = np.ones((40, 40, 1)) * self.state
        return observation, reward, done, {}

    def _get_reward(self, action):
        return 1 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0


def test_identity_continuous(model_class):
    """
    Test if the algorithm (with a given policy)
    can learn an identity transformation (i.e. return observation as an action)
    """
    env = DummyVecEnv([lambda: IdentityEnvImageBox(eps=0.5)])

    if model_class == TD3:
        n_actions = 1
        action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
    else:
        action_noise = None

    policy_kwargs = dict(layers=[64, 64])
    model = model_class("CnnPolicy", env, gamma=0.1, seed=0,
                         action_noise=action_noise, buffer_size=int(1e6),
                         policy_kwargs=policy_kwargs)

    old_parameters = model.get_parameters()

    model.learn(total_timesteps=500)

    print("Evaluation results: {}".format(
        evaluate_policy(model, env, n_eval_episodes=20)
    ))

    new_parameters = model.get_parameters()

    # Check what has changed
    is_same = OrderedDict([
        (
            parameter_name,
            np.all(np.isclose(old_parameters[parameter_name], new_parameters[parameter_name]))
        ) for parameter_name in old_parameters.keys()
    ])

    pprint(is_same)


if __name__ == "__main__":
    for model_class in [SAC, TD3]:
        print("Testing {}".format(model_class))
        test_identity_continuous(model_class)


C-monC commented May 27, 2020

Okay, this is unfortunate. Sorry for wasting your time.

PyBullet returns RGB images with values from 0 to 255 but depth images from 0 to 1, and when I swapped over to depth images I didn't scale the values up; instead I reduced the observation space's high parameter to 1.

The CNN updates correctly now.
Thanks a lot for the help, appreciate it.
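For anyone hitting the same issue, the fix described above amounts to something like this (a hedged sketch; the 80x80 single-channel shape comes from earlier in this thread, and depth_to_obs is a hypothetical helper name):

import numpy as np
from gym import spaces

# Declare the image observation space as uint8 in [0, 255] so stable-baselines
# applies its own scaling inside the CnnPolicy:
observation_space = spaces.Box(low=0, high=255, shape=(80, 80, 1), dtype=np.uint8)

# ...and scale the PyBullet depth buffer (float in [0, 1]) in reset()/step():
def depth_to_obs(depth):
    return (depth * 255).astype(np.uint8).reshape(80, 80, 1)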

C-monC closed this as completed May 27, 2020
@tkelestemur

Aren't we supposed to give image observations as values between 0-255? I am using 2-channel images as observations and map them from values between 0-1 to 0-255. Similar to @C-monC, I have depth images as the observations and I'm getting the same problem where the agent always chooses the same action no matter what the observations are. Btw, I'm using A2C.


C-monC commented May 28, 2020

Is your observation space a Box(low=0, high=255, shape=(self._width, self._height, 1), dtype=np.uint8)?
The env checker should tell you if you've set up your observation wrong.
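A minimal sketch of running the checker (env here is your custom environment instance):

from stable_baselines.common.env_checker import check_env

# Raises an error or prints warnings if the declared observation/action spaces
# do not match what the environment actually returns.
check_env(env, warn=True)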

tkelestemur commented May 28, 2020

It's gym.spaces.Box(shape=(self.width, self.height, 2), high=255, low=0, dtype=np.uint8). For some reason, the agent decides to assign a really high probability (mostly 0.9) to one of the 4 actions. The action with the high probability changes from training to training, but it's always just one action.


C-monC commented May 28, 2020

Is your depth map dual-channel? Did the env checker return nothing?
This is very similar to what was happening to me.

@tkelestemur

No, it's one channel, normalized between [0, 1]. I render the depth image in MuJoCo and scale the meter values between [0, 1]. The second channel is the belief of the agent's location; in other words, it's the posterior probability of the agent being in each of the pixels.

I tried the env_checker but it didn't throw any warnings. I'm currently trying PPO to see if it's a problem with A2C. @C-monC How did you solve your issue?


C-monC commented May 28, 2020

Is it possible you're normalizing it yourself before stable-baselines normalizes it? My issue was that my env was already returning images scaled to 0-1 instead of letting stable-baselines do it.


tkelestemur commented May 28, 2020

My environment produces images that are already normalized, i.e. all the values are between 0-1. I was multiplying the images by 255 and typecasting them to np.uint8. I tried training without scaling to 255, but that didn't work either; it still chooses only one action. I might have another bug somewhere else.

Edit:
If I run check_env with float images (not scaled to 255), I get the following warning:

UserWarning: It seems that your observation is an image but the `dtype` of your observation_space is not `np.uint8`. If your observation is not an image, we recommend you to flatten the observation to have only a 1D vector
  warnings.warn("It seems that your observation is an image but the `dtype` "

I think normalizing image-based observation spaces should be optional so that we wouldn't have to discretize the images, which causes information loss.

@pengzhenghao

Hi @tkelestemur, have you made any progress? I also encounter the problem that the TD3 agent learns nothing but repeats the same action, even though I have checked the image observation and it should be correct. I guess maybe this is because TD3 does not work well with image input, as @araffin suggests.
