diff --git a/scripts/generation/README.md b/scripts/generation/README.md
index 06a34a2799..580311b5a1 100644
--- a/scripts/generation/README.md
+++ b/scripts/generation/README.md
@@ -14,6 +14,7 @@ python3 generate_unconditional_gpt2_samples.py \
     --gpu 0 \
     --temperature 0.7 \
     --top_k 40 \
+    --batch_size 2 \
     --nsamples 1000 > samples
 ```
 
@@ -43,7 +44,7 @@ Some metrics for the unconditional generated text
 | topk=40 | 0.4291 | 0.9666 | 0.0 |
 | topk=640 | 0.3384 | 0.9623 | 0.0 |
 | topk=40 t=0.7 | 0.4621 | 0.9586 | 1.1 |
-
+| topp=0.95 | 0.2676 | 0.9519 | 0.0 |
 
 Part of some interesting generated unconditional example
 
diff --git a/src/gluonnlp/sequence_sampler.py b/src/gluonnlp/sequence_sampler.py
index d5f621acba..18556e085d 100644
--- a/src/gluonnlp/sequence_sampler.py
+++ b/src/gluonnlp/sequence_sampler.py
@@ -717,11 +717,20 @@ def forward(self, samples, valid_length, outputs, scores, step, beam_alive_mask,
         probs = mx.npx.softmax(outputs / self._temperature)
 
         if self._sampling_topp > 0:
-            probs = mx.np.where(
-                probs > self._sampling_topp,
-                probs,
-                mx.np.zeros_like(probs)
-            )
+            sorted_probs = mx.npx.topk(probs, axis=2, k=-1, ret_typ='value', is_ascend=False)
+            cumsum_probs = mx.np.cumsum(sorted_probs, axis=2)
+            masked_probs = mx.np.where(
+                cumsum_probs > self._sampling_topp,
+                sorted_probs,
+                mx.np.zeros_like(probs)
+            )
+            # borderline prob: largest prob past the threshold, i.e. the token
+            # whose inclusion first pushes cumulative mass above topp
+            p_prob = mx.np.max(masked_probs, axis=2, keepdims=True)
+            probs = mx.np.where(
+                probs >= p_prob,
+                probs,
+                mx.np.zeros_like(probs)
+            )
         elif self._sampling_topk > 0:
             topk_probs = mx.npx.topk(probs, axis=2, k=self._sampling_topk, ret_typ='value')
             # choose the k max prob