10
10
import time
11
11
from collections import defaultdict
12
12
from concurrent .futures import ProcessPoolExecutor , as_completed
13
- from dataclasses import dataclass
14
13
from multiprocessing import Manager , Process
15
- from typing import Union
14
+ from typing import TypedDict , Union
16
15
17
16
from .agents import Agent , GPTMe
18
17
from .execenv import SimpleExecutionEnv
27
26
logger = logging .getLogger (__name__ )
28
27
29
28
30
- @ dataclass
31
- class ProcessSuccess :
29
+ class ProcessSuccess ( TypedDict ):
30
+ status : str
32
31
files : dict [str , str | bytes ]
33
32
stdout : str
34
33
stderr : str
35
34
duration : float
36
35
37
36
38
- @ dataclass
39
- class ProcessError :
37
+ class ProcessError ( TypedDict ):
38
+ status : str
40
39
message : str
41
40
stdout : str
42
41
stderr : str
@@ -46,6 +45,10 @@ class ProcessError:
46
45
ProcessResult = Union [ProcessSuccess , ProcessError ]
47
46
48
47
48
+ class SyncedDict (TypedDict ):
49
+ result : ProcessResult
50
+
51
+
49
52
def run_evals (
50
53
tests : list [ExecTest ], models : list [str ], timeout : int , parallel : int
51
54
) -> dict [str , list [ExecResult ]]:
@@ -145,15 +148,15 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
145
148
time_eval = 0.0
146
149
147
150
with Manager () as manager :
148
- result_dict = manager .dict ()
151
+ sync_dict = manager .dict ()
149
152
p = Process (
150
153
target = act_process ,
151
154
args = (
152
155
agent ,
153
- test ["files" ],
154
- test ["prompt" ],
155
- result_dict ,
156
156
test ["name" ],
157
+ test ["prompt" ],
158
+ test ["files" ],
159
+ sync_dict ,
157
160
parallel ,
158
161
),
159
162
)
@@ -173,10 +176,10 @@ def execute(test: ExecTest, agent: Agent, timeout: int, parallel: bool) -> ExecR
173
176
p .terminate ()
174
177
p .join (timeout = 1 )
175
178
176
- if "result" in result_dict :
177
- result = result_dict ["result" ]
179
+ if "result" in sync_dict :
180
+ result = sync_dict ["result" ]
178
181
time_gen = max (result .get ("duration" , 0.0 ), time_gen )
179
- status = result . get ( "status" , "success" )
182
+ status = result [ "status" ]
180
183
files = result .get ("files" , {})
181
184
gen_stdout = result .get ("stdout" , "" )
182
185
gen_stderr = result .get ("stderr" , "" )
@@ -264,10 +267,10 @@ def getvalue(self):
264
267
265
268
def act_process (
266
269
agent : Agent ,
267
- files ,
268
- prompt ,
269
- result_dict : dict ,
270
270
test_name : str ,
271
+ prompt : str ,
272
+ files : dict [str , str | bytes ],
273
+ sync_dict : SyncedDict ,
271
274
parallel : bool ,
272
275
):
273
276
# Configure logging for this subprocess
@@ -290,13 +293,14 @@ def error_handler(e):
290
293
duration = time .time () - start
291
294
if not isinstance (e , KeyboardInterrupt ):
292
295
subprocess_logger .error (f"Error: { e } " )
293
- result_dict [ "result" ] = {
296
+ result_error : ProcessError = {
294
297
"status" : "error" ,
295
298
"message" : str (e ),
296
299
"stdout" : stdout .getvalue (),
297
300
"stderr" : stderr .getvalue (),
298
301
"duration" : duration ,
299
302
}
303
+ sync_dict ["result" ] = result_error
300
304
301
305
# kill child processes
302
306
os .killpg (pgrp , signal .SIGKILL )
@@ -315,13 +319,14 @@ def sigterm_handler(*_):
315
319
return
316
320
317
321
duration = time .time () - start
318
- result_dict [ "result" ] = {
322
+ result_success : ProcessSuccess = {
319
323
"status" : "success" ,
320
324
"files" : files ,
321
325
"stdout" : stdout .getvalue (),
322
326
"stderr" : stderr .getvalue (),
323
327
"duration" : duration ,
324
328
}
329
+ sync_dict ["result" ] = result_success
325
330
subprocess_logger .info ("Success" )
326
331
327
332
# kill child processes
0 commit comments