From 8364d7e89282ddd52fe4ca9cae3ac1c49d767e72 Mon Sep 17 00:00:00 2001 From: hutao Date: Tue, 26 Jan 2021 10:14:54 +0800 Subject: [PATCH 1/3] update --- scripts/generation/README.md | 3 ++- src/gluonnlp/sequence_sampler.py | 20 +++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/scripts/generation/README.md b/scripts/generation/README.md index 06a34a2799..580311b5a1 100644 --- a/scripts/generation/README.md +++ b/scripts/generation/README.md @@ -14,6 +14,7 @@ python3 generate_unconditional_gpt2_samples.py \ --gpu 0 \ --temperature 0.7 \ --top_k 40 \ + --batch_size 2 \ --nsamples 1000 > samples ``` @@ -43,7 +44,7 @@ Some metrics for the unconditional generated text | topk=40 | 0.4291 | 0.9666 | 0.0 | | topk=640 | 0.3384 | 0.9623 | 0.0 | | topk=40 t=0.7 | 0.4621 | 0.9586 | 1.1 | - +| topp=0.95 | 0.2676 | 0.9519 | 0.0 | Part of some interesting generated unconditional example diff --git a/src/gluonnlp/sequence_sampler.py b/src/gluonnlp/sequence_sampler.py index d5f621acba..dd3447a7f9 100644 --- a/src/gluonnlp/sequence_sampler.py +++ b/src/gluonnlp/sequence_sampler.py @@ -717,11 +717,21 @@ def forward(self, samples, valid_length, outputs, scores, step, beam_alive_mask, probs = mx.npx.softmax(outputs / self._temperature) if self._sampling_topp > 0: - probs = mx.np.where( - probs > self._sampling_topp, - probs, - mx.np.zeros_like(probs) - ) + sorted_probs, sorted_indices = mx.npx.topk(probs, axis=2, k=-1, ret_typ='both', is_ascend=False) + cumsum_probs = mx.np.cumsum(sorted_probs, axis=2) + masked_probs = mx.np.where( + cumsum_probs > self._sampling_topp, + sorted_probs, + mx.np.zeros_like(probs) + ) + # choose the borderline prob + p_prob = mx.np.min(masked_probs, axis=2, keepdims=True) + probs = mx.np.where( + probs > self._sampling_topp, + probs >= p_prob, + probs, + mx.np.zeros_like(probs) + ) elif self._sampling_topk > 0: topk_probs = mx.npx.topk(probs, axis=2, k=self._sampling_topk, ret_typ='value') # choose the k max prob From b3b846b06fb13ffa75a4b64e1fe2edc62b9db355 Mon Sep 17 00:00:00 2001 From: hutao Date: Tue, 26 Jan 2021 10:19:23 +0800 Subject: [PATCH 2/3] update --- src/gluonnlp/sequence_sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gluonnlp/sequence_sampler.py b/src/gluonnlp/sequence_sampler.py index dd3447a7f9..99807bb6cf 100644 --- a/src/gluonnlp/sequence_sampler.py +++ b/src/gluonnlp/sequence_sampler.py @@ -727,7 +727,6 @@ def forward(self, samples, valid_length, outputs, scores, step, beam_alive_mask, # choose the borderline prob p_prob = mx.np.min(masked_probs, axis=2, keepdims=True) probs = mx.np.where( - probs > self._sampling_topp, probs >= p_prob, probs, mx.np.zeros_like(probs) From f766c013b748358834d146c77913c2203d8d71e8 Mon Sep 17 00:00:00 2001 From: hutao Date: Wed, 3 Feb 2021 19:35:09 +0800 Subject: [PATCH 3/3] update --- src/gluonnlp/sequence_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/sequence_sampler.py b/src/gluonnlp/sequence_sampler.py index 99807bb6cf..18556e085d 100644 --- a/src/gluonnlp/sequence_sampler.py +++ b/src/gluonnlp/sequence_sampler.py @@ -722,7 +722,7 @@ def forward(self, samples, valid_length, outputs, scores, step, beam_alive_mask, masked_probs = mx.np.where( cumsum_probs > self._sampling_topp, sorted_probs, - mx.np.zeros_like(probs) + mx.np.ones_like(probs) ) # choose the borderline prob p_prob = mx.np.min(masked_probs, axis=2, keepdims=True)