
Guidance on supporting TGI-based LLM #22

Closed
gopalsarda opened this issue Mar 27, 2024 · 3 comments
Labels: enhancement (New feature or request), question (Further information is requested)

Comments

gopalsarda commented Mar 27, 2024

Great work on the design and documentation of the repo!

I want to introduce a new LLM class to work with TGI servers, but I did not find any detailed documentation on how to go about it. I referred to the "Creating a new LLM" docs page and also looked at other LLM implementations within the repo (MistralAI), but I was not able to get it to work as I had hoped.

The flow runs successfully, but I can see that the endpoint is called multiple times per input. I have attached my test script below. Temporarily, I have replaced the TGI call with a dummy response ("test response"). When I execute the test script, I see "get_generated_texts called" printed 6 times (as opposed to 2). This looks like either a bug in the implementation or a gap in my understanding. Can you please help clarify?

Test Script -

    # Imports needed to run this snippet (DataDreamer module paths are assumed from the docs)
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial
    from typing import Callable

    from datadreamer import DataDreamer
    from datadreamer.llms import MistralAI
    from datadreamer.steps import HFHubDataSource, Prompt

    # Stand-in for DataDreamer's internal default batch size constant
    DEFAULT_BATCH_SIZE = 10

    class TGI(MistralAI):

        def _run_batch(
            self,
            max_length_func: Callable[[list[str]], int],
            inputs: list[str],
            max_new_tokens: None | int = None,
            temperature: float = 1.0,
            top_p: float = 0.0,
            n: int = 1,
            stop: None | str | list[str] = None,
            repetition_penalty: None | float = None,
            logit_bias: None | dict[int, float] = None,
            batch_size: int = DEFAULT_BATCH_SIZE,
            seed: None | int = None,
            **kwargs,
        ) -> list[str] | list[list[str]]:
            prompts = inputs
            assert (
                stop is None or stop == []
            ), f"`stop` is not supported for {type(self).__name__}"
            assert (
                repetition_penalty is None
            ), f"`repetition_penalty` is not supported for {type(self).__name__}"
            assert (
                logit_bias is None
            ), f"`logit_bias` is not supported for {type(self).__name__}"
            assert n == 1, f"Only `n` = 1 is supported for {type(self).__name__}"
        
            # Run the model (the real TGI call is temporarily replaced with a dummy response)
            def get_generated_texts(self, kwargs, prompt) -> list[str]:
                print("get_generated_texts called")
                return ["test response"]

            if batch_size not in self.executor_pools:
                self.executor_pools[batch_size] = ThreadPoolExecutor(max_workers=batch_size)
            generated_texts_batch = list(
                self.executor_pools[batch_size].map(
                    partial(get_generated_texts, self, kwargs), prompts
                )
            )

            if n == 1:
                return [batch[0] for batch in generated_texts_batch]
            else:  # pragma: no cover
                return generated_texts_batch           
                
    

    with DataDreamer(":memory:"):
        tgi_model = TGI(model_name="tgi_model")

        eli5_dataset = HFHubDataSource(
            "Get ELI5 Questions",
            "eli5_category",
            split="train",
            trust_remote_code=True,
        ).select_columns(["title"])

        # Keep only 2 examples as a quick demo
        eli5_dataset = eli5_dataset.take(2, lazy=False)

        # Ask llm to ELI5
        questions_and_answers = Prompt(
            "Generate Explanations",
            inputs={"prompts": eli5_dataset.output["title"]},
            args={
                "llm": tgi_model,
                "instruction": (
                    'Given the question, give an "Explain it like I\'m 5" answer.'
                ),
                "lazy": False,
                "top_p": 1.0,
            },
            outputs={"prompts": "questions", "generations": "answers"},
        )

        print(f"{questions_and_answers.head()}")
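
For context, the dummy get_generated_texts above stands in for a real TGI call. A rough sketch (placeholder endpoint URL, not tested against a live server) of what that call could look like using huggingface_hub's InferenceClient, keeping the same signature as the nested helper in the script:

    from huggingface_hub import InferenceClient

    def get_generated_texts(self, kwargs, prompt) -> list[str]:
        # Placeholder URL for a locally running TGI server
        client = InferenceClient(model="http://localhost:8080")
        # text_generation() returns the generated string for a single prompt
        generated_text = client.text_generation(prompt, max_new_tokens=250)
        return [generated_text]
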
AjayP13 (Collaborator) commented Mar 27, 2024

Hi, thanks for looking into integrating this. Before going further, though: I believe we do actually have TGI support already.

See the HFAPIEndpoint class: https://datadreamer.dev/docs/latest/datadreamer.llms.html#datadreamer.llms.HFAPIEndpoint

You can see that it uses InferenceClient from huggingface_hub under the hood:

return InferenceClient(model=self.endpoint, token=self.token, **self.kwargs)

which is compatible with TGI according to this page: https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/consuming_tgi#inference-client

Does this solve your need?
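
For reference, a minimal sketch (not from this thread) of what pointing HFAPIEndpoint at a TGI server could look like; the endpoint URL and model name below are placeholders, and the constructor arguments are assumed from the docs page linked above:

    from datadreamer import DataDreamer
    from datadreamer.llms import HFAPIEndpoint

    with DataDreamer("./output"):
        # Placeholder endpoint URL for a locally running TGI server; model_name
        # is whatever model that server is serving.
        tgi_llm = HFAPIEndpoint(
            endpoint="http://localhost:8080",
            model_name="mistralai/Mistral-7B-Instruct-v0.2",
        )
        # tgi_llm can then be passed as the "llm" arg of the Prompt step in the
        # test script above, in place of the custom TGI class.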

AjayP13 added the question (Further information is requested) label Mar 27, 2024
gopalsarda (Author)

@AjayP13 Thanks! This solves my issue of using TGI. Two follow-up questions -

  1. I am planning to implement some custom steps for my use case, so I am trying to understand what the issue with my test script was. Why do I see "get_generated_texts called" printed 6 times?
  2. I am interested in implementing a SuperStep to generate multi-turn synthetic data. Do you have any examples I can refer to?

@AjayP13
Copy link
Collaborator

AjayP13 commented Mar 28, 2024

Perfect, yep, I can explain those; the answers are a little complicated.

  1. Even though you are specifying lazy=False, your Prompt step is indeed running lazily because of the :memory: option (making the step non-lazy in that code path actually requires a folder to write results to). When a step runs lazily, we internally call the iterator backing the lazy result just to get its first element (a "peek" operation) in order to check what its type is, since we can't detect the type from an iterator alone. The extra calls you are seeing come from these "peek" operations. As an end user you generally don't need to be concerned with them, and they don't add any significant overhead, for two reasons: 1) each peek only reads the first element, so if your lazy iterator returns 1 million rows, only the first one is read and the overhead doesn't scale with the number of rows; 2) results from the LLM are cached. (A generic sketch of this "peek" pattern is included after this comment.)

I can look into removing that constraint so the step doesn't have to run lazily in the future, though. I will take this up as an enhancement.

  2. No examples at this time, but I think this page should have everything you need: https://datadreamer.dev/docs/latest/pages/advanced_usage/creating_a_new_datadreamer_.../step.html

and feel free to ask any questions if you run into any trouble.
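
For illustration only (this is not DataDreamer's actual code), here is a generic sketch of the "peek" pattern described in point 1: read one element from an iterator to inspect its type, then re-attach it so the full sequence is still available downstream.

    from itertools import chain
    from typing import Any, Iterator, Tuple

    def peek(iterator: Iterator[Any]) -> Tuple[Any, Iterator[Any]]:
        # Pull the first element off the iterator (this is the "extra" call
        # observed in the test script), then chain it back on so downstream
        # consumers still see the full sequence.
        first = next(iterator)
        return first, chain([first], iterator)

    rows = iter(["row 1", "row 2", "row 3"])
    first_row, rows = peek(rows)
    print(type(first_row))  # inspect the type of the lazy result
    print(list(rows))       # ['row 1', 'row 2', 'row 3']
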

AjayP13 added the enhancement (New feature or request) label Mar 28, 2024
AjayP13 closed this as completed Apr 30, 2024