Commit 27f9a8a

incremental streaming of decoder graph in playground

1 parent c94b793 commit 27f9a8a

File tree

6 files changed: +95 −20 lines changed

src/lmql/runtime/dclib/dclib_seq.py (+40 −1)

@@ -11,6 +11,11 @@
 
 detokenize_seqs = True
 
+@dataclass
+class DecoderGraphSnapshot:
+    node_hashes = None
+    step_not_updated_count = None
+
 class DecoderGraph:
     def __init__(self):
         self.nodes = {}
@@ -19,6 +24,8 @@ def __init__(self):
 
         self.ctr = 0
 
+        self.json_snapshot: DecoderGraphSnapshot = DecoderGraphSnapshot()
+
     def add_node(self, node):
         uid = f"n{self.ctr}"
         self.ctr += 1
@@ -34,13 +41,39 @@ def set_pool(self, node, pool_name):
     def add_edge(self, from_node, to_node):
         self.edges.append((self.node_ids[from_node], self.node_ids[to_node]))
 
-    async def json(self):
+    async def json(self, diff: bool = False):
         nodes = []
+
+        if self.json_snapshot.node_hashes is None:
+            self.json_snapshot.node_hashes = {}
+        if self.json_snapshot.step_not_updated_count is None:
+            self.json_snapshot.step_not_updated_count = {}
+
         for k, v in self.nodes.items():
+            hash = None
+            if diff and k in self.json_snapshot.node_hashes:
+                # this will ignore changes to nodes that have not been updated for 3 steps (may miss changes to nodes that
+                # are not updated for a long time, careful for now)
+                if k in self.json_snapshot.step_not_updated_count and self.json_snapshot.step_not_updated_count[k] > 2:
+                    continue
+                hash = await v.json_hash()
+                if self.json_snapshot.node_hashes[k] == str(hash):
+                    self.json_snapshot.step_not_updated_count[k] = self.json_snapshot.step_not_updated_count.get(k, 0) + 1
+                    continue
+
+                # else:
+                #     print("node changed", k, self.json_snapshot.node_hashes[k], hash)
+            if hash is None:
+                hash = await v.json_hash()
+
             nodes.append({
                 "id": k,
                 **await v.json()
             })
+
+            if diff:
+                self.json_snapshot.node_hashes[k] = str(hash)
+                self.json_snapshot.step_not_updated_count[k] = 0
 
         return {
             "nodes": nodes,
@@ -180,6 +213,12 @@ async def text_provider():
 
         return self.tokenizer_cache[name]
 
+    async def json_hash(self):
+        o = await self.json()
+        if type(o["user_data"]) is dict and "openai-continuations" in o["user_data"]:
+            o["user_data"].pop("openai-continuations")
+        return hash(str(o))
+
     async def json(self):
         seqtext = await self.detokenized("seqtext")
         text = await self.detokenized("text")
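
Note on this change: json(diff=True) turns full graph serialization into incremental streaming. The graph keeps a per-node hash snapshot, unchanged nodes are skipped, and nodes that have been idle for more than two steps are no longer even hashed. (Since the DecoderGraphSnapshot attributes carry no type annotations, @dataclass treats them as plain class attributes rather than fields; that still works here because both are replaced with fresh dicts on first use.) A minimal standalone sketch of the idea, synchronous instead of async and with hypothetical names:

    class SnapshotDiffer:
        """Tracks per-node content hashes so repeated serializations
        only emit nodes whose payload changed since the last call."""

        def __init__(self, patience=2):
            self.node_hashes = {}    # node id -> last serialized hash
            self.not_updated = {}    # node id -> consecutive unchanged steps
            self.patience = patience # after this many idle steps, skip hashing

        def diff(self, nodes):
            changed = []
            for k, payload in nodes.items():
                # nodes idle for more than `patience` steps are skipped entirely
                if self.not_updated.get(k, 0) > self.patience:
                    continue
                h = str(hash(str(payload)))
                if self.node_hashes.get(k) == h:
                    self.not_updated[k] = self.not_updated.get(k, 0) + 1
                    continue
                changed.append({"id": k, **payload})
                self.node_hashes[k] = h
                self.not_updated[k] = 0
            return changed

    differ = SnapshotDiffer()
    print(differ.diff({"n0": {"text": "a"}}))  # -> [{'id': 'n0', 'text': 'a'}]
    print(differ.diff({"n0": {"text": "a"}}))  # -> [] (unchanged, suppressed)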

src/lmql/runtime/dclib/lmql_adapter.py (+9 −8)

@@ -166,7 +166,7 @@ async def query(self, prompt, mask_logits_processor, head_input_id_rewriter, act
 
         async def debug_out(decoder_step):
             if _DCLibDebugPrinter.printer is not None and dc.DecoderSequence.graph is not None:
-                data = await dc.DecoderSequence.graph.json()
+                data = await dc.DecoderSequence.graph.json(diff=True)
                 data = replace_inf_nan_with_str(data)
                 _DCLibDebugPrinter.printer.add_decoder_state(data)
             dcmodel.report_stats(_DCLibDebugPrinter.printer, decoder_step)
@@ -184,17 +184,18 @@ async def debug_out(decoder_step):
                 if step_budget is not None and decoder_step >= step_budget:
                     print("warning: step budget exceeded")
                     break
-
-                if "performance_stats" in decoder_args:
-                    Stats.print_all()
 
                 if interrupt.check():
                     interrupt.clear()
                     raise InterruptedError("lmql.runtime.interrupt")
 
-                average_step_time = time.time() - start if average_step_time is None else (average_step_time * 0.9 + (time.time() - start) * 0.1)
-                # if decoder_step % 10 == 0:
-                #     print("step", decoder_step, "time", average_step_time)
+                average_step_time = (time.time() - start) if average_step_time is None else (average_step_time * 0.9 + (time.time() - start) * 0.1)
+
+                if "performance_stats" in decoder_args:
+                    if decoder_step % 10 == 0:
+                        Stats.print_all()
+                        print("step", decoder_step, "time", average_step_time)
+
                 start = time.time()
 
         except dc.FinishException as fe:
@@ -219,7 +220,7 @@ async def debug_out(decoder_step):
         return [self.make_decoder_head(i,n,s) for i,s in enumerate(result_sequences)]
 
     def validate_args(self, decoder_args, decoder_fct):
-        INTERNAL_ARGS = ["decoder", "dcmodel", "dclib_additional_logits_processor", "input_id_rewriter", "output_writer", "chatty_openai", "distribution_batch_size", "openai_chunksize", "step_budget", "stats"]
+        INTERNAL_ARGS = ["decoder", "dcmodel", "dclib_additional_logits_processor", "input_id_rewriter", "output_writer", "chatty_openai", "distribution_batch_size", "openai_chunksize", "step_budget", "stats", "performance_stats"]
 
         # get all arg names and kwarg names of decoder function
         decoder_arg_names = inspect.getfullargspec(decoder_fct).args
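
The relocated timing code keeps an exponential moving average of per-step wall time (90% previous average, 10% newest measurement), and stats printing is now gated behind the recognized performance_stats flag and throttled to every 10th step. A small self-contained sketch of the same formula, with the sleep standing in for a decoding step:

    import time

    average_step_time = None
    start = time.time()

    for decoder_step in range(30):
        time.sleep(0.01)  # stand-in for one decoding step
        dt = time.time() - start
        # first measurement seeds the average; later ones blend in at 10%
        average_step_time = dt if average_step_time is None else (average_step_time * 0.9 + dt * 0.1)
        if decoder_step % 10 == 0:
            print("step", decoder_step, "time", average_step_time)
        start = time.time()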

src/lmql/runtime/tokenizer.py (+9 −1)

@@ -103,8 +103,16 @@ def decode(self, input_ids):
         if n-1 in self.detokenizer_cache.keys():
             key = str(input_ids[:-1])
             if key in self.detokenizer_cache[n-1].keys():
+                global reverse_special_token_mappings
                 # print("secondary cache hit")
-                return self.detokenizer_cache[n-1][key] + self.tokenizer_impl.decode([input_ids[-1]])
+                if input_ids[-1] >= self.tokenizer_impl.vocab_size:
+                    extended = self.detokenizer_cache[n-1][key] + "<" + reverse_special_token_mappings[input_ids[-1]] + "/>"
+                else:
+                    extended = self.detokenizer_cache[n-1][key] + self.tokenizer_impl.decode([input_ids[-1]])
+                if not n in self.detokenizer_cache.keys():
+                    self.detokenizer_cache[n] = {}
+                self.detokenizer_cache[n][str(input_ids)] = extended
+                return extended
 
         s = ""
         for chunk in self.chunk_out_by_special_ids(input_ids):
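
The decode fast path now also handles special tokens (ids at or above vocab_size are rendered via reverse_special_token_mappings) and, importantly, writes the extended string back into the cache, so the next call can hit on the one-token-longer prefix instead of falling through to a full decode. A rough standalone sketch of this incremental detokenization cache; decode_one and the cache layout here are illustrative, not LMQL's API:

    detokenizer_cache = {}  # prefix length -> {str(ids): decoded text}

    def decode_incremental(input_ids, decode_one):
        n = len(input_ids)
        prev = detokenizer_cache.get(n - 1, {}).get(str(input_ids[:-1]))
        if prev is not None:
            # extend the cached prefix by just the newest token and store
            # the result so the next call can extend it again
            extended = prev + decode_one(input_ids[-1])
            detokenizer_cache.setdefault(n, {})[str(input_ids)] = extended
            return extended
        # cache miss: decode from scratch and seed the cache
        full = "".join(decode_one(t) for t in input_ids)
        detokenizer_cache.setdefault(n, {})[str(input_ids)] = full
        return full

    print(decode_incremental([1], lambda t: f"<{t}>"))     # decodes from scratch
    print(decode_incremental([1, 2], lambda t: f"<{t}>"))  # extends the cached prefix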

src/lmql/ui/playground/src/App.jsx (+4 −4)

@@ -763,9 +763,9 @@ const ModelResultText = styled.div`
 
    div .tag-assistant {
        display: inline-block;
-       width: 65%;
        border: 1pt solid #5c5c5c;
        margin-top: 5pt;
+       margin-right: 4%;
 
        border-radius: 8pt;
        overflow: hidden;
@@ -774,7 +774,7 @@ const ModelResultText = styled.div`
 
    div .tag-user {
        display: block;
-       margin-left: 32%;
+       margin-left: 4%;
        position: relative;
        border: 1pt solid #5c5c5c;
        border-radius: 8pt;
@@ -918,8 +918,8 @@ class Truncated extends React.Component {
 
    componentDidMount() {
        this.stepper = setInterval(() => {
-           this.setState(s => Object.assign(s, { typingOffset: s.typingOffset + 4 }))
-       }, 10)
+           this.setState(s => Object.assign(s, { typingOffset: s.typingOffset + 16 }))
+       }, 5)
    }
 
    componentWillUnmount() {
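
Net effect of the Truncated change: the typing animation now advances 16 characters every 5 ms instead of 4 characters every 10 ms, roughly an 8x faster reveal, presumably to keep the displayed text in step with the faster-streaming output.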

src/lmql/ui/playground/src/DecoderGraph.js (+22 −6)

@@ -567,7 +567,11 @@ export function DecoderGraph(props) {
                    target: to
                }
            };
-           cy.add({group: "edges", ...e})
+           try {
+               cy.add({group: "edges", ...e})
+           } catch (e) {
+               console.error("error adding edge", e)
+           }
        }))
 
        let mostLikely = layoutDecoderGraph(cy)
@@ -588,11 +592,23 @@ export function DecoderGraph(props) {
    const renderer = {
        add_result: (output) => {
            if (output.type == "decoder-graph-state") {
-               // persist decoder graph in local state
-               const raw = JSON.stringify(output.data)
-               setRawGraphData(raw)
-               persistedState.setItem("decoder-graph", raw, onPersistedGraphChange)
-               setCyData(output.data)
+               setCyData(cyData => {
+                   // persist decoder graph in local state
+                   let updated = Object.assign({nodes: [], edges: []}, cyData || {})
+                   let nodes = {}
+                   updated.nodes.forEach(n => nodes[n.id] = n)
+                   output.data.nodes.forEach(n => nodes[n.id] = n)
+                   updated.nodes = Array.from(Object.values(nodes))
+
+                   updated.edges = Array.from([...updated.edges, ...output.data.edges])
+
+                   // console.log("updating", output.data.nodes.length)
+
+                   const raw = JSON.stringify(updated)
+                   setRawGraphData(raw)
+                   persistedState.queueSetItem("decoder-graph", raw, onPersistedGraphChange)
+                   return updated
+               })
            } else {
                // nop in this component
            }
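
With this change, add_result no longer replaces the stored graph on every message: it merges the incoming (now partial) state into the accumulated one, deduplicating nodes by id while simply concatenating edges. That concatenation can re-add edges cytoscape already knows, which is presumably why cy.add above is now wrapped in try/catch. A sketch of the merge step in Python, with a hypothetical function name:

    def merge_graph_state(current, update):
        # nodes are deduplicated by id; entries from the update win
        nodes = {n["id"]: n for n in current.get("nodes", [])}
        nodes.update({n["id"]: n for n in update.get("nodes", [])})
        return {
            "nodes": list(nodes.values()),
            # edges are concatenated, mirroring the JS version (no dedup)
            "edges": current.get("edges", []) + update.get("edges", []),
        }

    state = {"nodes": [{"id": "n0", "text": "a"}], "edges": []}
    diff = {"nodes": [{"id": "n0", "text": "ab"}, {"id": "n1", "text": "c"}],
            "edges": [("n0", "n1")]}
    print(merge_graph_state(state, diff))  # n0 updated in place, n1 appended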

src/lmql/ui/playground/src/State.js (+11 −0)

@@ -27,6 +27,8 @@ class PersistedState {
        this.items = {}
        this.listeners = {}
        this.restore()
+
+        this.saveQueue = {}
    }
 
    persist(k) {
@@ -116,6 +118,15 @@
            })
        }
    }
+
+    queueSetItem(key, value, exclude_listener=null) {
+        if (this.saveQueue[key]) {
+            clearTimeout(this.saveQueue[key]);
+        }
+        this.saveQueue[key] = setTimeout(() => {
+            this.setItem(key, value, exclude_listener);
+        })
+    }
 }
 
 export const persistedState = new PersistedState();
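
queueSetItem debounces persistence: each call cancels any pending write for the key and schedules a new one via setTimeout (with no delay, so the write is deferred to the next macrotask), collapsing bursts of graph updates into a single persisted write. An analogous sketch in Python; the class is hypothetical and uses threading.Timer in place of setTimeout:

    import threading

    class DebouncedStore:
        """Each call cancels the pending write and schedules a fresh one,
        so rapid successive updates collapse into one persisted write."""

        def __init__(self, delay=0.05):
            self.delay = delay
            self.save_queue = {}  # key -> pending Timer
            self.items = {}

        def set_item(self, key, value):
            self.items[key] = value
            print("persisted", key, value)

        def queue_set_item(self, key, value):
            if key in self.save_queue:
                self.save_queue[key].cancel()  # drop the superseded write
            t = threading.Timer(self.delay, self.set_item, args=(key, value))
            self.save_queue[key] = t
            t.start()

    store = DebouncedStore()
    for i in range(100):
        store.queue_set_item("decoder-graph", i)  # only the last write lands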
