
WIP: basic correctness test #192

Closed
wants to merge 5 commits

Conversation

derekk-nm

Introducing an end-to-end test case that verifies basic correctness of the vllm engine by comparing the tokens output by the vllm OpenAI server with tokens generated by the HuggingFace model created by AutoModelForCausalLM.from_pretrained().
This Work In Progress PR is intended to verify the approach used to reconcile the different logprobs output generated by the server (keyed by token strings) versus that generated by the HF model (keyed by token ids). It currently uses duplicate code in the tests.conftest.HfRunnerNM.generate_greedy_logprobs_nm_use_tokens() method, replacing each token id with the actual token. As long as this approach is approved, I'll come up with some code refactoring to reduce this and other code duplication in other fixtures in conftest.py.
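For illustration, here is a minimal sketch of the re-keying idea (this is not the actual helper in conftest.py; the function name is made up): the HF logprobs, keyed by token id, are converted to be keyed by the decoded token string so they can be compared against the server response.

from transformers import AutoTokenizer

# hypothetical helper, shown only to illustrate the re-keying approach
def rekey_logprobs_by_token(tokenizer, logprobs_by_id: dict) -> dict:
    # decode each token id to its string so the dict lines up with the OpenAI
    # server response, which keys logprobs by token text
    return {tokenizer.decode([token_id]): logprob
            for token_id, logprob in logprobs_by_id.items()}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(rekey_logprobs_by_token(tokenizer, {464: -0.1, 1169: -2.3}))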
The use of the new VllmServer class is also an open question, as we are still determining the actual infrastructure that will be used to execute this test. [Credit to Domenic Barbuzzi for the VllmServer code.]

For review of approach; still requires refactoring of HfRunnerNM methods.
Author

this edit should have been a copy of the existing code. clearly, these changes will break other tests using this method.


I think it would make sense to make a copy somewhere inside the neuralmagic folder for this usage.

@dbarbuzzi

Thanks @derekk-nm, I think the new test is looking good!

# now repeat using two gpus
# specifically doing it here, rather than as a pytest param,
# to avoid repeating the huggingface inference and data collection

If this test isn’t getting forked through pytest-forked or similar, I can work on a cached helper function to generate the referenced outputs, so using 1 vs. 2 GPUs can be separate tests while still only incurring the cost of the HF output generation once.
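For reference, a minimal sketch of what such a cached helper might look like (names here are hypothetical, and an in-process cache like this only helps if the tests are not forked into separate processes):

import functools

@functools.lru_cache(maxsize=None)
def hf_reference_outputs(model_name: str) -> tuple:
    # stand-in for the expensive Hugging Face generation step, e.g.
    # hf_runner_nm(model).generate_greedy_logprobs_nm_use_tokens(...)
    return (f"reference outputs for {model_name}",)

def test_single_gpu():
    ref = hf_reference_outputs("facebook/opt-125m")  # computed once
    assert ref

def test_two_gpus():
    ref = hf_reference_outputs("facebook/opt-125m")  # cache hit, nothing recomputed
    assert ref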

@derekk-nm
Author

derekk-nm commented Apr 17, 2024

Another open question to resolve with this code is the handling of spaces in tokens.

Here's an image from my debugger showing that the huggingface-generated text does not have leading spaces in the tokens, but the openai server's response does. (I did try using clean_up_tokenization_spaces=False on the huggingface tokenizer's decode method when I retrieved the token's string, but that didn't help.)
[image: generated tokens and leading spaces diff]

Also, another image shows an "empty" response from one of the other prompts that includes that marker; we'll have to deal with this too.
[screenshot from 2024-04-17 showing the "empty" response]

@derekk-nm
Author

Thanks @derekk-nm, I think the new test is looking good!

# now repeat using two gpus
# specifically doing it here, rather than as a pytest param,
# to avoid repeating the huggingface inference and data collection

If this test isn’t getting forked through pytest-forked or similar, I can work on a cached helper function to generate the referenced outputs, so using 1 vs. 2 GPUs can be separate tests while still only incurring the cost of the HF output generation once.

As far as I can tell, the tests are currently run with --forked, at least the existing tests. I don't know if we'll end up with some different architecture to run end-to-end tests.

@dhuangnm requested a review from bnellnm on April 17, 2024 14:31
@dbarbuzzi

As far as I can tell, the tests are currently run with --forked, at least the existing tests. I don't know if we'll end up with some different architecture to run end-to-end tests.

In that case, maybe we can review the tests so that only those that require being forked are forked (by using the pytest marker instead of the CLI flag), or otherwise split the suite into multiple runs so this test doesn't have to be forked. Overall, I don't think such an enhancement is required for the initial implementation.
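For example (a sketch that assumes pytest-forked is installed; the test names are made up):

import pytest

@pytest.mark.forked  # only this test runs in its own forked process
def test_that_needs_process_isolation():
    assert True

def test_basic_server_correctness():
    # left unmarked so in-process state (e.g. cached HF reference outputs) survives
    assert True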

@dbarbuzzi

@derekk-nm I added a couple of comments about the arg list. Per a previous discussion, I left the initial implementation in place, where the args dict passed to the VllmServer constructor expects all keys and values to be strings; it should actually be throwing errors in your tests because some of the values in the dicts are ints, floats, and None.

So, summarizing some comments:

  • If the test and server helper are working as-is, something is going wrong with input validation
    • Actually, going even further: even if the built-in error-checking is not catching this, the "command" list passed to something like Popen basically requires strings (technically also bytes or path-like objects), so no matter what, passing ints or None as values should not actually work…
  • All that said, we can certainly amend VllmServer so that this args dict can have non-string values; that was a design question and simply requires some refactoring of the logic within the _args_to_list private method (no other external usage changes). A sketch of one possible approach follows below.
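For illustration, a minimal sketch of what that coercion could look like (written as a standalone function; the real logic would live in VllmServer._args_to_list, and the exact signature is an assumption):

def args_to_list(args: dict) -> list:
    arg_list = []
    for flag, value in args.items():
        if value is None:
            # treat None as a boolean flag that takes no value
            arg_list.append(str(flag))
        else:
            # coerce ints/floats so the command list handed to Popen is all strings
            arg_list.extend([str(flag), str(value)])
    return arg_list

print(args_to_list({"--model": "facebook/opt-125m", "--tensor-parallel-size": 2, "--disable-log-requests": None}))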

@dbarbuzzi

Should this --tensor-parallel-size be treated similarly to pytest.mark.parametrize for other arguments, e.g. model, etc.? There is no need to repeat the test code here just for --tensor-parallel-size=2. Also, it could be set to a value other than 2 based on the GPU type.

@dhuangnm I believe, per the comments, that this is currently part of the same test specifically so that the Hugging Face outputs only need to be generated once (they are needed as a baseline comparison after both server portions).


with VllmServer(api_server_args, logger):
    completion = await client.completions.create(model=model,
                                                 prompt=example_prompts,
Collaborator

This is doing something different than the tests that we previously created.

  • Specifically, this tests ONE client, sending N requests

The case we want to test is

  • N clients sending 1 request

I used threading in the previous example to simulate this.

I think we should recreate this here

Author

Ok, it wasn't mentioned in the Basic Correctness Automation doc, so I forgot to think about how your client script worked. I'll refactor the client submissions accordingly.

Author

I've added an async function to accomplish the same thing.
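For context, a minimal sketch of the "N clients, one request each" pattern using asyncio (this is not the PR's exact code; the base_url/api_key values are placeholders for the locally launched server):

import asyncio
from openai import AsyncOpenAI

async def fan_out(model: str, prompts: list, base_url: str = "http://localhost:8000/v1"):
    async def one_client_one_request(prompt: str):
        # a dedicated client per prompt, so N clients each send a single request
        client = AsyncOpenAI(base_url=base_url, api_key="EMPTY")
        return await client.completions.create(
            model=model, prompt=prompt, max_tokens=32, temperature=0.0)

    # issue all of the single-request clients concurrently
    return await asyncio.gather(*(one_client_one_request(p) for p in prompts))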

api_server_args["--sparsity"] = sparsity

with VllmServer(api_server_args, logger):
    completion = await client.completions.create(model=model,
Collaborator

We should use the client.chat.completions API

This requires sending a list of messages
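For reference, the chat API takes a list of role/content messages rather than a plain prompt string; a minimal sketch (the server URL and request parameters here are placeholders, not the PR's actual values):

from openai import AsyncOpenAI

client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def ask(model: str, user_prompt: str):
    # messages is a chat history: a list of {"role": ..., "content": ...} dicts
    return await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": user_prompt}],
        max_tokens=32,
        logprobs=True,
        top_logprobs=5,
    )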

Author

@robertgshaw2-neuralmagic, are all of the models in the test "chat" models? Some say "chat" in their name, but not all.

Author

@robertgshaw2-neuralmagic, I've made the necessary changes for this, but we'll need to discuss this some more. For the one model I've run through this with, the responses from HuggingFace and the vllm server are not very close at all when using the chat mode; they seemed to be much closer with the non-chat mode of requests.

) -> None:

    ds = load_dataset("nm-testing/qa-chat-prompts", split="train_sft")
    example_prompts = [m[0]["content"] for m in ds["messages"]]
Collaborator

To match the chat.completions API, we need to convert these messages into a chat history by using the tokenizer's "chat template":

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model)

# 3 chat convo turns
NUM_TURNS = 3
messages_list = [row["messages"][:NUM_TURNS] for row in ds]
prompts = [
    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    for messages in messages_list
]

# feed the chat-templated prompts to the HF reference model
hf_model = hf_runner_nm(model)
hf_outputs = hf_model.generate_greedy_logprobs_nm_use_tokens(
    prompts, max_tokens, num_logprobs)

### ...
# when we call the completions api
completion = await client.chat.completions.create(model=model, messages=messages)

@@ -0,0 +1,141 @@
import logging
Collaborator

I don't think these should go in the neuralmagic repo

These should go in the tests repo

@@ -0,0 +1,124 @@
import logging
Collaborator

this should be in the tests repo

perhaps

tests/integration or something?

Author

I thought we decided to put these end-to-end tests in the neuralmagic tests path?

all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_ids = seq_ids[-output_len:]
Collaborator

This is pretty hacky and will likely cause us errors down the road, as you can see that the hugging face tokens we create have all the whitespace removed

vLLM output: ['\n', '\n', 'The', 'Sydney', 'Conserv', 'ator', 'ium', 'of', 'Music', '(', 'SC', 'M', '),', 'located', 'in', 'the', 'heart', 'of', 'Sydney', "'", 's', 'central', 'business', 'district', ',', 'offers', 'unique', 'academic', 'and', 'professional', 'opportunities', 'to'], 

HF output: ['\n', '\n', 'The', ' Sydney', ' Conserv', 'ator', 'ium', ' of', ' Music', ' (', 'SC', 'M', '),', ' located', ' in', ' the', ' heart', ' of', ' Sydney', "'", 's', ' central', ' business', ' district', ',', ' offers', ' unique', ' academic', ' and', ' professional', ' opportunities', ' to'],

Collaborator

Here's a simple example of how to detokenize properly. You have to "lookback" over 3-4 tokens and just extract the incremental text that is generated

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
text = "Hello my name is Robert and I love working on vLLM with the team at Neural Magic"

token_ids = tokenizer(text).input_ids

NUM_PROMPT_TOKENS = 5
NUM_TOKENS = len(token_ids)

LOOKBACK = 4

generation_tokens = []
for cur_idx in range(NUM_PROMPT_TOKENS, NUM_TOKENS):
    # decode a small window ending just before, and just after, the current token;
    # the newly generated token's text is whatever the longer decode adds
    prior_str = tokenizer.decode(token_ids[cur_idx - LOOKBACK: cur_idx])
    current_str = tokenizer.decode(token_ids[cur_idx - LOOKBACK: cur_idx + 1])
    token = current_str[-(len(current_str) - len(prior_str)):]
    generation_tokens.append(token)

print(generation_tokens)

generation_str = ""
for generation_token in generation_tokens:
    generation_str += generation_token

print(f"Full Text: {text}")
print(f"Prompt: |{tokenizer.decode(token_ids[:NUM_PROMPT_TOKENS], skip_special_tokens=True)}|")
print(f"Generation: |{generation_str}|")

Author

Great! I'm happy that this should fix the whitespace issue! I'll come up with a solution that can be shared with your generate_greedy_logprobs_nm() method.

* update to latest Server and logging from Domenic
* improve compare_utils.py to report all prompts in error, not just the first to fail.
* temporarily just do one model
* temporarily include a huggingface token to the HfRunner
* use the client.chat.completions.create with vllm server, and chat prompts with HfRunner
* use individual concurrent clients for the server requests instead of concurrent requests from one client
@derekk-nm
Author

@robertgshaw2-neuralmagic, @dbarbuzzi, @dhuangnm, I've pushed a bunch of changes to this WIP PR. Key things are specified in the commit message. The things that remain are:

  • figure out how to ultimately do the merge to main... pile my stuff on top of Domenic's server branch, or do these separately
  • move files around per the PR comments
  • identify which models will actually work with this test (many will fail due to a missing optimum package)
  • retrieve the token strings correctly
  • uncomment the line that includes the tensor_parallel_size = 2 parameter and test that it works on a k8s env that includes the necessary GPUs

* format messages to the ChatCompletion.create method correctly.
* skipped more models and kept just one for now.
* added logging, doc, and type hints
* used torch.cuda.get_device_capability to set the `dtype` arg accordingly (see the sketch after this commit list)
* refactored some common code to a hidden method in HfRunnerNM
* Added tokenizer decode method to properly decode tokens
* improve conversion of tokenids to strings from HfRunnerNM so that the output can be compared with the output from vllm server.
* reverted the old check_logprobs_close() method and added a new one with the different approach for our basic server correctness test.
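Regarding the get_device_capability commit above, a minimal sketch of the idea (the helper name is made up; the threshold reflects that bfloat16 requires compute capability 8.0 or newer):

import torch

def pick_dtype() -> str:
    # bfloat16 needs an Ampere-or-newer GPU (compute capability >= 8.0);
    # fall back to float16 ("half") on older devices
    major, _minor = torch.cuda.get_device_capability()
    return "bfloat16" if major >= 8 else "half"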
@derekk-nm
Author

I believe this latest commit correctly decodes the token ids from the HfRunnerNM, but I can't find a way to decode the token ids in the logprobs keys using _decode_token_by_position_index. The call to topk = tok_logprobs.topk() (line 416 in conftest.py), followed by the call to topk.indices[0] (line 427), is returning the actual token_ids. I can't find a useful explanation of how topk.indices is defined, but the doc I found states that topk returns the indices into the tensor, and I don't know which array the logprobs come from. When I attempted to use the keys in topk.indices as an index into seq_ids, the decoded values didn't make any sense (whether decoded directly or passed to self._decode_token_by_position_index).
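For what it's worth, a small standalone example of torch.topk behavior: the indices it returns are positions along the dimension of the tensor it was called on, so for a logprobs tensor over the vocabulary they are the token ids themselves (they do not index into seq_ids):

import torch

# toy "vocabulary" of 10 tokens
logprobs = torch.log_softmax(torch.randn(10), dim=-1)
topk = logprobs.topk(3)
# indices are offsets into the vocab-sized tensor, which is why they look like token ids
print(topk.indices, topk.values)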

What I've done is to decode the top logprob keys using self.tokenizer.decode. This leaves us with the missing leading space for the huggingface output, so my compare util strips whitespace for the comparison. This is not ideal, I know, but it seems to be working; most of the prompts pass with flying colors. Certainly open to suggestions!

With that said, there are a number of prompts that result in responses that fail the "closeness" test we're using here. A manual review of the results (latest run output attached) shows that the responses are conceptually similar but grammatically different, so it's reasonable for them to fail this test. @robertgshaw2-neuralmagic, we'll need to discuss which model to use temporarily. My run yesterday of all of the models in the current list of params had none pass.

This result prompted me to do additional exploration for a way to get the OpenAI server to generate token ids instead of text. ChatGPT and Groq both seem to have hallucinated some interesting args that simply aren't real (at least not for the version of OpenAI that we're using). I still don't think it's currently possible to get the openai server to generate output that matches the format of the HuggingFace runner.

Re-reading other comments here, I know that I still have to move things around.

@derekk-nm
Author

server_basic_correctness_mistral_results_202404301351.txt
This attachment is the latest output from the test execution. There are two tests using the same model: one using a single GPU, and another using two (via the tensor_parallel_size param). If you scroll to the lines that start with "E", you'll get to the error messages for the failing tests. As you can see, the majority of prompts compare successfully between the HF and vllm server implementations, but the multi-GPU env is more "brittle" for this test.

It's not clear to me whether the issue is that the token strings in the logprobs dictionary were created by decoding the token id without any lookback (thus possibly missing a potential option in the logprobs that would have allowed the prompt to pass the test), or whether there's an actual bug in the server implementation causing these prompts to generate results that do not match the HuggingFace implementation.

derekk-nm and others added 2 commits May 2, 2024 01:31
* revert to using check_logprobs_close, but enhance it with a more detailed error message
* correctly decode logprobs keys by replacing the current "chosen" token with each key of our current `indexed_seq` list of token ids, then passing that token id and the list to our decoding method (see the sketch below)
Co-authored-by: Domenic Barbuzzi <dbarbuzzi@gmail.com>
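A minimal sketch of the decoding idea described in that commit (the helper and argument names here are hypothetical, not the PR's actual code):

from transformers import AutoTokenizer

def decode_candidate(tokenizer, indexed_seq, position, candidate_id, lookback=4):
    # substitute the candidate token id at the current position, then decode a
    # small window with and without it; the difference is the candidate's text,
    # including any leading whitespace
    seq = list(indexed_seq)
    seq[position] = candidate_id
    start = max(0, position - lookback)
    prior = tokenizer.decode(seq[start:position])
    current = tokenizer.decode(seq[start:position + 1])
    return current[len(prior):]

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("Hello my name is Robert").input_ids
print(decode_candidate(tokenizer, ids, len(ids) - 1, ids[-1]))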
@derekk-nm
Author

Thank you @robertgshaw2-neuralmagic for the consultation. I've updated the script to use the existing comparison function, and corrected the method that decodes the logprob keys. The two tests are now passing with the single model.
@dbarbuzzi, let's coordinate on the location of the tests and merging to main.

@derekk-nm derekk-nm closed this May 29, 2024
@derekk-nm derekk-nm deleted the end_to_end-test-basic_correctness branch May 29, 2024 20:52