Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node/bindnode: allow nilable types for IPLD optional/nullable #401

Merged
merged 1 commit into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions node/bindnode/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func ExampleWrap_withSchema() {
type Person struct {
Name String
Age optional Int
Friends [String]
Friends optional [String]
}
`))
if err != nil {
Expand All @@ -27,8 +27,8 @@ func ExampleWrap_withSchema() {

type Person struct {
Name string
Age *int64 // optional
Friends []string
Age *int64 // optional
Friends []string // optional; no need for a pointer as slices are nilable
}
person := &Person{
Name: "Michael",
Expand Down
57 changes: 42 additions & 15 deletions node/bindnode/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp
}
goType = goType.Elem()
if schemaType.ValueIsNullable() {
if goType.Kind() != reflect.Ptr {
doPanic("nullable types must be pointers")
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("nullable types must be nilable")
} else if ptr {
goType = goType.Elem()
}
goType = goType.Elem()
}
verifyCompatibility(seen, goType, schemaType.ValueType())
case *schema.TypeMap:
Expand Down Expand Up @@ -141,10 +142,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp

elemType := fieldValues.Type.Elem()
if schemaType.ValueIsNullable() {
if elemType.Kind() != reflect.Ptr {
doPanic("nullable types must be pointers")
if ptr, nilable := ptrOrNilable(elemType.Kind()); !nilable {
doPanic("nullable types must be nilable")
} else if ptr {
elemType = elemType.Elem()
}
elemType = elemType.Elem()
}
verifyCompatibility(seen, elemType, schemaType.ValueType())
case *schema.TypeStruct:
Expand All @@ -159,18 +161,31 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp
for i, schemaField := range schemaFields {
schemaType := schemaField.Type()
goType := goType.Field(i).Type
// TODO: allow "is nilable" to some degree?
if schemaField.IsNullable() {
switch {
case schemaField.IsOptional() && schemaField.IsNullable():
// TODO: https://github.com/ipld/go-ipld-prime/issues/340 will
// help here, to avoid the double pointer. We can't use nilable
// but non-pointer types because that's just one "nil" state.
if goType.Kind() != reflect.Ptr {
doPanic("nullable types must be pointers")
doPanic("optional and nullable fields must use double pointers (**)")
}
goType = goType.Elem()
}
if schemaField.IsOptional() {
if goType.Kind() != reflect.Ptr {
doPanic("optional types must be pointers")
doPanic("optional and nullable fields must use double pointers (**)")
}
goType = goType.Elem()
case schemaField.IsOptional():
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("optional fields must be nilable")
} else if ptr {
goType = goType.Elem()
}
case schemaField.IsNullable():
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("nullable fields must be nilable")
} else if ptr {
goType = goType.Elem()
}
}
verifyCompatibility(seen, goType, schemaType)
}
Expand All @@ -186,10 +201,11 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp

for i, schemaType := range schemaMembers {
goType := goType.Field(i).Type
if goType.Kind() != reflect.Ptr {
doPanic("union members must be pointers")
if ptr, nilable := ptrOrNilable(goType.Kind()); !nilable {
doPanic("union members must be nilable")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is the ideal error here - it's talking schema language about a Go struct, maybe that's the best we have? but istm that this would be better stated as "must be pointers or interfaces".

Not blocking, just noting that it's weird encountering this error in the test when it's refering to the Go type.

} else if ptr {
goType = goType.Elem()
}
goType = goType.Elem()
verifyCompatibility(seen, goType, schemaType)
}
case *schema.TypeLink:
Expand All @@ -206,6 +222,17 @@ func verifyCompatibility(seen map[seenEntry]bool, goType reflect.Type, schemaTyp
}
}

func ptrOrNilable(kind reflect.Kind) (ptr, nilable bool) {
switch kind {
case reflect.Ptr:
return true, true
case reflect.Interface, reflect.Map, reflect.Slice:
return false, true
default:
return false, false
}
}

// If we recurse past a large number of levels, we're mostly stuck in a loop.
// Prevent burning CPU or causing OOM crashes.
// If a user really wrote an IPLD schema or Go type with such deep nesting,
Expand Down
94 changes: 70 additions & 24 deletions node/bindnode/infer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,37 +334,64 @@ func TestPrototypePointerCombinations(t *testing.T) {
})(nil), `{"x":3,"y":4}`},
}

// For each IPLD kind, we test a matrix of combinations for IPLD's optional
// and nullable fields alongside pointer usage on the Go field side.
modifiers := []struct {
schemaField string // "", "optional", "nullable", "optional nullable"
goPointers int // 0 (T), 1 (*T), 2 (**T)
}{
{"", 0}, // regular IPLD field with Go's T
{"", 1}, // regular IPLD field with Go's *T
{"optional", 0}, // optional IPLD field with Go's T (skipped unless T is nilable)
{"optional", 1}, // optional IPLD field with Go's *T
{"nullable", 0}, // nullable IPLD field with Go's T (skipped unless T is nilable)
{"nullable", 1}, // nullable IPLD field with Go's *T
{"optional nullable", 2}, // optional and nullable IPLD field with Go's **T
}
for _, kindTest := range kindTests {
for _, modifier := range []string{"", "optional", "nullable"} {
for _, modifier := range modifiers {
// don't reuse range vars
kindTest := kindTest
modifier := modifier
t.Run(fmt.Sprintf("%s/%s", kindTest.name, modifier), func(t *testing.T) {
goFieldType := reflect.TypeOf(kindTest.fieldPtrType)
switch modifier.goPointers {
case 0:
goFieldType = goFieldType.Elem() // dereference fieldPtrType
case 1:
// fieldPtrType already uses one pointer
case 2:
goFieldType = reflect.PtrTo(goFieldType) // dereference fieldPtrType
}
if modifier.schemaField != "" && !nilable(goFieldType.Kind()) {
continue
}
t.Run(fmt.Sprintf("%s/%s-%dptr", kindTest.name, modifier.schemaField, modifier.goPointers), func(t *testing.T) {
t.Parallel()

var buf bytes.Buffer
err := template.Must(template.New("").Parse(`
type Root struct {
field {{.Modifier}} {{.Type}}
}`)).Execute(&buf, struct {
Type, Modifier string
}{kindTest.schemaType, modifier})
type Root struct {
field {{.Modifier}} {{.Type}}
}`)).Execute(&buf,
struct {
Type, Modifier string
}{kindTest.schemaType, modifier.schemaField})
qt.Assert(t, err, qt.IsNil)
schemaSrc := buf.String()
t.Logf("IPLD schema: %T", schemaSrc)
t.Logf("IPLD schema: %s", schemaSrc)

