# Remove the old sampler utilities #948

This PR removes the standalone decoding utilities (`greedy_search`, `beam_search`, `random_search`, `top_k_search`, `top_p_search`) from the text-generation benchmark in favor of the unified `keras_nlp.samplers` API, and refreshes the benchmark results in the README.
**Diff 1: the benchmarks README**

````diff
@@ -11,16 +11,15 @@ from the root of the repository:
 python3 ./keras_nlp/benchmarks/text_generation.py
 ```
 
-On running this script on Google Colab (with Tesla T4 GPU, and TensorFlow 2.10.0),
+On running this script on Google Colab (with 3090 GPU, and TensorFlow 2.11.0),
 the following results were obtained:
 
 | **Decoding Strategy** | **Graph Mode (sec)** | **Graph Mode with XLA (sec)** |
 |:---------------------:|:--------------------:|:-----------------------------:|
-| Greedy Search | 495.78 | 293.77 |
-| Beam Search | 564.23 | 615.17 |
-| Random Search | 446.55 | 296.21 |
-| Top-k Search | 458.68 | 302.66 |
-| Top-p Search | 468.63 | 565.50 |
+| Greedy Search | 470.23 | 61.79 |
+| Beam Search | 530.13 | 189.61 |
+| Top-k Search | 374.05 | 62.87 |
+| Top-p Search | 401.97 | 260.31 |
 
 To change the configuration, say, for example, number of layers in the transformer
 model used for inference, the user can modify the config dictionaries given at
````
**Review comments on lines +19 to +22** (the updated results rows):

> Damn, crazy speedup!

> It's a bit apples to oranges. Switched hardware. We should consider updating this to a "cached" version. That will be a huge speedup.
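The two timing columns in the table above differ only in whether XLA is enabled: the benchmark maps "graph" to `jit_compile=False` and "xla" to `jit_compile=True` (see the loop in `main()` in the script diff below). A minimal, self-contained sketch of that toggle via `tf.function`; the helper name and the matmul workload are illustrative, not from the PR:

```python
import time

import tensorflow as tf


def time_compiled(fn, x, jit_compile):
    """Time one graph-mode call of `fn`, optionally compiled with XLA."""
    compiled = tf.function(fn, jit_compile=jit_compile)
    compiled(x)  # Warm-up call keeps tracing/compilation out of the timing.
    start = time.time()
    result = compiled(x)
    _ = result.numpy()  # Force execution to finish before stopping the clock.
    return time.time() - start


x = tf.random.uniform((512, 512))
print("graph:", time_compiled(lambda t: t @ t, x, jit_compile=False))
print("xla:  ", time_compiled(lambda t: t @ t, x, jit_compile=True))
```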
**Diff 2: `keras_nlp/benchmarks/text_generation.py`**

```diff
@@ -20,11 +20,6 @@
 from tensorflow import keras
 
 import keras_nlp
-from keras_nlp.utils import beam_search
-from keras_nlp.utils import greedy_search
-from keras_nlp.utils import random_search
-from keras_nlp.utils import top_k_search
-from keras_nlp.utils import top_p_search
 
 SEED = 42
 
```

**Review comment** (on the imports above):

> We need to update this README with new metrics: https://github.com/keras-team/keras-nlp/blob/master/benchmarks/README.md
```diff
@@ -34,14 +29,8 @@
     "batch_size": 2,
 }
 
-TEXT_GEN_ARGS = {
-    "max_length": 64,
-    "end_token_id": 2,
-    "pad_token_id": 0,
-}
-
 MODEL_ARGS = {
-    "max_length": 300,
+    "max_length": 64,
     "embed_dim": 768,
     "num_layers": 8,
     "num_heads": 8,
```
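With `TEXT_GEN_ARGS` gone, `MODEL_ARGS["max_length"]` becomes the single source of truth for sequence length (it now feeds both the model and, per a later hunk, the prompt dataset). For orientation, `build_model` is referenced by a later hunk header but not shown in this diff; a hypothetical sketch of such a builder using real `keras_nlp` layers, not the PR's actual code:

```python
from tensorflow import keras

import keras_nlp


def build_model(vocab_size, max_length, embed_dim, num_layers, num_heads, ff_dim):
    """Hypothetical stand-in for the benchmark's model builder."""
    inputs = keras.Input(shape=(max_length,), dtype="int32")
    x = keras_nlp.layers.TokenAndPositionEmbedding(
        vocabulary_size=vocab_size,
        sequence_length=max_length,
        embedding_dim=embed_dim,
    )(inputs)
    for _ in range(num_layers):
        x = keras_nlp.layers.TransformerDecoder(
            intermediate_dim=ff_dim,
            num_heads=num_heads,
        )(x)
    # Per-position logits over the vocabulary.
    outputs = keras.layers.Dense(vocab_size)(x)
    return keras.Model(inputs, outputs)
```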
```diff
@@ -50,55 +39,27 @@
 
 TEST_RUNS = [
     {
-        "decoding_fn": greedy_search,
+        "sampler": "greedy",
         "execution_methods": ["xla", "graph"],
-        "args": TEXT_GEN_ARGS,
     },
     {
-        "decoding_fn": beam_search,
+        "sampler": "beam",
         "execution_methods": ["xla", "graph"],
-        "args": {
-            "num_beams": 2,
-            "from_logits": True,
-            **TEXT_GEN_ARGS,
-        },
     },
     {
-        "decoding_fn": random_search,
-        "execution_methods": ["xla", "graph"],
-        "args": {
-            "seed": SEED,
-            "from_logits": True,
-            **TEXT_GEN_ARGS,
-        },
-    },
-    {
-        "decoding_fn": top_k_search,
+        "sampler": "top_k",
         "execution_methods": ["xla", "graph"],
-        "args": {
-            "k": 5,
-            "seed": SEED,
-            "from_logits": True,
-            **TEXT_GEN_ARGS,
-        },
     },
     {
-        "decoding_fn": top_p_search,
+        "sampler": "top_p",
         "execution_methods": ["xla", "graph"],
-        "args": {
-            "p": 0.9,
-            "seed": SEED,
-            "from_logits": True,
-            **TEXT_GEN_ARGS,
-        },
     },
 ]
 
 
-def generate_random_ds(vocab_size, num_samples, batch_size, seed):
-    prompt_length = 2
+def generate_random_ds(vocab_size, num_samples, batch_size, length, seed):
     inputs = tf.random.uniform(
-        shape=(num_samples, prompt_length),
+        shape=(num_samples, length),
         minval=0,
         maxval=vocab_size - 1,
         dtype=tf.dtypes.int32,
```
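The prompt generator now takes the target `length` instead of hard-coding a two-token prompt, so the benchmark decodes over full-length inputs. A sketch of the complete helper under that change; the dataset-construction tail falls outside this hunk, so the `tf.data` part is an assumption:

```python
import tensorflow as tf


def generate_random_ds(vocab_size, num_samples, batch_size, length, seed):
    """Batched random token-id prompts, `length` tokens each."""
    inputs = tf.random.uniform(
        shape=(num_samples, length),
        minval=0,
        maxval=vocab_size - 1,
        dtype=tf.dtypes.int32,
        seed=seed,
    )
    # Assumed tail (not shown in the diff): wrap the tensor in a batched
    # tf.data pipeline, matching how main() iterates over `ds`.
    return tf.data.Dataset.from_tensor_slices(inputs).batch(batch_size)


ds = generate_random_ds(
    vocab_size=1000, num_samples=4, batch_size=2, length=64, seed=42
)
```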
```diff
@@ -134,18 +95,16 @@ def build_model(
 
 
 def generate_text(
-    decoding_fn,
-    token_probability_fn,
+    sampler,
+    next,
     prompt,
-    text_gen_args,
    jit_compile,
 ):
     class TestModel(tf.keras.Model):
         def call(self, inputs):
-            generated = decoding_fn(
-                token_probability_fn=token_probability_fn,
+            generated = keras_nlp.samplers.get(sampler)(
+                next=next,
                 prompt=inputs,
-                **text_gen_args,
             )
             return generated
 
```
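This is the core of the migration: rather than importing one decoding function per strategy, callers fetch a sampler by name with `keras_nlp.samplers.get` and pass a `next` callback that returns the logits at the current index plus threaded state. A toy, self-contained example mirroring the call exactly as it appears in this diff; uniform logits stand in for a real model, and other keras-nlp versions may expect extra arguments:

```python
import tensorflow as tf

import keras_nlp

VOCAB_SIZE = 10


def next(prompt, state, index):
    # A real benchmark runs the model here and slices logits at `index`;
    # uniform logits keep this toy self-contained.
    batch_size = tf.shape(prompt)[0]
    return tf.ones((batch_size, VOCAB_SIZE)), state


# Prompts are full-length, so the sampler fills them in place.
prompt = tf.zeros((2, 16), dtype=tf.int32)
output = keras_nlp.samplers.get("top_k")(next=next, prompt=prompt)
```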
```diff
@@ -165,6 +124,7 @@ def main():
         vocab_size=DATASET_ARGS["vocab_size"],
         num_samples=DATASET_ARGS["num_samples"],
         batch_size=DATASET_ARGS["batch_size"],
+        length=MODEL_ARGS["max_length"],
         seed=SEED,
     )
 
```
```diff
@@ -177,36 +137,34 @@ def main():
         ff_dim=MODEL_ARGS["ff_dim"],
     )
 
-    def token_logits_fn(inputs):
-        output = model(inputs)
-        return output[:, -1, :]
+    def next(prompt, state, index):
+        output = model(prompt)
+        return output[:, index, :], state
 
     print("*************************************\n")
 
     with open(csv_path, "w") as res_handler:
         res_handler.write("decoding_strategy,execution_method,time\n")
         for test_run in TEST_RUNS:
-            decoding_fn = test_run["decoding_fn"]
-            decoding_strategy = decoding_fn.__name__
+            sampler = test_run["sampler"]
 
             for execution_method in test_run["execution_methods"]:
-                print(f"Running {decoding_strategy} in {execution_method} mode")
+                print(f"Running {sampler} in {execution_method} mode")
 
                 if execution_method == "graph":
                     jit_compile = False
                 elif execution_method == "xla":
                     jit_compile = True
 
                 time_taken = generate_text(
-                    decoding_fn=decoding_fn,
-                    token_probability_fn=token_logits_fn,
+                    sampler=sampler,
+                    next=next,
                     prompt=ds,
-                    text_gen_args=test_run["args"],
                     jit_compile=jit_compile,
                 )
                 print("Time taken: ", time_taken)
                 res_handler.write(
-                    f"{decoding_strategy},{execution_method}," f"{time_taken}\n"
+                    f"{sampler},{execution_method}," f"{time_taken}\n"
                 )
             print()
     print("*************************************")
```
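The "cached" version floated in the README comment thread maps naturally onto this new contract: `state` is an opaque value the sampler threads through every decoding step, so it can carry key/value caches and let each step do incremental work rather than re-running the model over the full prompt. A purely hypothetical illustration; `model_step` and `initial_cache` are assumed helpers, and nothing like this exists in the PR:

```python
def make_cached_next(model_step, initial_cache):
    """Wrap a hypothetical per-token `model_step(tokens, cache)` so the
    sampler's threaded `state` carries a growing key/value cache."""

    def next(prompt, state, index):
        cache = state if state is not None else initial_cache
        # Feed only the newest token; earlier positions live in the cache.
        logits, cache = model_step(prompt[:, index : index + 1], cache)
        return logits, cache

    return next
```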
**Review comment:**

> I should probably upgrade my PC 🛩️