From ab68a1b9f90a5724230ad1ba77e6b2b80a873343 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 15 Jan 2025 15:12:49 +0100 Subject: [PATCH 01/20] fix: Lighteval communication with TGI --- src/lighteval/models/endpoints/tgi_model.py | 20 ++++++++++++++++++-- src/lighteval/models/model_loader.py | 4 +--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index f0bb712b6..903241ef9 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -101,7 +101,7 @@ def __init__(self, config: TGIModelConfig) -> None: model_name = str(self.model_info["model_id"]) model_sha = self.model_info["model_sha"] - model_precision = self.model_info["model_dtype"] + model_precision = self.model_info.get("model_dtype") self.model_info = ModelInfo( model_name=model_name, model_sha=model_sha, @@ -127,7 +127,23 @@ def _async_process_request( grammar=grammar, ) - generated_text = self.client.generate(prompt=context, generation_config=generation_config) + generated_text = self.client.generate( + prompt=context, + do_sample=generation_config.do_sample or False, + max_new_tokens=generation_config.max_new_tokens, + best_of=generation_config.best_of, + repetition_penalty=generation_config.repetition_penalty, + return_full_text=generation_config.return_full_text or False, + seed=generation_config.seed, + stop_sequences=generation_config.stop, + temperature=generation_config.temperature, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + truncate=generation_config.truncate, + typical_p=generation_config.typical_p, + watermark=generation_config.watermark or False, + decoder_input_details=generation_config.decoder_input_details, + ) return generated_text diff --git a/src/lighteval/models/model_loader.py b/src/lighteval/models/model_loader.py index 68835fda7..d24e84045 100644 --- a/src/lighteval/models/model_loader.py +++ b/src/lighteval/models/model_loader.py @@ -108,9 +108,7 @@ def load_model_with_tgi(config: TGIModelConfig): raise ImportError(NO_TGI_ERROR_MSG) logger.info(f"Load model from inference server: {config.inference_server_address}") - model = ModelClient( - address=config.inference_server_address, auth_token=config.inference_server_auth, model_id=config.model_id - ) + model = ModelClient(config=config) return model From f442a29aaf9e74a0b42cb6c2ad7d5d3e25521a07 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 15 Jan 2025 18:00:27 +0100 Subject: [PATCH 02/20] fix: JSON grammar constrained generation --- pyproject.toml | 2 +- src/lighteval/models/endpoints/endpoint_model.py | 2 ++ src/lighteval/models/endpoints/tgi_model.py | 1 + src/lighteval/models/model_input.py | 2 ++ 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f60e610ea..a4729105e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ dependencies = [ [project.optional-dependencies] litellm = ["litellm", "diskcache"] -tgi = ["text-generation==0.6.0"] +tgi = ["text-generation==0.7.0"] optimum = ["optimum==1.12.0"] quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"] adapters = ["peft==0.3.0"] diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 37bb9754e..ad8d79682 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -478,6 +478,7 @@ async def _async_process_batch_logprob( context=request.context 
if rolling else request.context + request.choice, stop_tokens=[], max_tokens=1, + grammar=request.generation_grammar, ) for request in requests ] @@ -491,6 +492,7 @@ def _process_batch_logprob( context=request.context if rolling else request.context + request.choice, stop_tokens=[], max_tokens=1, + grammar=request.generation_grammar, ) for request in requests ] diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index 903241ef9..fc1083aa0 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -143,6 +143,7 @@ def _async_process_request( typical_p=generation_config.typical_p, watermark=generation_config.watermark or False, decoder_input_details=generation_config.decoder_input_details, + grammar=generation_config.grammar, ) return generated_text diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index 04e35be17..c552a7ae0 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -42,6 +42,7 @@ class GenerationParameters: min_p: Optional[float] = None # vllm, transformers top_p: Optional[int] = None # vllm, transformers, tgi truncate_prompt: Optional[bool] = None # vllm, tgi + grammar: Optional[str] = None # tgi @classmethod def from_dict(cls, config_dict: dict): @@ -117,5 +118,6 @@ def to_tgi_ie_dict(self) -> dict: "top_k": self.top_k, "top_p": self.top_p, "truncate": self.truncate_prompt, + "grammar": self.grammar, } return {k: v for k, v in args.items() if v is not None} From dedab95bf2acc819313e9e826aa5ca7ac1a1cb77 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 11 Jun 2025 10:58:45 +0200 Subject: [PATCH 03/20] fix: unit tests + add: dep in extra --- pyproject.toml | 3 ++- tests/models/endpoints/test_endpoint_model.py | 1 + tests/models/endpoints/test_tgi_model.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 212f6fbf6..163f09851 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ dependencies = [ "GitPython>=3.1.41", # for logging "datasets>=3.5.0", "pydantic", - "numpy<2", # pinned to avoid incompatibilities + "numpy<2", # pinned to avoid incompatibilities # Prettiness "typer", "termcolor==2.3.0", @@ -110,6 +110,7 @@ multilingual = [ "spacy[ja,ko,th]", "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer + "langcodes>=3.5.0", ] math = ["latex2sympy2_extended==1.0.6"] wandb = ["wandb"] diff --git a/tests/models/endpoints/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py index d2291d17e..dd46d874c 100644 --- a/tests/models/endpoints/test_endpoint_model.py +++ b/tests/models/endpoints/test_endpoint_model.py @@ -52,6 +52,7 @@ class TestInferenceEndpointModelConfig: "generation_parameters": { "early_stopping": None, "frequency_penalty": None, + "grammar": None, "length_penalty": None, "max_new_tokens": 256, "min_new_tokens": None, diff --git a/tests/models/endpoints/test_tgi_model.py b/tests/models/endpoints/test_tgi_model.py index 872dc06ce..d030607f8 100644 --- a/tests/models/endpoints/test_tgi_model.py +++ b/tests/models/endpoints/test_tgi_model.py @@ -39,6 +39,7 @@ class TestTGIModelConfig: "generation_parameters": { "early_stopping": None, "frequency_penalty": None, + "grammar": None, "length_penalty": None, "max_new_tokens": None, "min_new_tokens": None, From b98702cf7374da902375d084ee48ae95102e2237 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Fri, 15 Aug 2025 11:12:47 +0200 Subject: [PATCH 04/20] fix: 
request var => doc var after refactor --- src/lighteval/models/endpoints/endpoint_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py index 39c2ad6c3..f2dc2c03b 100644 --- a/src/lighteval/models/endpoints/endpoint_model.py +++ b/src/lighteval/models/endpoints/endpoint_model.py @@ -527,7 +527,7 @@ async def _async_process_batch_logprob(self, docs: list[Doc], rolling: bool = Fa context=context if rolling else context + doc.choices[0], stop_tokens=[], max_tokens=1, - grammar=request.generation_grammar, + grammar=doc.generation_grammar, ) for context, doc in zip(contexts, docs) ] @@ -540,7 +540,7 @@ def _process_batch_logprob(self, docs: list[Doc], rolling: bool = False) -> list context=context if rolling else context + doc.choices[0], stop_tokens=[], max_tokens=1, - grammar=request.generation_grammar, + grammar=doc.generation_grammar, ) for context, doc in zip(contexts, docs) ] From 9f1b75c9df925d381a1617d935bb8b92b54c2c44 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Fri, 15 Aug 2025 11:13:32 +0200 Subject: [PATCH 05/20] fix: update test to support the new grammar field --- tests/logging/test_evaluation_tracker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index ba4517245..8d464e963 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -241,6 +241,7 @@ def setUp(self): "presence_penalty": None, "max_new_tokens": None, "min_new_tokens": None, + "grammar": None, "seed": None, "stop_tokens": None, "temperature": 0, From b049367fe5d7ae12f96f35c8d5a23511ee7618b1 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 15:15:57 +0200 Subject: [PATCH 06/20] fix: TGI endpoint with the new refactor --- pyproject.toml | 4 ++-- src/lighteval/main_endpoint.py | 2 +- src/lighteval/models/endpoints/tgi_model.py | 14 ++++++++++++++ src/lighteval/tasks/lighteval_task.py | 6 ++++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c7e42e60..c9fb2c0a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,8 +85,8 @@ dependencies = [ ] [project.optional-dependencies] -litellm = ["litellm", "diskcache"] -tgi = ["text-generation==0.7.0"] +litellm = ["litellm[caching]", "diskcache"] +tgi = ["text-generation>=0.7.0"] optimum = ["optimum==1.12.0"] quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"] adapters = ["peft==0.3.0"] diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index 793a578a5..c16fee2ab 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -265,7 +265,7 @@ def tgi( config = yaml.safe_load(f) generation_parameters = GenerationParameters(**config.get("generation", {})) - model_config = TGIModelConfig(**config["model"], generation_parameters=generation_parameters) + model_config = TGIModelConfig(**config["model_parameters"], generation_parameters=generation_parameters) pipeline_params = PipelineParameters( launcher_type=parallelism_manager, diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index 8527f7dbd..350196a87 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -30,6 +30,8 @@ from lighteval.models.abstract_model import ModelConfig from lighteval.models.endpoints.endpoint_model import InferenceEndpointModel 
+from lighteval.tasks.prompt_manager import PromptManager +from lighteval.utils.cache_management import SampleCache from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available @@ -87,6 +89,7 @@ class TGIModelConfig(ModelConfig): inference_server_auth: str | None = None model_name: str | None model_info: dict | None = None + batch_size: int = 1 # inherit from InferenceEndpointModel instead of LightevalModel since they both use the same interface, and only overwrite @@ -110,11 +113,22 @@ def __init__(self, config: TGIModelConfig) -> None: raise ValueError("Error occurred when fetching info: " + str(self.model_info)) if config.model_name: self.model_info["model_id"] = config.model_name + else: + # Set the model_name in config to the actual model_id from server for caching + config.model_name = self.model_info["model_id"] self.config = config self._tokenizer = AutoTokenizer.from_pretrained(self.model_info["model_id"]) self._add_special_tokens = True self.use_async = True self.config.model_info = self.model_info + + # Initialize prompt manager (required by parent class) + self.prompt_manager = PromptManager( + use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt + ) + + # Initialize cache for tokenization and predictions + self._cache = SampleCache(config) def _async_process_request( self, diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 42ba3408e..3317b7c8a 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -256,6 +256,12 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]: item["__index"] = ix doc = self.formatter(item, self.name) doc.id = str(ix) + + # Transfer task-level generation parameters to the document + doc.generation_grammar = self.generation_grammar + doc.generation_size = self.generation_size + doc.stop_sequences = self.stop_sequence + docs.append(doc) return docs From cfd61a1cf4716eb5e50c8bb08046256ab32dbf99 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 15:42:15 +0200 Subject: [PATCH 07/20] update: TGI model config in examples with the latest parameters --- examples/model_configs/tgi_model.yaml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/model_configs/tgi_model.yaml b/examples/model_configs/tgi_model.yaml index 34dbaa831..55525b4b8 100644 --- a/examples/model_configs/tgi_model.yaml +++ b/examples/model_configs/tgi_model.yaml @@ -1,4 +1,10 @@ model_parameters: - inference_server_address: "" + inference_server_address: "http://localhost:8080" # Replace with your actual TGI server address inference_server_auth: null model_name: null # Optional, only required if the TGI container was launched with model_id pointing to a local directory + batch_size: 1 # Batch size for inference + +generation: + temperature: 0.1 + max_new_tokens: 256 + top_p: 0.9 From e144ddd7cc1620373088a2a58c68bb36261f4139 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 15:42:55 +0200 Subject: [PATCH 08/20] add: example custom task on a classification dataset to demonstrate the usage of constrained grammar generation using TGI --- ...custom_task_classification_grammar_task.py | 437 ++++++++++++++++++ 1 file changed, 437 insertions(+) create mode 100644 examples/custom_tasks_templates/custom_task_classification_grammar_task.py diff --git a/examples/custom_tasks_templates/custom_task_classification_grammar_task.py 
b/examples/custom_tasks_templates/custom_task_classification_grammar_task.py
new file mode 100644
index 000000000..6dbe994c0
--- /dev/null
+++ b/examples/custom_tasks_templates/custom_task_classification_grammar_task.py
@@ -0,0 +1,437 @@
+"""Emotion Classification Task with Grammar Constraints using LightEval
+
+This module demonstrates how to create a classification task in LightEval with JSON grammar-constrained generation for structured responses.
+
+
+The task performs emotion classification on the 'emotion' dataset from HuggingFace Hub,
+classifying text into one of six emotion categories: sadness, joy, love, anger, fear, surprise.
+
+Example usage:
+    TGI endpoint evaluation:
+    ```bash
+    uv run --active --extra litellm --extra tgi lighteval endpoint tgi examples/model_configs/tgi_model.yaml "custom|emotion_classification|0|0"
+    --custom-tasks examples/custom_tasks_templates/custom_task_classification_grammar_task.py
+    --output-dir results
+    --save-details
+    --no-public-run
+    ```
+
+Dataset:
+    The task uses the 'emotion' dataset from HuggingFace Hub, which contains
+    English Twitter messages labeled with one of six emotions. The dataset
+    includes train/validation/test splits with the following distribution:
+    - Total samples: ~20k (train: ~16k, validation: ~2k, test: ~2k)
+    - Labels: sadness, joy, love, anger, fear, surprise
+    - Text format: Short social media posts in English
+
+Customization:
+    To adapt this task for other classification problems:
+    1. Update EMOTION_LABELS with your target labels
+    2. Modify prompt_emotion_classification() for your use case
+    3. Update the grammar schema in get_emotion_classification_grammar()
+    4. Adjust the HuggingFace dataset reference in EMOTION_CLASSIFICATION_TASK
+    5. Update metric calculations in emotion_classification_metric() if needed
+"""
+
+import json
+import logging
+from typing import Any
+
+import numpy as np
+
+from lighteval.metrics.utils.metric_utils import SampleLevelMetricGrouping
+from lighteval.models.model_output import ModelResponse
+from lighteval.tasks.lighteval_task import (
+    LightevalTaskConfig,
+    TextGenerationInputGrammarType,
+)
+from lighteval.tasks.requests import Doc, SamplingMethod
+
+logger = logging.getLogger(__name__)
+
+# Emotion labels for the emotion dataset from HuggingFace Hub
+# These correspond to the 6-class emotion classification task with the following mapping:
+# 0: sadness, 1: joy, 2: love, 3: anger, 4: fear, 5: surprise
+EMOTION_LABELS = ["sadness", "joy", "love", "anger", "fear", "surprise"]
+
+
+def parse_emotion_response(response: str | dict) -> dict[str, Any]:
+    """Parse the model's response into a standardized format.
+
+    This function handles both JSON string and dictionary inputs, providing robust
+    parsing with validation against the predefined emotion labels. Invalid predictions
+    are automatically mapped to 'unknown' with appropriate logging.
+
+    Args:
+        response (str | dict): The model's response, either as a JSON string
+            containing {"classification": "emotion_label"} or as a dictionary
+            with the same structure.
+
+    Returns:
+        dict[str, Any]: Standardized dictionary containing:
+            - classification (str): The predicted emotion label, validated against
+              EMOTION_LABELS or 'unknown' if invalid/unparseable
+
+    Examples:
+        >>> parse_emotion_response('{"classification": "joy"}')
+        {'classification': 'joy'}
+
+        >>> parse_emotion_response({'classification': 'ANGER'})
+        {'classification': 'anger'}
+
+        >>> parse_emotion_response('{"classification": "invalid_emotion"}')
+        {'classification': 'unknown'}  # with warning logged
+
+        >>> parse_emotion_response('malformed json')
+        {'classification': 'unknown'}  # with error logged
+
+    Note:
+        - Case-insensitive matching: 'ANGER' and 'Anger' are normalized to 'anger'
+        - Whitespace is automatically stripped from predictions
+        - All parsing errors result in 'unknown' classification with detailed logging
+    """
+    try:
+        # Handle dictionary input (already parsed JSON)
+        if isinstance(response, dict):
+            result = response
+        # Handle string input (JSON string that needs parsing)
+        else:
+            result = json.loads(response.strip())
+
+        # Extract and normalize the predicted emotion
+        predicted_emotion = result["classification"].lower().strip()
+
+        # Validate that the prediction is one of the valid emotion labels
+        if predicted_emotion not in EMOTION_LABELS:
+            logger.warning(
+                f"Invalid emotion prediction: '{predicted_emotion}'. "
+                f"Expected one of {EMOTION_LABELS}. Using 'unknown'."
+            )
+            predicted_emotion = "unknown"
+
+        return {
+            "classification": predicted_emotion,
+        }
+    except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e:
+        # Handle specific parsing errors with detailed logging
+        logger.error(f"Error parsing response: {str(e)}")
+        logger.error(f"Failed response was: {response}")
+        logger.error("Expected format: {\"classification\": \"emotion_label\"}")
+        return {
+            "classification": "unknown",
+        }
+    except Exception as e:
+        # Catch any other unexpected errors
+        logger.error(f"Unexpected error parsing response: {str(e)}")
+        logger.error(f"Failed response was: {response}")
+        return {
+            "classification": "unknown",
+        }
+
+
+def emotion_classification_metric(
+    model_response: ModelResponse, doc: Doc, **kwargs
+) -> dict[str, float]:
+    """Evaluate emotion classification predictions at the sample level.
+
+    This function computes evaluation metrics for a single prediction, comparing
+    the model's emotion classification against the gold standard. It provides
+    detailed logging for debugging and tracks prediction quality.
+
+    Args:
+        model_response (ModelResponse): The model's response containing generated text
+            in the text attribute, typically containing one prediction as either a
+            JSON string or dictionary with format {"classification": "emotion_label"}
+        doc (Doc): The document containing the query, choices, and gold
+            standard information. Must have gold_index attribute pointing to the
+            correct emotion label index.
+        **kwargs: Additional keyword arguments (unused but required for compatibility
+            with LightEval's metric interface)
+
+    Returns:
+        dict[str, float]: Dictionary containing sample-level metrics:
+            - exact_match (float): 1.0 if prediction matches gold label, 0.0 otherwise
+            - unknown_prediction (float): 1.0 if prediction was 'unknown' (parsing
+              failure), 0.0 otherwise
+            - total_samples (float): Always 1.0 (count for this sample)
+
+    Examples:
+        >>> doc = Doc(query="I'm so happy!", gold_index=1)  # joy
+        >>> model_response = ModelResponse(text=['{"classification": "joy"}'], ...)
+        >>> result = emotion_classification_metric(model_response, doc)
+        >>> result
+        {'exact_match': 1.0, 'unknown_prediction': 0.0, 'total_samples': 1.0}
+
+        >>> model_response = ModelResponse(text=['{"classification": "sadness"}'], ...)
+        >>> result = emotion_classification_metric(model_response, doc)
+        >>> result
+        {'exact_match': 0.0, 'unknown_prediction': 0.0, 'total_samples': 1.0}
+
+    Note:
+        - The function expects exactly one prediction in the model_response.text list
+        - Gold labels are mapped from integer indices to emotion label strings
+        - All errors in prediction parsing result in 'unknown' classification
+        - Detailed logging is provided for debugging classification performance
+    """
+    try:
+        # Parse the first (and typically only) prediction
+        prediction = parse_emotion_response(model_response.text[0])
+
+        # Map the gold label index to the corresponding emotion string
+        # The emotion dataset uses integer indices: 0=sadness, 1=joy, 2=love, etc.
+        gold_label_idx = doc.gold_index
+        expected_emotion = EMOTION_LABELS[gold_label_idx]
+
+        # Log detailed information for debugging and analysis
+        logger.info("-" * 50)
+        logger.info("Processing new sample")
+        logger.info(f"- Text: {doc.query}")
+        logger.info(f"- Prediction: {prediction}")
+        logger.info(f"- Expected: {expected_emotion} (index: {gold_label_idx})")
+
+        # Calculate evaluation metrics
+        is_exact_match = prediction["classification"] == expected_emotion
+        is_unknown = prediction["classification"] == "unknown"
+
+        metrics = {
+            "exact_match": float(is_exact_match),
+            "unknown_prediction": float(is_unknown),
+            "total_samples": 1.0,
+        }
+
+        logger.info(f"- Metrics: {metrics}")
+        if is_exact_match:
+            logger.info("✓ Correct prediction")
+        elif is_unknown:
+            logger.info("⚠ Parsing failure (unknown prediction)")
+        else:
+            logger.info("✗ Incorrect prediction")
+        logger.info("-" * 50)
+
+        return metrics
+
+    except (IndexError, KeyError) as e:
+        # Handle errors related to accessing gold label or prediction structure
+        logger.error(f"Error accessing gold label or prediction: {str(e)}")
+        logger.error(f"Gold index: {getattr(doc, 'gold_index', 'N/A')}")
+        logger.error(f"Raw prediction: {model_response.text[0] if model_response.text else 'Empty predictions'}")
+        return {
+            "exact_match": 0.0,
+            "unknown_prediction": 1.0,
+            "total_samples": 1.0,
+        }
+    except Exception as e:
+        # Handle any other unexpected errors
+        logger.error(f"Unexpected error processing prediction: {str(e)}")
+        logger.error(f"Raw prediction was: {model_response.text[0] if model_response.text else 'Empty predictions'}")
+        return {
+            "exact_match": 0.0,
+            "unknown_prediction": 1.0,
+            "total_samples": 1.0,
+        }
+
+
+# Define the metric group for emotion classification evaluation
+# This configures both sample-level and corpus-level metric calculations
+emotion_classification_group = SampleLevelMetricGrouping(
+    metric_name=[
+        "exact_match", # Primary accuracy metric
+        "unknown_prediction", # Tracks parsing failures
+        "total_samples", # Sample count for aggregation
+    ],
+    higher_is_better={
+        "exact_match": True, # Higher accuracy is better
+        "unknown_prediction": False, # Fewer parsing failures is better
+        "total_samples": True, # More samples processed is better
+    },
+    category=SamplingMethod.GENERATIVE, # Classification via text generation
+    sample_level_fn=emotion_classification_metric, # Function for individual samples
+    corpus_level_fn={
+        "exact_match": np.mean, # Average accuracy across all samples
+        "unknown_prediction": np.mean, # Proportion of parsing failures
+
"total_samples": np.sum, # Total number of samples processed + }, +) + + +def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) -> Doc: + """Format the emotion classification task with detailed prompt engineering. + + This function converts a single sample from the emotion dataset into a structured + prompt that provides clear instructions and emotion definitions to improve + classification accuracy. The prompt includes detailed explanations of each + emotion category to reduce ambiguity. + + Args: + line (dict[str, Any]): A single sample from the emotion dataset containing: + - 'text' (str): The input text to classify + - 'label' (int): The gold standard emotion label (0-5) + task_name (str, optional): Name of the task for identification purposes. + Defaults to None. + + Returns: + Doc: A formatted document object containing: + - task_name: Task identifier + - query: The formatted prompt with text and emotion definitions + - choices: List of available emotion labels + - gold_index: The correct emotion label index + - instruction: Empty string (instructions are embedded in query) + + Examples: + >>> line = {'text': 'I am so excited for tomorrow!', 'label': 2} + >>> doc = prompt_emotion_classification(line, 'emotion_test') + >>> print(doc.query) + Classify the emotion expressed in the following text: "I am so excited for tomorrow!" + ... + >>> doc.gold_index + 2 + >>> doc.choices + ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'] + + Note: + - The prompt includes detailed definitions for each emotion to improve accuracy + - Emotion definitions are based on common psychological categorizations + - The format is optimized for both human readability and model understanding + """ + # Extract the text to be classified + text = line["text"] + + # Create a comprehensive classification prompt with detailed emotion definitions + # This approach helps models understand the subtle differences between emotions + prompt = f"""Classify the emotion expressed in the following text: "{text}" + +Available emotion labels and their meanings: +- sadness: Feeling of sorrow, grief, or unhappiness. Covers melancholy, disappointment, + loss, or general negative emotional states related to unfortunate circumstances. +- joy: Feeling of happiness, delight, or pleasure. Encompasses positive emotions like + excitement, satisfaction, contentment, and general well-being. +- love: Feeling of affection, care, or romantic attachment. Includes expressions of + deep fondness, romantic interest, or strong positive feelings toward people or things. +- anger: Feeling of displeasure, hostility, or annoyance. Often involves frustration, + irritation, or aggressive sentiments toward people, situations, or objects. +- fear: Feeling of anxiety, worry, or being afraid. Includes nervousness, concern + about future events, or apprehension about potential threats or negative outcomes. +- surprise: Feeling of astonishment or being caught off guard. Includes unexpected + reactions, amazement, or responses to sudden or unanticipated events. 
+ +Choose the emotion that best matches the sentiment expressed in the text.""" + + return Doc( + task_name=task_name, + query=prompt, + choices=EMOTION_LABELS, # Available emotion label options + gold_index=line["label"], # Gold standard emotion index (0-5) + instruction="", # Instructions are embedded in the query + ) + + +def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: + """Define the JSON schema grammar for constrained emotion classification responses. + + This function creates a strict JSON schema that constrains the model's output + to only valid emotion labels, preventing hallucination and ensuring consistent + response format. The grammar constraint is enforced during text generation. + + Returns: + TextGenerationInputGrammarType: A JSON schema grammar specification that: + - Enforces JSON object structure with required "classification" field + - Constrains classification values to only valid emotion labels + - Ensures consistent response parsing across different models + + Schema Structure: + { + "type": "object", + "properties": { + "classification": { + "type": "string", + "description": "Emotion classification", + "enum": ["anger", "fear", "joy", "love", "sadness", "surprise"] + } + }, + "required": ["classification"] + } + + Examples: + Valid responses that match this grammar: + - {"classification": "joy"} + - {"classification": "anger"} + + Invalid responses that would be rejected: + - {"emotion": "joy"} # Wrong field name + - {"classification": "happy"} # Invalid emotion label + - "joy" # Not a JSON object + + Note: + - This grammar constraint significantly improves response consistency + - It prevents the model from generating invalid emotion labels + - Compatible with grammar-enabled backends like vLLM, TGI, and others + - The enum constraint is crucial for maintaining label consistency + """ + return TextGenerationInputGrammarType( + type="json", # Specify JSON schema grammar type + value={ + "type": "object", # Require JSON object structure + "properties": { + "classification": { + "type": "string", # Classification must be a string + "description": "Emotion classification from the provided list", + "enum": EMOTION_LABELS, # Strictly constrain to valid emotion labels only + }, + }, + "required": ["classification"], # Classification field is mandatory + "additionalProperties": False, # Prevent extra fields in response + }, + ) + + +# Task configuration for emotion classification using the HuggingFace emotion dataset +# This configuration optimizes for accuracy while maintaining efficient resource usage +EMOTION_CLASSIFICATION_TASK = LightevalTaskConfig( + name="emotion_classification", # Unique task identifier + prompt_function=prompt_emotion_classification, # Custom prompt formatting function + suite=["custom"], # Classification as a community/custom task + hf_repo="emotion", # HuggingFace Hub dataset repository + hf_subset=None, # Use default subset (no subset specified) + metrics=[emotion_classification_group], # Evaluation metrics configuration + generation_size=64, # Conservative token limit for JSON responses (~30-40 tokens typical) + generation_grammar=get_emotion_classification_grammar(), # JSON schema constraint + stop_sequence=["\n\n"], # Early stopping on double newline + trust_dataset=True, # Trust the HuggingFace dataset (required for emotion dataset) + evaluation_splits=["test"], # Evaluate on test split only + hf_avail_splits=["train", "validation", "test"], # Available dataset splits + # Additional configuration notes: + # - 
generation_size is kept small since responses are simple JSON objects + # - Grammar constraint ensures valid JSON structure and emotion labels + # - Using test split for evaluation follows standard ML practices + # - trust_dataset=True is required for datasets that need additional verification +) + +# Export the task for LightEval discovery +# This list is automatically detected by LightEval when loading custom tasks +TASKS_TABLE = [EMOTION_CLASSIFICATION_TASK] + +# Development and testing utilities +if __name__ == "__main__": + # Print available tasks for verification + print("Available tasks:", [t.name for t in TASKS_TABLE]) + print("Total tasks:", len(TASKS_TABLE)) + + # Print task configuration summary for debugging + task = TASKS_TABLE[0] + print(f"\nTask Configuration Summary:") + print(f" Name: {task.name}") + print(f" Dataset: {task.hf_repo}") + print(f" Splits: {task.evaluation_splits}") + print(f" Metrics: {[m.metric_name for m in task.metric]}") + print(f" Generation size: {task.generation_size}") + print(f" Grammar constrained: {task.generation_grammar is not None}") + print(f" Stop sequences: {task.stop_sequence}") + + # Verify emotion labels configuration + print(f"\nEmotion Labels ({len(EMOTION_LABELS)}):") + for i, label in enumerate(EMOTION_LABELS): + print(f" {i}: {label}") + + print(f"\nUsage Examples:") + print(f" TGI: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|0|0' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run --max-samples 10") + print(f" Full: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|5|1' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run") From c95c88dc2441a118c919f94ce12ba9e921ed0e93 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 15:54:56 +0200 Subject: [PATCH 09/20] add: format example task --- ...custom_task_classification_grammar_task.py | 95 ++++++++++--------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/examples/custom_tasks_templates/custom_task_classification_grammar_task.py b/examples/custom_tasks_templates/custom_task_classification_grammar_task.py index 6dbe994c0..e4d54f935 100644 --- a/examples/custom_tasks_templates/custom_task_classification_grammar_task.py +++ b/examples/custom_tasks_templates/custom_task_classification_grammar_task.py @@ -47,6 +47,7 @@ ) from lighteval.tasks.requests import Doc, SamplingMethod + logger = logging.getLogger(__name__) # Emotion labels for the emotion dataset from HuggingFace Hub @@ -57,34 +58,34 @@ def parse_emotion_response(response: str | dict) -> dict[str, Any]: """Parse the model's response into a standardized format. - + This function handles both JSON string and dictionary inputs, providing robust parsing with validation against the predefined emotion labels. Invalid predictions are automatically mapped to 'unknown' with appropriate logging. - + Args: response (str | dict): The model's response, either as a JSON string containing {"classification": "emotion_label"} or as a dictionary with the same structure. 
- + Returns: dict[str, Any]: Standardized dictionary containing: - classification (str): The predicted emotion label, validated against EMOTION_LABELS or 'unknown' if invalid/unparseable - + Examples: >>> parse_emotion_response('{"classification": "joy"}') {'classification': 'joy'} - + >>> parse_emotion_response({'classification': 'ANGER'}) {'classification': 'anger'} - + >>> parse_emotion_response('{"classification": "invalid_emotion"}') {'classification': 'unknown'} # with warning logged - + >>> parse_emotion_response('malformed json') {'classification': 'unknown'} # with error logged - + Note: - Case-insensitive matching: 'ANGER' and 'Anger' are normalized to 'anger' - Whitespace is automatically stripped from predictions @@ -100,7 +101,7 @@ def parse_emotion_response(response: str | dict) -> dict[str, Any]: # Extract and normalize the predicted emotion predicted_emotion = result["classification"].lower().strip() - + # Validate that the prediction is one of the valid emotion labels if predicted_emotion not in EMOTION_LABELS: logger.warning( @@ -116,7 +117,7 @@ def parse_emotion_response(response: str | dict) -> dict[str, Any]: # Handle specific parsing errors with detailed logging logger.error(f"Error parsing response: {str(e)}") logger.error(f"Failed response was: {response}") - logger.error("Expected format: {\"classification\": \"emotion_label\"}") + logger.error('Expected format: {"classification": "emotion_label"}') return { "classification": "unknown", } @@ -129,15 +130,13 @@ def parse_emotion_response(response: str | dict) -> dict[str, Any]: } -def emotion_classification_metric( - model_response: ModelResponse, doc: Doc, **kwargs -) -> dict[str, float]: +def emotion_classification_metric(model_response: ModelResponse, doc: Doc, **kwargs) -> dict[str, float]: """Evaluate emotion classification predictions at the sample level. - + This function computes evaluation metrics for a single prediction, comparing the model's emotion classification against the gold standard. It provides detailed logging for debugging and tracks prediction quality. - + Args: model_response (ModelResponse): The model's response containing generated text in the text attribute, typically containing one prediction as either a @@ -147,26 +146,26 @@ def emotion_classification_metric( correct emotion label index. **kwargs: Additional keyword arguments (unused but required for compatibility with LightEval's metric interface) - + Returns: dict[str, float]: Dictionary containing sample-level metrics: - exact_match (float): 1.0 if prediction matches gold label, 0.0 otherwise - unknown_prediction (float): 1.0 if prediction was 'unknown' (parsing failure), 0.0 otherwise - total_samples (float): Always 1.0 (count for this sample) - + Examples: >>> doc = Doc(query="I'm so happy!", gold_index=2) # joy >>> model_response = ModelResponse(text=['{"classification": "joy"}'], ...) >>> result = emotion_classification_metric(model_response, doc) >>> result {'exact_match': 1.0, 'unknown_prediction': 0.0, 'total_samples': 1.0} - + >>> model_response = ModelResponse(text=['{"classification": "sadness"}'], ...) 
>>> result = emotion_classification_metric(model_response, doc) >>> result {'exact_match': 0.0, 'unknown_prediction': 0.0, 'total_samples': 1.0} - + Note: - The function expects exactly one prediction in the model_response.text list - Gold labels are mapped from integer indices to emotion label strings @@ -176,7 +175,7 @@ def emotion_classification_metric( try: # Parse the first (and typically only) prediction prediction = parse_emotion_response(model_response.text[0]) - + # Map the gold label index to the corresponding emotion string # The emotion dataset uses integer indices: 0=anger, 1=fear, 2=joy, etc. gold_label_idx = doc.gold_index @@ -192,7 +191,7 @@ def emotion_classification_metric( # Calculate evaluation metrics is_exact_match = prediction["classification"] == expected_emotion is_unknown = prediction["classification"] == "unknown" - + metrics = { "exact_match": float(is_exact_match), "unknown_prediction": float(is_unknown), @@ -235,40 +234,40 @@ def emotion_classification_metric( # This configures both sample-level and corpus-level metric calculations emotion_classification_group = SampleLevelMetricGrouping( metric_name=[ - "exact_match", # Primary accuracy metric + "exact_match", # Primary accuracy metric "unknown_prediction", # Tracks parsing failures - "total_samples", # Sample count for aggregation + "total_samples", # Sample count for aggregation ], higher_is_better={ - "exact_match": True, # Higher accuracy is better - "unknown_prediction": False, # Fewer parsing failures is better - "total_samples": True, # More samples processed is better + "exact_match": True, # Higher accuracy is better + "unknown_prediction": False, # Fewer parsing failures is better + "total_samples": True, # More samples processed is better }, category=SamplingMethod.GENERATIVE, # Classification via text generation sample_level_fn=emotion_classification_metric, # Function for individual samples corpus_level_fn={ - "exact_match": np.mean, # Average accuracy across all samples + "exact_match": np.mean, # Average accuracy across all samples "unknown_prediction": np.mean, # Proportion of parsing failures - "total_samples": np.sum, # Total number of samples processed + "total_samples": np.sum, # Total number of samples processed }, ) def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) -> Doc: """Format the emotion classification task with detailed prompt engineering. - + This function converts a single sample from the emotion dataset into a structured prompt that provides clear instructions and emotion definitions to improve classification accuracy. The prompt includes detailed explanations of each emotion category to reduce ambiguity. - + Args: line (dict[str, Any]): A single sample from the emotion dataset containing: - 'text' (str): The input text to classify - 'label' (int): The gold standard emotion label (0-5) task_name (str, optional): Name of the task for identification purposes. Defaults to None. 
- + Returns: Doc: A formatted document object containing: - task_name: Task identifier @@ -276,7 +275,7 @@ def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) - - choices: List of available emotion labels - gold_index: The correct emotion label index - instruction: Empty string (instructions are embedded in query) - + Examples: >>> line = {'text': 'I am so excited for tomorrow!', 'label': 2} >>> doc = prompt_emotion_classification(line, 'emotion_test') @@ -287,7 +286,7 @@ def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) - 2 >>> doc.choices ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'] - + Note: - The prompt includes detailed definitions for each emotion to improve accuracy - Emotion definitions are based on common psychological categorizations @@ -295,7 +294,7 @@ def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) - """ # Extract the text to be classified text = line["text"] - + # Create a comprehensive classification prompt with detailed emotion definitions # This approach helps models understand the subtle differences between emotions prompt = f"""Classify the emotion expressed in the following text: "{text}" @@ -315,7 +314,7 @@ def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) - reactions, amazement, or responses to sudden or unanticipated events. Choose the emotion that best matches the sentiment expressed in the text.""" - + return Doc( task_name=task_name, query=prompt, @@ -327,17 +326,17 @@ def prompt_emotion_classification(line: dict[str, Any], task_name: str = None) - def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: """Define the JSON schema grammar for constrained emotion classification responses. - + This function creates a strict JSON schema that constrains the model's output to only valid emotion labels, preventing hallucination and ensuring consistent response format. The grammar constraint is enforced during text generation. 
- + Returns: TextGenerationInputGrammarType: A JSON schema grammar specification that: - Enforces JSON object structure with required "classification" field - Constrains classification values to only valid emotion labels - Ensures consistent response parsing across different models - + Schema Structure: { "type": "object", @@ -350,17 +349,17 @@ def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: }, "required": ["classification"] } - + Examples: Valid responses that match this grammar: - {"classification": "joy"} - {"classification": "anger"} - + Invalid responses that would be rejected: - {"emotion": "joy"} # Wrong field name - {"classification": "happy"} # Invalid emotion label - "joy" # Not a JSON object - + Note: - This grammar constraint significantly improves response consistency - It prevents the model from generating invalid emotion labels @@ -415,7 +414,7 @@ def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: # Print available tasks for verification print("Available tasks:", [t.name for t in TASKS_TABLE]) print("Total tasks:", len(TASKS_TABLE)) - + # Print task configuration summary for debugging task = TASKS_TABLE[0] print(f"\nTask Configuration Summary:") @@ -426,12 +425,16 @@ def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: print(f" Generation size: {task.generation_size}") print(f" Grammar constrained: {task.generation_grammar is not None}") print(f" Stop sequences: {task.stop_sequence}") - + # Verify emotion labels configuration print(f"\nEmotion Labels ({len(EMOTION_LABELS)}):") for i, label in enumerate(EMOTION_LABELS): print(f" {i}: {label}") - + print(f"\nUsage Examples:") - print(f" TGI: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|0|0' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run --max-samples 10") - print(f" Full: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|5|1' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run") + print( + f" TGI: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|0|0' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run --max-samples 10" + ) + print( + f" Full: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|5|1' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run" + ) From 9df8575d7ab645e2a672dbf86963bf6b8459ceb1 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 16:16:47 +0200 Subject: [PATCH 10/20] fix: unit test --- tests/models/endpoints/test_tgi_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/endpoints/test_tgi_model.py b/tests/models/endpoints/test_tgi_model.py index 969dcfb19..a6f1ce92d 100644 --- a/tests/models/endpoints/test_tgi_model.py +++ b/tests/models/endpoints/test_tgi_model.py @@ -33,11 +33,12 @@ class TestTGIModelConfig: ( "examples/model_configs/tgi_model.yaml", { - "inference_server_address": "", + "inference_server_address": "http://localhost:8080", "inference_server_auth": None, "model_name": None, "model_info": None, "system_prompt": None, + "batch_size": 1, "generation_parameters": { "block_size": None, "num_blocks": None, @@ -53,7 +54,7 @@ class TestTGIModelConfig: "repetition_penalty": None, "seed": None, "stop_tokens": None, 
- "temperature": 0, + "temperature": 0.1, "top_k": None, "top_p": None, "truncate_prompt": None, From 33cebe95219d418c8a6f1ae7cad75b8a8b5e2323 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 16:17:39 +0200 Subject: [PATCH 11/20] add: adapt the in the yaml config to use similarly to the other endpoints --- examples/model_configs/tgi_model.yaml | 10 +++------- src/lighteval/main_endpoint.py | 11 +++++++++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/model_configs/tgi_model.yaml b/examples/model_configs/tgi_model.yaml index 55525b4b8..a31ba0f7b 100644 --- a/examples/model_configs/tgi_model.yaml +++ b/examples/model_configs/tgi_model.yaml @@ -1,10 +1,6 @@ model_parameters: - inference_server_address: "http://localhost:8080" # Replace with your actual TGI server address + inference_server_address: "http://localhost:8080" # Replace with your actual TGI server address inference_server_auth: null model_name: null # Optional, only required if the TGI container was launched with model_id pointing to a local directory - batch_size: 1 # Batch size for inference - -generation: - temperature: 0.1 - max_new_tokens: 256 - top_p: 0.9 + generation_parameters: + temperature: 0.1 diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index ccd1582eb..2d4ad517f 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -272,8 +272,15 @@ def tgi( with open(model_config_path, "r") as f: config = yaml.safe_load(f) - generation_parameters = GenerationParameters(**config.get("generation", {})) - model_config = TGIModelConfig(**config["model_parameters"], generation_parameters=generation_parameters) + # Extract generation_parameters from model_parameters if they exist + model_params = config["model_parameters"].copy() + yaml_gen_params = model_params.pop("generation_parameters", {}) + + # Start with defaults and override with YAML values + generation_parameters = GenerationParameters(**yaml_gen_params) + + # Create model config without generation_parameters in model_params + model_config = TGIModelConfig(**model_params, generation_parameters=generation_parameters) pipeline_params = PipelineParameters( launcher_type=parallelism_manager, From 93a645b9b8e4980c9b29ff10f80dfa47cd916ccb Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 17:37:49 +0200 Subject: [PATCH 12/20] clean: moved new task to community_tasks --- .../custom_task_classification_grammar_task.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {examples/custom_tasks_templates => community_tasks}/custom_task_classification_grammar_task.py (100%) diff --git a/examples/custom_tasks_templates/custom_task_classification_grammar_task.py b/community_tasks/custom_task_classification_grammar_task.py similarity index 100% rename from examples/custom_tasks_templates/custom_task_classification_grammar_task.py rename to community_tasks/custom_task_classification_grammar_task.py From a34353b361d2546b16432fb47db53ee997d06177 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 17:41:52 +0200 Subject: [PATCH 13/20] fix: format --- .../custom_task_classification_grammar_task.py | 9 ++++----- src/lighteval/main_endpoint.py | 4 ++-- src/lighteval/models/endpoints/tgi_model.py | 4 ++-- src/lighteval/tasks/lighteval_task.py | 4 ++-- tests/logging/test_evaluation_tracker.py | 1 - tests/models/endpoints/test_tgi_model.py | 1 - 6 files changed, 10 insertions(+), 13 deletions(-) diff --git a/community_tasks/custom_task_classification_grammar_task.py 
b/community_tasks/custom_task_classification_grammar_task.py index e4d54f935..c7130f50b 100644 --- a/community_tasks/custom_task_classification_grammar_task.py +++ b/community_tasks/custom_task_classification_grammar_task.py @@ -115,16 +115,15 @@ def parse_emotion_response(response: str | dict) -> dict[str, Any]: } except (json.JSONDecodeError, KeyError, AttributeError, TypeError) as e: # Handle specific parsing errors with detailed logging - logger.error(f"Error parsing response: {str(e)}") - logger.error(f"Failed response was: {response}") - logger.error('Expected format: {"classification": "emotion_label"}') + logger.error( + f"Error parsing response: {str(e)}. Failed response was: {response}. Expected format: {{'classification': 'emotion_label'}}" + ) return { "classification": "unknown", } except Exception as e: # Catch any other unexpected errors - logger.error(f"Unexpected error parsing response: {str(e)}") - logger.error(f"Failed response was: {response}") + logger.error(f"Unexpected error parsing response: {str(e)}. Failed response was: {response}") return { "classification": "unknown", } diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index aeb86d62c..969c742df 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -275,10 +275,10 @@ def tgi( # Extract generation_parameters from model_parameters if they exist model_params = config["model_parameters"].copy() yaml_gen_params = model_params.pop("generation_parameters", {}) - + # Start with defaults and override with YAML values generation_parameters = GenerationParameters(**yaml_gen_params) - + # Create model config without generation_parameters in model_params model_config = TGIModelConfig(**model_params, generation_parameters=generation_parameters) diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py index 350196a87..91940f88f 100644 --- a/src/lighteval/models/endpoints/tgi_model.py +++ b/src/lighteval/models/endpoints/tgi_model.py @@ -121,12 +121,12 @@ def __init__(self, config: TGIModelConfig) -> None: self._add_special_tokens = True self.use_async = True self.config.model_info = self.model_info - + # Initialize prompt manager (required by parent class) self.prompt_manager = PromptManager( use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt ) - + # Initialize cache for tokenization and predictions self._cache = SampleCache(config) diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py index 7c1df4d36..4749a25a0 100644 --- a/src/lighteval/tasks/lighteval_task.py +++ b/src/lighteval/tasks/lighteval_task.py @@ -251,12 +251,12 @@ def _get_docs_from_split(self, splits: list[str], few_shots=False) -> list[Doc]: item["__index"] = ix doc = self.formatter(item, self.name) doc.id = str(ix) - + # Transfer task-level generation parameters to the document doc.generation_grammar = self.generation_grammar doc.generation_size = self.generation_size doc.stop_sequences = self.stop_sequence - + docs.append(doc) return docs diff --git a/tests/logging/test_evaluation_tracker.py b/tests/logging/test_evaluation_tracker.py index 8d464e963..ba4517245 100644 --- a/tests/logging/test_evaluation_tracker.py +++ b/tests/logging/test_evaluation_tracker.py @@ -241,7 +241,6 @@ def setUp(self): "presence_penalty": None, "max_new_tokens": None, "min_new_tokens": None, - "grammar": None, "seed": None, "stop_tokens": None, "temperature": 0, diff --git 
a/tests/models/endpoints/test_tgi_model.py b/tests/models/endpoints/test_tgi_model.py index a6f1ce92d..e784bc0d4 100644 --- a/tests/models/endpoints/test_tgi_model.py +++ b/tests/models/endpoints/test_tgi_model.py @@ -45,7 +45,6 @@ class TestTGIModelConfig: "cache_implementation": None, "early_stopping": None, "frequency_penalty": None, - "grammar": None, "length_penalty": None, "max_new_tokens": None, "min_new_tokens": None, From 88da108f6fd6ead9aa38be8f99a8a4139fcfd5eb Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 18:03:14 +0200 Subject: [PATCH 14/20] clean: delete unused grammar field --- .../custom_task_classification_grammar_task.py | 10 ++-------- src/lighteval/models/model_input.py | 2 -- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/community_tasks/custom_task_classification_grammar_task.py b/community_tasks/custom_task_classification_grammar_task.py index c7130f50b..ac18a0e5d 100644 --- a/community_tasks/custom_task_classification_grammar_task.py +++ b/community_tasks/custom_task_classification_grammar_task.py @@ -394,14 +394,8 @@ def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: generation_size=64, # Conservative token limit for JSON responses (~30-40 tokens typical) generation_grammar=get_emotion_classification_grammar(), # JSON schema constraint stop_sequence=["\n\n"], # Early stopping on double newline - trust_dataset=True, # Trust the HuggingFace dataset (required for emotion dataset) evaluation_splits=["test"], # Evaluate on test split only hf_avail_splits=["train", "validation", "test"], # Available dataset splits - # Additional configuration notes: - # - generation_size is kept small since responses are simple JSON objects - # - Grammar constraint ensures valid JSON structure and emotion labels - # - Using test split for evaluation follows standard ML practices - # - trust_dataset=True is required for datasets that need additional verification ) # Export the task for LightEval discovery @@ -416,7 +410,7 @@ def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: # Print task configuration summary for debugging task = TASKS_TABLE[0] - print(f"\nTask Configuration Summary:") + print("\nTask Configuration Summary:") print(f" Name: {task.name}") print(f" Dataset: {task.hf_repo}") print(f" Splits: {task.evaluation_splits}") @@ -430,7 +424,7 @@ def get_emotion_classification_grammar() -> TextGenerationInputGrammarType: for i, label in enumerate(EMOTION_LABELS): print(f" {i}: {label}") - print(f"\nUsage Examples:") + print("\nUsage Examples:") print( f" TGI: uv run lighteval endpoint tgi config/tgi/tgi.yaml 'custom|{task.name}|0|0' --custom-tasks {__file__} --output-dir results --override-batch-size 1 --use-chat-template --save-details --no-public-run --max-samples 10" ) diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index d5cdfe889..b617ba368 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -35,7 +35,6 @@ class GenerationParameters(BaseModel, extra="forbid"): presence_penalty: NonNegativeFloat | None = None # vllm, sglang max_new_tokens: NonNegativeInt | None = None # vllm, transformers, tgi, litellm, sglang min_new_tokens: NonNegativeInt | None = None # vllm, transformers, sglang - grammar: str | None = None # tgi seed: NonNegativeInt | None = None # vllm, tgi, litellm stop_tokens: list[str] | None = None # vllm, transformers, tgi, litellm, sglang temperature: NonNegativeFloat = ( @@ -217,7 +216,6 @@ def 
to_tgi_ie_dict(self) -> dict: "top_k": self.top_k, "top_p": self.top_p, "truncate": self.truncate_prompt, - "grammar": self.grammar, } return {k: v for k, v in args.items() if v is not None} From 007932db93f52bc55af7cb6f96af2303569a4815 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Wed, 20 Aug 2025 18:18:49 +0200 Subject: [PATCH 15/20] del: grammar --- tests/models/endpoints/test_endpoint_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/endpoints/test_endpoint_model.py b/tests/models/endpoints/test_endpoint_model.py index f1d6f1726..4f009ca9a 100644 --- a/tests/models/endpoints/test_endpoint_model.py +++ b/tests/models/endpoints/test_endpoint_model.py @@ -57,7 +57,6 @@ class TestInferenceEndpointModelConfig: "cache_implementation": None, "early_stopping": None, "frequency_penalty": None, - "grammar": None, "length_penalty": None, "max_new_tokens": 256, "min_new_tokens": None, From 19afc28738913eece45bcb55b997d393c20bba41 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Thu, 21 Aug 2025 16:42:07 +0200 Subject: [PATCH 16/20] add: copyright at the top --- ...custom_task_classification_grammar_task.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/community_tasks/custom_task_classification_grammar_task.py b/community_tasks/custom_task_classification_grammar_task.py index ac18a0e5d..f513cf0bf 100644 --- a/community_tasks/custom_task_classification_grammar_task.py +++ b/community_tasks/custom_task_classification_grammar_task.py @@ -1,3 +1,26 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# ruff: noqa: F405, F403, F401 """Emotion Classification Task with Grammar Constraints using LightEval This module demonstrates how to create a classification task in LightEval with JSON grammar-constrained generation for structured responses. 
From 90e505fbf4ce54052940bf8f0ba89b16e42937a6 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Thu, 21 Aug 2025 16:42:47 +0200 Subject: [PATCH 17/20] del: langcodes dep isn't needed anymore --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8f0cb9c05..b3b69a384 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,6 @@ multilingual = [ "spacy[ja,ko,th]", "jieba", # for chinese tokenizer "pyvi", # for vietnamese tokenizer - "langcodes>=3.5.0", ] math = ["latex2sympy2_extended==1.0.6"] wandb = ["wandb"] From ab5c6b9f5c8ceadec8d1124678017cb72d79304c Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Thu, 21 Aug 2025 16:43:15 +0200 Subject: [PATCH 18/20] add: use load from file directly in the main endpoint --- src/lighteval/main_endpoint.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/lighteval/main_endpoint.py b/src/lighteval/main_endpoint.py index 969c742df..f824ca7ab 100644 --- a/src/lighteval/main_endpoint.py +++ b/src/lighteval/main_endpoint.py @@ -249,11 +249,8 @@ def tgi( """ Evaluate models using TGI as backend. """ - import yaml - from lighteval.logging.evaluation_tracker import EvaluationTracker from lighteval.models.endpoints.tgi_model import TGIModelConfig - from lighteval.models.model_input import GenerationParameters from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters evaluation_tracker = EvaluationTracker( @@ -269,18 +266,7 @@ def tgi( parallelism_manager = ParallelismManager.TGI - with open(model_config_path, "r") as f: - config = yaml.safe_load(f) - - # Extract generation_parameters from model_parameters if they exist - model_params = config["model_parameters"].copy() - yaml_gen_params = model_params.pop("generation_parameters", {}) - - # Start with defaults and override with YAML values - generation_parameters = GenerationParameters(**yaml_gen_params) - - # Create model config without generation_parameters in model_params - model_config = TGIModelConfig(**model_params, generation_parameters=generation_parameters) + model_config = TGIModelConfig.from_path(model_config_path) pipeline_params = PipelineParameters( launcher_type=parallelism_manager, From 834aa083b779f6a03226095b8cdee7882a548154 Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Thu, 21 Aug 2025 16:43:53 +0200 Subject: [PATCH 19/20] del: newlines --- src/lighteval/models/model_input.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index b617ba368..d6e5fe193 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -33,8 +33,10 @@ class GenerationParameters(BaseModel, extra="forbid"): frequency_penalty: NonNegativeFloat | None = None # vllm, tgi, sglang length_penalty: NonNegativeFloat | None = None # vllm, transformers presence_penalty: NonNegativeFloat | None = None # vllm, sglang + max_new_tokens: NonNegativeInt | None = None # vllm, transformers, tgi, litellm, sglang min_new_tokens: NonNegativeInt | None = None # vllm, transformers, sglang + seed: NonNegativeInt | None = None # vllm, tgi, litellm stop_tokens: list[str] | None = None # vllm, transformers, tgi, litellm, sglang temperature: NonNegativeFloat = ( From 2e5c74ba41211861e2ace0ff85a980cf416758ac Mon Sep 17 00:00:00 2001 From: cpcdoy Date: Thu, 21 Aug 2025 16:44:09 +0200 Subject: [PATCH 20/20] add: mock HTTP request for info to TGI server --- tests/utils/test_caching.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 
deletions(-) diff --git a/tests/utils/test_caching.py b/tests/utils/test_caching.py index dcaee4559..1d8f6060d 100644 --- a/tests/utils/test_caching.py +++ b/tests/utils/test_caching.py @@ -219,9 +219,10 @@ def test_cache_vllm(self, mock_create_model, mock_greedy_until, mock_loglikeliho self._test_cache(model) + @patch("requests.get") @patch("lighteval.models.endpoints.tgi_model.ModelClient._greedy_until") @patch("lighteval.models.endpoints.tgi_model.ModelClient._loglikelihood") - def test_cache_tgi(self, mock_greedy_until, mock_loglikelihood): + def test_cache_tgi(self, mock_loglikelihood, mock_greedy_until, mock_requests_get): from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig from lighteval.utils.imports import is_tgi_available @@ -229,11 +230,16 @@ def test_cache_tgi(self, mock_greedy_until, mock_loglikelihood): pytest.skip("Skipping because missing the imports") # Mock TGI requests - mock_greedy_until.return_value = self.model_responses mock_loglikelihood.return_value = self.model_responses + mock_greedy_until.return_value = self.model_responses + + # Mock HTTP info request + mock_requests_get.return_value.json.return_value = {"model_id": "Qwen/Qwen3-0.6B"} with tempfile.TemporaryDirectory() as temp_dir: - config = TGIModelConfig(model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir) + config = TGIModelConfig( + model_name="Qwen/Qwen3-0.6B", cache_dir=temp_dir, inference_server_address="http://localhost:8080" + ) model = ModelClient(config) self._test_cache(model)
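Note on the mocking pattern in the final patch (editorial sketch, not part of the patch series): stacked `unittest.mock.patch` decorators are applied bottom-up, so the mock for the decorator closest to the function definition arrives as the first argument, which is why `test_cache_tgi` receives `mock_loglikelihood` before `mock_greedy_until` and `mock_requests_get`. A minimal standalone illustration of the same `requests.get` mocking, assuming nothing from lighteval — the `fetch_model_id` helper below is hypothetical and only stands in for the TGI `/info` request issued by `ModelClient`:

# Minimal sketch of the mocking pattern used above; `fetch_model_id` is a
# hypothetical helper standing in for the TGI /info request made by ModelClient.
from unittest.mock import patch

import requests


def fetch_model_id(address: str) -> str:
    # Real code would query the TGI /info endpoint; this is only illustrative.
    return requests.get(f"{address}/info").json()["model_id"]


@patch("requests.get")
def check_mocked_info(mock_requests_get):
    # .return_value is the mocked Response object; its .json() is mocked in turn,
    # mirroring `mock_requests_get.return_value.json.return_value` in the test.
    mock_requests_get.return_value.json.return_value = {"model_id": "Qwen/Qwen3-0.6B"}
    assert fetch_model_id("http://localhost:8080") == "Qwen/Qwen3-0.6B"


if __name__ == "__main__":
    check_mocked_info()
    print("mocked /info request returned the expected model_id")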