// *struct { Field {{.fieldPtrType}} }
ptrType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{
{Name: "Field", Type: reflect.TypeOf(kindTest.fieldPtrType)},
// *struct { Field {{.goFieldType}} }
goType := reflect.Zero(reflect.PtrTo(reflect.StructOf([]reflect.StructField{
{Name: "Field", Type: goFieldType},
}))).Interface()
t.Logf("Go type: %T", ptrType)
t.Logf("Go type: %T", goType)

ts, err := ipld.LoadSchemaBytes([]byte(schemaSrc))
qt.Assert(t, err, qt.IsNil)
schemaType := ts.TypeByName("Root")
qt.Assert(t, schemaType, qt.Not(qt.IsNil))

proto := bindnode.Prototype(ptrType, schemaType)
proto := bindnode.Prototype(goType, schemaType)
wantEncodedBytes, err := json.Marshal(map[string]interface{}{"field": json.RawMessage(kindTest.fieldDagJSON)})
qt.Assert(t, err, qt.IsNil)
wantEncoded := string(wantEncodedBytes)
Expand All @@ -377,33 +404,48 @@ func TestPrototypePointerCombinations(t *testing.T) {
// Assigning with the missing field should only work with optional.
nb := proto.NewBuilder()
err = dagjson.Decode(nb, strings.NewReader(`{}`))
if modifier == "optional" {
switch modifier.schemaField {
case "optional", "optional nullable":
qt.Assert(t, err, qt.IsNil)
node := nb.Build()
// The resulting node should be non-nil with a nil field.
nodeVal := reflect.ValueOf(bindnode.Unwrap(node))
qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue)
} else {
default:
qt.Assert(t, err, qt.Not(qt.IsNil))
}

// Assigning with a null field should only work with nullable.
nb = proto.NewBuilder()
err = dagjson.Decode(nb, strings.NewReader(`{"field":null}`))
if modifier == "nullable" {
switch modifier.schemaField {
case "nullable", "optional nullable":
qt.Assert(t, err, qt.IsNil)
node := nb.Build()
// The resulting node should be non-nil with a nil field.
nodeVal := reflect.ValueOf(bindnode.Unwrap(node))
qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue)
} else {
if modifier.schemaField == "nullable" {
qt.Assert(t, nodeVal.Elem().FieldByName("Field").IsNil(), qt.IsTrue)
} else {
qt.Assert(t, nodeVal.Elem().FieldByName("Field").Elem().IsNil(), qt.IsTrue)
}
default:
qt.Assert(t, err, qt.Not(qt.IsNil))
}
})
}
}
}

