Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion keras_nlp/samplers/beam_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class BeamSampler(Sampler):
return_all_beams: bool. When set to `True`, the sampler will return all
beams and their respective probabilities score.

Call Args:
Call arguments:
{{call_args}}

Examples:
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/samplers/contrastive_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ContrastiveSampler(Sampler):
on the similarity than the token probability.
seed: int, defaults to None. The random seed.

Call Args:
Call arguments:
{{call_args}}

Examples:
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/samplers/greedy_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class GreedySampler(Sampler):
This sampler is implemented on greedy search, i.e., always picking up the
token of the largest probability as the next token.

Call Args:
Call arguments:
{{call_args}}

Examples:
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/samplers/random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class RandomSampler(Sampler):
Args:
seed: int, defaults to None. The random seed.

Call Args:
Call arguments:
{{call_args}}

Examples:
Expand Down
41 changes: 20 additions & 21 deletions keras_nlp/samplers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,25 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.utils.python_utils import format_docstring

call_args_docstring = """
next: A function which takes in the `prompt, cache, index` of the
current generation loop, and outputs a tuple
`(logits, cache, hidden_states)` with `logits` being the logits of next
token, `cache` for next iteration, and `hidden_states` being the
representation of the token.
prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
tensor will be iteratively updated column by column with new sampled
values, starting at `index`.
cache: Optional. A tensor or nested structure of tensors that will be
updated by each call to `next`. This can be used to cache computations
from early iterations of the generative loop.
index: Optional. The first index to start sampling at.
mask: Optional. A 2D integer tensor with the same shape as `prompt`.
Locations which are `True` in the mask are never updated during
sampling. Often this will mark all ids in `prompt` which were present in
the original input.
end_token_id: Optional. The token marking the end of the sequence. If
specified, sampling will stop as soon as all sequences in the prompt
produce a `end_token_id` in a location where `mask` is `False`.
call_args_docstring = """next: A function which takes in the
`prompt, cache, index` of the current generation loop, and outputs
a tuple `(logits, cache, hidden_states)` with `logits` being the
logits of next token, `cache` for next iteration, and
`hidden_states` being the representation of the token.
prompt: A 2D integer tensor with shape `(batch_size, max_length)`. This
tensor will be iteratively updated column by column with new sampled
values, starting at `index`.
cache: Optional. A tensor or nested structure of tensors that will be
updated by each call to `next`. This can be used to cache
computations from early iterations of the generative loop.
index: Optional. The first index to start sampling at.
mask: Optional. A 2D integer tensor with the same shape as `prompt`.
Locations which are `True` in the mask are never updated during
sampling. Often this will mark all ids in `prompt` which were
present in the original input.
end_token_id: Optional. The token marking the end of the sequence. If
specified, sampling will stop as soon as all sequences in the prompt
produce a `end_token_id` in a location where `mask` is `False`.
"""


Expand All @@ -53,7 +52,7 @@ class Sampler:
randomness of the sampling. The higher the temperature, the
more diverse the samples.

Call Args:
Call arguments:
{{call_args}}

This base class can be extended to implement different auto-regressive
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/samplers/top_k_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TopKSampler(Sampler):
k: int, the `k` value of top-k.
seed: int, defaults to None. The random seed.

Call Args:
Call arguments:
{{call_args}}

Examples:
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/samplers/top_p_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TopPSampler(Sampler):
of tokens to sort.
seed: int, defaults to None. The random seed.

Call Args:
Call arguments:
{{call_args}}

Examples:
Expand Down