Skip to content

Commit

Permalink
fix an error with the way action log prob is collected during the epi…
Browse files Browse the repository at this point in the history
…sode rollouts, addressing #31 and thanks to @kisseternity
  • Loading branch information
lucidrains committed Feb 22, 2023
1 parent bfcffe7 commit a159310
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions palm_rlhf_pytorch/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))

def log_prob(prob, indices):
assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
return log(prob.gather(-1, indices[..., None])).squeeze(-1)

def shift(t, value = 0, shift = 1, dim = -1):
Expand Down Expand Up @@ -608,11 +609,13 @@ def train(
temperature = temperature,
return_values = True
)

action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token

action_prob = action_logits.softmax(dim = -1)
action_log_prob = log_prob(action_prob, actions)

action_len = actions.shape[-1]
action_log_prob = log_prob(action_prob, sequence)
action_log_prob = action_log_prob[:, -action_len:]

actions = rearrange(actions, '1 ... -> ...')

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'PaLM-rlhf-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.64',
version = '0.0.65',
license='MIT',
description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit a159310

Please sign in to comment.