diff --git a/src/lmql/runtime/dclib/dclib_array.py b/src/lmql/runtime/dclib/dclib_array.py
index d3e85b69..e7817ba1 100644
--- a/src/lmql/runtime/dclib/dclib_array.py
+++ b/src/lmql/runtime/dclib/dclib_array.py
@@ -12,6 +12,7 @@ class Continuation:
     token: Any
     logprob: Any
     user_data: Any
+    distribution_logprobs: Any = None
 
 class criterion:
     def __and__(self, other): return logical_and(self, other)
@@ -325,8 +326,9 @@ def op_extend(p1, p2):
             tokens = continuation.token.reshape(-1)
             logprobs = continuation.logprob.reshape(-1)
             user_data = continuation.user_data or [None] * len(tokens)
-            for t,s,u in zip(tokens, logprobs, user_data):
-                extended_seqs.append(sq.extend(Continuation(t, s, u)))
+            distribution_logprobs = continuation.distribution_logprobs or [None] * len(tokens)
+            for t,s,u,d in zip(tokens, logprobs, user_data, distribution_logprobs):
+                extended_seqs.append(sq.extend(Continuation(t, s, u, d)))
             return extended_seqs
 
         return DataArray(apply_componentwise(op_extend, self.sequences, other.sequences, "extend", allow_mismatch_keys=False), dims=self.shape)
diff --git a/src/lmql/runtime/dclib/dclib_seq.py b/src/lmql/runtime/dclib/dclib_seq.py
index f5c1e13b..5c439881 100644
--- a/src/lmql/runtime/dclib/dclib_seq.py
+++ b/src/lmql/runtime/dclib/dclib_seq.py
@@ -82,7 +82,7 @@ async def json(self, diff: bool = False):
         }
 
 class DecoderSequence:
-    def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phrase=None, predecessor=None, user_data=None, sticky_user_data_keys=None, epsilon_node=False, internal=False):
+    def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phrase=None, predecessor=None, user_data=None, sticky_user_data_keys=None, epsilon_node=False, internal=False, distribution_logprobs=None):
         if logprobs is not None:
             if not all([p > get_truncation_threshold() for p in logprobs]):
                 warnings.warn("logprobs contain values below the current logprob truncation threshold {t}, which may cause unexpected behavior. Consider increasing the truncation threshold via lmql.model(..., truncation_threshold=...).".format(t=get_truncation_threshold()))
@@ -141,6 +141,10 @@ def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phr
 
         # indicates to dc.rewrite whether this sequence can be rewritten
         self.needs_rewrite = True
+        if not distribution_logprobs:
+            self.distribution_logprobs = [None] * len(self.logprobs)
+        else:
+            self.distribution_logprobs = distribution_logprobs
 
     def __hash__(self) -> int:
         return hash(self.id)
@@ -371,7 +375,8 @@ def extend(self, continuation, internal=False):
             predecessor=self,
             user_data=self.extend_user_data(continuation),
             sticky_user_data_keys=self.sticky_user_data_keys,
-            internal=internal
+            internal=internal,
+            distribution_logprobs=self.distribution_logprobs + [continuation.distribution_logprobs]
         )
 
     def detect_stop_phrase(self, continuation):
@@ -436,6 +441,10 @@ def make_successors(self, next_tokens, next_token_scores, logits, user_data=None
         tokens = [t for t, s in zip(next_tokens, next_token_scores) if s > get_truncation_threshold()]
         scores = [s for s in next_token_scores if s > get_truncation_threshold()]
 
+        distribution_logprobs = [{k: v for k, v in logits.probs.items() if type(k) == str}]
+        if len(distribution_logprobs[0]) < 1:
+            distribution_logprobs = None
+
         if len(tokens) == 0:
             print("WARNING: all continuation token fall below the current logprob truncation threshold {t}. This is likely due to a too low truncation threshold. Please increase the truncation threshold via lmql.model(..., truncation_threshold=...).".format(t=get_truncation_threshold()))
             tokens = [t for t, s in zip(next_tokens, next_token_scores)][:1]
@@ -443,7 +452,7 @@
         next_tokens = np.stack(tokens, axis=0)
         next_token_scores = np.stack(scores, axis=0)
-        return Continuation(next_tokens, next_token_scores, user_data)
+        return Continuation(next_tokens, next_token_scores, user_data, distribution_logprobs)
 
 # global counter for all sequences created in this process for identification purposes
 DecoderSequence.seq_ctr = 0
 DecoderSequence.graph = None
diff --git a/src/lmql/runtime/interpreter.py b/src/lmql/runtime/interpreter.py
index 230153d4..968440ce 100644
--- a/src/lmql/runtime/interpreter.py
+++ b/src/lmql/runtime/interpreter.py
@@ -102,6 +102,8 @@ class PromptState(NamedTuple):
     where: Optional[Any]
     tail: Optional[str]
 
+    distribution_logprobs: Optional[Dict[str, float]] = {}
+
     def __str__(self):
         return f""
 
@@ -543,15 +545,20 @@ async def where_step_for_sequence(self, s: dc.DecoderSequence, needs_masking, se
         # update hint for max_tokens to generate for current var
         max_tokens_hint = ops.most_restrictive_hint([sub_max_token_hints, max_tokens_hint])
 
+        if len(s.distribution_logprobs) > variable_offset:
+            scores = s.distribution_logprobs[variable_offset]
+        else:
+            scores = None
+
         # current context
         program_state: ProgramState = state.program_state.copy()
-        program_state.set(variable, text, scores=(), diff=diff_text, montonicity="inc", tokens=text_tokens)
+        program_state.set(variable, text, scores=scores, diff=diff_text, montonicity="inc", tokens=text_tokens)
         program_state.subinterpreter_results = subvalid
         program_state.prompt = state.prompt
 
         # follow context
         follow_program_state: ProgramState = state.program_state.copy()
-        follow_program_state.set(variable, text + str(ops.NextToken), scores=(), diff=diff_text, montonicity="inc", tokens=text_tokens)
+        follow_program_state.set(variable, text + str(ops.NextToken), scores=scores, diff=diff_text, montonicity="inc", tokens=text_tokens)
         follow_program_state.subinterpreter_results = subfollow
         follow_program_state.prompt = state.prompt
 
@@ -611,6 +618,7 @@ async def where_step_for_sequence(self, s: dc.DecoderSequence, needs_masking, se
             program_state=program_state,
             stopping_phrases=stopping_phrases,
             where=await self.where_graph_with_trace(where, trace, follow_trace),
+            distribution_logprobs=scores,
         )
 
         # extract hint of maximum number of tokens to generate for 'variable' from
@@ -758,7 +766,7 @@ async def rewrite_for_sequence(self, seq: dc.DecoderSequence, needs_rewrite, ass
             variable_value = text
 
         # set raw variable value
-        program_state.set(variable, variable_value, scores=(), diff=text_diff, montonicity="fin", tokens=text_tokens)
+        program_state.set(variable, variable_value, scores=state.distribution_logprobs, diff=text_diff, montonicity="fin", tokens=text_tokens)
 
         where = state.full_where_condition(self)
 
@@ -1019,6 +1027,9 @@ async def debug_out(decoder_step):
             if _DCLibDebugPrinter.printer.records_graph:
                 dc.set_record_graph()
                 self.decoder_graph = dc.DecoderSequence.graph
+        if self.model.adapter.decoder_args.get("decoder_graph", False):
+            dc.set_record_graph()
+            self.decoder_graph = dc.DecoderSequence.graph
 
         # get decoder function
         mode = decoder_args["decoder"].lower()
diff --git a/src/lmql/runtime/openai_integration.py b/src/lmql/runtime/openai_integration.py
index 69e2d1b1..555e3f1f 100644
--- a/src/lmql/runtime/openai_integration.py
+++ b/src/lmql/runtime/openai_integration.py
@@ -557,7 +557,7 @@ async def op_sample(seqs):
             sampling_modes = [f"sample-{temperature}-sample-id-{random.randint(0, 2**32-1)}" for _ in range(len(seqs))]
             edge_type_populated_user_data = [{"dc-edge-type": sm} for sm in sampling_modes]
 
-            completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=num_samples, sampling_modes=sampling_modes, **kwargs)
+            completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=5, sampling_modes=sampling_modes, **kwargs)
 
             next_token_ids = []
             next_token_scores = []
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..707a43cc
--- /dev/null
+++ b/test.py
@@ -0,0 +1,48 @@
+import sys
+sys.path.append('/Users/felix/Programming/lmql/src')
+import lmql
+import asyncio
+
+#add replicate api key to env
+import os
+os.environ['REPLICATE_API_TOKEN'] = 'r8_aOlrg82Wfg30Rx4L4mv9wI2npPfBQGO0Pvci4'
+
+# def test_decorator(variable_value, prompt_value, context):
+#     return variable_value, prompt_value
+
+async def main():
+
+    test = lmql.model(
+        "openai/gpt-3.5-turbo-instruct"
+        # "meta-llama/Llama-2-13b-chat-hf",
+        # endpoint="replicate:deployment/ml-delphai/llama2-13b-chat-lmtp",
+        # endpoint="replicate:charles-dyfis-net/llama-2-7b-chat-hf--lmtp-8bit",
+        # tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded",
+    )
+    pass
+
+    answer = await lmql.run(
+        """
+        import math
+        def get_probs(variable_value, prompt_value, context):
+            breakpoint()
+            logprob_scores = list(context.variable_scores.items())[-1][1]
+            scores = dict()
+            for key, value in logprob_scores.items():
+                if value > -5:
+                    scores[key] = math.exp(value)
+            return scores
+        argmax(verbose=True)
+        \"How much you like monkeys between 0 and 2?[@get_probs MONKEY]\" where MONKEY in set ([\"0\", \"1\", \"2\"])
+        \"How much you like birds between 0 and 2?[@get_probs BIRD]\" where BIRD in set ([\"0\", \"1\", \"2\"])
+        return (MONKEY, BIRD)
+        """,
+        max_len=4000,
+        model=test,
+        # decoder_graph=True,
+    )
+
+    print(answer)
+
+if __name__ == "__main__":
+    asyncio.run(main())
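For reference, the consumption side of this change in isolation: the distribution logprobs threaded through Continuation and DecoderSequence above reach @-decorators as a {token_string: logprob} dict (test.py reads it via context.variable_scores). The following is a minimal standalone sketch of that conversion, using illustrative values rather than output from a real run; logprobs_to_probs and the threshold are hypothetical names, mirroring the get_probs decorator in test.py:

    import math

    def logprobs_to_probs(logprob_scores, threshold=-5.0):
        # logprob_scores: the {token_string: logprob} dict that a decorator like
        # get_probs in test.py reads from context.variable_scores; values below
        # the threshold are dropped before converting to (unnormalized) probabilities.
        return {k: math.exp(v) for k, v in logprob_scores.items() if v > threshold}

    # Illustrative values only (not output from a real run):
    print(logprobs_to_probs({"0": -0.11, "1": -2.30, "2": -4.61}))
    # -> {'0': 0.895..., '1': 0.100..., '2': 0.009...}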