Skip to content

Commit

Permalink
node/bindnode: redesign the shape of unions in Go
Browse files Browse the repository at this point in the history
Before, the shape of IPLD schema Unions in Go was as follows:

	struct {
		Index int // 0..len(typ.Members)-1
		Value interface{}
	}

This worked perfectly fine when inferring Go types from a schema.
The inferred Go types would be anonymous,
so it didn't really matter that the value was behind interface{}.

However, this mechanism did not work for providing a custom Go type.
For example, the equivalent of the added example would be:

	type CustomIntType int64
	type StringOrInt struct {
		Index int
		Value interface{}
	}
	proto := bindnode.Prototype((*StringOrInt)(nil), schemaType)

bindnode failed to use CustomIntType at all,
since StringOrInt did not reference CustomIntType in any way.
Moreover, the interface{} layer also felt prone to type assert panics.

Now, the shape of Unions in Go is:

	struct {
		Type1 *Type1
		Type2 *Type2
		...
	}

The problem described above is no longer present anymore,
as the added runnable example demonstrates.
That alone makes the redesign worthwhile.

One minor downside of the new method is large unions,
since looking up the "index" is now a linear search for the first
non-nil field pointer,
and the size of the Go type increases with each member type.
However, unions with more than a dozen member types should be rare.

A different design direction would have been to keep interface{} values,
and have "Register" APIs to tell bindnode about the member Go types.
However, that feels worse in terms of usability and design complexity.

Fixes #210.
  • Loading branch information
mvdan committed Aug 12, 2021
1 parent 050d445 commit 1e8e5c6
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 29 deletions.
1 change: 1 addition & 0 deletions node/bindnode/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
// the other provided type. For example, we can infer an unnamed Go struct type
// for a schema struct tyep, and we can infer a schema Int type for a Go int64
// type. The inferring logic is still a work in progress and subject to change.
// At this time, inferring IPLD Unions and Enums from Go types is not supported.
//
// When supplying a non-nil ptrType, Prototype only obtains the Go pointer type
// from it, so its underlying value will typically be nil. For example:
Expand Down
58 changes: 58 additions & 0 deletions node/bindnode/example_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package bindnode_test

