Skip to content

Commit f198ba3

Browse files
authored
Fix go enum encoding (#1892)
Adds a test and fixes enums as members of structs <!-- ELLIPSIS_HIDDEN --> ---- > [!IMPORTANT] > Fixes Go enum encoding by updating pointer handling in `generate_types.rs`, `decode.go`, and `types.go`, with a new test case in `cffi_test.go`. > > - **Behavior**: > - Fixes enum encoding in Go by updating `render_value_coercion` in `generate_types.rs` to handle pointers correctly. > - Modifies `decodePrimitiveValue` in `decode.go` to return pointers. > - Updates `Decode` methods in `types.go` and `unions.go` to use pointer dereferencing. > - **Tests**: > - Adds a test case in `cffi_test.go` for `Person` struct to verify enum encoding fix. > - **Misc**: > - Removes unnecessary type casting in `types.go` and `unions.go`. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for ca36789. You can [customize](https://app.ellipsis.dev/BoundaryML/settings/summaries) this summary. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
1 parent 42b3fbe commit f198ba3

5 files changed

Lines changed: 254 additions & 310 deletions

File tree

engine/language_client_codegen/src/go/generate_types.rs

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,25 +55,15 @@ pub(crate) fn cast_value(container_variable_name: &str, field_type: &GoType) ->
5555

5656
fn render_value_coercion(container_variable_name: &str, field_type: &GoType) -> String {
5757
if field_type.is_pointer {
58-
let inner_type = field_type.underlying_type.as_ref().unwrap();
5958
return format!(
6059
"func () {} {{
6160
val := baml.Decode({})
6261
if val == nil {{
6362
return nil
6463
}}
65-
castVal := val.({})
66-
return &castVal
64+
return val.({})
6765
}}()",
68-
field_type.name, container_variable_name, inner_type.name,
69-
);
70-
} else if field_type.is_class {
71-
return format!(
72-
"*baml.Decode({}).(*{})",
73-
container_variable_name,
74-
filters::type_name_without_pointer(&field_type.name)
75-
.ok()
76-
.unwrap()
66+
field_type.name, container_variable_name, field_type.name,
7767
);
7868
} else if field_type.is_slice {
7969
let inner_type = field_type.underlying_type.as_ref().unwrap();
@@ -84,14 +74,9 @@ fn render_value_coercion(container_variable_name: &str, field_type: &GoType) ->
8474
inner_type.name,
8575
render_value_coercion("__holder", inner_type),
8676
);
87-
} else if field_type.is_union {
88-
return format!(
89-
"*baml.Decode({container_variable_name}).(*{})",
90-
field_type.name
91-
);
9277
} else {
9378
return format!(
94-
"baml.Decode({container_variable_name}).({})",
79+
"*baml.Decode({container_variable_name}).(*{})",
9580
field_type.name
9681
);
9782
}

engine/language_client_go/pkg/decode.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,15 @@ func (d *DynamicEnum) Decode(holder cffi.CFFIValueEnum) {
6262
d.Value = string(holder.Value())
6363
}
6464

65-
func decodePrimitiveValue[U any, T cffiValue[U]](valueHolder *cffi.CFFIValueHolder, t T) U {
65+
func decodePrimitiveValue[U any, T cffiValue[U]](valueHolder *cffi.CFFIValueHolder, t T) *U {
6666
var tbl flatbuffers.Table
6767
if !valueHolder.Value(&tbl) {
6868
panic("error decoding value")
6969
}
7070

7171
t.Init(tbl.Bytes, tbl.Pos)
72-
return t.Value()
72+
val := t.Value()
73+
return &val
7374
}
7475

7576
func decodeListValue(valueHolder *cffi.CFFIValueHolder) any {
@@ -396,8 +397,9 @@ func Decode(holder *cffi.CFFIValueHolder) any {
396397
case cffi.CFFIValueUnionNONE:
397398
return nil
398399
case cffi.CFFIValueUnionCFFIValueString:
399-
valueString := decodePrimitiveValue(holder, &cffi.CFFIValueString{})
400-
return string(valueString)
400+
valueBytes := decodePrimitiveValue(holder, &cffi.CFFIValueString{})
401+
valueString := string(*valueBytes)
402+
return &valueString
401403
case cffi.CFFIValueUnionCFFIValueInt:
402404
return decodePrimitiveValue(holder, &cffi.CFFIValueInt{})
403405
case cffi.CFFIValueUnionCFFIValueFloat:

0 commit comments

Comments
 (0)