Commit

refactor
lmzintgraf committed Sep 28, 2020
1 parent 2fcca1c commit a36d251
Showing 187 changed files with 4,527 additions and 20,424 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -4,3 +4,5 @@ __pycache__
.idea

logs/
scripts/
docker/
11 changes: 5 additions & 6 deletions README.md
@@ -74,16 +74,15 @@ to learn the posterior distribution in a supervised way.
(Note that our implementation is based on the variBAD architecture,
so differs slightly from theirs.)
- The size of the latent dimension can be changed using `--latent_dim`.
- In our experience, the performance of PPO depends a lot on the number of
minibatches (`--ppo_num_minibatch`),
the clip parameter (`--ppo_clip_param`, we suggest values between 0.01 and 0.3),
- In our experience, the performance of PPO depends a lot on
the number of minibatches (`--ppo_num_minibatch`),
the number of epochs (`--ppo_num_epochs`),
and the batchsize (change with `--policy_num_steps` and/or `--num_processes`).
Another important parameter is the kl term (`--kl_weight`) for the ELBO term.

Another important parameter is the weight of the kl term (`--kl_weight`) in the ELBO.
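As a rough illustration of how these flags interact (a minimal sketch; the values below are made up, not defaults from this repository):

```python
# Illustrative only: how the PPO batch and minibatch sizes typically follow
# from the flags above (all values are placeholders).
policy_num_steps = 400             # environment steps collected per process
num_processes = 16                 # parallel environments
ppo_num_minibatch = 4              # minibatches per PPO epoch
ppo_num_epochs = 2                 # passes over each collected batch

batch_size = policy_num_steps * num_processes         # 6400 transitions per update
minibatch_size = batch_size // ppo_num_minibatch      # 1600 transitions per gradient step
gradient_steps = ppo_num_epochs * ppo_num_minibatch   # 8 gradient steps per update
```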

### Comments

- When the flag `disable_varibad` is activated, the file `learner.py` will be used instead of `metalearner.py`.
- When the flag `disable_metalearner` is activated, the file `learner.py` will be used instead of `metalearner.py`.
This is a stripped down version without encoder, decoder, stochastic latent variables, etc.
It can be used to train (belief) oracles or policies that are good on average.
- For the environments, do not use `np.random` (it's not thread safe); stick to `random` or `torch.random` instead.
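As a minimal sketch of that last point (the helper below is hypothetical and only shows drawing randomness from `random` / `torch` instead of `np.random`):

```python
import random
import torch

# Hypothetical environment helper: sample a 2D goal using the stdlib `random`
# module and torch's generator rather than the shared np.random state.
def sample_goal():
    goal = [random.uniform(-1.0, 1.0), random.uniform(-1.0, 1.0)]
    return torch.tensor(goal) + 0.01 * torch.randn(2)   # small torch-sampled jitter
```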
67 changes: 46 additions & 21 deletions algorithms/a2c.py
@@ -9,14 +9,18 @@

class A2C:
def __init__(self,
args,
actor_critic,
value_loss_coef,
entropy_coef,
policy_optimiser,
policy_anneal_lr,
train_steps,
optimiser_vae=None,
lr=None,
eps=None,
alpha=None,
):
self.args = args

# the model
self.actor_critic = actor_critic
@@ -26,50 +30,66 @@ def __init__(self,
self.entropy_coef = entropy_coef

# optimiser
self.optimizer = optim.RMSprop(actor_critic.parameters(), lr, eps=eps, alpha=alpha)
if policy_optimiser == 'adam':
self.optimiser = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps)
elif policy_optimiser == 'rmsprop':
self.optimiser = optim.RMSprop(actor_critic.parameters(), lr, eps=eps, alpha=0.99)
self.optimiser_vae = optimiser_vae

self.lr_scheduler_policy = None
self.lr_scheduler_encoder = None
if policy_anneal_lr:
lam = lambda f: 1 - f / train_steps
self.lr_scheduler_policy = optim.lr_scheduler.LambdaLR(self.optimiser, lr_lambda=lam)
if hasattr(self.args, 'rlloss_through_encoder') and self.args.rlloss_through_encoder:
self.lr_scheduler_encoder = optim.lr_scheduler.LambdaLR(self.optimiser_vae, lr_lambda=lam)

def update(self,
args,
policy_storage,
encoder=None, # variBAD encoder
rlloss_through_encoder=False, # whether or not to backprop RL loss through encoder
compute_vae_loss=None # function that can compute the VAE loss
):

# -- get action values --
# get action values
advantages = policy_storage.returns[:-1] - policy_storage.value_preds[:-1]

if rlloss_through_encoder:
# re-compute encoding (to build the computation graph from scratch)
utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=0)
utl.recompute_embeddings(policy_storage, encoder, sample=False, update_idx=0,
detach_every=self.args.tbptt_stepsize if hasattr(self.args,
'tbptt_stepsize') else None)

# update the normalisation parameters of policy inputs before updating
self.actor_critic.update_rms(args=self.args, policy_storage=policy_storage)

data_generator = policy_storage.feed_forward_generator(advantages, 1)
for sample in data_generator:

obs_batch, actions_batch, latent_sample_batch, latent_mean_batch, latent_logvar_batch, value_preds_batch, \
state_batch, belief_batch, task_batch, \
actions_batch, latent_sample_batch, latent_mean_batch, latent_logvar_batch, value_preds_batch, \
return_batch, old_action_log_probs_batch, adv_targ = sample

if not rlloss_through_encoder:
obs_batch = obs_batch.detach()
state_batch = state_batch.detach()
if latent_sample_batch is not None:
latent_sample_batch = latent_sample_batch.detach()
latent_mean_batch = latent_mean_batch.detach()
latent_logvar_batch = latent_logvar_batch.detach()

obs_aug = utl.get_augmented_obs(args=args,
obs=obs_batch,
latent_sample=latent_sample_batch, latent_mean=latent_mean_batch,
latent_logvar=latent_logvar_batch
)
latent_batch = utl.get_latent_for_policy(args=self.args, latent_sample=latent_sample_batch,
latent_mean=latent_mean_batch, latent_logvar=latent_logvar_batch
)

values, action_log_probs, dist_entropy, action_mean, action_logstd = \
self.actor_critic.evaluate_actions(obs_aug, actions_batch, return_action_mean=True)
self.actor_critic.evaluate_actions(state=state_batch, latent=latent_batch,
belief=belief_batch, task=task_batch,
action=actions_batch, return_action_mean=True)

# -- UPDATE --

# zero out the gradients
self.optimizer.zero_grad()
self.optimiser.zero_grad()
if rlloss_through_encoder:
self.optimiser_vae.zero_grad()

@@ -82,24 +102,29 @@ def update(self,

# compute vae loss and backprop
if rlloss_through_encoder:
loss += args.vae_loss_coeff * compute_vae_loss()
loss += self.args.vae_loss_coeff * compute_vae_loss()

# compute gradients (will attach to all networks involved in this computation)
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), args.policy_max_grad_norm)
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.args.policy_max_grad_norm)
if encoder is not None and rlloss_through_encoder:
nn.utils.clip_grad_norm_(encoder.parameters(), args.policy_max_grad_norm)
nn.utils.clip_grad_norm_(encoder.parameters(), self.args.policy_max_grad_norm)

# update
self.optimizer.step()
self.optimiser.step()
if rlloss_through_encoder:
self.optimiser_vae.step()

if (not rlloss_through_encoder) and (self.optimiser_vae is not None):
for _ in range(args.num_vae_updates - 1):
for _ in range(self.args.num_vae_updates):
compute_vae_loss(update=True)

if self.lr_scheduler_policy is not None:
self.lr_scheduler_policy.step()
if self.lr_scheduler_encoder is not None:
self.lr_scheduler_encoder.step()

return value_loss, action_loss, dist_entropy, loss

def act(self, obs, deterministic=False):
return self.actor_critic.act(obs, deterministic=deterministic)
def act(self, state, latent, belief, task, deterministic=False):
return self.actor_critic.act(state=state, latent=latent, belief=belief, task=task, deterministic=deterministic)
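A note on the learning-rate annealing added in `__init__` above: `LambdaLR` multiplies the optimiser's base learning rate by whatever the lambda returns, so `1 - f / train_steps` decays it linearly towards zero as the scheduler is stepped. A self-contained sketch with illustrative values (the lr and eps below are placeholders, not taken from this commit):

```python
import torch
from torch import optim

# Linear lr annealing via LambdaLR: the lambda maps the scheduler's step count
# to a multiplier on the base learning rate.
train_steps = 100
params = [torch.nn.Parameter(torch.zeros(1))]
optimiser = optim.Adam(params, lr=7e-4, eps=1e-5)
scheduler = optim.lr_scheduler.LambdaLR(optimiser, lr_lambda=lambda f: 1 - f / train_steps)

for step in range(train_steps):
    optimiser.step()        # in training this follows loss.backward()
    scheduler.step()        # lr is now 7e-4 * (1 - (step + 1) / train_steps)
```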

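On the `detach_every=self.args.tbptt_stepsize` argument passed to `utl.recompute_embeddings` above: this is the usual truncated backpropagation-through-time pattern, where the recurrent state is detached every fixed number of steps so the RL loss only backpropagates through a bounded window. A generic sketch of the pattern (a toy GRU, not the variBAD encoder itself):

```python
import torch
import torch.nn as nn

# Generic truncated-BPTT sketch: detach the hidden state every `detach_every`
# steps so gradients cannot flow back past the most recent detach point.
gru = nn.GRU(input_size=8, hidden_size=16)
inputs = torch.randn(50, 1, 8)            # (seq_len, batch, features)
detach_every = 10

hidden = torch.zeros(1, 1, 16)            # (num_layers, batch, hidden_size)
outputs = []
for t in range(inputs.shape[0]):
    if detach_every is not None and t > 0 and t % detach_every == 0:
        hidden = hidden.detach()          # cut the computation graph here
    out, hidden = gru(inputs[t:t + 1], hidden)
    outputs.append(out)
outputs = torch.cat(outputs, dim=0)       # downstream losses see a truncated graph
```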