diff --git a/keras_nlp/samplers/contrastive_sampler.py b/keras_nlp/samplers/contrastive_sampler.py index e89f4d2b0c..7c1892958a 100644 --- a/keras_nlp/samplers/contrastive_sampler.py +++ b/keras_nlp/samplers/contrastive_sampler.py @@ -76,8 +76,9 @@ def __init__( k=5, alpha=0.6, seed=None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.k = k self.alpha = alpha self.seed = seed @@ -133,7 +134,7 @@ def cond(prompt, cache, index, logits, hidden_states): def body(prompt, cache, index, logits, hidden_states): # Compute the softmax distribution for the next token. - probabilities = keras.activations.softmax(logits) + probabilities = keras.activations.softmax(logits / self.temperature) # Replicate for `self.k` times to find the best token in top-k # candidates. diff --git a/keras_nlp/samplers/contrastive_sampler_test.py b/keras_nlp/samplers/contrastive_sampler_test.py index 8981c809cb..d8637488d6 100644 --- a/keras_nlp/samplers/contrastive_sampler_test.py +++ b/keras_nlp/samplers/contrastive_sampler_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Top-K sampler.""" +"""Tests for Contrastive Sampler.""" import tensorflow as tf from absl.testing import parameterized @@ -45,7 +45,7 @@ def next(prompt, cache, index): return logits, hidden_states, cache self.next = next - self.sampler = ContrastiveSampler(k=5, alpha=0.2) + self.sampler = ContrastiveSampler(k=5, alpha=0.2, temperature=1.0) def join_as_string(self, x): return ["".join([self.int_lookup[i] for i in s]) for s in x.numpy()]