
Speculative sampling #17

Open · wants to merge 10 commits into main
Conversation

daviswer (Collaborator)

Implements simple speculative sampling via candidate-consistent ground-truth sampling. See #12 for a discussion on implementation details and why this is needed in the first place.

  • Add __generate_targets() function, implementing both greedy and non-greedy selection
  • Enable sampling in speculative_generate()
  • Support temperature and top_k sampling as in non-speculative generate
  • Allow user to set these as arguments in the paged_speculative_inference.py demo script

Notably, for low temperature and top_k, we anecdotally observe no reduction in speculator performance compared to the greedy case!
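For reference, temperature and top_k sampling as described in the feature list above might look roughly like the following sketch (the function name, tensor shapes, and defaults are illustrative assumptions, not the PR's actual code):

```python
import torch

def sample_next(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 5) -> torch.Tensor:
    """Sketch of temperature + top_k sampling, as in non-speculative generate.

    logits: (batch, vocab) raw model outputs for the next token.
    Returns: (batch, 1) sampled token ids.
    """
    # Temperature scaling: lower values sharpen the distribution toward greedy
    logits = logits / temperature
    if top_k:
        # Mask out everything below the k-th largest logit
        v, _ = torch.topk(logits, top_k)
        logits = logits.masked_fill(logits < v[..., -1:], -float("inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```

With top_k=1 this reduces to greedy selection regardless of temperature, which is consistent with the observation that low temperature and small top_k behave close to the greedy case.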

daviswer (Collaborator, Author) commented Apr 11, 2024

Example outputs demonstrating new sampling capabilities:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Provide a list of instructions for preparing chicken soup.

### Response:

Greedy baseline:

Sure! Here are the steps to prepare chicken soup:

1. Start by chopping 1 onion and 3 cloves of garlic.
2. In a large pot, heat 2 tablespoons of olive oil over medium heat.
3. Add the chopped onion and sauté until it's translucent.
4. Add the chopped garlic and sauté for an additional 2 minutes.
5. Add 
103 tokens in 34 steps

time to first token: 0.026240825653076172
time per token (decode): 0.009298106999073213

top_k=5, temperature=2: no slowdown

Here is a recipe for preparing a delicious homemade chicken soup:

Step 1: Gather Ingredients:

* 2 lbs. boneless skinless chicken breasts or thighs, cut into bite-sized pieces
* 2 carrots, peeled and chopped
* 2 celery stalks, chopped
* 1 large onion, chopped
* 2-4 cups of
104 tokens in 32 steps

time to first token: 0.02620673179626465
time per token (decode): 0.00881159076323876

top_k=5, temperature=5: slowdown due to low likelihood of (ridiculous) output

Soup-tactical Movement:
To create this savor-ific broth:
1. Gather the squad of chicken brews and chicken parts (legs and feet work best) in one area for staging. Ensuring all ingredients are at room temperature.
2. In a galleon (Larger pot) add one-quint (about 1.3 litters) water and two pinch's salt to create the bre
102 tokens in 68 steps

time to first token: 0.026073217391967773
time per token (decode): 0.018784146682888855

# Composite greedy and non-greedy outputs
greedy = logits.argmax(-1)
mask = do_sample[:, None, None].int()
return samples * mask + (1 - mask) * greedy
nairbv (Member) commented Apr 30, 2024

If the mask is really a mask and not a weighting, it might be better to use torch.where.

Also, we're calculating the sampled results even if we don't use them? I guess that has something to do with compilation, but I would have thought the generation code would be outside the compile path?

daviswer (Collaborator, Author)
Good point, I'll swap to torch.where. We are calculating the sampled result in every case; while that will be useful for compile down the road, here it's mostly just for efficient GPU usage - partitioning the greedy and non-greedy rows and then re-mixing them afterward is more work than simply sampling everything.
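The torch.where version being agreed to might look like this sketch (the tensor shapes are assumptions inferred from the `[:, None, None]` indexing in the snippet, not the PR's actual dimensions):

```python
import torch

def composite_sample(logits: torch.Tensor, samples: torch.Tensor,
                     do_sample: torch.Tensor) -> torch.Tensor:
    """Sketch: composite greedy and non-greedy outputs via torch.where.

    logits:    (batch, candidates, seq_len, vocab) base-model logits
    samples:   (batch, candidates, seq_len) pre-sampled token ids
    do_sample: (batch,) bool - True selects sampled tokens for that row
    """
    greedy = logits.argmax(-1)  # (batch, candidates, seq_len)
    # Broadcast the per-row flag across candidate and sequence dims;
    # torch.where picks elementwise between the two token tensors
    return torch.where(do_sample[:, None, None], samples, greedy)
```

As the reply notes, both branches are still computed for every row; torch.where only changes how the results are combined, not how much work is done.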

For example, if the base model predicts tokens A and B with equal 50% probability, and the
speculator produces one candidate with A and another with B, with independent sampling there's
a 25% chance of rejecting both, even though one must be correct. Consistent sampling allows us
to avoid this.
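The 25% rejection figure above can be checked with a small Monte Carlo sketch (pure illustration, not the PR's code): with independent sampling each candidate is compared against its own ground-truth draw, while consistent sampling compares all candidates against one shared draw.

```python
import random

# Two candidates speculate tokens "A" and "B"; the base model assigns
# each 50% probability.

def independent_reject(trials: int = 100_000, seed: int = 0) -> float:
    """Fraction of trials where both candidates are rejected when each
    is checked against an independent ground-truth draw (~0.25)."""
    rng = random.Random(seed)
    rejected_both = 0
    for _ in range(trials):
        if rng.choice("AB") != "A" and rng.choice("AB") != "B":
            rejected_both += 1
    return rejected_both / trials

def consistent_reject(trials: int = 100_000, seed: int = 0) -> float:
    """Same experiment with one shared ground-truth draw: whichever
    candidate matches it is accepted, so both are never rejected."""
    rng = random.Random(seed)
    rejected_both = 0
    for _ in range(trials):
        truth = rng.choice("AB")
        if truth != "A" and truth != "B":
            rejected_both += 1
    return rejected_both / trials
```

Independent sampling rejects both candidates whenever the first draw is "B" and the second is "A" (probability 0.5 × 0.5 = 0.25); with a shared draw that event is impossible.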
nairbv (Member)

If the goal is to speculate on a mutually exclusive set of possible continuations, why are we sampling at all instead of just speculating on the top-k predictions?

daviswer (Collaborator, Author)

We could do this, but we're more concerned with the ability to sample here than with the non-greediness of the approach. In this case "not greedy" is meant strictly literally, in that sampling involves not selecting greedily (assuming I'm understanding the question).