func nilable(kind reflect.Kind) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need this, or can you look at the second return of ptrOrNilable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The funcs are unexported so they cannot be shared between the test and non-test packages - that's how I arrived at the smaller copy. Any other alternative is more lines of code AFAICT.

switch kind {
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
default:
return false
}
}

func assembleAsKind(proto datamodel.NodePrototype, schemaType schema.Type, asKind datamodel.Kind) (ipld.Node, error) {
nb := proto.NewBuilder()
switch asKind {
Expand Down Expand Up @@ -895,13 +937,13 @@ var verifyTests = []struct {
Keys []string
Values map[string]*datamodel.Node
})(nil),
(*struct {
Keys []string
Values map[string]datamodel.Node
})(nil),
},
badTypes: []verifyBadType{
{(*string)(nil), `.*type Root .* type string: kind mismatch;.*`},
{(*struct {
Keys []string
Values map[string]datamodel.Node
})(nil), `.*type Root .*: nullable types must be pointers`},
},
},
{
Expand All @@ -918,6 +960,10 @@ var verifyTests = []struct {
List *[]string
String *string
})(nil),
(*struct {
List []string
String *string
})(nil),
(*struct {
List *[]namedString
String *namedString
Expand All @@ -927,9 +973,9 @@ var verifyTests = []struct {
{(*string)(nil), `.*type Root .* type string: kind mismatch;.*`},
{(*struct{ List *[]string })(nil), `.*type Root .*: 1 vs 2 members`},
{(*struct {
List *[]string
List []string
String string
})(nil), `.*type Root .*: union members must be pointers`},
})(nil), `.*type Root .*: union members must be nilable`},
{(*struct {
List *[]string
String *int
Expand Down
26 changes: 20 additions & 6 deletions node/bindnode/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,17 @@ func (w *_node) LookupByString(key string) (datamodel.Node, error) {
if fval.IsNil() {
return datamodel.Absent, nil
}
fval = fval.Elem()
if fval.Kind() == reflect.Ptr {
fval = fval.Elem()
}
}
if field.IsNullable() {
if fval.IsNil() {
return datamodel.Null, nil
}
fval = fval.Elem()
if fval.Kind() == reflect.Ptr {
fval = fval.Elem()
}
}
if _, ok := field.Type().(*schema.TypeAny); ok {
return nonPtrVal(fval).Interface().(datamodel.Node), nil
Expand Down Expand Up @@ -822,8 +826,14 @@ func (w *_structAssembler) AssembleValue() datamodel.NodeAssembler {
w.doneFields[ftyp.Index[0]] = true
fval := w.val.FieldByIndex(ftyp.Index)
if field.IsOptional() {
fval.Set(reflect.New(fval.Type().Elem()))
fval = fval.Elem()
if fval.Kind() == reflect.Ptr {
// ptrVal = new(T); val = *ptrVal
fval.Set(reflect.New(fval.Type().Elem()))
fval = fval.Elem()
} else {
// val = *new(T)
fval.Set(reflect.New(fval.Type()).Elem())
}
}
// TODO: reuse same assembler for perf?
return &_assembler{
Expand Down Expand Up @@ -1087,13 +1097,17 @@ func (w *_structIterator) Next() (key, value datamodel.Node, _ error) {
if val.IsNil() {
return key, datamodel.Absent, nil
}
val = val.Elem()
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
}
if field.IsNullable() {
if val.IsNil() {
return key, datamodel.Null, nil
}
val = val.Elem()
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
}
if _, ok := field.Type().(*schema.TypeAny); ok {
return key, nonPtrVal(val).Interface().(datamodel.Node), nil
Expand Down