diff --git a/keras_nlp/models/gpt2/gpt2_causal_lm.py b/keras_nlp/models/gpt2/gpt2_causal_lm.py index b772dfa5be..8fb1b68a55 100644 --- a/keras_nlp/models/gpt2/gpt2_causal_lm.py +++ b/keras_nlp/models/gpt2/gpt2_causal_lm.py @@ -79,7 +79,7 @@ class GPT2CausalLM(Task): gpt2_lm.generate("I want to say", max_length=30) gpt2_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2)) - gpt2_lm.generate("I want to say", max_length=30, sampler=sampler) + gpt2_lm.generate("I want to say", max_length=30) ``` Map raw string to languages model logit predictions. @@ -331,15 +331,13 @@ def generate( This method generates text based on given `prompt`. Generation will continue until `max_length` is met, and all tokens generated after - `end_token` will be truncated. The sampling approach used can be - controlled via the sampler argument. + `end_token` will be truncated. The sampling strategy can be set in + the `compile` method. Args: prompt: a string, string Tensor or string RaggedTensor. The prompt text for generation. max_length: int. The max length of generated sequence. - sampler: a string or `keras_nlp.samplers.Sampler` instance. The - sampler to be used for text generation. """ if self.preprocessor is None: raise ValueError(