From b81d29f46b2aaeeca7d0854e2279e5868dcb2df0 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Sat, 11 Aug 2018 04:10:00 -0700 Subject: [PATCH] proto: fix handling of required fields after multiple violations The previous logic only treated a required not set violation as non-fatal for the first instance. Afterwards, the logic incorrectly switched back to being fatal. --- proto/all_test.go | 22 ++++++++++++++++++++++ proto/table_marshal.go | 8 +++++--- proto/table_unmarshal.go | 6 ++++-- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/proto/all_test.go b/proto/all_test.go index a68d91d2aa..1bea4b6e8e 100644 --- a/proto/all_test.go +++ b/proto/all_test.go @@ -2324,6 +2324,28 @@ func TestInvalidUTF8(t *testing.T) { } } +func TestRequired(t *testing.T) { + // The F_BoolRequired field appears after all of the required fields. + // It should still be handled even after multiple required field violations. + m := &GoTest{F_BoolRequired: Bool(true)} + got, err := Marshal(m) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("Marshal() = %v, want RequiredNotSetError error", err) + } + if want := []byte{0x50, 0x01}; !bytes.Equal(got, want) { + t.Errorf("Marshal() = %x, want %x", got, want) + } + + m = new(GoTest) + err = Unmarshal(got, m) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("Marshal() = %v, want RequiredNotSetError error", err) + } + if !m.GetF_BoolRequired() { + t.Error("m.F_BoolRequired = false, want true") + } +} + // Benchmarks func testMsg() *GoTest { diff --git a/proto/table_marshal.go b/proto/table_marshal.go index eafe04d14b..b16794496f 100644 --- a/proto/table_marshal.go +++ b/proto/table_marshal.go @@ -252,11 +252,13 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte } } for _, f := range u.fields { - if f.required && errLater == nil { + if f.required { if ptr.offset(f.field).getPointer().isNil() { // Required field is not set. // We record the error but keep going, to give a complete marshaling. - errLater = &RequiredNotSetError{f.name} + if errLater == nil { + errLater = &RequiredNotSetError{f.name} + } continue } } @@ -2592,7 +2594,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de p := toAddrPointer(&v, ei.isptr) b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic) b = append(b, 1<<3|WireEndGroup) - if nerr.Merge(err) { + if !nerr.Merge(err) { return b, err } } diff --git a/proto/table_unmarshal.go b/proto/table_unmarshal.go index de868ae927..ebf1caa56a 100644 --- a/proto/table_unmarshal.go +++ b/proto/table_unmarshal.go @@ -175,10 +175,12 @@ func (u *unmarshalInfo) unmarshal(m pointer, b []byte) error { reqMask |= f.reqMask continue } - if r, ok := err.(*RequiredNotSetError); ok && errLater == nil { + if r, ok := err.(*RequiredNotSetError); ok { // Remember this error, but keep parsing. We need to produce // a full parse even if a required field is missing. - errLater = r + if errLater == nil { + errLater = r + } reqMask |= f.reqMask continue }