diff --git a/run_evals_accelerate.py b/run_evals_accelerate.py
index 692daaf76..4bf8348b0 100644
--- a/run_evals_accelerate.py
+++ b/run_evals_accelerate.py
@@ -71,6 +71,7 @@ def get_parser():
     parser.add_argument("--inference_server_auth", type=str, default=None)
     # Model type 3) Inference endpoints
     task_type_group.add_argument("--endpoint_model_name", type=str)
+    parser.add_argument("--revision", type=str)
     parser.add_argument("--accelerator", type=str, default=None)
     parser.add_argument("--vendor", type=str, default=None)
     parser.add_argument("--region", type=str, default=None)
diff --git a/src/lighteval/models/endpoint_model.py b/src/lighteval/models/endpoint_model.py
index e03e0f7ec..d8904309c 100644
--- a/src/lighteval/models/endpoint_model.py
+++ b/src/lighteval/models/endpoint_model.py
@@ -70,6 +70,7 @@ def __init__(
         self.endpoint: InferenceEndpoint = create_inference_endpoint(
             name=config.name,
             repository=config.repository,
+            revision=config.revision,
             framework=config.framework,
             task="text-generation",
             accelerator=config.accelerator,
diff --git a/src/lighteval/models/model_config.py b/src/lighteval/models/model_config.py
index 20ccd4d66..19c8bdaf8 100644
--- a/src/lighteval/models/model_config.py
+++ b/src/lighteval/models/model_config.py
@@ -238,6 +238,7 @@ class InferenceEndpointModelConfig:
     endpoint_type: str = "protected"
     should_reuse_existing: bool = False
     add_special_tokens: bool = True
+    revision: str = "main"
 
     def get_dtype_args(self) -> Dict[str, str]:
         model_dtype = self.model_dtype.lower()
@@ -296,6 +297,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
             instance_type=args.instance_type,
             should_reuse_existing=args.reuse_existing,
             model_dtype=args.model_dtype,
+            revision=args.revision or "main",
         )
 
     return InferenceModelConfig(model=args.endpoint_model_name)
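
For context, a minimal standalone sketch of the behaviour the patch adds: the new --revision flag defaults to None, and create_model_config substitutes "main" when it is unset. Only the flag definition and the `or "main"` fallback come from the diff above; the example revision value is hypothetical.

    # Illustrative sketch only: reproduces the fallback logic added in create_model_config.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--revision", type=str)  # same definition as in run_evals_accelerate.py

    args = parser.parse_args([])                            # --revision not supplied
    print(args.revision or "main")                          # -> "main"

    args = parser.parse_args(["--revision", "refs/pr/1"])   # hypothetical revision value
    print(args.revision or "main")                          # -> "refs/pr/1"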