3 changes: 2 additions & 1 deletion src/lighteval/main_accelerate.py
@@ -142,6 +142,7 @@ def main(args):

     print(make_results_table(final_dict))

-    model.cleanup()
+    if not args.reuse_existing:
+        model.cleanup()

     return final_dict
25 changes: 18 additions & 7 deletions src/lighteval/models/endpoint_model.py
@@ -109,14 +109,23 @@ def __init__(
         self.async_client = AsyncInferenceClient(model=config.model, token=env_config.token)
         self.client = InferenceClient(model=config.model, token=env_config.token)

-        self.use_async = False # for debug - async use is faster
+        self.use_async = True # set to False for debug - async use is faster

         self._tokenizer = AutoTokenizer.from_pretrained(self.name)
+        self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False

+    @property
+    def tokenizer(self):
+        return self._tokenizer
+
+    @property
+    def add_special_tokens(self):
+        return self._add_special_tokens
+
     @property
     def disable_tqdm(self) -> bool:
         False # no accelerator = this is the main process

     def cleanup(self):
         if self.endpoint is not None:
             self.endpoint.delete()
@@ -250,7 +259,8 @@ def greedy_until(
         override_bs: Optional[int] = None,
     ) -> List[GenerateReturn]:
         for request in requests:
-            request.stop_sequence = request.stop_sequence + [self.tokenizer.eos_token]
+            request.tokenized_context = self.tok_encode(request.context)
+            request.stop_sequence = as_list(request.stop_sequence) + [self.tokenizer.eos_token]

         dataset = GenerativeTaskDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
         batch_size = override_bs if override_bs is not None else BATCH_SIZE
@@ -268,10 +278,11 @@
             for batch in tqdm(
                 dataloader, desc="Greedy generation", position=1, leave=False, disable=self.disable_tqdm
             ):
+                # the `returns_logits` flag is only used to filter the results, we always request the full details.
                 if self.use_async:
-                    responses = asyncio.run(self.__async_process_batch_generate(batch, returns_logits))
+                    responses = asyncio.run(self.__async_process_batch_generate(batch))
                 else:
-                    responses = self.__process_batch_generate(batch, returns_logits)
+                    responses = self.__process_batch_generate(batch)
                 for response in responses:
                     results.append(
                         GenerateReturn(
@@ -282,7 +293,7 @@
                         )
                     )

-        return results
+        return dataset.get_original_order(results)

     def loglikelihood(
         self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
@@ -321,7 +332,7 @@ def loglikelihood(
                         )
                     )

-        return results
+        return dataset.get_original_order(results)

     def loglikelihood_rolling(
         self, requests: list[LoglikelihoodRollingRequest], override_bs=None
@@ -361,7 +372,7 @@ def loglikelihood_rolling(
                         )
                     )

-        return results
+        return dataset.get_original_order(results)

     def loglikelihood_single_token(
         self,
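Note on the `return dataset.get_original_order(results)` changes above: the task dataset reorders requests (presumably by context length) so that similar-sized items are batched together, which means the collected results must be mapped back to the caller's submission order before returning. A minimal sketch of that idea, assuming sorted processing; it does not reproduce the actual `GenerativeTaskDataset` implementation:

```python
# Toy stand-in for the order-restoring behaviour assumed above; the real
# lighteval dataset class differs, this only illustrates the index mapping.
class OrderPreservingDataset:
    def __init__(self, items: list[str]):
        # Process items sorted by length, but remember where each one came from.
        self.original_indices = sorted(range(len(items)), key=lambda i: len(items[i]))
        self.sorted_items = [items[i] for i in self.original_indices]

    def get_original_order(self, results: list) -> list:
        # results[k] belongs to items[self.original_indices[k]].
        restored = [None] * len(results)
        for sorted_pos, original_pos in enumerate(self.original_indices):
            restored[original_pos] = results[sorted_pos]
        return restored


items = ["bb", "a", "cccc"]
dataset = OrderPreservingDataset(items)
results_in_sorted_order = [item.upper() for item in dataset.sorted_items]
assert dataset.get_original_order(results_in_sorted_order) == ["BB", "A", "CCCC"]
```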
4 changes: 3 additions & 1 deletion src/lighteval/models/model_config.py
@@ -221,6 +221,7 @@ class TGIModelConfig:
 @dataclass
 class InferenceModelConfig:
     model: str
+    add_special_tokens: bool = True


 @dataclass
@@ -235,6 +236,7 @@ class InferenceEndpointModelConfig:
     framework: str = "pytorch"
     endpoint_type: str = "protected"
     should_reuse_existing: bool = False
+    add_special_tokens: bool = True


 def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]) -> BaseModelConfig: # noqa: C901
@@ -270,7 +272,7 @@ def create_model_config(args: Namespace, accelerator: Union["Accelerator", None]
     # Endpoint
     if args.endpoint_model_name:
         if args.reuse_existing or args.vendor is not None:
-            model = args.endpoint_model_name.split("/")[1].lower()
+            model = args.endpoint_model_name.split("/")[1].replace(".", "-").lower()
             return InferenceEndpointModelConfig(
                 name=f"{model}-lighteval",
                 repository=args.endpoint_model_name,
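The added `.replace(".", "-")` above handles repository ids whose name contains a dot (version suffixes such as `v0.1`), presumably because the derived endpoint name does not accept dots, so they are turned into dashes before lowercasing and appending `-lighteval`. A small illustration with a hypothetical repository id:

```python
# Hypothetical repository id, used only to illustrate the new name sanitization.
repository = "my-org/My-Model-v0.1"

model = repository.split("/")[1].replace(".", "-").lower()
endpoint_name = f"{model}-lighteval"

assert endpoint_name == "my-model-v0-1-lighteval"
```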
4 changes: 2 additions & 2 deletions src/lighteval/tasks/requests.py
@@ -114,7 +114,7 @@ class GreedyUntilRequest(Request):
         request_type (RequestType): The type of the request, set to RequestType.GREEDY_UNTIL.
     """

-    stop_sequence: str
+    stop_sequence: Union[str, tuple[str], list[str]]
     generation_size: int
     request_type = RequestType.GREEDY_UNTIL
     tokenized_context: list[int] = None
@@ -132,7 +132,7 @@ class GreedyUntilWithLogitsRequest(Request):
         request_type (RequestType): The type of the request (GREEDY_UNTIL_WITH_LOGITS).
     """

-    stop_sequence: str
+    stop_sequence: Union[str, tuple[str], list[str]]
     generation_size: int
     request_type = RequestType.GREEDY_UNTIL_WITH_LOGITS
     tokenized_context: list[int] = None
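With `stop_sequence` widened to `Union[str, tuple[str], list[str]]`, the `as_list(request.stop_sequence)` call in `greedy_until` above normalizes whichever form a task provides before appending the tokenizer's EOS token. A rough equivalent of that normalization, not lighteval's actual `as_list` helper:

```python
from typing import Union


def normalize_stop_sequence(stop: Union[str, tuple, list, None], eos_token: str) -> list[str]:
    """Coerce a single string, tuple, or list of stop strings into a list and append EOS."""
    if stop is None:
        stop = []
    elif isinstance(stop, str):
        stop = [stop]
    else:
        stop = list(stop)
    return stop + [eos_token]


assert normalize_stop_sequence("###", "</s>") == ["###", "</s>"]
assert normalize_stop_sequence(("###", "\n\n"), "</s>") == ["###", "\n\n", "</s>"]
assert normalize_stop_sequence(None, "</s>") == ["</s>"]
```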