Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[BUGFIX]fix bug of top-p sampling #1503

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ python3 generate_unconditional_gpt2_samples.py \
--gpu 0 \
--temperature 0.7 \
--top_k 40 \
--batch_size 2 \
--nsamples 1000 > samples
```

Expand Down Expand Up @@ -43,7 +44,7 @@ Some metrics for the unconditional generated text
| topk=40 | 0.4291 | 0.9666 | 0.0 |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think previously we have the results of t=0.9, we should remove that row.

| topk=640 | 0.3384 | 0.9623 | 0.0 |
| topk=40 t=0.7 | 0.4621 | 0.9586 | 1.1 |

| topp=0.95 | 0.2676 | 0.9519 | 0.0 |

Part of some interesting generated unconditional example

Expand Down
19 changes: 14 additions & 5 deletions src/gluonnlp/sequence_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,11 +717,20 @@ def forward(self, samples, valid_length, outputs, scores, step, beam_alive_mask,
probs = mx.npx.softmax(outputs / self._temperature)

if self._sampling_topp > 0:
probs = mx.np.where(
probs > self._sampling_topp,
probs,
mx.np.zeros_like(probs)
)
sorted_probs, sorted_indices = mx.npx.topk(probs, axis=2, k=-1, ret_typ='both', is_ascend=False)
cumsum_probs = mx.np.cumsum(sorted_probs, axis=2)
masked_probs = mx.np.where(
cumsum_probs > self._sampling_topp,
sorted_probs,
mx.np.zeros_like(probs)
)
# choose the borderline prob
p_prob = mx.np.min(masked_probs, axis=2, keepdims=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to use exactly the same implementation as https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm referring to the part in which they choose not to mask the top-1 probability:

sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. I see that both sort and argsort are implemented but I don't see a way to get both values and indices in one call. The usage of topk(k=-1) that assumes the return values to be sorted seems to be undocumented, which is a bit of a concern.

probs = mx.np.where(
probs >= p_prob,
probs,
mx.np.zeros_like(probs)
)
Copy link
Member

@sxjscience sxjscience Feb 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The major difference between the current implementation and the original pytorch-based implementation is that when sampling_topp < max(probs), it is not clear which probability will be picked.

The pytorch-based implementation will always choose the token that is most probable.

elif self._sampling_topk > 0:
topk_probs = mx.npx.topk(probs, axis=2, k=self._sampling_topk, ret_typ='value')
# choose the k max prob
Expand Down