Merged
35 commits
ab68a1b  fix: Lighteval communication with TGI (cpcdoy, Jan 15, 2025)
f442a29  fix: JSON grammar constrained generation (cpcdoy, Jan 15, 2025)
6bb6b13  Merge branch 'main' into fix/tgi_inference (cpcdoy, Feb 7, 2025)
cb13617  Merge branch 'main' into fix/tgi_inference (NathanHB, Apr 8, 2025)
749887e  Merge branch 'main' into fix/tgi_inference (cpcdoy, Apr 9, 2025)
55e3790  Merge branch 'main' into fix/tgi_inference (NathanHB, Apr 17, 2025)
7cc8ad3  Merge branch 'main' into fix/tgi_inference (cpcdoy, Jun 11, 2025)
dedab95  fix: unit tests + (cpcdoy, Jun 11, 2025)
2cbc91d  Merge branch 'main' into fix/tgi_inference (cpcdoy, Jun 18, 2025)
3014120  Merge branch 'main' into fix/tgi_inference (cpcdoy, Jun 23, 2025)
2691fd4  Merge branch 'main' into fix/tgi_inference (cpcdoy, Jul 8, 2025)
61d47ad  Merge branch 'main' into fix/tgi_inference (NathanHB, Aug 12, 2025)
4730d1a  Merge branch 'main' into fix/tgi_inference (cpcdoy, Aug 15, 2025)
b98702c  fix: request var => doc var after refactor (cpcdoy, Aug 15, 2025)
9f1b75c  fix: update test to support the new grammar field (cpcdoy, Aug 15, 2025)
ff9553b  Merge branch 'main' into fix/tgi_inference (cpcdoy, Aug 15, 2025)
b049367  fix: TGI endpoint with the new refactor (cpcdoy, Aug 20, 2025)
cfd61a1  update: TGI model config in examples with the latest parameters (cpcdoy, Aug 20, 2025)
e144ddd  add: example custom task on a classification dataset to demonstrate t… (cpcdoy, Aug 20, 2025)
3faea79  Merge branch 'main' into fix/tgi_inference (cpcdoy, Aug 20, 2025)
c95c88d  add: format example task (cpcdoy, Aug 20, 2025)
9df8575  fix: unit test (cpcdoy, Aug 20, 2025)
33cebe9  add: adapt the in the yaml config to use similarly to the other end… (cpcdoy, Aug 20, 2025)
b529ef3  Merge branch 'main' into fix/tgi_inference (cpcdoy, Aug 20, 2025)
1aebd59  Merge branch 'main' into fix/tgi_inference (cpcdoy, Aug 20, 2025)
93a645b  clean: moved new task to community_tasks (cpcdoy, Aug 20, 2025)
a34353b  fix: format (cpcdoy, Aug 20, 2025)
88da108  clean: delete unused grammar field (cpcdoy, Aug 20, 2025)
007932d  del: grammar (cpcdoy, Aug 20, 2025)
19afc28  add: copyright at the top (cpcdoy, Aug 21, 2025)
90e505f  del: langcodes dep isn't needed anymore (cpcdoy, Aug 21, 2025)
ab5c6b9  add: use load from file directly in the main endpoint (cpcdoy, Aug 21, 2025)
834aa08  del: newlines (cpcdoy, Aug 21, 2025)
2e5c74b  add: mock HTTP request for info to TGI server (cpcdoy, Aug 21, 2025)
a40b754  Merge branch 'main' into fix/tgi_inference (cpcdoy, Aug 22, 2025)
456 changes: 456 additions & 0 deletions community_tasks/custom_task_classification_grammar_task.py

Large diffs are not rendered by default.
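Since this new 456-line task file is not rendered here, below is a purely illustrative sketch of what a grammar-constrained classification task in the community_tasks style could look like. It is not the contents of custom_task_classification_grammar_task.py: the dataset, column names, label set, and metric are placeholder assumptions, and it assumes LightevalTaskConfig accepts a generation_grammar argument (as the lighteval_task.py change in this PR suggests) and that text_generation>=0.7.0 provides Grammar/GrammarType.

```python
# Purely illustrative sketch -- NOT the contents of custom_task_classification_grammar_task.py.
# Assumes LightevalTaskConfig accepts a `generation_grammar` argument and that
# text_generation>=0.7.0 provides Grammar/GrammarType. Dataset, columns, labels,
# and metric are placeholders.
from text_generation.types import Grammar, GrammarType

from lighteval.metrics.metrics import Metrics
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc

LABELS = ["negative", "neutral", "positive"]  # hypothetical label set


def classification_prompt(line: dict, task_name: str) -> Doc:
    # "text" and "label" are assumed column names in the hypothetical dataset.
    return Doc(
        task_name=task_name,
        query=f"Classify the sentiment of the text as one of {LABELS}.\n\nText: {line['text']}\nAnswer:",
        choices=LABELS,
        gold_index=LABELS.index(line["label"]),
    )


# JSON grammar constraining the model to emit {"label": "<one of the classes>"}.
classification_grammar = Grammar(
    type=GrammarType.Json,
    value={
        "type": "object",
        "properties": {"label": {"type": "string", "enum": LABELS}},
        "required": ["label"],
    },
)

task = LightevalTaskConfig(
    name="classification_grammar_example",
    prompt_function=classification_prompt,
    suite=["community"],
    hf_repo="your-org/your-classification-dataset",  # placeholder repo
    hf_subset="default",
    hf_avail_splits=["train", "test"],
    evaluation_splits=["test"],
    metric=[Metrics.exact_match],  # metric choice is illustrative
    generation_size=32,
    stop_sequence=["\n"],
    generation_grammar=classification_grammar,
)

TASKS_TABLE = [task]
```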

4 changes: 3 additions & 1 deletion examples/model_configs/tgi_model.yaml
@@ -1,4 +1,6 @@
model_parameters:
inference_server_address: ""
inference_server_address: "http://localhost:8080" # Replace with your actual TGI server address
inference_server_auth: null
model_name: null # Optional, only required if the TGI container was launched with model_id pointing to a local directory
generation_parameters:
temperature: 0.1
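As a usage note, this updated YAML is what TGIModelConfig.from_path now consumes directly (see the main_endpoint.py change further down). A rough loading sketch, under the assumption that from_path parses exactly this file:

```python
# Rough sketch: load the example config the same way main_endpoint.py now does.
from lighteval.models.endpoints.tgi_model import TGIModelConfig

config = TGIModelConfig.from_path("examples/model_configs/tgi_model.yaml")
print(config.inference_server_address)           # "http://localhost:8080"
print(config.generation_parameters.temperature)  # 0.1
```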
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -85,8 +85,8 @@ dependencies = [
]

[project.optional-dependencies]
litellm = ["litellm", "diskcache"]
tgi = ["text-generation>=0.6.0"]
litellm = ["litellm[caching]", "diskcache"]
tgi = ["text-generation>=0.7.0"]
optimum = ["optimum==1.12.0"]
quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
adapters = ["peft==0.3.0"]
9 changes: 1 addition & 8 deletions src/lighteval/main_endpoint.py
@@ -249,11 +249,8 @@ def tgi(
"""
Evaluate models using TGI as backend.
"""
import yaml

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.endpoints.tgi_model import TGIModelConfig
from lighteval.models.model_input import GenerationParameters
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters

evaluation_tracker = EvaluationTracker(
@@ -269,11 +266,7 @@ def tgi(

parallelism_manager = ParallelismManager.TGI

with open(model_config_path, "r") as f:
config = yaml.safe_load(f)

generation_parameters = GenerationParameters(**config.get("generation", {}))
model_config = TGIModelConfig(**config["model"], generation_parameters=generation_parameters)
model_config = TGIModelConfig.from_path(model_config_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
2 changes: 2 additions & 0 deletions src/lighteval/models/endpoints/endpoint_model.py
@@ -527,6 +527,7 @@ async def _async_process_batch_logprob(self, docs: list[Doc], rolling: bool = Fa
context=context if rolling else context + doc.choices[0],
stop_tokens=[],
max_tokens=1,
grammar=doc.generation_grammar,
)
for context, doc in zip(contexts, docs)
]
@@ -539,6 +540,7 @@ def _process_batch_logprob(self, docs: list[Doc], rolling: bool = False) -> list
context=context if rolling else context + doc.choices[0],
stop_tokens=[],
max_tokens=1,
grammar=doc.generation_grammar,
)
for context, doc in zip(contexts, docs)
]
33 changes: 32 additions & 1 deletion src/lighteval/models/endpoints/tgi_model.py
@@ -30,6 +30,8 @@

from lighteval.models.abstract_model import ModelConfig
from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.utils.cache_management import SampleCache
from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available


@@ -87,6 +89,7 @@ class TGIModelConfig(ModelConfig):
inference_server_auth: str | None = None
model_name: str | None
model_info: dict | None = None
batch_size: int = 1


# inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite
@@ -110,12 +113,23 @@ def __init__(self, config: TGIModelConfig) -> None:
raise ValueError("Error occurred when fetching info: " + str(self.model_info))
if config.model_name:
self.model_info["model_id"] = config.model_name
else:
# Set the model_name in config to the actual model_id from server for caching
config.model_name = self.model_info["model_id"]
self.config = config
self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"])
self._add_special_tokens = True
self.use_async = True
self.config.model_info = self.model_info

# Initialize prompt manager (required by parent class)
self.prompt_manager = PromptManager(
use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
)

# Initialize cache for tokenization and predictions
self._cache = SampleCache(config)

def _async_process_request(
self,
context: str,
@@ -134,7 +148,24 @@ def _async_process_request(
grammar=grammar,
)

generated_text = self.client.generate(prompt=context, generation_config=generation_config)
generated_text = self.client.generate(
prompt=context,
do_sample=generation_config.do_sample or False,
max_new_tokens=generation_config.max_new_tokens,
best_of=generation_config.best_of,
repetition_penalty=generation_config.repetition_penalty,
return_full_text=generation_config.return_full_text or False,
seed=generation_config.seed,
stop_sequences=generation_config.stop,
temperature=generation_config.temperature,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
truncate=generation_config.truncate,
typical_p=generation_config.typical_p,
watermark=generation_config.watermark or False,
decoder_input_details=generation_config.decoder_input_details,
grammar=generation_config.grammar,
)

return generated_text

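For context on the expanded generate() call above, here is a rough client-level illustration of grammar-constrained generation against a TGI server, assuming a server at localhost:8080. The JSON schema is made up for the example, and the text_generation>=0.7.0 bump in pyproject.toml above is presumably what provides the grammar parameter on the client.

```python
# Illustrative only: roughly what the expanded client.generate(...) call above does.
from text_generation import Client
from text_generation.types import Grammar, GrammarType

client = Client("http://localhost:8080")  # assumed local TGI server
response = client.generate(
    prompt="Classify the sentiment of 'I loved this movie.' Answer in JSON:",
    max_new_tokens=32,
    temperature=0.1,
    grammar=Grammar(  # constrain the output to a small JSON object
        type=GrammarType.Json,
        value={
            "type": "object",
            "properties": {"label": {"type": "string", "enum": ["positive", "negative"]}},
            "required": ["label"],
        },
    ),
)
print(response.generated_text)  # e.g. {"label": "positive"}
```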
4 changes: 1 addition & 3 deletions src/lighteval/models/model_loader.py
@@ -109,9 +109,7 @@ def load_model_with_tgi(config: TGIModelConfig):
raise ImportError(NO_TGI_ERROR_MSG)

logger.info(f"Load model from inference server: {config.inference_server_address}")
model = ModelClient(
address=config.inference_server_address, auth_token=config.inference_server_auth, model_id=config.model_id
)
model = ModelClient(config=config)
return model


6 changes: 6 additions & 0 deletions src/lighteval/tasks/lighteval_task.py
@@ -251,6 +251,12 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]:
item["__index"] = ix
doc = self.formatter(item, self.name)
doc.id = str(ix)

# Transfer task-level generation parameters to the document
doc.generation_grammar = self.generation_grammar
doc.generation_size = self.generation_size
doc.stop_sequences = self.stop_sequence

docs.append(doc)

return docs
5 changes: 3 additions & 2 deletions tests/models/endpoints/test_tgi_model.py
@@ -33,11 +33,12 @@ class TestTGIModelConfig:
(
"examples/model_configs/tgi_model.yaml",
{
"inference_server_address": "",
"inference_server_address": "http://localhost:8080",
"inference_server_auth": None,
"model_name": None,
"model_info": None,
"system_prompt": None,
"batch_size": 1,
"generation_parameters": {
"block_size": None,
"num_blocks": None,
@@ -52,7 +53,7 @@ class TestTGIModelConfig:
"repetition_penalty": None,
"seed": None,
"stop_tokens": None,
"temperature": 0,
"temperature": 0.1,
"top_k": None,
"top_p": None,
"truncate_prompt": None,
12 changes: 9 additions & 3 deletions tests/utils/test_caching.py
@@ -219,21 +219,27 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikeliho

self._test_cache(model)

@patch("requests.get")
@patch("lighteval.models.endpoints.tgi_model.ModelClient._greedy_until")
@patch("lighteval.models.endpoints.tgi_model.ModelClient._loglikelihood")
def test_cache_tgi(self, mock_greedy_until, mock_loglikelihood):
def test_cache_tgi(self, mock_loglikelihood, mock_greedy_until, mock_requests_get):
from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
from lighteval.utils.imports import is_tgi_available

if not is_tgi_available():
pytest.skip("Skipping because missing the imports")

# Mock TGI requests
mock_greedy_until.return_value = self.model_responses
mock_loglikelihood.return_value = self.model_responses
mock_greedy_until.return_value = self.model_responses

# Mock HTTP info request
mock_requests_get.return_value.json.return_value = {"model_id": "Qwen/Qwen3-0.6B"}

with tempfile.TemporaryDirectory() as temp_dir:
config = TGIModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir)
config = TGIModelConfig(
model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir, inference_server_address="http://localhost:8080"
)
model = ModelClient(config)

self._test_cache(model)