diff --git a/modelgauge/simple_test_runner.py b/modelgauge/simple_test_runner.py index ee61069d..7ef19031 100644 --- a/modelgauge/simple_test_runner.py +++ b/modelgauge/simple_test_runner.py @@ -122,37 +122,38 @@ def _process_test_item( ) -> TestItemRecord: interactions: List[PromptInteractionAnnotations] = [] for prompt in item.prompts: - if isinstance(prompt.prompt, TextPrompt): - sut_request = sut.translate_text_prompt(prompt.prompt) - else: - sut_request = sut.translate_chat_prompt(prompt.prompt) try: + if isinstance(prompt.prompt, TextPrompt): + sut_request = sut.translate_text_prompt(prompt.prompt) + else: + sut_request = sut.translate_chat_prompt(prompt.prompt) with sut_cache as cache: sut_response = cache.get_or_call(sut_request, sut.evaluate) + response = sut.translate_response(sut_request, sut_response) except Exception as e: raise Exception( - f"Exception while handling SUT request `{sut_request}` for TestItem `{item}`" + f"Exception while handling SUT {sut.uid} for TestItem `{item}`" ) from e - response = sut.translate_response(sut_request, sut_response) annotated_completions: List[SUTCompletionAnnotations] = [] for completion in response.completions: annotations = {} for annotator_data in annotators: annotator = annotator_data.annotator - annotator_request = annotator.translate_request(prompt, completion) try: + annotator_request = annotator.translate_request(prompt, completion) with annotator_data.cache as cache: annotator_response = cache.get_or_call( annotator_request, annotator.annotate ) + annotation = annotator.translate_response( + annotator_request, annotator_response + ) except Exception as e: raise Exception( f"Exception while handling annotation for {annotator_data.key} on {response}" ) from e - annotation = annotator.translate_response( - annotator_request, annotator_response - ) + annotations[annotator_data.key] = Annotation.from_instance(annotation) annotated_completions.append( SUTCompletionAnnotations(completion=completion, annotations=annotations) diff --git a/tests/test_simple_test_runner.py b/tests/test_simple_test_runner.py index 0d705f16..0cc58bec 100644 --- a/tests/test_simple_test_runner.py +++ b/tests/test_simple_test_runner.py @@ -227,7 +227,7 @@ def _raise_exception(*args, **kwargs): tmpdir, ) err_text = str(err_info.value) - assert "SUT request `text='1' num_completions=1`" in err_text + assert "Exception while handling SUT fake-sut" in err_text assert "TestItem `prompts=[PromptWithContext(" in err_text # Ensure it forwards the original issue assert str(err_info.value.__cause__) == "some-exception"