From 7b36aa95a1150c4d2317396d3babd0b6107f7fed Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 9 Jan 2024 11:41:32 -0500 Subject: [PATCH 1/2] Add a custom proto marshaller to force consistency against other languages for signature verification --- pkg/code/auth/encoding.go | 400 ++++++++++++++++++++++++++++++++ pkg/code/auth/encoding_test.go | 47 ++++ pkg/code/auth/signature.go | 46 +++- pkg/code/auth/signature_test.go | 175 +++++++------- 4 files changed, 576 insertions(+), 92 deletions(-) create mode 100644 pkg/code/auth/encoding.go create mode 100644 pkg/code/auth/encoding_test.go diff --git a/pkg/code/auth/encoding.go b/pkg/code/auth/encoding.go new file mode 100644 index 00000000..61e3c624 --- /dev/null +++ b/pkg/code/auth/encoding.go @@ -0,0 +1,400 @@ +package auth + +import ( + "math" + "sort" + "sync" + "unicode/utf8" + + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoiface" +) + +// Primarily based off of https://github.com/protocolbuffers/protobuf-go/blob/v1.28.1/proto/encode.go +// but tweaked for consistent marshalling using field number ordering. +// +// todo: Potentially use a custom approach. This is a temporary measure to fix +// discrepancies between Go and other language implementations that result +// in signature mismatch issues due to not using field numbers for ordering. + +const speculativeLength = 1 + +var wireTypes = map[protoreflect.Kind]protowire.Type{ + protoreflect.BoolKind: protowire.VarintType, + protoreflect.EnumKind: protowire.VarintType, + protoreflect.Int32Kind: protowire.VarintType, + protoreflect.Sint32Kind: protowire.VarintType, + protoreflect.Uint32Kind: protowire.VarintType, + protoreflect.Int64Kind: protowire.VarintType, + protoreflect.Sint64Kind: protowire.VarintType, + protoreflect.Uint64Kind: protowire.VarintType, + protoreflect.Sfixed32Kind: protowire.Fixed32Type, + protoreflect.Fixed32Kind: protowire.Fixed32Type, + protoreflect.FloatKind: protowire.Fixed32Type, + protoreflect.Sfixed64Kind: protowire.Fixed64Type, + protoreflect.Fixed64Kind: protowire.Fixed64Type, + protoreflect.DoubleKind: protowire.Fixed64Type, + protoreflect.StringKind: protowire.BytesType, + protoreflect.BytesKind: protowire.BytesType, + protoreflect.MessageKind: protowire.BytesType, + protoreflect.GroupKind: protowire.StartGroupType, +} + +func forceConsistentMarshal(m proto.Message) ([]byte, error) { + out, err := consistentMarshal(nil, m.ProtoReflect()) + if err != nil { + return nil, err + } + return out.Buf, nil +} + +func consistentMarshal(b []byte, m protoreflect.Message) (out protoiface.MarshalOutput, err error) { + out.Buf, err = consistentMarshalMessageSlow(b, m) + if err != nil { + return out, err + } + return out, checkInitialized(m) +} + +func consistentMarshalMessageSlow(b []byte, m protoreflect.Message) ([]byte, error) { + var err error + rangeFields(m, numberFieldOrder, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + b, err = marshalField(b, fd, v) + return err == nil + }) + if err != nil { + return b, err + } + b = append(b, m.GetUnknown()...) + return b, nil +} + +func marshalSingular(b []byte, fd protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) { + switch fd.Kind() { + case protoreflect.BoolKind: + b = protowire.AppendVarint(b, protowire.EncodeBool(v.Bool())) + case protoreflect.EnumKind: + b = protowire.AppendVarint(b, uint64(v.Enum())) + case protoreflect.Int32Kind: + b = protowire.AppendVarint(b, uint64(int32(v.Int()))) + case protoreflect.Sint32Kind: + b = protowire.AppendVarint(b, protowire.EncodeZigZag(int64(int32(v.Int())))) + case protoreflect.Uint32Kind: + b = protowire.AppendVarint(b, uint64(uint32(v.Uint()))) + case protoreflect.Int64Kind: + b = protowire.AppendVarint(b, uint64(v.Int())) + case protoreflect.Sint64Kind: + b = protowire.AppendVarint(b, protowire.EncodeZigZag(v.Int())) + case protoreflect.Uint64Kind: + b = protowire.AppendVarint(b, v.Uint()) + case protoreflect.Sfixed32Kind: + b = protowire.AppendFixed32(b, uint32(v.Int())) + case protoreflect.Fixed32Kind: + b = protowire.AppendFixed32(b, uint32(v.Uint())) + case protoreflect.FloatKind: + b = protowire.AppendFixed32(b, math.Float32bits(float32(v.Float()))) + case protoreflect.Sfixed64Kind: + b = protowire.AppendFixed64(b, uint64(v.Int())) + case protoreflect.Fixed64Kind: + b = protowire.AppendFixed64(b, v.Uint()) + case protoreflect.DoubleKind: + b = protowire.AppendFixed64(b, math.Float64bits(v.Float())) + case protoreflect.StringKind: + if enforceUTF8(fd) && !utf8.ValidString(v.String()) { + return b, errors.Errorf("field %v contains invalid UTF-8", string(fd.FullName())) + } + b = protowire.AppendString(b, v.String()) + case protoreflect.BytesKind: + b = protowire.AppendBytes(b, v.Bytes()) + case protoreflect.MessageKind: + var pos int + var err error + b, pos = appendSpeculativeLength(b) + b, err = marshalMessage(b, v.Message()) + if err != nil { + return b, err + } + b = finishSpeculativeLength(b, pos) + case protoreflect.GroupKind: + var err error + b, err = marshalMessage(b, v.Message()) + if err != nil { + return b, err + } + b = protowire.AppendVarint(b, protowire.EncodeTag(fd.Number(), protowire.EndGroupType)) + default: + return b, errors.Errorf("invalid kind %v", fd.Kind()) + } + return b, nil +} + +func marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) { + out, err := consistentMarshal(b, m) + return out.Buf, err +} + +func marshalField(b []byte, fd protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) { + switch { + case fd.IsList(): + return marshalList(b, fd, value.List()) + case fd.IsMap(): + return marshalMap(b, fd, value.Map()) + default: + b = protowire.AppendTag(b, fd.Number(), wireTypes[fd.Kind()]) + return marshalSingular(b, fd, value) + } +} + +func marshalList(b []byte, fd protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) { + if fd.IsPacked() && list.Len() > 0 { + b = protowire.AppendTag(b, fd.Number(), protowire.BytesType) + b, pos := appendSpeculativeLength(b) + for i, llen := 0, list.Len(); i < llen; i++ { + var err error + b, err = marshalSingular(b, fd, list.Get(i)) + if err != nil { + return b, err + } + } + b = finishSpeculativeLength(b, pos) + return b, nil + } + + kind := fd.Kind() + for i, llen := 0, list.Len(); i < llen; i++ { + var err error + b = protowire.AppendTag(b, fd.Number(), wireTypes[kind]) + b, err = marshalSingular(b, fd, list.Get(i)) + if err != nil { + return b, err + } + } + return b, nil +} + +func marshalMap(b []byte, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) ([]byte, error) { + keyf := fd.MapKey() + valf := fd.MapValue() + var err error + rangeEntries(mapv, genericKeyOrder, func(key protoreflect.MapKey, value protoreflect.Value) bool { + b = protowire.AppendTag(b, fd.Number(), protowire.BytesType) + var pos int + b, pos = appendSpeculativeLength(b) + + b, err = marshalField(b, keyf, key.Value()) + if err != nil { + return false + } + b, err = marshalField(b, valf, value) + if err != nil { + return false + } + b = finishSpeculativeLength(b, pos) + return true + }) + return b, err +} + +// fieldOrder specifies the ordering to visit message fields. +// It is a function that reports whether x is ordered before y. +type fieldOrder func(x, y protoreflect.FieldDescriptor) bool + +var ( + // numberFieldOrder sorts fields by their field number. + numberFieldOrder fieldOrder = func(x, y protoreflect.FieldDescriptor) bool { + return x.Number() < y.Number() + } +) + +type messageField struct { + fd protoreflect.FieldDescriptor + v protoreflect.Value +} + +var messageFieldPool = sync.Pool{ + New: func() interface{} { return new([]messageField) }, +} + +type ( + // fieldRnger is an interface for visiting all fields in a message. + // The protoreflect.Message type implements this interface. + fieldRanger interface{ Range(visitField) } + // visitField is called every time a message field is visited. + visitField = func(protoreflect.FieldDescriptor, protoreflect.Value) bool +) + +func rangeFields(fs fieldRanger, less fieldOrder, fn visitField) { + if less == nil { + fs.Range(fn) + return + } + + // Obtain a pre-allocated scratch buffer. + p := messageFieldPool.Get().(*[]messageField) + fields := (*p)[:0] + defer func() { + if cap(fields) < 1024 { + *p = fields + messageFieldPool.Put(p) + } + }() + + // Collect all fields in the message and sort them. + fs.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + fields = append(fields, messageField{fd, v}) + return true + }) + sort.Slice(fields, func(i, j int) bool { + return less(fields[i].fd, fields[j].fd) + }) + + // Visit the fields in the specified ordering. + for _, f := range fields { + if !fn(f.fd, f.v) { + return + } + } +} + +// keyOrder specifies the ordering to visit map entries. +// It is a function that reports whether x is ordered before y. +type keyOrder func(x, y protoreflect.MapKey) bool + +var ( + // genericKeyOrder sorts false before true, numeric keys in ascending order, + // and strings in lexicographical ordering according to UTF-8 codepoints. + genericKeyOrder keyOrder = func(x, y protoreflect.MapKey) bool { + switch x.Interface().(type) { + case bool: + return !x.Bool() && y.Bool() + case int32, int64: + return x.Int() < y.Int() + case uint32, uint64: + return x.Uint() < y.Uint() + case string: + return x.String() < y.String() + default: + panic("invalid map key type") + } + } +) + +type ( + // entryRanger is an interface for visiting all fields in a message. + // The protoreflect.Map type implements this interface. + entryRanger interface{ Range(visitEntry) } + // visitEntry is called every time a map entry is visited. + visitEntry = func(protoreflect.MapKey, protoreflect.Value) bool +) + +type mapEntry struct { + k protoreflect.MapKey + v protoreflect.Value +} + +var mapEntryPool = sync.Pool{ + New: func() interface{} { return new([]mapEntry) }, +} + +// rangeEntries iterates over the entries of es according to the specified order. +func rangeEntries(es entryRanger, less keyOrder, fn visitEntry) { + if less == nil { + es.Range(fn) + return + } + + // Obtain a pre-allocated scratch buffer. + p := mapEntryPool.Get().(*[]mapEntry) + entries := (*p)[:0] + defer func() { + if cap(entries) < 1024 { + *p = entries + mapEntryPool.Put(p) + } + }() + + // Collect all entries in the map and sort them. + es.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + entries = append(entries, mapEntry{k, v}) + return true + }) + sort.Slice(entries, func(i, j int) bool { + return less(entries[i].k, entries[j].k) + }) + + // Visit the entries in the specified ordering. + for _, e := range entries { + if !fn(e.k, e.v) { + return + } + } +} + +func checkInitialized(m protoreflect.Message) error { + return checkInitializedSlow(m) +} + +func checkInitializedSlow(m protoreflect.Message) error { + md := m.Descriptor() + fds := md.Fields() + for i, nums := 0, md.RequiredNumbers(); i < nums.Len(); i++ { + fd := fds.ByNumber(nums.Get(i)) + if !m.Has(fd) { + return errors.Errorf("required field %v not set", string(fd.FullName())) + } + } + var err error + m.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + switch { + case fd.IsList(): + if fd.Message() == nil { + return true + } + for i, list := 0, v.List(); i < list.Len() && err == nil; i++ { + err = checkInitialized(list.Get(i).Message()) + } + case fd.IsMap(): + if fd.MapValue().Message() == nil { + return true + } + v.Map().Range(func(key protoreflect.MapKey, v protoreflect.Value) bool { + err = checkInitialized(v.Message()) + return err == nil + }) + default: + if fd.Message() == nil { + return true + } + err = checkInitialized(v.Message()) + } + return err == nil + }) + return err +} + +func enforceUTF8(fd protoreflect.FieldDescriptor) bool { + return fd.Syntax() == protoreflect.Proto3 +} + +func appendSpeculativeLength(b []byte) ([]byte, int) { + pos := len(b) + b = append(b, "\x00\x00\x00\x00"[:speculativeLength]...) + return b, pos +} + +func finishSpeculativeLength(b []byte, pos int) []byte { + mlen := len(b) - pos - speculativeLength + msiz := protowire.SizeVarint(uint64(mlen)) + if msiz != speculativeLength { + for i := 0; i < msiz-speculativeLength; i++ { + b = append(b, 0) + } + copy(b[pos+msiz:], b[pos+speculativeLength:]) + b = b[:pos+msiz+mlen] + } + protowire.AppendVarint(b[:pos], uint64(mlen)) + return b +} diff --git a/pkg/code/auth/encoding_test.go b/pkg/code/auth/encoding_test.go new file mode 100644 index 00000000..c2c845f9 --- /dev/null +++ b/pkg/code/auth/encoding_test.go @@ -0,0 +1,47 @@ +package auth + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" +) + +func TestCrossLanguageProtoEncoding(t *testing.T) { + goValue := "CtYBKtMBCiIKIKuIZy+UTqRbPCXCbXMtLl5A1cfBYsPNaFjVyVoj+jRVIhEKD2FwcC5nZXRjb2RlLmNvbSoiCiCQXMeWrnoZmEYYNs2fWQUviipSzVObQX5XfBGCg9KgbDJCCkBgpVkQnlTv9ackQCHPV39NBCHOKh0N5n8gSwQ7Hz8nFldMcdI+TbF+9foOcW/0g+DSnR5kbxbRYEWuRTKo5O8BOiIKIHCkrraPdjY/ImaB3xZiv8D2Qjbpenpkh0Zqk5lUXnr7Gg4KA3VzZBEAAAAAAADgPxIiCiBwpK62j3Y2PyJmgd8WYr/A9kI26Xp6ZIdGapOZVF56+xpCCkDU3CnRyHQ4w0O5D5eIqizAoaBwDft+RjWsGl+Wzo+jCGyE7u+Siw4uZT7U4VcLV6lcsfe9XeB66E7RYlmwAv0I" + otherLanguageValue := "CtYBKtMBCiIKIKuIZy+UTqRbPCXCbXMtLl5A1cfBYsPNaFjVyVoj+jRVGg4KA3VzZBEAAAAAAADgPyIRCg9hcHAuZ2V0Y29kZS5jb20qIgogkFzHlq56GZhGGDbNn1kFL4oqUs1Tm0F+V3wRgoPSoGwyQgpAYKVZEJ5U7/WnJEAhz1d/TQQhziodDeZ/IEsEOx8/JxZXTHHSPk2xfvX6DnFv9IPg0p0eZG8W0WBFrkUyqOTvAToiCiBwpK62j3Y2PyJmgd8WYr/A9kI26Xp6ZIdGapOZVF56+xIiCiBwpK62j3Y2PyJmgd8WYr/A9kI26Xp6ZIdGapOZVF56+xpCCkDU3CnRyHQ4w0O5D5eIqizAoaBwDft+RjWsGl+Wzo+jCGyE7u+Siw4uZT7U4VcLV6lcsfe9XeB66E7RYlmwAv0I" + + var msg1 messagingpb.SendMessageRequest + buffer, err := base64.StdEncoding.DecodeString(goValue) + require.NoError(t, err) + require.NoError(t, proto.Unmarshal(buffer, &msg1)) + + var msg2 messagingpb.SendMessageRequest + buffer, err = base64.StdEncoding.DecodeString(otherLanguageValue) + require.NoError(t, err) + require.NoError(t, proto.Unmarshal(buffer, &msg2)) + + require.True(t, proto.Equal(&msg1, &msg2)) + + for _, encoded := range []string{ + goValue, + otherLanguageValue, + } { + var msg messagingpb.SendMessageRequest + buffer, err := base64.StdEncoding.DecodeString(encoded) + require.NoError(t, err) + require.NoError(t, proto.Unmarshal(buffer, &msg)) + + marshalled, err := proto.Marshal(&msg) + require.NoError(t, err) + assert.Equal(t, goValue, base64.StdEncoding.EncodeToString(marshalled)) + + marshalled, err = forceConsistentMarshal(&msg) + require.NoError(t, err) + assert.Equal(t, otherLanguageValue, base64.StdEncoding.EncodeToString(marshalled)) + } +} diff --git a/pkg/code/auth/signature.go b/pkg/code/auth/signature.go index e2ac20a5..cbf9efb7 100644 --- a/pkg/code/auth/signature.go +++ b/pkg/code/auth/signature.go @@ -3,7 +3,9 @@ package auth import ( "context" "crypto/ed25519" + "encoding/base64" + "github.com/mr-tron/base58" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -45,13 +47,12 @@ func (v *RPCSignatureVerifier) Authenticate(ctx context.Context, owner *common.A "owner_account": owner.PublicKey().ToBase58(), }) - messageBytes, err := proto.Marshal(message) + isSignatureValid, err := v.isSignatureVerifiedProtoMessage(owner, message, signature) if err != nil { - log.WithError(err).Warn("failure marshalling message") + log.WithError(err).Warn("failure verifying signature") return status.Error(codes.Internal, "") } - isSignatureValid := ed25519.Verify(owner.PublicKey().ToBytes(), messageBytes, signature.Value) if !isSignatureValid { return status.Error(codes.Unauthenticated, "") } @@ -69,13 +70,12 @@ func (v *RPCSignatureVerifier) AuthorizeDataAccess(ctx context.Context, dataCont "owner_account": owner.PublicKey().ToBase58(), }) - messageBytes, err := proto.Marshal(message) + isSignatureValid, err := v.isSignatureVerifiedProtoMessage(owner, message, signature) if err != nil { - log.WithError(err).Warn("failure marshalling message") + log.WithError(err).Warn("failure verifying signature") return status.Error(codes.Internal, "") } - isSignatureValid := ed25519.Verify(owner.PublicKey().ToBytes(), messageBytes, signature.Value) if !isSignatureValid { return status.Error(codes.Unauthenticated, "") } @@ -91,3 +91,37 @@ func (v *RPCSignatureVerifier) AuthorizeDataAccess(ctx context.Context, dataCont } return nil } + +// marshalStrategy is a strategy for marshalling protobuf messages for signature +// verification +type marshalStrategy func(proto.Message) ([]byte, error) + +// defaultMarshalStrategies are the default marshal strategies +var defaultMarshalStrategies = []marshalStrategy{ + forceConsistentMarshal, + proto.Marshal, // todo: deprecate this option +} + +func (v *RPCSignatureVerifier) isSignatureVerifiedProtoMessage(owner *common.Account, message proto.Message, signature *commonpb.Signature) (bool, error) { + for _, marshalStrategy := range defaultMarshalStrategies { + messageBytes, err := marshalStrategy(message) + if err != nil { + return false, err + } + + isSignatureValid := ed25519.Verify(owner.PublicKey().ToBytes(), messageBytes, signature.Value) + if isSignatureValid { + return true, nil + } + } + + encoded, err := proto.Marshal(message) + if err == nil { + v.log.WithFields(logrus.Fields{ + "proto_message": base64.StdEncoding.EncodeToString(encoded), + "signature": base58.Encode(signature.Value), + }).Info("proto message is not signature verified") + } + + return false, nil +} diff --git a/pkg/code/auth/signature_test.go b/pkg/code/auth/signature_test.go index 881787b3..191500e9 100644 --- a/pkg/code/auth/signature_test.go +++ b/pkg/code/auth/signature_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" - "google.golang.org/protobuf/proto" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" @@ -35,94 +34,98 @@ func setup(t *testing.T) (env testEnv) { } func TestAuthenticate(t *testing.T) { - env := setup(t) - - ownerAccount := testutil.NewRandomAccount(t) - maliciousAccount := testutil.NewRandomAccount(t) - - msgValue, _ := uuid.New().MarshalBinary() - msg := &messagingpb.MessageId{ - Value: msgValue, - } - - msgBytes, err := proto.Marshal(msg) - require.NoError(t, err) - - signature, err := ownerAccount.Sign(msgBytes) - require.NoError(t, err) - signatureProto := &commonpb.Signature{ - Value: signature, + for _, marshalStrategy := range defaultMarshalStrategies { + env := setup(t) + + ownerAccount := testutil.NewRandomAccount(t) + maliciousAccount := testutil.NewRandomAccount(t) + + msgValue, _ := uuid.New().MarshalBinary() + msg := &messagingpb.MessageId{ + Value: msgValue, + } + + msgBytes, err := marshalStrategy(msg) + require.NoError(t, err) + + signature, err := ownerAccount.Sign(msgBytes) + require.NoError(t, err) + signatureProto := &commonpb.Signature{ + Value: signature, + } + + err = env.verifier.Authenticate(env.ctx, ownerAccount, msg, signatureProto) + require.NoError(t, err) + + signature, err = maliciousAccount.Sign(msgBytes) + require.NoError(t, err) + signatureProto = &commonpb.Signature{ + Value: signature, + } + + err = env.verifier.Authenticate(env.ctx, ownerAccount, msg, signatureProto) + assert.Error(t, err) + testutil.AssertStatusErrorWithCode(t, err, codes.Unauthenticated) } - - err = env.verifier.Authenticate(env.ctx, ownerAccount, msg, signatureProto) - require.NoError(t, err) - - signature, err = maliciousAccount.Sign(msgBytes) - require.NoError(t, err) - signatureProto = &commonpb.Signature{ - Value: signature, - } - - err = env.verifier.Authenticate(env.ctx, ownerAccount, msg, signatureProto) - assert.Error(t, err) - testutil.AssertStatusErrorWithCode(t, err, codes.Unauthenticated) } func TestAuthorizeDataAccess(t *testing.T) { - env := setup(t) - - dataContainerID := user.NewDataContainerID() - phoneNumber := "+11234567890" - - ownerAccount := testutil.NewRandomAccount(t) - - maliciousAccount := testutil.NewRandomAccount(t) - - msgValue, _ := uuid.New().MarshalBinary() - msg := &messagingpb.MessageId{ - Value: msgValue, - } - - msgBytes, err := proto.Marshal(msg) - require.NoError(t, err) - - signature, err := ownerAccount.Sign(msgBytes) - require.NoError(t, err) - signatureProto := &commonpb.Signature{ - Value: signature, - } - - // Data container doesn't exist - err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, ownerAccount, msg, signatureProto) - assert.Error(t, err) - testutil.AssertStatusErrorWithCode(t, err, codes.PermissionDenied) - - require.NoError(t, env.data.PutUserDataContainer(env.ctx, &storage.Record{ - ID: dataContainerID, - OwnerAccount: ownerAccount.PublicKey().ToBase58(), - IdentifyingFeatures: &user.IdentifyingFeatures{ - PhoneNumber: &phoneNumber, - }, - CreatedAt: time.Now(), - })) - - // Successful authorization - err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, ownerAccount, msg, signatureProto) - assert.NoError(t, err) - - signature, err = maliciousAccount.Sign(msgBytes) - require.NoError(t, err) - signatureProto = &commonpb.Signature{ - Value: signature, + for _, marshalStrategy := range defaultMarshalStrategies { + env := setup(t) + + dataContainerID := user.NewDataContainerID() + phoneNumber := "+11234567890" + + ownerAccount := testutil.NewRandomAccount(t) + + maliciousAccount := testutil.NewRandomAccount(t) + + msgValue, _ := uuid.New().MarshalBinary() + msg := &messagingpb.MessageId{ + Value: msgValue, + } + + msgBytes, err := marshalStrategy(msg) + require.NoError(t, err) + + signature, err := ownerAccount.Sign(msgBytes) + require.NoError(t, err) + signatureProto := &commonpb.Signature{ + Value: signature, + } + + // Data container doesn't exist + err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, ownerAccount, msg, signatureProto) + assert.Error(t, err) + testutil.AssertStatusErrorWithCode(t, err, codes.PermissionDenied) + + require.NoError(t, env.data.PutUserDataContainer(env.ctx, &storage.Record{ + ID: dataContainerID, + OwnerAccount: ownerAccount.PublicKey().ToBase58(), + IdentifyingFeatures: &user.IdentifyingFeatures{ + PhoneNumber: &phoneNumber, + }, + CreatedAt: time.Now(), + })) + + // Successful authorization + err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, ownerAccount, msg, signatureProto) + assert.NoError(t, err) + + signature, err = maliciousAccount.Sign(msgBytes) + require.NoError(t, err) + signatureProto = &commonpb.Signature{ + Value: signature, + } + + // Token account doesn't own data container + err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, maliciousAccount, msg, signatureProto) + assert.Error(t, err) + testutil.AssertStatusErrorWithCode(t, err, codes.PermissionDenied) + + // Signature doesn't match public key + err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, ownerAccount, msg, signatureProto) + assert.Error(t, err) + testutil.AssertStatusErrorWithCode(t, err, codes.Unauthenticated) } - - // Token account doesn't own data container - err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, maliciousAccount, msg, signatureProto) - assert.Error(t, err) - testutil.AssertStatusErrorWithCode(t, err, codes.PermissionDenied) - - // Signature doesn't match public key - err = env.verifier.AuthorizeDataAccess(env.ctx, dataContainerID, ownerAccount, msg, signatureProto) - assert.Error(t, err) - testutil.AssertStatusErrorWithCode(t, err, codes.Unauthenticated) } From cfd79716c26fddd53093f93ef518f6f4d2a3470a Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 9 Jan 2024 12:01:06 -0500 Subject: [PATCH 2/2] Add SDK parity tests for proto encoding --- pkg/code/auth/encoding_test.go | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/pkg/code/auth/encoding_test.go b/pkg/code/auth/encoding_test.go index c2c845f9..3d2dc994 100644 --- a/pkg/code/auth/encoding_test.go +++ b/pkg/code/auth/encoding_test.go @@ -11,7 +11,7 @@ import ( messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" ) -func TestCrossLanguageProtoEncoding(t *testing.T) { +func TestProtoEncoding_CrossLanguageSupport(t *testing.T) { goValue := "CtYBKtMBCiIKIKuIZy+UTqRbPCXCbXMtLl5A1cfBYsPNaFjVyVoj+jRVIhEKD2FwcC5nZXRjb2RlLmNvbSoiCiCQXMeWrnoZmEYYNs2fWQUviipSzVObQX5XfBGCg9KgbDJCCkBgpVkQnlTv9ackQCHPV39NBCHOKh0N5n8gSwQ7Hz8nFldMcdI+TbF+9foOcW/0g+DSnR5kbxbRYEWuRTKo5O8BOiIKIHCkrraPdjY/ImaB3xZiv8D2Qjbpenpkh0Zqk5lUXnr7Gg4KA3VzZBEAAAAAAADgPxIiCiBwpK62j3Y2PyJmgd8WYr/A9kI26Xp6ZIdGapOZVF56+xpCCkDU3CnRyHQ4w0O5D5eIqizAoaBwDft+RjWsGl+Wzo+jCGyE7u+Siw4uZT7U4VcLV6lcsfe9XeB66E7RYlmwAv0I" otherLanguageValue := "CtYBKtMBCiIKIKuIZy+UTqRbPCXCbXMtLl5A1cfBYsPNaFjVyVoj+jRVGg4KA3VzZBEAAAAAAADgPyIRCg9hcHAuZ2V0Y29kZS5jb20qIgogkFzHlq56GZhGGDbNn1kFL4oqUs1Tm0F+V3wRgoPSoGwyQgpAYKVZEJ5U7/WnJEAhz1d/TQQhziodDeZ/IEsEOx8/JxZXTHHSPk2xfvX6DnFv9IPg0p0eZG8W0WBFrkUyqOTvAToiCiBwpK62j3Y2PyJmgd8WYr/A9kI26Xp6ZIdGapOZVF56+xIiCiBwpK62j3Y2PyJmgd8WYr/A9kI26Xp6ZIdGapOZVF56+xpCCkDU3CnRyHQ4w0O5D5eIqizAoaBwDft+RjWsGl+Wzo+jCGyE7u+Siw4uZT7U4VcLV6lcsfe9XeB66E7RYlmwAv0I" @@ -45,3 +45,41 @@ func TestCrossLanguageProtoEncoding(t *testing.T) { assert.Equal(t, otherLanguageValue, base64.StdEncoding.EncodeToString(marshalled)) } } + +func TestProtoEncoding_SDKTestParity(t *testing.T) { + expected := []byte{ + 0x2a, 0xd3, 0x01, 0x0a, 0x22, 0x0a, 0x20, 0xab, 0x88, 0x67, + 0x2f, 0x94, 0x4e, 0xa4, 0x5b, 0x3c, 0x25, 0xc2, 0x6d, 0x73, + 0x2d, 0x2e, 0x5e, 0x40, 0xd5, 0xc7, 0xc1, 0x62, 0xc3, 0xcd, + 0x68, 0x58, 0xd5, 0xc9, 0x5a, 0x23, 0xfa, 0x34, 0x55, 0x1a, + 0x0e, 0x0a, 0x03, 0x75, 0x73, 0x64, 0x11, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0xe0, 0x3f, 0x22, 0x11, 0x0a, 0x0f, 0x61, + 0x70, 0x70, 0x2e, 0x67, 0x65, 0x74, 0x63, 0x6f, 0x64, 0x65, + 0x2e, 0x63, 0x6f, 0x6d, 0x2a, 0x22, 0x0a, 0x20, 0x90, 0x5c, + 0xc7, 0x96, 0xae, 0x7a, 0x19, 0x98, 0x46, 0x18, 0x36, 0xcd, + 0x9f, 0x59, 0x05, 0x2f, 0x8a, 0x2a, 0x52, 0xcd, 0x53, 0x9b, + 0x41, 0x7e, 0x57, 0x7c, 0x11, 0x82, 0x83, 0xd2, 0xa0, 0x6c, + 0x32, 0x42, 0x0a, 0x40, 0xec, 0x47, 0x69, 0x3f, 0xc4, 0xd2, + 0x6a, 0x35, 0x49, 0xfb, 0xbf, 0x57, 0xd3, 0x20, 0xa6, 0x1b, + 0x91, 0x40, 0x94, 0x89, 0x69, 0x24, 0x43, 0xbc, 0x42, 0xcb, + 0xe8, 0xe0, 0x2b, 0x92, 0xc3, 0x23, 0x8b, 0xb0, 0x93, 0x17, + 0x32, 0xa6, 0xf5, 0xe5, 0x3a, 0xd5, 0xca, 0xb1, 0x62, 0x34, + 0x83, 0x44, 0x60, 0x75, 0x9f, 0xc6, 0x9c, 0xc9, 0xbf, 0x03, + 0x64, 0xbe, 0xb7, 0x62, 0x65, 0x3e, 0xf8, 0x0e, 0x3a, 0x22, + 0x0a, 0x20, 0x70, 0xa4, 0xae, 0xb6, 0x8f, 0x76, 0x36, 0x3f, + 0x22, 0x66, 0x81, 0xdf, 0x16, 0x62, 0xbf, 0xc0, 0xf6, 0x42, + 0x36, 0xe9, 0x7a, 0x7a, 0x64, 0x87, 0x46, 0x6a, 0x93, 0x99, + 0x54, 0x5e, 0x7a, 0xfb, + } + + var msg messagingpb.Message + require.NoError(t, proto.Unmarshal(expected, &msg)) + + marshalled, err := proto.Marshal(&msg) + require.NoError(t, err) + assert.NotEqual(t, marshalled, expected) + + marshalled, err = forceConsistentMarshal(&msg) + require.NoError(t, err) + assert.Equal(t, marshalled, expected) +}