Skip to content

Commit eac4bbf

Browse files
authored
Addressed issue where dynamic types were not streaming. (#1606)
- Updated `ir_helpers/mod.rs` to introduce `IRHelperExtended` and `IRSemanticStreamingHelper` traits, adding methods for class field management and streaming behavior. - Modified `mod.rs` to export the new traits and updated relevant imports across the codebase. - Refactored `validate_streaming_state` and `process_node` functions in `semantic_streaming.rs` to utilize the new traits for improved type handling. - Adjusted various modules to implement the extended functionality, ensuring better integration with runtime context and type management. This update improves the flexibility and usability of the IR helper functions, facilitating better handling of class fields and streaming requirements. <!-- ELLIPSIS_HIDDEN --> ---- > [!IMPORTANT] > Introduce `IRHelperExtended` and `IRSemanticStreamingHelper` traits to enhance dynamic type handling and streaming behavior in the Intermediate Representation. > > - **Traits**: > - Introduce `IRHelperExtended` and `IRSemanticStreamingHelper` in `ir_helpers/mod.rs` for class field management and streaming behavior. > - Implement these traits in `IntermediateRepr` and `ScopedIr`. > - **Functions**: > - Refactor `validate_streaming_state` and `process_node` in `semantic_streaming.rs` to use new traits. > - Update `parsed_value_to_response` in `llm_client/mod.rs` and `helpers/mod.rs` to utilize new traits. > - **Modules**: > - Modify `mod.rs` to export new traits and update imports across the codebase. > - Adjust various modules to implement extended functionality for better runtime context and type management. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for e9cb26c. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
1 parent 366af1a commit eac4bbf

File tree

19 files changed

+1060
-810
lines changed

19 files changed

+1060
-810
lines changed

engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs

Lines changed: 770 additions & 741 deletions
Large diffs are not rendered by default.

engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::path::PathBuf;
77

88
use crate::ir::IntermediateRepr;
99

10-
use super::{scope_diagnostics::ScopeStack, IRHelper};
10+
use super::{scope_diagnostics::ScopeStack, IRHelper, IRHelperExtended};
1111
use crate::ir::jinja_helpers::evaluate_predicate;
1212

1313
#[derive(Default)]

engine/baml-lib/baml-core/src/ir/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
mod ir_helpers;
1+
pub mod ir_helpers;
22
pub mod jinja_helpers;
33
mod json_schema;
44
pub mod repr;
55
mod walker;
66

77
pub use ir_helpers::{
88
scope_diagnostics, ArgCoercer, ClassFieldWalker, ClassWalker, ClientWalker, EnumValueWalker,
9-
EnumWalker, FunctionWalker, IRHelper, RetryPolicyWalker, TemplateStringWalker, TestCaseWalker,
9+
EnumWalker, FunctionWalker, IRHelper, IRHelperExtended, IRSemanticStreamingHelper,
10+
RetryPolicyWalker, TemplateStringWalker, TestCaseWalker, TypeAliasWalker,
1011
};
1112

1213
pub(super) use repr::IntermediateRepr;

engine/baml-lib/jsonish/src/deserializer/semantic_streaming.rs

Lines changed: 12 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::deserializer::coercer::ParsingError;
55
use crate::{BamlValueWithFlags, Flag};
66
use indexmap::{IndexMap, IndexSet};
77
use internal_baml_core::ir::repr::{IntermediateRepr, Walker};
8-
use internal_baml_core::ir::{Field, IRHelper};
8+
use internal_baml_core::ir::{Field, IRHelper, IRHelperExtended, IRSemanticStreamingHelper};
99

1010
use baml_types::{
1111
BamlMap, BamlValueWithMeta, Completion, CompletionState, FieldType, ResponseCheck,
@@ -31,7 +31,7 @@ pub enum StreamingError {
3131
/// For a given baml value, traverse its nodes, comparing the completion state
3232
/// of each node against the streaming behavior of the node's type.
3333
pub fn validate_streaming_state(
34-
ir: &IntermediateRepr,
34+
ir: &impl IRHelperExtended,
3535
baml_value: &BamlValueWithFlags,
3636
field_type: &FieldType,
3737
allow_partials: bool,
@@ -63,7 +63,7 @@ pub fn validate_streaming_state(
6363
/// allow_partials: Whether this node may contain partial values. (Once we
6464
/// see a false, all child nodes will also get false).
6565
fn process_node(
66-
ir: &IntermediateRepr,
66+
ir: &impl IRHelperExtended,
6767
value: BamlValueWithMeta<(CompletionState, &FieldType)>,
6868
allow_partials: bool,
6969
depth: usize,
@@ -233,15 +233,9 @@ fn process_node(
233233

234234
/// Extract the field names from a field_type that is expected to be a `Class`.
235235
/// If it is not a known class, return no field names.
236-
fn type_field_names(ir: &IntermediateRepr, field_type: &FieldType) -> IndexSet<String> {
236+
fn type_field_names(ir: &impl IRHelperExtended, field_type: &FieldType) -> IndexSet<String> {
237237
match ir.distribute_metadata(field_type).0 {
238-
FieldType::Class(class_name) => match ir.find_class(class_name) {
239-
Err(_) => IndexSet::new(),
240-
Ok(class) => class
241-
.walk_fields()
242-
.map(|field| field.name().to_string())
243-
.collect(),
244-
},
238+
FieldType::Class(class_name) => ir.class_field_names(class_name).unwrap_or_default(),
245239
_ => IndexSet::new(),
246240
}
247241
}
@@ -250,31 +244,15 @@ fn type_field_names(ir: &IntermediateRepr, field_type: &FieldType) -> IndexSet<S
250244
/// fields in the class need to be filled in by a null. A field needs to be
251245
/// filled by a null if it is not present in the map value.
252246
fn fields_needing_null_filler<'a>(
253-
ir: &'a IntermediateRepr,
247+
ir: &'a impl IRSemanticStreamingHelper,
254248
class_name: &'a str,
255249
value_names: HashSet<String>,
256250
allow_partials: bool,
257251
) -> Result<HashSet<String>, anyhow::Error> {
258252
if allow_partials == false {
259253
return Ok(HashSet::new());
260254
}
261-
let res = match ir.find_class(class_name) {
262-
Err(_) => Ok(HashSet::new()),
263-
Ok(class) => {
264-
let missing_fields = class
265-
.walk_fields()
266-
.filter_map(|field: Walker<'_, &Field>| {
267-
if !value_names.contains(field.name()) {
268-
Some(field.name().to_string())
269-
} else {
270-
None
271-
}
272-
})
273-
.collect();
274-
Ok(missing_fields)
275-
}
276-
};
277-
res
255+
ir.find_class_fields_needing_null_filler(class_name, &value_names)
278256
}
279257

280258
/// For a given type, assume that it is a class, and list the fields of that
@@ -285,33 +263,21 @@ fn fields_needing_null_filler<'a>(
285263
/// and return an empty set (because we are ignoring the "@stream.not_null" property,
286264
/// which only applies when `allow_partials==true`).
287265
fn needed_fields(
288-
ir: &IntermediateRepr,
266+
ir: &impl IRHelperExtended,
289267
class_name: &str,
290268
allow_partials: bool,
291269
) -> Result<HashSet<String>, anyhow::Error> {
292270
if allow_partials == false {
293271
return Ok(HashSet::new());
294272
}
295-
let class = ir
296-
.find_class(class_name)
273+
ir.class_streaming_needed_fields(class_name)
297274
.map_err(|_| StreamingError::ExpectedClass)
298-
.context("needed_fields failed to lookup class")?;
299-
let needed_fields = class
300-
.walk_fields()
301-
.filter_map(|field: Walker<'_, &Field>| {
302-
if field.streaming_needed() {
303-
Some(field.name().to_string())
304-
} else {
305-
None
306-
}
307-
})
308-
.collect();
309-
Ok(needed_fields)
275+
.context("needed_fields failed to lookup class")
310276
}
311277

312278
/// Whether a type must be complete before being included as a node
313279
/// in a streamed value.
314-
fn required_done(ir: &IntermediateRepr, field_type: &FieldType) -> bool {
280+
fn required_done(ir: &impl IRHelperExtended, field_type: &FieldType) -> bool {
315281
let (base_type, (_, streaming_behavior)) = ir.distribute_metadata(field_type);
316282
let type_implies_done = match base_type {
317283
FieldType::Primitive(tv) => match tv {
@@ -353,7 +319,7 @@ fn completion_state(flags: &Vec<Flag>) -> CompletionState {
353319
}
354320
}
355321

356-
fn type_streaming_behavior(ir: &IntermediateRepr, r#type: &FieldType) -> StreamingBehavior {
322+
fn type_streaming_behavior(ir: &impl IRHelperExtended, r#type: &FieldType) -> StreamingBehavior {
357323
let (_base_type, (_constraints, streaming_behavior)) = ir.distribute_metadata(r#type);
358324
streaming_behavior
359325
}

engine/baml-lib/jsonish/src/helpers/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ pub mod common;
22
use std::{collections::HashSet, path::PathBuf};
33

44
use anyhow::Result;
5-
use baml_types::{EvaluationContext, JinjaExpression};
65
use baml_types::{BamlValueWithMeta, ResponseCheck, StreamingBehavior};
6+
use baml_types::{EvaluationContext, JinjaExpression};
77
use indexmap::{IndexMap, IndexSet};
8+
use internal_baml_core::ir::IRHelperExtended;
89
use internal_baml_core::{
910
ast::Field,
1011
internal_baml_diagnostics::SourceFile,
@@ -263,7 +264,6 @@ pub fn parsed_value_to_response(
263264
field_type: &FieldType,
264265
allow_partials: bool,
265266
) -> Result<ResponseBamlValue> {
266-
267267
let meta_flags: BamlValueWithMeta<Vec<Flag>> = baml_value.clone().into();
268268
let baml_value_with_meta: BamlValueWithMeta<Vec<(String, JinjaExpression, bool)>> =
269269
baml_value.clone().into();
@@ -292,6 +292,6 @@ pub fn parsed_value_to_response(
292292
let response_value = baml_value_with_streaming
293293
.zip_meta(&value_with_response_checks)?
294294
.zip_meta(&meta_flags)?
295-
.map_meta(|((x, y), z)| (z.clone(), y.clone(), x.clone() ));
295+
.map_meta(|((x, y), z)| (z.clone(), y.clone(), x.clone()));
296296
Ok(ResponseBamlValue(response_value))
297297
}

engine/baml-runtime/src/internal/llm_client/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub mod traits;
1212
use anyhow::{Context, Result};
1313

1414
use baml_types::{BamlMap, BamlValueWithMeta, FieldType, JinjaExpression, ResponseCheck};
15-
use internal_baml_core::ir::{repr::IntermediateRepr, ClientWalker};
15+
use internal_baml_core::ir::{repr::IntermediateRepr, ClientWalker, IRHelper, IRHelperExtended};
1616
use internal_baml_jinja::RenderedPrompt;
1717
use internal_llm_client::AllowedRoleMetadata;
1818
pub use jsonish::ResponseBamlValue;
@@ -33,7 +33,7 @@ use wasm_bindgen::JsValue;
3333

3434
/// Validate a parsed value, checking asserts and checks.
3535
pub fn parsed_value_to_response(
36-
ir: &IntermediateRepr,
36+
ir: &impl IRHelperExtended,
3737
baml_value: BamlValueWithFlags,
3838
field_type: &FieldType,
3939
allow_partials: bool,

engine/baml-runtime/src/internal/prompt_renderer/mod.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11
mod render_output_format;
2+
pub(crate) mod scoped_ir;
23
use internal_llm_client::ClientSpec;
34
use jsonish::{BamlValueWithFlags, ResponseBamlValue};
45
use render_output_format::render_output_format;
56

67
use anyhow::Result;
7-
use baml_types::{BamlValue, FieldType};
8+
use baml_types::{BamlValue, FieldType, StreamingBehavior};
89
use internal_baml_core::{
910
error_unsupported,
10-
ir::{repr::IntermediateRepr, FunctionWalker, IRHelper},
11+
ir::{
12+
repr::IntermediateRepr, FunctionWalker, IRHelper, IRHelperExtended,
13+
IRSemanticStreamingHelper,
14+
},
1115
};
1216
use internal_baml_jinja::{
1317
types::OutputFormatContent, RenderContext, RenderContext_Client, RenderedPrompt,
1418
TemplateStringMacro,
1519
};
20+
use scoped_ir::ScopedIr;
1621

17-
use crate::RuntimeContext;
22+
use crate::{runtime_context::RuntimeClassOverride, RuntimeContext};
1823

1924
use super::llm_client::parsed_value_to_response;
2025

@@ -55,6 +60,7 @@ impl PromptRenderer {
5560
pub fn parse(
5661
&self,
5762
ir: &IntermediateRepr,
63+
ctx: &RuntimeContext,
5864
raw_string: &str,
5965
allow_partials: bool,
6066
) -> Result<ResponseBamlValue> {
@@ -64,7 +70,8 @@ impl PromptRenderer {
6470
raw_string,
6571
allow_partials,
6672
)?;
67-
let res = parsed_value_to_response(ir, parsed, &self.output_type, allow_partials);
73+
let scoped_ir = ScopedIr::new(ir, ctx);
74+
let res = parsed_value_to_response(&scoped_ir, parsed, &self.output_type, allow_partials);
6875
res
6976
}
7077

engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use anyhow::Result;
44
use baml_types::BamlValue;
55
use indexmap::{IndexMap, IndexSet};
66
use internal_baml_core::ir::{
7-
repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper,
7+
repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, IRHelperExtended,
88
};
99
use internal_baml_jinja::types::{Class, Enum, Name, OutputFormatContent};
1010

0 commit comments

Comments
 (0)