import (
"fmt"
"os"

ipld "github.com/ipld/go-ipld-prime"
Expand Down Expand Up @@ -99,3 +100,60 @@ func ExamplePrototype_onlySchema() {
// Output:
// {"Friends":["Sarah","Alex"],"Name":"Michael"}
}

func ExamplePrototype_union() {
ts := schema.TypeSystem{}
ts.Init()
ts.Accumulate(schema.SpawnString("String"))
ts.Accumulate(schema.SpawnInt("Int"))
ts.Accumulate(schema.SpawnUnion("StringOrInt",
[]schema.TypeName{
"String",
"Int",
},
schema.SpawnUnionRepresentationKeyed(map[string]schema.TypeName{
"hasString": "String",
"hasInt": "Int",
}),
))

schemaType := ts.TypeByName("StringOrInt")

type CustomIntType int64
type StringOrInt struct {
String *string
Int *CustomIntType // We can use custom types, too.
}

proto := bindnode.Prototype((*StringOrInt)(nil), schemaType)

node, err := qp.BuildMap(proto.Representation(), -1, func(ma ipld.MapAssembler) {
qp.MapEntry(ma, "hasInt", qp.Int(123))
})
if err != nil {
panic(err)
}

fmt.Print("Type level DAG-JSON: ")
dagjson.Encode(node, os.Stdout)
fmt.Println()

fmt.Print("Representation level DAG-JSON: ")
nodeRepr := node.(schema.TypedNode).Representation()
dagjson.Encode(nodeRepr, os.Stdout)
fmt.Println()

// Inspect what the underlying Go value contains.
union := bindnode.Unwrap(node).(*StringOrInt)
switch {
case union.String != nil:
fmt.Printf("Go StringOrInt.String: %v\n", *union.String)
case union.Int != nil:
fmt.Printf("Go StringOrInt.Int: %v\n", *union.Int)
}

// Output:
// Type level DAG-JSON: {"Int":123}
// Representation level DAG-JSON: {"hasInt":123}
// Go StringOrInt.Int: 123
}
19 changes: 14 additions & 5 deletions node/bindnode/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,21 @@ func inferGoType(typ schema.Type) reflect.Type {
}
return reflect.SliceOf(etyp)
case *schema.TypeUnion:
// We need an extra field to record what member we stored.
type goUnion struct {
Index int // 0..len(typ.Members)-1
Value interface{}
// type goUnion struct {
// Type1 *Type1
// Type2 *Type2
// ...
// }
members := typ.Members()
fieldsGo := make([]reflect.StructField, len(members))
for i, ftyp := range members {
ftypGo := inferGoType(ftyp)
fieldsGo[i] = reflect.StructField{
Name: fieldNameFromSchema(string(ftyp.Name())),
Type: reflect.PtrTo(ftypGo),
}
}
return reflect.TypeOf(goUnion{})
return reflect.StructOf(fieldsGo)
}
panic(fmt.Sprintf("%T\n", typ))
}
Expand Down
39 changes: 30 additions & 9 deletions node/bindnode/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,11 @@ func (w *_node) LookupByString(key string) (ipld.Node, error) {
if mtyp == nil { // not found
return nil, ipld.ErrNotExists{Segment: ipld.PathSegmentOfString(key)}
}
haveIdx := int(w.val.FieldByName("Index").Int())
// TODO: we could look up the right Go field straight away via idx.
haveIdx, mval := unionMember(w.val)
if haveIdx != idx { // mismatching type
return nil, ipld.ErrNotExists{Segment: ipld.PathSegmentOfString(key)}
}
mval := w.val.FieldByName("Value").Elem()
node := &_node{
schemaType: mtyp,
val: mval,
Expand All @@ -180,6 +180,28 @@ func (w *_node) LookupByString(key string) (ipld.Node, error) {
}
}

var invalidValue reflect.Value

func unionMember(val reflect.Value) (int, reflect.Value) {
// The first non-nil field is a match.
for i := 0; i < val.NumField(); i++ {
elemVal := val.Field(i)
if elemVal.IsNil() {
continue
}
return i, elemVal.Elem()
}
return -1, invalidValue
}

func unionSetMember(val reflect.Value, memberIdx int, memberPtr reflect.Value) {
// Reset the entire union struct to zero, to clear any non-nil pointers.
val.Set(reflect.Zero(val.Type()))

// Set the index pointer to the given value.
val.Field(memberIdx).Set(memberPtr)
}

func (w *_node) LookupByIndex(idx int64) (ipld.Node, error) {
switch typ := w.schemaType.(type) {
case *schema.TypeList:
Expand Down Expand Up @@ -917,17 +939,17 @@ func (w *_unionAssembler) AssembleValue() ipld.NodeAssembler {
// Key: basicnode.NewString(name),
// }
}
goType := inferGoType(mtyp) // TODO: do this upfront
val := reflect.New(goType).Elem()

goType := w.val.Field(idx).Type().Elem()
valPtr := reflect.New(goType)
finish := func() error {
// fmt.Println(kval.Interface(), val.Interface())
w.val.FieldByName("Index").SetInt(int64(idx))
w.val.FieldByName("Value").Set(val)
unionSetMember(w.val, idx, valPtr)
return nil
}
return &_assembler{
schemaType: mtyp,
val: val,
val: valPtr.Elem(),
finish: finish,
}
}
Expand Down Expand Up @@ -1076,9 +1098,8 @@ func (w *_unionIterator) Next() (key, value ipld.Node, _ error) {
}
w.done = true

haveIdx := int(w.val.FieldByName("Index").Int())
haveIdx, mval := unionMember(w.val)
mtyp := w.members[haveIdx]
mval := w.val.FieldByName("Value").Elem()

node := &_node{
schemaType: mtyp,
Expand Down
31 changes: 16 additions & 15 deletions node/bindnode/repr.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (w *_nodeRepr) Kind() ipld.Kind {
case schema.UnionRepresentation_Keyed:
return ipld.Kind_Map
case schema.UnionRepresentation_Kinded:
haveIdx := int(w.val.FieldByName("Index").Int())
haveIdx, _ := unionMember(w.val)
mtyp := w.schemaType.(*schema.TypeUnion).Members()[haveIdx]
return mtyp.TypeKind().ActsLike()
case schema.UnionRepresentation_Stringprefix:
Expand Down Expand Up @@ -108,12 +108,12 @@ func inboundMappedType(typ *schema.TypeUnion, stg schema.UnionRepresentation_Key
func (w *_nodeRepr) asKinded(stg schema.UnionRepresentation_Kinded, kind ipld.Kind) *_nodeRepr {
name := stg.GetMember(kind)
members := w.schemaType.(*schema.TypeUnion).Members()
for _, member := range members {
for i, member := range members {
if member.Name() != name {
continue
}
w2 := *w
w2.val = w.val.FieldByName("Value").Elem()
w2.val = w.val.Field(i).Elem()
w2.schemaType = member
return &w2
}
Expand Down Expand Up @@ -340,11 +340,11 @@ func (w *_nodeRepr) AsString() (string, error) {
}
return b.String(), nil
case schema.UnionRepresentation_Stringprefix:
haveIdx := int(w.val.FieldByName("Index").Int())
haveIdx, mval := unionMember(w.val)
mtyp := w.schemaType.(*schema.TypeUnion).Members()[haveIdx]

w2 := *w
w2.val = w.val.FieldByName("Value").Elem()
w2.val = mval
w2.schemaType = mtyp
s, err := w2.AsString()
if err != nil {
Expand Down Expand Up @@ -432,8 +432,9 @@ func (w *_assemblerRepr) asKinded(stg schema.UnionRepresentation_Kinded, kind ip
continue
}
w2 := *w
goType := inferGoType(member) // TODO: do this upfront
w2.val = reflect.New(goType).Elem()
goType := w.val.Field(idx).Type().Elem()
valPtr := reflect.New(goType)
w2.val = valPtr.Elem()
w2.schemaType = member

// Layer a new finish func on top, to set Index/Value.
Expand All @@ -443,8 +444,7 @@ func (w *_assemblerRepr) asKinded(stg schema.UnionRepresentation_Kinded, kind ip
return err
}
}
w.val.FieldByName("Index").SetInt(int64(idx))
w.val.FieldByName("Value").Set(w2.val)
unionSetMember(w.val, idx, valPtr)
return nil
}
return &w2
Expand Down Expand Up @@ -557,8 +557,9 @@ func (w *_assemblerRepr) AssignString(s string) error {
if member.Name() != name {
continue
}
w.val.FieldByName("Index").SetInt(int64(idx))
w.val.FieldByName("Value").Set(reflect.ValueOf(s))
valPtr := reflect.New(goTypeString)
valPtr.Elem().SetString(s)
unionSetMember(w.val, idx, valPtr)
return nil
}
panic("TODO: GetMember result is missing?")
Expand All @@ -575,17 +576,17 @@ func (w *_assemblerRepr) AssignString(s string) error {
}

w2 := *w
goType := inferGoType(member) // TODO: do this upfront
w2.val = reflect.New(goType).Elem()
goType := w.val.Field(idx).Type().Elem()
valPtr := reflect.New(goType)
w2.val = valPtr.Elem()
w2.schemaType = member
w2.finish = func() error {
if w.finish != nil {
if err := w.finish(); err != nil {
return err
}
}
w.val.FieldByName("Index").SetInt(int64(idx))
w.val.FieldByName("Value").Set(w2.val)
unionSetMember(w.val, idx, valPtr)
return nil
}

Expand Down

0 comments on commit 1e8e5c6

Please sign in to comment.