No more Q states (#30)
* No more Q-states. Support for more complex environments.

* Whoops I suck at git

* Fix typo, remove unused discrete action-space support
Raelifin authored and nottombrown committed Aug 15, 2017
1 parent 572b19f commit 18670ce
Showing 2 changed files with 116 additions and 73 deletions.
11 changes: 9 additions & 2 deletions rl_teacher/nn.py
@@ -1,10 +1,15 @@
import numpy as np
import tensorflow as tf

from keras.layers import Dense, Dropout, LeakyReLU
from keras.models import Sequential

class FullyConnectedMLP(object):
"""Vanilla two hidden layer multi-layer perceptron"""

def __init__(self, input_dim, h_size=64):
def __init__(self, obs_shape, act_shape, h_size=64):
input_dim = np.prod(obs_shape) + np.prod(act_shape)

self.model = Sequential()
self.model.add(Dense(h_size, input_dim=input_dim))
self.model.add(LeakyReLU())
@@ -16,5 +21,7 @@ def __init__(self, input_dim, h_size=64):
self.model.add(Dropout(0.5))
self.model.add(Dense(1))

def run(self, x):
def run(self, obs, act):
flat_obs = tf.contrib.layers.flatten(obs)
x = tf.concat([flat_obs, act], axis=1)
return self.model(x)
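
For orientation, here is a minimal sketch of how the revised FullyConnectedMLP is meant to be called. The environment shapes, the session setup, and the zero-filled inputs below are illustrative assumptions; the constructor signature and run(obs, act) come from the diff above. run() flattens the observation tensor, concatenates the action vector along the feature axis, and returns one scalar reward per (observation, action) pair.

import numpy as np
import tensorflow as tf
from keras import backend as K

from rl_teacher.nn import FullyConnectedMLP

# Assumed environment shapes (e.g. a small continuous-control task).
obs_shape = (11,)
act_shape = (3,)

obs = tf.placeholder(tf.float32, shape=(None,) + obs_shape, name="obs")
act = tf.placeholder(tf.float32, shape=(None,) + act_shape, name="act")

mlp = FullyConnectedMLP(obs_shape, act_shape)
reward = mlp.run(obs, act)  # shape (batch_size, 1): one scalar per (obs, act) pair

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(reward, feed_dict={
        obs: np.zeros((2,) + obs_shape, dtype=np.float32),
        act: np.zeros((2,) + act_shape, dtype=np.float32),
        K.learning_phase(): False,  # disable the MLP's dropout at inference time
    })
    print(out.shape)  # (2, 1)
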
178 changes: 107 additions & 71 deletions rl_teacher/teach.py
@@ -15,7 +15,6 @@
from rl_teacher.envs import make_with_torque_removed
from rl_teacher.label_schedules import LabelAnnealer, ConstantLabelSchedule
from rl_teacher.nn import FullyConnectedMLP
from rl_teacher.segment_sampling import create_segment_q_states
from rl_teacher.segment_sampling import sample_segment_from_path
from rl_teacher.segment_sampling import segments_from_rand_rollout
from rl_teacher.summaries import AgentLogger, make_summary_writer
@@ -24,14 +23,14 @@

CLIP_LENGTH = 1.5

class TraditionalRLRewardPredictor():
class TraditionalRLRewardPredictor(object):
"""Predictor that always returns the true reward provided by the environment."""

def __init__(self, summary_writer):
self.agent_logger = AgentLogger(summary_writer)

def predict_reward(self, path):
self.agent_logger.log_episode(path)
self.agent_logger.log_episode(path) # <-- This may cause problems in future versions of Teacher.
return path["original_rewards"]

def path_callback(self, path):
@@ -54,44 +53,66 @@ def __init__(self, env, summary_writer, comparison_collector, agent_logger, labe
self._elapsed_predictor_training_iters = 0

# Build and initialize our predictor model
self.sess = tf.InteractiveSession()
self.q_state_size = np.product(env.observation_space.shape) + np.product(env.action_space.shape)
self._build_model()
config = tf.ConfigProto(
device_count={'GPU': 0}
)
self.sess = tf.InteractiveSession(config=config)
self.obs_shape = env.observation_space.shape
self.discrete_action_space = not hasattr(env.action_space, "shape")
self.act_shape = (env.action_space.n,) if self.discrete_action_space else env.action_space.shape
self.graph = self._build_model()
self.sess.run(tf.global_variables_initializer())

def _predict_rewards(self, segments):
def _predict_rewards(self, obs_segments, act_segments, network):
"""
:param segments: tensor with shape = (batch_size, segment_length, q_state_size)
:param obs_segments: tensor with shape = (batch_size, segment_length) + obs_shape
:param act_segments: tensor with shape = (batch_size, segment_length) + act_shape
:param network: neural net with .run() that maps obs and act tensors into a (scalar) value tensor
:return: tensor with shape = (batch_size, segment_length)
"""
segment_length = tf.shape(segments)[1]
batchsize = tf.shape(segments)[0]
batchsize = tf.shape(obs_segments)[0]
segment_length = tf.shape(obs_segments)[1]

# Temporarily chop up segments into individual q_states
q_states = tf.reshape(segments, [batchsize * segment_length, self.q_state_size])
# Temporarily chop up segments into individual observations and actions
obs = tf.reshape(obs_segments, (-1,) + self.obs_shape)
acts = tf.reshape(act_segments, (-1,) + self.act_shape)

# Run them through our MLP
rewards = self.mlp.run(q_states)
# Run them through our neural network
rewards = network.run(obs, acts)

# Group the rewards back into their segments
return tf.reshape(rewards, (batchsize, segment_length))

def _build_model(self):
"""Our model takes in a vector of q_states from a segment and returns a reward for each one"""
self.segment_placeholder = tf.placeholder(
dtype=tf.float32, shape=(None, None, self.q_state_size), name="obs_placeholder")
self.segment_alt_placeholder = tf.placeholder(
dtype=tf.float32, shape=(None, None, self.q_state_size), name="obs_placeholder")

# A vanilla MLP maps a q_state to a reward
self.mlp = FullyConnectedMLP(self.q_state_size)
self.q_state_reward_pred = self._predict_rewards(self.segment_placeholder)
q_state_alt_reward_pred = self._predict_rewards(self.segment_alt_placeholder)

# We use trajectory segments rather than individual q_states because video clips of segments are easier for
# humans to evaluate
segment_reward_pred_left = tf.reduce_sum(self.q_state_reward_pred, axis=1)
segment_reward_pred_right = tf.reduce_sum(q_state_alt_reward_pred, axis=1)
"""
Our model takes in path segments with states and actions, and generates Q values.
These Q values serve as predictions of the true reward.
We can compare two segments and sum the Q values to get a prediction of a label
of which segment is better. We then learn the weights for our model by comparing
these labels with an authority (either a human or synthetic labeler).
"""
# Set up observation placeholders
self.segment_obs_placeholder = tf.placeholder(
dtype=tf.float32, shape=(None, None) + self.obs_shape, name="obs_placeholder")
self.segment_alt_obs_placeholder = tf.placeholder(
dtype=tf.float32, shape=(None, None) + self.obs_shape, name="alt_obs_placeholder")

self.segment_act_placeholder = tf.placeholder(
dtype=tf.float32, shape=(None, None) + self.act_shape, name="act_placeholder")
self.segment_alt_act_placeholder = tf.placeholder(
dtype=tf.float32, shape=(None, None) + self.act_shape, name="alt_act_placeholder")


# A vanilla multi-layer perceptron maps a (state, action) pair to a reward (Q-value)
mlp = FullyConnectedMLP(self.obs_shape, self.act_shape)

self.q_value = self._predict_rewards(self.segment_obs_placeholder, self.segment_act_placeholder, mlp)
alt_q_value = self._predict_rewards(self.segment_alt_obs_placeholder, self.segment_alt_act_placeholder, mlp)

# We use trajectory segments rather than individual (state, action) pairs because
# video clips of segments are easier for humans to evaluate
segment_reward_pred_left = tf.reduce_sum(self.q_value, axis=1)
segment_reward_pred_right = tf.reduce_sum(alt_q_value, axis=1)
reward_logits = tf.stack([segment_reward_pred_left, segment_reward_pred_right], axis=1) # (batch_size, 2)

self.labels = tf.placeholder(dtype=tf.int32, shape=(None,), name="comparison_labels")
@@ -103,16 +124,20 @@ def _build_model(self):

self.loss_op = tf.reduce_mean(data_loss)

self.global_step = tf.Variable(0, name='global_step', trainable=False)
self.train_op = tf.train.AdamOptimizer().minimize(self.loss_op, global_step=self.global_step)
global_step = tf.Variable(0, name='global_step', trainable=False)
self.train_op = tf.train.AdamOptimizer().minimize(self.loss_op, global_step=global_step)

return tf.get_default_graph()

def predict_reward(self, path):
"""Predict the reward for each step in a given path"""
q_state_reward_pred = self.sess.run(self.q_state_reward_pred, feed_dict={
self.segment_placeholder: np.array([create_segment_q_states(path)]),
K.learning_phase(): False
})
return q_state_reward_pred[0]
with self.graph.as_default():
q_value = self.sess.run(self.q_value, feed_dict={
self.segment_obs_placeholder: np.asarray([path["obs"]]),
self.segment_act_placeholder: np.asarray([path["actions"]]),
K.learning_phase(): False
})
return q_value[0]

def path_callback(self, path):
path_length = len(path["obs"])
@@ -141,32 +166,40 @@ def train_predictor(self):

minibatch_size = min(64, len(self.comparison_collector.labeled_decisive_comparisons))
labeled_comparisons = random.sample(self.comparison_collector.labeled_decisive_comparisons, minibatch_size)
left_q_states = np.asarray([comp['left']['q_states'] for comp in labeled_comparisons])
right_q_states = np.asarray([comp['right']['q_states'] for comp in labeled_comparisons])

_, loss = self.sess.run([self.train_op, self.loss_op], feed_dict={
self.segment_placeholder: left_q_states,
self.segment_alt_placeholder: right_q_states,
self.labels: np.asarray([comp['label'] for comp in labeled_comparisons]),
K.learning_phase(): True
})
self._elapsed_predictor_training_iters += 1
self._write_training_summaries(loss)
left_obs = np.asarray([comp['left']['obs'] for comp in labeled_comparisons])
left_acts = np.asarray([comp['left']['actions'] for comp in labeled_comparisons])
right_obs = np.asarray([comp['right']['obs'] for comp in labeled_comparisons])
right_acts = np.asarray([comp['right']['actions'] for comp in labeled_comparisons])
labels = np.asarray([comp['label'] for comp in labeled_comparisons])

with self.graph.as_default():
_, loss = self.sess.run([self.train_op, self.loss_op], feed_dict={
self.segment_obs_placeholder: left_obs,
self.segment_act_placeholder: left_acts,
self.segment_alt_obs_placeholder: right_obs,
self.segment_alt_act_placeholder: right_acts,
self.labels: labels,
K.learning_phase(): True
})
self._elapsed_predictor_training_iters += 1
self._write_training_summaries(loss)

def _write_training_summaries(self, loss):
self.agent_logger.log_simple("predictor/loss", loss)

# Calculate correlation between true and predicted reward by running validation on recent episodes
recent_paths = self.agent_logger.get_recent_paths_with_padding()
if len(recent_paths) > 1 and self.agent_logger.summary_step % 10 == 0: # Run validation every 10 iters
validation_q_states = np.asarray([create_segment_q_states(path) for path in recent_paths])
q_state_reward_pred = self.sess.run(self.q_state_reward_pred, feed_dict={
self.segment_placeholder: validation_q_states,
validation_obs = np.asarray([path["obs"] for path in recent_paths])
validation_acts = np.asarray([path["actions"] for path in recent_paths])
q_value = self.sess.run(self.q_value, feed_dict={
self.segment_obs_placeholder: validation_obs,
self.segment_act_placeholder: validation_acts,
K.learning_phase(): False
})
ep_reward_pred = np.sum(q_state_reward_pred, axis=1)
q_state_reward_true = np.asarray([path['original_rewards'] for path in recent_paths])
ep_reward_true = np.sum(q_state_reward_true, axis=1)
ep_reward_pred = np.sum(q_value, axis=1)
reward_true = np.asarray([path['original_rewards'] for path in recent_paths])
ep_reward_true = np.sum(reward_true, axis=1)
self.agent_logger.log_simple("predictor/correlations", corrcoef(ep_reward_true, ep_reward_pred))

self.agent_logger.log_simple("predictor/num_training_iters", self._elapsed_predictor_training_iters)
@@ -191,6 +224,8 @@ def main():
parser.add_argument('-V', '--no_videos', action="store_true")
args = parser.parse_args()

print("Setting things up...")

env_id = args.env_id
run_name = "%s/%s-%s" % (env_id, args.name, int(time()))
summary_writer = make_summary_writer(run_name)
@@ -205,16 +240,6 @@ def main():
else:
agent_logger = AgentLogger(summary_writer)

if args.predictor == "synth":
comparison_collector = SyntheticComparisonCollector()

elif args.predictor == "human":
bucket = os.environ.get('RL_TEACHER_GCS_BUCKET')
assert bucket and bucket.startswith("gs://"), "env variable RL_TEACHER_GCS_BUCKET must start with gs://"
comparison_collector = HumanComparisonCollector(env_id, experiment_name=experiment_name)
else:
raise ValueError("Bad value for --predictor: %s" % args.predictor)

pretrain_labels = args.pretrain_labels if args.pretrain_labels else args.n_labels // 4

if args.n_labels:
@@ -224,9 +249,27 @@ def main():
final_labels=args.n_labels,
pretrain_labels=pretrain_labels)
else:
print("No label limit given. We will request one label every few seconds")
print("No label limit given. We will request one label every few seconds.")
label_schedule = ConstantLabelSchedule(pretrain_labels=pretrain_labels)

if args.predictor == "synth":
comparison_collector = SyntheticComparisonCollector()

elif args.predictor == "human":
bucket = os.environ.get('RL_TEACHER_GCS_BUCKET')
assert bucket and bucket.startswith("gs://"), "env variable RL_TEACHER_GCS_BUCKET must start with gs://"
comparison_collector = HumanComparisonCollector(env_id, experiment_name=experiment_name)
else:
raise ValueError("Bad value for --predictor: %s" % args.predictor)

predictor = ComparisonRewardPredictor(
env,
summary_writer,
comparison_collector=comparison_collector,
agent_logger=agent_logger,
label_schedule=label_schedule,
)

print("Starting random rollouts to generate pretraining segments. No learning will take place...")
pretrain_segments = segments_from_rand_rollout(
env_id, make_with_torque_removed, n_desired_segments=pretrain_labels * 2,
@@ -245,13 +288,6 @@ def main():
sleep(5)

# Start the actual training
predictor = ComparisonRewardPredictor(
env,
summary_writer,
comparison_collector=comparison_collector,
agent_logger=agent_logger,
label_schedule=label_schedule,
)
for i in range(args.pretrain_iters):
predictor.train_predictor() # Train on pretraining labels
if i % 100 == 0:
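Pulling the new pieces of teach.py together, the sketch below reconstructs the comparison objective end to end: each segment of (observation, action) pairs is flattened into individual steps, scored by a shared FullyConnectedMLP, regrouped, and summed per segment; the two segment sums are stacked as logits and trained with sparse softmax cross entropy against the label of which segment the (human or synthetic) labeler preferred. The environment shapes and placeholder names here are assumptions for illustration; the graph structure mirrors _predict_rewards and _build_model above.

import numpy as np
import tensorflow as tf

from rl_teacher.nn import FullyConnectedMLP

obs_shape = (11,)   # assumed observation shape
act_shape = (3,)    # assumed action shape

def predict_segment_rewards(obs_segments, act_segments, network):
    # Chop (batch_size, segment_length, ...) segments into individual steps,
    # score each (obs, act) pair with the shared network, then regroup.
    batchsize = tf.shape(obs_segments)[0]
    segment_length = tf.shape(obs_segments)[1]
    obs = tf.reshape(obs_segments, (-1,) + obs_shape)
    acts = tf.reshape(act_segments, (-1,) + act_shape)
    rewards = network.run(obs, acts)              # (batch * segment_length, 1)
    return tf.reshape(rewards, (batchsize, segment_length))

left_obs = tf.placeholder(tf.float32, (None, None) + obs_shape, name="left_obs")
left_act = tf.placeholder(tf.float32, (None, None) + act_shape, name="left_act")
right_obs = tf.placeholder(tf.float32, (None, None) + obs_shape, name="right_obs")
right_act = tf.placeholder(tf.float32, (None, None) + act_shape, name="right_act")
labels = tf.placeholder(tf.int32, (None,), name="labels")  # 0 = left preferred, 1 = right

mlp = FullyConnectedMLP(obs_shape, act_shape)  # one network, shared across both segments

left_sum = tf.reduce_sum(predict_segment_rewards(left_obs, left_act, mlp), axis=1)
right_sum = tf.reduce_sum(predict_segment_rewards(right_obs, right_act, mlp), axis=1)
logits = tf.stack([left_sum, right_sum], axis=1)  # (batch_size, 2)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
train_op = tf.train.AdamOptimizer().minimize(loss)

In use, train_predictor feeds minibatches of labeled comparisons into the left/right placeholders, while predict_reward reuses the same per-step reward head on a single path by wrapping path["obs"] and path["actions"] in a batch of one, as shown in the diff above.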
