@@ -166,7 +166,7 @@ async def query(self, prompt, mask_logits_processor, head_input_id_rewriter, act
166
166
167
167
async def debug_out (decoder_step ):
168
168
if _DCLibDebugPrinter .printer is not None and dc .DecoderSequence .graph is not None :
169
- data = await dc .DecoderSequence .graph .json ()
169
+ data = await dc .DecoderSequence .graph .json (diff = True )
170
170
data = replace_inf_nan_with_str (data )
171
171
_DCLibDebugPrinter .printer .add_decoder_state (data )
172
172
dcmodel .report_stats (_DCLibDebugPrinter .printer , decoder_step )
@@ -184,17 +184,18 @@ async def debug_out(decoder_step):
184
184
if step_budget is not None and decoder_step >= step_budget :
185
185
print ("warning: step budget exceeded" )
186
186
break
187
-
188
- if "performance_stats" in decoder_args :
189
- Stats .print_all ()
190
187
191
188
if interrupt .check ():
192
189
interrupt .clear ()
193
190
raise InterruptedError ("lmql.runtime.interrupt" )
194
191
195
- average_step_time = time .time () - start if average_step_time is None else (average_step_time * 0.9 + (time .time () - start ) * 0.1 )
196
- # if decoder_step % 10 == 0:
197
- # print("step", decoder_step, "time", average_step_time)
192
+ average_step_time = (time .time () - start ) if average_step_time is None else (average_step_time * 0.9 + (time .time () - start ) * 0.1 )
193
+
194
+ if "performance_stats" in decoder_args :
195
+ if decoder_step % 10 == 0 :
196
+ Stats .print_all ()
197
+ print ("step" , decoder_step , "time" , average_step_time )
198
+
198
199
start = time .time ()
199
200
200
201
except dc .FinishException as fe :
@@ -219,7 +220,7 @@ async def debug_out(decoder_step):
219
220
return [self .make_decoder_head (i ,n ,s ) for i ,s in enumerate (result_sequences )]
220
221
221
222
def validate_args (self , decoder_args , decoder_fct ):
222
- INTERNAL_ARGS = ["decoder" , "dcmodel" , "dclib_additional_logits_processor" , "input_id_rewriter" , "output_writer" , "chatty_openai" , "distribution_batch_size" , "openai_chunksize" , "step_budget" , "stats" ]
223
+ INTERNAL_ARGS = ["decoder" , "dcmodel" , "dclib_additional_logits_processor" , "input_id_rewriter" , "output_writer" , "chatty_openai" , "distribution_batch_size" , "openai_chunksize" , "step_budget" , "stats" , "performance_stats" ]
223
224
224
225
# get all arg names and kwarg names of decoder function
225
226
decoder_arg_names = inspect .getfullargspec (decoder_fct ).args
0 commit comments