6 changes: 4 additions & 2 deletions src/lmql/runtime/dclib/dclib_array.py
@@ -12,6 +12,7 @@ class Continuation:
token: Any
logprob: Any
user_data: Any
distribution_logprobs: Any = None
class criterion:
def __and__(self, other):
return logical_and(self, other)
@@ -325,8 +326,9 @@ def op_extend(p1, p2):
tokens = continuation.token.reshape(-1)
logprobs = continuation.logprob.reshape(-1)
user_data = continuation.user_data or [None] * len(tokens)
for t,s,u in zip(tokens, logprobs, user_data):
extended_seqs.append(sq.extend(Continuation(t, s, u)))
distribution_logprobs = continuation.distribution_logprobs or [None] * len(tokens)
for t,s,u,d in zip(tokens, logprobs, user_data, distribution_logprobs):
extended_seqs.append(sq.extend(Continuation(t, s, u, d)))
return extended_seqs

return DataArray(apply_componentwise(op_extend, self.sequences, other.sequences, "extend", allow_mismatch_keys=False), dims=self.shape)
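A minimal, self-contained sketch of the change above, using a stand-in dataclass rather than the real dclib types: Continuation gains an optional distribution_logprobs field, and the extension loop zips it alongside tokens, logprobs and user data, falling back to a list of Nones when the field is unset.

from dataclasses import dataclass
from typing import Any

@dataclass
class Continuation:                      # stand-in for dclib's Continuation
    token: Any
    logprob: Any
    user_data: Any
    distribution_logprobs: Any = None    # new optional field, defaults to None

def extend_all(continuation):
    tokens = list(continuation.token)
    logprobs = list(continuation.logprob)
    user_data = continuation.user_data or [None] * len(tokens)
    # same fallback pattern as user_data: call sites that never set the field keep working
    dists = continuation.distribution_logprobs or [None] * len(tokens)
    return [Continuation(t, s, u, d) for t, s, u, d in zip(tokens, logprobs, user_data, dists)]

per_token = extend_all(Continuation(token=[42, 7], logprob=[-0.3, -1.2], user_data=None))
print(per_token[0].distribution_logprobs)   # None, since no distribution was recorded for this step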
15 changes: 12 additions & 3 deletions src/lmql/runtime/dclib/dclib_seq.py
@@ -82,7 +82,7 @@ async def json(self, diff: bool = False):
}

class DecoderSequence:
def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phrase=None, predecessor=None, user_data=None, sticky_user_data_keys=None, epsilon_node=False, internal=False):
def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phrase=None, predecessor=None, user_data=None, sticky_user_data_keys=None, epsilon_node=False, internal=False, distribution_logprobs=None):
if logprobs is not None:
if not all([p > get_truncation_threshold() for p in logprobs]):
warnings.warn("logprobs contain values below the current logprob truncation threshold {t}, which may cause unexpected behavior. Consider increasing the truncation threshold via lmql.model(..., truncation_threshold=...).".format(t=get_truncation_threshold()))
@@ -141,6 +141,10 @@ def __init__(self, input_ids_or_str, logprobs=None, deterministic=None, stop_phr

# indicates to dc.rewrite whether this sequence can be rewritten
self.needs_rewrite = True
if not distribution_logprobs:
self.distribution_logprobs = [None] * len(self.logprobs)
else:
self.distribution_logprobs = distribution_logprobs

def __hash__(self) -> int:
return hash(self.id)
@@ -371,7 +375,8 @@ def extend(self, continuation, internal=False):
predecessor=self,
user_data=self.extend_user_data(continuation),
sticky_user_data_keys=self.sticky_user_data_keys,
internal=internal
internal=internal,
distribution_logprobs=self.distribution_logprobs + [continuation.distribution_logprobs]
)

def detect_stop_phrase(self, continuation):
@@ -436,14 +441,18 @@ def make_successors(self, next_tokens, next_token_scores, logits, user_data=None
tokens = [t for t, s in zip(next_tokens, next_token_scores) if s > get_truncation_threshold()]
scores = [s for s in next_token_scores if s > get_truncation_threshold()]

distribution_logprobs = [{k: v for k, v in logits.probs.items() if type(k) == str}]
if len(distribution_logprobs[0]) < 1:
distribution_logprobs = None

if len(tokens) == 0:
print("WARNING: all continuation token fall below the current logprob truncation threshold {t}. This is likely due to a too low truncation threshold. Please increase the truncation threshold via lmql.model(..., truncation_threshold=...).".format(t=get_truncation_threshold()))
tokens = [t for t, s in zip(next_tokens, next_token_scores)][:1]
scores = [s for s in next_token_scores][:1]
next_tokens = np.stack(tokens, axis=0)
next_token_scores = np.stack(scores, axis=0)

return Continuation(next_tokens, next_token_scores, user_data)
return Continuation(next_tokens, next_token_scores, user_data, distribution_logprobs)
# global counter for all sequences created in this process for identification purposes
DecoderSequence.seq_ctr = 0
DecoderSequence.graph = None
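A hedged sketch of the sequence-side bookkeeping introduced above, with the DecoderSequence machinery reduced to a plain list: make_successors keeps only the string-keyed entries of the scorer's top-k output (detokenized continuations), and each extend() appends one such dict, or None, per generated token. Names and shapes here are simplified stand-ins, not the real classes.

import math

def extract_distribution_logprobs(probs):
    # mirrors the filter in make_successors: keep detokenized (string-keyed) entries only
    scores = {k: v for k, v in probs.items() if isinstance(k, str)}
    return scores if scores else None

history = []                                    # plays the role of seq.distribution_logprobs
step_probs = {" 1": -0.11, " 0": -2.3, " 2": -4.7, 352: -0.11}   # int keys stand for raw token ids
history = history + [extract_distribution_logprobs(step_probs)]  # what extend() appends per step

print(history[-1])                              # {' 1': -0.11, ' 0': -2.3, ' 2': -4.7}
print({k: round(math.exp(v), 3) for k, v in history[-1].items()})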
17 changes: 14 additions & 3 deletions src/lmql/runtime/interpreter.py
@@ -102,6 +102,8 @@ class PromptState(NamedTuple):
where: Optional[Any]
tail: Optional[str]

distribution_logprobs: Optional[Dict[str, float]] = {}

def __str__(self):
return f"<PromptState '{self.variable}' '{[self.prompt]}'>"

@@ -543,15 +545,20 @@ async def where_step_for_sequence(self, s: dc.DecoderSequence, needs_masking, se
# update hint for max_tokens to generate for current var
max_tokens_hint = ops.most_restrictive_hint([sub_max_token_hints, max_tokens_hint])

if len(s.distribution_logprobs) > variable_offset:
scores = s.distribution_logprobs[variable_offset]
else:
scores = None

# current context
program_state: ProgramState = state.program_state.copy()
program_state.set(variable, text, scores=(), diff=diff_text, montonicity="inc", tokens=text_tokens)
program_state.set(variable, text, scores=scores, diff=diff_text, montonicity="inc", tokens=text_tokens)
program_state.subinterpreter_results = subvalid
program_state.prompt = state.prompt

# follow context
follow_program_state: ProgramState = state.program_state.copy()
follow_program_state.set(variable, text + str(ops.NextToken), scores=(), diff=diff_text, montonicity="inc", tokens=text_tokens)
follow_program_state.set(variable, text + str(ops.NextToken), scores=scores, diff=diff_text, montonicity="inc", tokens=text_tokens)
follow_program_state.subinterpreter_results = subfollow
follow_program_state.prompt = state.prompt

@@ -611,6 +618,7 @@ async def where_step_for_sequence(self, s: dc.DecoderSequence, needs_masking, se
program_state=program_state,
stopping_phrases=stopping_phrases,
where=await self.where_graph_with_trace(where, trace, follow_trace),
distribution_logprobs=scores,
)

# extract hint of maximum number of tokens to generate for 'variable' from
@@ -758,7 +766,7 @@ async def rewrite_for_sequence(self, seq: dc.DecoderSequence, needs_rewrite, ass

variable_value = text
# set raw variable value
program_state.set(variable, variable_value, scores=(), diff=text_diff, montonicity="fin", tokens=text_tokens)
program_state.set(variable, variable_value, scores=state.distribution_logprobs, diff=text_diff, montonicity="fin", tokens=text_tokens)

where = state.full_where_condition(self)

@@ -1019,6 +1027,9 @@ async def debug_out(decoder_step):
if _DCLibDebugPrinter.printer.records_graph:
dc.set_record_graph()
self.decoder_graph = dc.DecoderSequence.graph
if self.model.adapter.decoder_args.get("decoder_graph", False):
dc.set_record_graph()
self.decoder_graph = dc.DecoderSequence.graph

# get decoder function
mode = decoder_args["decoder"].lower()
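A simplified sketch of the interpreter-side lookup added above: when the sequence recorded a distribution for the position at which the current variable starts (variable_offset), that dict is passed through to program_state.set(...) as the variable's scores and carried on the PromptState so the final rewrite can attach it to the finished variable; otherwise scores stay None. The surrounding state objects are stand-ins, not the real interpreter types.

def scores_for_variable(distribution_logprobs, variable_offset):
    # same guard as in where_step_for_sequence: only index if a distribution was recorded there
    if len(distribution_logprobs) > variable_offset:
        return distribution_logprobs[variable_offset]
    return None

# hypothetical run: the placeholder variable opens at offset 2
per_step = [None, None, {" 0": -0.1, " 1": -2.4, " 2": -5.0}]
print(scores_for_variable(per_step, 2))   # {' 0': -0.1, ' 1': -2.4, ' 2': -5.0}
print(scores_for_variable(per_step, 5))   # None, nothing was recorded at that offset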
2 changes: 1 addition & 1 deletion src/lmql/runtime/openai_integration.py
@@ -557,7 +557,7 @@ async def op_sample(seqs):
sampling_modes = [f"sample-{temperature}-sample-id-{random.randint(0, 2**32-1)}" for _ in range(len(seqs))]
edge_type_populated_user_data = [{"dc-edge-type": sm} for sm in sampling_modes]

completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=num_samples, sampling_modes=sampling_modes, **kwargs)
completions: List[CompletionResult] = await self.completion_buffer(seqs, logprobs=5, sampling_modes=sampling_modes, **kwargs)

next_token_ids = []
next_token_scores = []
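Presumably logprobs is pinned to 5 here because the legacy OpenAI Completions API returns at most the top-5 logprobs per position, so requesting the maximum keeps the per-token distribution that feeds distribution_logprobs as complete as the API allows, independent of num_samples. A hedged illustration of the response shape this relies on (values made up):

response_choice = {
    "text": " 1",
    "logprobs": {
        "tokens": [" 1"],
        "token_logprobs": [-0.11],
        "top_logprobs": [{" 1": -0.11, " 0": -2.3, " 2": -4.7, " 3": -6.0, " two": -7.9}],
    },
}
top = response_choice["logprobs"]["top_logprobs"][0]   # top-5 alternatives for the first position
print(max(top, key=top.get))                           # " 1" is the most likely continuation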
48 changes: 48 additions & 0 deletions test.py
@@ -0,0 +1,48 @@
import sys
sys.path.append('/Users/felix/Programming/lmql/src')
import lmql
import asyncio

#add replicate api key to env
import os
os.environ['REPLICATE_API_TOKEN'] = 'r8_aOlrg82Wfg30Rx4L4mv9wI2npPfBQGO0Pvci4'

# def test_decorator(variable_value, prompt_value, context):
# return variable_value, prompt_value

async def main():

test = lmql.model(
"openai/gpt-3.5-turbo-instruct"
# "meta-llama/Llama-2-13b-chat-hf",
# endpoint="replicate:deployment/ml-delphai/llama2-13b-chat-lmtp",
# endpoint="replicate:charles-dyfis-net/llama-2-7b-chat-hf--lmtp-8bit",
# tokenizer="AyyYOO/Luna-AI-Llama2-Uncensored-FP16-sharded",
)
pass

answer = await lmql.run(
"""
import math
def get_probs(variable_value, prompt_value, context):
breakpoint()
logprob_scores = list(context.variable_scores.items())[-1][1]
scores = dict()
for key, value in logprob_scores.items():
if value > -5:
scores[key] = math.exp(value)
return scores
argmax(verbose=True)
\"How much you like monkeys between 0 and 2?[@get_probs MONKEY]\" where MONKEY in set ([\"0\", \"1\", \"2\"])
\"How much you like birds between 0 and 2?[@get_probs BIRD]\" where BIRD in set ([\"0\", \"1\", \"2\"])
return (MONKEY, BIRD)
""",
max_len=4000,
model=test,
# decoder_graph=True,
)

print(answer)

if __name__ == "__main__":
asyncio.run(main())
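For reference, the computation inside the query's @get_probs decorator, restated as a standalone function with the breakpoint() removed and lmql's context object replaced by a plain dict: the last placeholder variable's logprob scores are filtered by a cutoff and exponentiated into probabilities.

import math

def get_probs(variable_scores):
    # variable_scores stands in for context.variable_scores in the query above
    logprob_scores = list(variable_scores.items())[-1][1]
    return {k: math.exp(v) for k, v in logprob_scores.items() if v > -5}

print(get_probs({"MONKEY": {" 0": -0.2, " 1": -1.9, " 2": -6.1}}))
# {' 0': 0.818..., ' 1': 0.149...}  and ' 2' falls below the -5 cutoff, so it is dropped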