Skip to content

Commit

Permalink
fix mask logic in rlhf trainer when eos_token is designated
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 13, 2023
1 parent 6e3a60d commit 1a23215
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions palm_rlhf_pytorch/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def exists(val):
return val is not None

def default(val, d):
    """Return `val` if it is not None, otherwise fall back to `d`.

    If the fallback `d` is callable it is treated as a lazy factory and
    invoked, so expensive defaults are only constructed when needed.
    """
    if val is not None:
        return val
    return d() if callable(d) else d

def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
dim = default(dim, tuple(range(t.ndim)))
Expand Down Expand Up @@ -630,7 +632,7 @@ def train(

sequence = rearrange(sequence, 'n -> 1 n')
prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device)
mask = default(mask, lambda: torch.ones(sequence.shape, dtype = torch.bool, device = device))

reward = self.reward_model(
sequence,
Expand Down

0 comments on commit 1a23215

Please sign in to comment.