HER : new functionality, enables demo based training #474

Merged
merged 1 commit on Oct 23, 2018
48 changes: 48 additions & 0 deletions baselines/her/README.md
@@ -30,3 +30,51 @@ python -m baselines.her.experiment.train --num_cpu 19
This will require a machine with a sufficient number of physical CPU cores. In our experiments,
we used [Azure's D15v2 instances](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes),
which have 20 physical cores. We only scheduled the experiment on 19 of those to leave some head-room on the system.


## Hindsight Experience Replay with Demonstrations
Use pre-recorded demonstrations to overcome the exploration problem in HER-based reinforcement learning.
For details, please read the [paper](https://arxiv.org/pdf/1709.10089.pdf).

### Getting started
The first step is to generate the demonstration dataset. This can be done in two ways: either use a VR system to manipulate the arm with physical VR trackers, or, more simply, write a script that carries out the task. Some tasks are complex enough that writing a hardcoded script for them is difficult (e.g. Fetch Push), but our focus here is on an algorithm that helps the agent learn from demonstrations, not on the demonstration-generation paradigm itself. The data collection method is therefore left to the reader's choice.

We provide a script for the Fetch Pick and Place task. To generate demonstrations for it, execute:
```bash
python experiment/data_generation/fetch_data_generation.py
```
This outputs `data_fetch_random_100.npz`, which is the demonstration data file.
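
The file stores three parallel object arrays: `acs` (actions), `obs` (the full goal-environment observation dicts) and `info`, one entry per recorded episode (see the generation script further down in this diff). A minimal sketch for inspecting it; with recent NumPy versions the `allow_pickle` flag is required because the arrays hold Python dicts:

```python
import numpy as np

demo = np.load('data_fetch_random_100.npz', allow_pickle=True)
print(demo.files)          # keys: acs, obs, info
print(len(demo['obs']))    # number of recorded episodes (100 here)

first = demo['obs'][0][0]  # first observation of the first episode
print(first['observation'].shape, first['desired_goal'], first['achieved_goal'])
```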

#### Configuration
The provided configuration trains an agent with HER but without demonstrations. For the HER algorithm to learn from demonstrations, a few parameters need to change; set:

* bc_loss: 1 - whether or not to use the behavior cloning loss as an auxiliary loss
* q_filter: 1 - whether or not a Q value filter should be used on the actor outputs
* num_demo: 100 - number of expert demo episodes
* demo_batch_size: 128 - number of samples to be used from the demonstrations buffer, per MPI thread
* prm_loss_weight: 0.001 - weight corresponding to the primary loss
* aux_loss_weight: 0.0078 - weight corresponding to the auxiliary loss, also called the cloning loss

Apart from these changes, the reported results also use the following configuration changes (a sketch of the corresponding defaults follows the list):

* n_cycles: 20 - per epoch
* batch_size: 1024 - per MPI thread, total batch size
* random_eps: 0.1 - fraction of the time a random action is taken
* noise_eps: 0.1 - std of Gaussian noise added to not-completely-random actions
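
The corresponding defaults are added to the `DEFAULT_PARAMS` dict in `experiment/config.py` (see the diff further down). A minimal sketch collecting the demonstration-related settings with the values reported above; the `demo_params` name is only for illustration, in practice the entries are edited in place:

```python
# Illustrative only: these entries are normally edited directly in
# DEFAULT_PARAMS inside baselines/her/experiment/config.py.
demo_params = {
    'bc_loss': 1,               # use the behavior cloning loss as an auxiliary loss
    'q_filter': 1,              # filter cloning updates through the critic's Q values
    'num_demo': 100,            # number of expert demo episodes
    'demo_batch_size': 128,     # demo samples mixed into each batch, per MPI thread
    'prm_loss_weight': 0.001,   # weight of the primary (DDPG) loss
    'aux_loss_weight': 0.0078,  # weight of the auxiliary (cloning) loss
    'n_cycles': 20,             # cycles per epoch
    'batch_size': 1024,         # total batch size per MPI thread
    'random_eps': 0.1,          # fraction of fully random actions
    'noise_eps': 0.1,           # std of Gaussian noise on non-random actions
}
```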

Now train an agent with the pre-recorded demonstrations:
```bash
python -m baselines.her.experiment.train --env=FetchPickAndPlace-v0 --n_epochs=1000 --demo_file=/Path/to/demo_file.npz --num_cpu=1
```

This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
To inspect what the agent has learned, use the play script as described above.
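
Under the hood (see the `ddpg.py` changes below), the policy loads the demo file into a separate replay buffer via `initDemoBuffer`, and `sample_batch` appends `demo_batch_size` demonstration transitions to each training batch. A boolean mask then singles out the demonstration part so that the behavior cloning loss (optionally filtered by the critic's Q values) is applied only there. A minimal runnable sketch of that mask construction, using the batch sizes from the configuration above:

```python
import numpy as np

batch_size, demo_batch_size = 1024, 128   # values from the configuration above

# sample_batch puts rollout transitions first and demonstration transitions
# last, so this mask selects exactly the demo part of each batch:
mask = np.concatenate((np.zeros(batch_size - demo_batch_size),
                       np.ones(demo_batch_size)))
print(int(mask.sum()))   # 128 demo samples per batch, per MPI thread
```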

### Results
Training with demonstrations helps overcome the exploration problem and achieves faster and better convergence. The following graphs contrast training with and without demonstration data; we report the mean Q values vs. epoch and the success rate vs. epoch:


<div class="imgcap" align="middle">
<center><img src="../../data/fetchPickAndPlaceContrast.png"></center>
<div class="thecap" align="middle"><b>Training results for Fetch Pick and Place task constrasting between training with and without demonstration data.</b></div>
</div>
101 changes: 99 additions & 2 deletions baselines/her/ddpg.py
@@ -6,7 +6,7 @@

from baselines import logger
from baselines.her.util import (
import_function, store_args, flatten_grads, transitions_in_episode_batch)
import_function, store_args, flatten_grads, transitions_in_episode_batch, convert_episode_to_batch_major)
from baselines.her.normalizer import Normalizer
from baselines.her.replay_buffer import ReplayBuffer
from baselines.common.mpi_adam import MpiAdam
@@ -16,13 +16,17 @@ def dims_to_shapes(input_dims):
return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}


global demoBuffer #buffer for demonstrations

class DDPG(object):
@store_args
def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
bc_loss, q_filter, num_demo, demo_batch_size, prm_loss_weight, aux_loss_weight,
sample_transitions, gamma, reuse=False, **kwargs):
"""Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
Added functionality to use demonstrations for training to overcome the exploration problem.
Args:
input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
@@ -50,6 +54,12 @@ def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polya
sample_transitions (function) function that samples from the replay buffer
gamma (float): gamma used for Q learning updates
reuse (boolean): whether or not the networks should be reused
bc_loss: whether or not the behavior cloning loss should be used as an auxiliary loss
q_filter: whether or not a filter on the Q value update should be used when training with demonstrations
num_demo: number of expert demo episodes to be used in the demonstration buffer
demo_batch_size: number of samples to be used from the demonstrations buffer, per MPI thread
prm_loss_weight: weight corresponding to the primary loss
aux_loss_weight: weight corresponding to the auxiliary loss, also called the cloning loss
"""
if self.clip_return is None:
self.clip_return = np.inf
@@ -92,6 +102,9 @@ def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polya
buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)

global demoBuffer
demoBuffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer in the same way as the primary data buffer

def _random_action(self, n):
return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))

@@ -138,6 +151,57 @@ def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=Fals
else:
return ret

def initDemoBuffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer

demoData = np.load(demoDataFile) #load the demonstration data from data file
info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
info_values = [np.empty((self.T, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]

for epsd in range(self.num_demo): # we initialize the whole demo buffer at the start of the training
obs, acts, goals, achieved_goals = [], [], [], []
i = 0
for transition in range(self.T):
obs.append([demoData['obs'][epsd][transition].get('observation')])
acts.append([demoData['acs'][epsd][transition]])
goals.append([demoData['obs'][epsd][transition].get('desired_goal')])
achieved_goals.append([demoData['obs'][epsd][transition].get('achieved_goal')])
for idx, key in enumerate(info_keys):
info_values[idx][transition, i] = demoData['info'][epsd][transition][key]

obs.append([demoData['obs'][epsd][self.T].get('observation')])
achieved_goals.append([demoData['obs'][epsd][self.T].get('achieved_goal')])

episode = dict(o=obs,
u=acts,
g=goals,
ag=achieved_goals)
for key, value in zip(info_keys, info_values):
episode['info_{}'.format(key)] = value

episode = convert_episode_to_batch_major(episode)
global demoBuffer
demoBuffer.store_episode(episode) # store the episode in the demonstration buffer

print("Demo buffer size currently ", demoBuffer.get_current_size()) #print out the demonstration buffer size

if update_stats:
# add transitions to normalizer to normalize the demo data as well
episode['o_2'] = episode['o'][:, 1:, :]
episode['ag_2'] = episode['ag'][:, 1:, :]
num_normalizing_transitions = transitions_in_episode_batch(episode)
transitions = self.sample_transitions(episode, num_normalizing_transitions)

o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
# No need to preprocess the o_2 and g_2 since this is only used for stats

self.o_stats.update(transitions['o'])
self.g_stats.update(transitions['g'])

self.o_stats.recompute_stats()
self.g_stats.recompute_stats()
episode.clear()

def store_episode(self, episode_batch, update_stats=True):
"""
episode_batch: array of batch_size x (T or T+1) x dim_key
@@ -185,7 +249,18 @@ def _update(self, Q_grad, pi_grad):
self.pi_adam.update(pi_grad, self.pi_lr)

def sample_batch(self):
transitions = self.buffer.sample(self.batch_size)
if self.bc_loss: #also sample from the demonstration buffer if the bc_loss flag is set to True
transitions = self.buffer.sample(self.batch_size - self.demo_batch_size)
global demoBuffer
transitionsDemo = demoBuffer.sample(self.demo_batch_size) #sample from the demo buffer
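# NOTE: the demonstration transitions are appended after the rollout
# transitions below, so they always occupy the last demo_batch_size slots
# of the batch; the boolean mask built in _create_network relies on this.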
for k, values in transitionsDemo.items():
rolloutV = transitions[k].tolist()
for v in values:
rolloutV.append(v.tolist())
transitions[k] = np.array(rolloutV)
else:
transitions = self.buffer.sample(self.batch_size) #otherwise only sample from primary buffer

o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
ag, ag_2 = transitions['ag'], transitions['ag_2']
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
@@ -248,6 +323,9 @@ def _create_network(self, reuse=False):
for i, key in enumerate(self.stage_shapes.keys())])
batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])

#mask to select only the demonstration samples: sample_batch appends them at the end of each batch, so the last demo_batch_size entries are demos
mask = np.concatenate((np.zeros(self.batch_size - self.demo_batch_size), np.ones(self.demo_batch_size)), axis=0)

# networks
with tf.variable_scope('main') as vs:
if reuse:
@@ -270,6 +348,25 @@
clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))

if self.bc_loss == 1 and self.q_filter == 1: # train with demonstrations and use both bc_loss and q_filter
maskMain = tf.reshape(tf.boolean_mask(self.main.Q_tf > self.main.Q_pi_tf, mask), [-1]) #where is the demonstrator action better than the actor action according to the critic? choose those samples only
#define the cloning loss on the actor's actions, only on the samples which adhere to the above masks
self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask(tf.boolean_mask((self.main.pi_tf), mask), maskMain, axis=0) - tf.boolean_mask(tf.boolean_mask((batch_tf['u']), mask), maskMain, axis=0)))
self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf) #primary loss scaled by its respective weight prm_loss_weight
self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u)) #L2 loss on action values scaled by the same weight prm_loss_weight
self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf #add the cloning loss to the actor loss as an auxiliary loss scaled by its weight aux_loss_weight

elif self.bc_loss == 1 and self.q_filter == 0: # train with demonstrations without q_filter
self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask((self.main.pi_tf), mask) - tf.boolean_mask((batch_tf['u']), mask)))
self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf)
self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf

else: #if not training with demonstrations
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))

Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
13 changes: 13 additions & 0 deletions baselines/her/experiment/config.py
@@ -44,6 +44,13 @@
# normalization
'norm_eps': 0.01, # epsilon used for observation normalization
'norm_clip': 5, # normalized observations are clipped to this value

'bc_loss': 0, # whether or not to use the behavior cloning loss as an auxiliary loss
'q_filter': 0, # whether or not a Q value filter should be used on the actor outputs
'num_demo': 100, # number of expert demo episodes
'demo_batch_size': 128, # number of samples to be used from the demonstrations buffer, per MPI thread (128/1024 or 32/256)
'prm_loss_weight': 0.001, # weight corresponding to the primary loss
'aux_loss_weight': 0.0078, # weight corresponding to the auxiliary loss, also called the cloning loss
}


@@ -145,6 +152,12 @@ def configure_ddpg(dims, params, reuse=False, use_mpi=True, clip_return=True):
'subtract_goals': simple_goal_subtract,
'sample_transitions': sample_her_transitions,
'gamma': gamma,
'bc_loss': params['bc_loss'],
'q_filter': params['q_filter'],
'num_demo': params['num_demo'],
'demo_batch_size': params['demo_batch_size'],
'prm_loss_weight': params['prm_loss_weight'],
'aux_loss_weight': params['aux_loss_weight'],
})
ddpg_params['info'] = {
'env_name': params['env_name'],
149 changes: 149 additions & 0 deletions baselines/her/experiment/data_generation/fetch_data_generation.py
@@ -0,0 +1,149 @@
import gym
import numpy as np


"""Data generation for the case of a single block pick and place in Fetch Env"""

actions = []
observations = []
infos = []

def main():
env = gym.make('FetchPickAndPlace-v0')
numItr = 100
initStateSpace = "random"
env.reset()
print("Reset!")
while len(actions) < numItr:
obs = env.reset()
print("ITERATION NUMBER ", len(actions))
goToGoal(env, obs)


fileName = "data_fetch"
fileName += "_" + initStateSpace
fileName += "_" + str(numItr)
fileName += ".npz"

np.savez_compressed(fileName, acs=actions, obs=observations, info=infos) # save the file

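# goToGoal runs a simple scripted policy in four phases:
#   1) move the gripper to a point slightly above the object with the gripper open,
#   2) descend onto the object and close the gripper,
#   3) carry the object towards the desired goal,
#   4) keep the gripper closed until the episode's step limit is reached.
# Every action, observation and info dict along the way is recorded.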
def goToGoal(env, lastObs):

goal = lastObs['desired_goal']
objectPos = lastObs['observation'][3:6]
gripperPos = lastObs['observation'][:3]
gripperState = lastObs['observation'][9:11]
object_rel_pos = lastObs['observation'][6:9]
episodeAcs = []
episodeObs = []
episodeInfo = []

object_oriented_goal = object_rel_pos.copy()
object_oriented_goal[2] += 0.03 # first make the gripper go slightly above the object

timeStep = 0 #count the total number of timesteps
episodeObs.append(lastObs)

while np.linalg.norm(object_oriented_goal) >= 0.005 and timeStep <= env._max_episode_steps:
env.render()
action = [0, 0, 0, 0]
object_oriented_goal = object_rel_pos.copy()
object_oriented_goal[2] += 0.03

for i in range(len(object_oriented_goal)):
action[i] = object_oriented_goal[i]*6

action[len(action)-1] = 0.05 #open

obsDataNew, reward, done, info = env.step(action)
timeStep += 1

episodeAcs.append(action)
episodeInfo.append(info)
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]

while np.linalg.norm(object_rel_pos) >= 0.005 and timeStep <= env._max_episode_steps :
env.render()
action = [0, 0, 0, 0]
for i in range(len(object_rel_pos)):
action[i] = object_rel_pos[i]*6

action[len(action)-1] = -0.005

obsDataNew, reward, done, info = env.step(action)
timeStep += 1

episodeAcs.append(action)
episodeInfo.append(info)
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]


while np.linalg.norm(goal - objectPos) >= 0.01 and timeStep <= env._max_episode_steps :
env.render()
action = [0, 0, 0, 0]
for i in range(len(goal - objectPos)):
action[i] = (goal - objectPos)[i]*6

action[len(action)-1] = -0.005

obsDataNew, reward, done, info = env.step(action)
timeStep += 1

episodeAcs.append(action)
episodeInfo.append(info)
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]

while True: #hold the object at the goal with the gripper closed until the episode's step limit is reached
env.render()
action = [0, 0, 0, 0]
action[len(action)-1] = -0.005 # keep the gripper closed

obsDataNew, reward, done, info = env.step(action)
timeStep += 1

episodeAcs.append(action)
episodeInfo.append(info)
episodeObs.append(obsDataNew)

objectPos = obsDataNew['observation'][3:6]
gripperPos = obsDataNew['observation'][:3]
gripperState = obsDataNew['observation'][9:11]
object_rel_pos = obsDataNew['observation'][6:9]

if timeStep >= env._max_episode_steps: break

actions.append(episodeAcs)
observations.append(episodeObs)
infos.append(episodeInfo)


if __name__ == "__main__":
main()