Skip to content

Commit 6d00be7

Browse files
committed
fix: improved typing in gptme.evals.run
1 parent 8c3cb77 commit 6d00be7

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

gptme/eval/run.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -10,9 +10,8 @@
1010
import time
1111
from collections import defaultdict
1212
from concurrent.futures import ProcessPoolExecutor, as_completed
13-
from dataclasses import dataclass
1413
from multiprocessing import Manager, Process
15-
from typing import Union
14+
from typing import TypedDict, Union
1615

1716
from .agents import Agent, GPTMe
1817
from .execenv import SimpleExecutionEnv
@@ -27,16 +26,16 @@
2726
logger = logging.getLogger(__name__)
2827

2928

30-
@dataclass
31-
class ProcessSuccess:
29+
class ProcessSuccess(TypedDict):
30+
status: str
3231
files: dict[str, str | bytes]
3332
stdout: str
3433
stderr: str
3534
duration: float
3635

3736

38-
@dataclass
39-
class ProcessError:
37+
class ProcessError(TypedDict):
38+
status: str
4039
message: str
4140
stdout: str
4241
stderr: str
@@ -46,6 +45,10 @@ class ProcessError:
4645
ProcessResult = Union[ProcessSuccess, ProcessError]
4746

4847

48+
class SyncedDict(TypedDict):
49+
result: ProcessResult
50+
51+
4952
def run_evals(
5053
tests: list[ExecTest], models: list[str], timeout: int, parallel: int
5154
) -> dict[str, list[ExecResult]]:
@@ -145,15 +148,15 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
145148
time_eval = 0.0
146149

147150
with Manager() as manager:
148-
result_dict = manager.dict()
151+
sync_dict = manager.dict()
149152
p = Process(
150153
target=act_process,
151154
args=(
152155
agent,
153-
test["files"],
154-
test["prompt"],
155-
result_dict,
156156
test["name"],
157+
test["prompt"],
158+
test["files"],
159+
sync_dict,
157160
parallel,
158161
),
159162
)
@@ -173,10 +176,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
173176
p.terminate()
174177
p.join(timeout=1)
175178

176-
if "result" in result_dict:
177-
result = result_dict["result"]
179+
if "result" in sync_dict:
180+
result = sync_dict["result"]
178181
time_gen = max(result.get("duration", 0.0), time_gen)
179-
status = result.get("status", "success")
182+
status = result["status"]
180183
files = result.get("files", {})
181184
gen_stdout = result.get("stdout", "")
182185
gen_stderr = result.get("stderr", "")
@@ -264,10 +267,10 @@ def getvalue(self):
264267

265268
def act_process(
266269
agent: Agent,
267-
files,
268-
prompt,
269-
result_dict: dict,
270270
test_name: str,
271+
prompt: str,
272+
files: dict[str, str | bytes],
273+
sync_dict: SyncedDict,
271274
parallel: bool,
272275
):
273276
# Configure logging for this subprocess
@@ -290,13 +293,14 @@ def error_handler(e):
290293
duration = time.time() - start
291294
if not isinstance(e, KeyboardInterrupt):
292295
subprocess_logger.error(f"Error: {e}")
293-
result_dict["result"] = {
296+
result_error: ProcessError = {
294297
"status": "error",
295298
"message": str(e),
296299
"stdout": stdout.getvalue(),
297300
"stderr": stderr.getvalue(),
298301
"duration": duration,
299302
}
303+
sync_dict["result"] = result_error
300304

301305
# kill child processes
302306
os.killpg(pgrp, signal.SIGKILL)
@@ -315,13 +319,14 @@ def sigterm_handler(*_):
315319
return
316320

317321
duration = time.time() - start
318-
result_dict["result"] = {
322+
result_success: ProcessSuccess = {
319323
"status": "success",
320324
"files": files,
321325
"stdout": stdout.getvalue(),
322326
"stderr": stderr.getvalue(),
323327
"duration": duration,
324328
}
329+
sync_dict["result"] = result_success
325330
subprocess_logger.info("Success")
326331

327332
# kill child processes

0 commit comments

Comments
 (0)