
Add support for prompt lookup decoding during generation #3917

Merged: 2 commits merged into master from prompt_lookup_decoding on Jan 27, 2024

Conversation


@arnavgarg1 arnavgarg1 commented Jan 25, 2024

Implements support for Prompt Lookup Decoding by exposing a new generation config parameter, prompt_lookup_num_tokens. Compatible with transformers version >= 4.37.0.

In scenarios where the prompt is long and the generated output re-uses many of its n-grams (such as the code-editing demo below), this can speed up token generation by roughly 2x-2.4x. However, in scenarios where that is not the case, such as open-ended questions, it leads to roughly a 10% decrease in tokens per second.
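
The speedup comes from the candidate-generation step of prompt lookup decoding: the most recent n-gram of the sequence is searched for in the prompt itself, and the tokens that followed the match are proposed as draft tokens for the model to verify in a single forward pass. The sketch below is purely illustrative (it is not Ludwig's or transformers' actual implementation; the function name and token ids are made up):

# Illustrative sketch of the n-gram lookup behind prompt lookup decoding:
# match the most recent n-gram of the sequence against earlier positions
# and propose the tokens that followed the match as draft tokens.
def propose_draft_tokens(token_ids, ngram_size=3, num_draft_tokens=10):
    if len(token_ids) < ngram_size:
        return []
    tail = token_ids[-ngram_size:]
    # Scan earlier windows for the same n-gram, most recent match first.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            follow = token_ids[start + ngram_size:start + ngram_size + num_draft_tokens]
            if follow:
                return follow
    return []

# Hypothetical token ids: the repeated trigram [5, 8, 2] yields draft tokens [9, 4, 5, 8].
print(propose_draft_tokens([5, 8, 2, 9, 4, 5, 8, 2], ngram_size=3, num_draft_tokens=4))

Because the draft tokens are copied from the prompt rather than produced by a separate draft model, they are essentially free to generate, which is why the win is largest when the output quotes long spans of the input.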

Demo

https://drive.google.com/file/d/1E8qq8HnJBhL7GOuFDuMdih_GY1aAVwEC/view?usp=sharing

Script to Reproduce Demo

import yaml
import logging
from ludwig.api import LudwigModel

config = yaml.safe_load(
    """
model_type: llm
base_model: meta-llama/Llama-2-7b-chat-hf

quantization:
  bits: 4

input_features:
  - name: instruction
    type: text

output_features:
  - name: output
    type: text

generation:
  max_new_tokens: 64
  temperature: 0.1

trainer:
  type: none

backend:
  type: local
"""
)

model = LudwigModel(config, logging_level=logging.INFO)

code_text = """import numpy as np
import matplotlib.pyplot as plt

# Calculate the average
average_throughput = np.mean(tokens_per_sec_arr)
print(f"Average Throughput: {average_throughput} tokens/sec")

# Plotting the histogram
plt.hist(tokens_per_sec_arr, bins=20, color='blue', edgecolor='black', alpha=0.7)
plt.title('Histogram of Throughput Values')
plt.xlabel('Tokens per Second')
plt.ylabel('Frequency')
plt.axvline(average_throughput, color='red', linestyle='dashed', linewidth=1)
plt.text(average_throughput*0.9, max(plt.ylim())*0.9, f'Average: {average_throughput:.2f}', color = 'red')
plt.show()
"""

question = "Can you please change x axis to start from 0"

prompt = """<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
```python\n{code_text}``` \n\n{question}<|im_end|>
<|im_start|>assistant
""".format(code_text=code_text, question=question)

# Normal generation -> ~ 17s
output = model.generate(
    prompt, 
    generation_config={"max_new_tokens": 500}, 
    streaming=True
)

# With Prompt Lookup Decoding -> ~ 8.7s
output = model.generate(
    prompt, 
    generation_config={"max_new_tokens": 500, "prompt_lookup_num_tokens": 10}, 
    streaming=True
)
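
For reference, prompt_lookup_num_tokens maps onto the prompt lookup decoding support added in transformers 4.37.0, where the same knob can be passed straight to generate(). A minimal standalone sketch, assuming the same checkpoint as the demo (device placement and quantization details are omitted for brevity):

# Rough equivalent at the transformers level (requires transformers >= 4.37.0).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

inputs = tokenizer(prompt, return_tensors="pt")
outputs = hf_model.generate(
    **inputs,
    max_new_tokens=500,
    prompt_lookup_num_tokens=10,  # enables prompt lookup decoding
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))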


github-actions bot commented Jan 25, 2024

Unit Test Results

6 files (±0), 6 suites (±0), 14m 15s duration (-1s)
12 tests (±0): 9 passed, 3 skipped, 0 failed
60 runs (±0): 42 passed, 18 skipped, 0 failed

Results for commit 218f58b. ± Comparison against base commit 9bb89c6.

♻️ This comment has been updated with latest results.

@arnavgarg1 arnavgarg1 merged commit 6c0ca09 into master Jan 27, 2024
18 checks passed
@arnavgarg1 arnavgarg1 deleted the prompt_lookup_decoding branch January 27, 2024 14:56