I'm confused about how the values are calculated: why are different masks used? In the `generate` method, the mask includes the prompt tokens, but during training in the `learn` method, the mask excludes them.
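To make the question concrete, here is a minimal sketch. It assumes the critic pools its per-token embeddings with a masked mean; `masked_mean` and the toy tensors are hypothetical and not taken from the repo. Under that assumption, the same sequence gets a different value depending on whether the prompt positions are masked in:

```python
import torch

def masked_mean(t, mask, dim=1):
    # Hypothetical helper: mean of t over `dim`, ignoring positions where mask is False.
    # t: (batch, seq_len, features), mask: (batch, seq_len) bool
    t = t.masked_fill(~mask.unsqueeze(-1), 0.)
    denom = mask.sum(dim=dim).clamp(min=1).unsqueeze(-1)
    return t.sum(dim=dim) / denom

# Toy critic embeddings for one sequence of 6 tokens (the first 2 are the prompt)
embeds = torch.arange(6.).view(1, 6, 1)

full_mask = torch.ones(1, 6, dtype=torch.bool)                        # prompt included (like generate)
action_mask = torch.tensor([[False, False, True, True, True, True]])  # prompt excluded (like learn)

pooled_full = masked_mean(embeds, full_mask)
pooled_action = masked_mean(embeds, action_mask)
print(pooled_full.item(), pooled_action.item())  # 2.5 3.5
```

This only illustrates the pooled case. If the value head produces per-token values instead, the mask may affect the values differently.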
This is in the `learn` method:
```python
action_masks = ~prompt_masks & masks
action_logits, values = self.actor_critic(
    sequences,
    mask = action_masks
)
```
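A toy illustration of the `learn` mask, using hypothetical `prompt_masks` and `masks` tensors. The expression `~prompt_masks & masks` keeps only the generated positions, which are neither prompt nor padding:

```python
import torch

# prompt_masks is True where a token belongs to the prompt (hypothetical toy data)
prompt_masks = torch.tensor([
    [True, True, False, False, False, False],
    [True, True, True, False, False, False],
])
# masks is True for every real (non-padding) token, prompt included
masks = torch.tensor([
    [True, True, True, True, True, False],
    [True, True, True, True, True, True],
])

# Same expression as in learn(): drop the prompt, keep the generated tokens
action_masks = ~prompt_masks & masks
print(action_masks.int())
# tensor([[0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]], dtype=torch.int32)
```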
And in the `generate` method:
```python
mask = None
if exists(eos_token):
    mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
    mask = F.pad(mask, (1, -1), value = True) # include eos token

action_logits, value = self.forward(
    sequence,
    mask = mask,
    return_values = return_values
)
```
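A toy illustration of the `generate` mask, assuming a hypothetical EOS id of `0`. Every position up to and including the first EOS stays `True`. Because `sequence` here is the full sequence, the prompt tokens are kept as well:

```python
import torch
import torch.nn.functional as F

eos_token = 0  # hypothetical EOS id for this toy example
sequence = torch.tensor([
    [5, 7, 9, 0, 3, 3],  # EOS at position 3, followed by trailing tokens
    [4, 4, 4, 4, 4, 4],  # no EOS generated
])

# True for every position before the first EOS
mask = (sequence == eos_token).cumsum(dim=-1) == 0
# Shift right by one (pad 1 on the left, crop 1 on the right) so the EOS itself is kept
mask = F.pad(mask, (1, -1), value=True)
print(mask.int())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```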