diff --git a/jsonpb/jsonpb.go b/jsonpb/jsonpb.go index 91dcd26af5..71985e64ff 100644 --- a/jsonpb/jsonpb.go +++ b/jsonpb/jsonpb.go @@ -233,12 +233,14 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU } // Handle proto2 extensions. - if ep, ok := v.(extendableProto); ok { + if ep, ok := v.(proto.Message); ok { extensions := proto.RegisteredExtensions(v) - extensionMap := ep.ExtensionMap() // Sort extensions for stable output. - ids := make([]int32, 0, len(extensionMap)) - for id := range extensionMap { + ids := make([]int32, 0, len(extensions)) + for id, desc := range extensions { + if !proto.HasExtension(ep, desc) { + continue + } ids = append(ids, id) } sort.Sort(int32Slice(ids)) @@ -767,13 +769,6 @@ func acceptedJSONFieldNames(prop *proto.Properties) fieldNames { return opts } -// extendableProto is an interface implemented by any protocol buffer that may be extended. -type extendableProto interface { - proto.Message - ExtensionRangeArray() []proto.ExtensionRange - ExtensionMap() map[int32]proto.Extension -} - // Writer wrapper inspired by https://blog.golang.org/errors-are-values type errWriter struct { writer io.Writer diff --git a/proto/clone.go b/proto/clone.go index e98ddec981..e392575b35 100644 --- a/proto/clone.go +++ b/proto/clone.go @@ -84,9 +84,15 @@ func mergeStruct(out, in reflect.Value) { mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) } - if emIn, ok := in.Addr().Interface().(extendableProto); ok { - emOut := out.Addr().Interface().(extendableProto) - mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap()) + if emIn, ok := extendable(in.Addr().Interface()); ok { + emOut, _ := extendable(out.Addr().Interface()) + mIn, muIn := emIn.extensionsRead() + if mIn != nil { + mOut := emOut.extensionsWrite() + muIn.Lock() + mergeExtension(mOut, mIn) + muIn.Unlock() + } } uf := in.FieldByName("XXX_unrecognized") diff --git a/proto/decode.go b/proto/decode.go index f94b9f416e..07288a250b 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -390,11 +390,12 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group if !ok { // Maybe it's an extension? if prop.extendable { - if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) { + if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) { if err = o.skip(st, tag, wire); err == nil { - ext := e.ExtensionMap()[int32(tag)] // may be missing + extmap := e.extensionsWrite() + ext := extmap[int32(tag)] // may be missing ext.enc = append(ext.enc, o.buf[oi:o.index]...) - e.ExtensionMap()[int32(tag)] = ext + extmap[int32(tag)] = ext } continue } diff --git a/proto/encode.go b/proto/encode.go index 401c1143c8..1b5578d9db 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -1073,10 +1073,25 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) { // Encode an extension map. func (o *Buffer) enc_map(p *Properties, base structPointer) error { - v := *structPointer_ExtMap(base, p.field) - if err := encodeExtensionMap(v); err != nil { + exts := structPointer_ExtMap(base, p.field) + if err := encodeExtensionsMap(*exts); err != nil { return err } + + return o.enc_map_body(*exts) +} + +func (o *Buffer) enc_exts(p *Properties, base structPointer) error { + exts := structPointer_Extensions(base, p.field) + if err := encodeExtensions(exts); err != nil { + return err + } + v, _ := exts.extensionsRead() + + return o.enc_map_body(v) +} + +func (o *Buffer) enc_map_body(v map[int32]Extension) error { // Fast-path for common cases: zero or one extensions. if len(v) <= 1 { for _, e := range v { @@ -1099,8 +1114,13 @@ func (o *Buffer) enc_map(p *Properties, base structPointer) error { } func size_map(p *Properties, base structPointer) int { - v := *structPointer_ExtMap(base, p.field) - return sizeExtensionMap(v) + v := structPointer_ExtMap(base, p.field) + return extensionsMapSize(*v) +} + +func size_exts(p *Properties, base structPointer) int { + v := structPointer_Extensions(base, p.field) + return extensionsSize(v) } // Encode a map field. diff --git a/proto/equal.go b/proto/equal.go index f5db1def3c..8c6fa85371 100644 --- a/proto/equal.go +++ b/proto/equal.go @@ -121,9 +121,9 @@ func equalStruct(v1, v2 reflect.Value) bool { } } - if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { - em2 := v2.FieldByName("XXX_extensions") - if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { + if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_InternalExtensions") + if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) { return false } } @@ -223,8 +223,10 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool { } // base is the struct type that the extensions are based on. -// em1 and em2 are extension maps. -func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool { +// x1 and x2 are InternalExtensions. +func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool { + em1, _ := x1.extensionsRead() + em2, _ := x2.extensionsRead() if len(em1) != len(em2) { return false } diff --git a/proto/extensions.go b/proto/extensions.go index 0de0b424f8..9f484f53a3 100644 --- a/proto/extensions.go +++ b/proto/extensions.go @@ -52,14 +52,99 @@ type ExtensionRange struct { Start, End int32 // both inclusive } -// extendableProto is an interface implemented by any protocol buffer that may be extended. +// extendableProto is an interface implemented by any protocol buffer generated by the current +// proto compiler that may be extended. type extendableProto interface { + Message + ExtensionRangeArray() []ExtensionRange + extensionsWrite() map[int32]Extension + extensionsRead() (map[int32]Extension, sync.Locker) +} + +// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous +// version of the proto compiler that may be extended. +type extendableProtoV1 interface { Message ExtensionRangeArray() []ExtensionRange ExtensionMap() map[int32]Extension } +// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. +type extensionAdapter struct { + extendableProtoV1 +} + +func (e extensionAdapter) extensionsWrite() map[int32]Extension { + return e.ExtensionMap() +} + +func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { + return e.ExtensionMap(), notLocker{} +} + +// notLocker is a sync.Locker whose Lock and Unlock methods are nops. +type notLocker struct{} + +func (n notLocker) Lock() {} +func (n notLocker) Unlock() {} + +// extendable returns the extendableProto interface for the given generated proto message. +// If the proto message has the old extension format, it returns a wrapper that implements +// the extendableProto interface. +func extendable(p interface{}) (extendableProto, bool) { + if ep, ok := p.(extendableProto); ok { + return ep, ok + } + if ep, ok := p.(extendableProtoV1); ok { + return extensionAdapter{ep}, ok + } + return nil, false +} + +// XXX_InternalExtensions is an internal representation of proto extensions. +// +// Each generated message struct type embeds an anonymous XXX_InternalExtensions field, +// thus gaining the unexported 'extensions' method, which can be called only from the proto package. +// +// The methods of XXX_InternalExtensions are not concurrency safe in general, +// but calls to logically read-only methods such as has and get may be executed concurrently. +type XXX_InternalExtensions struct { + // The struct must be indirect so that if a user inadvertently copies a + // generated message and its embedded XXX_InternalExtensions, they + // avoid the mayhem of a copied mutex. + // + // The mutex serializes all logically read-only operations to p.extensionMap. + // It is up to the client to ensure that write operations to p.extensionMap are + // mutually exclusive with other accesses. + p *struct { + mu sync.Mutex + extensionMap map[int32]Extension + } +} + +// extensionsWrite returns the extension map, creating it on first use. +func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { + if e.p == nil { + e.p = new(struct { + mu sync.Mutex + extensionMap map[int32]Extension + }) + e.p.extensionMap = make(map[int32]Extension) + } + return e.p.extensionMap +} + +// extensionsRead returns the extensions map for read-only use. It may be nil. +// The caller must hold the returned mutex's lock when accessing Elements within the map. +func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { + if e.p == nil { + return nil, nil + } + return e.p.extensionMap, &e.p.mu +} + var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() +var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem() // ExtensionDesc represents an extension specification. // Used in generated code from the protocol compiler. @@ -93,11 +178,12 @@ type Extension struct { // SetRawExtension is for testing only. func SetRawExtension(base Message, id int32, b []byte) { - epb, ok := base.(extendableProto) + epb, ok := extendable(base) if !ok { return } - epb.ExtensionMap()[id] = Extension{enc: b} + extmap := epb.extensionsWrite() + extmap[id] = Extension{enc: b} } // isExtensionField returns true iff the given field number is in an extension range. @@ -112,8 +198,12 @@ func isExtensionField(pb extendableProto, field int32) bool { // checkExtensionTypes checks that the given extension is valid for pb. func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { + var pbi interface{} = pb // Check the extended type. - if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { + if ea, ok := pbi.(extensionAdapter); ok { + pbi = ea.extendableProtoV1 + } + if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) } // Check the range. @@ -159,8 +249,19 @@ func extensionProperties(ed *ExtensionDesc) *Properties { return prop } -// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. -func encodeExtensionMap(m map[int32]Extension) error { +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensions(e *XXX_InternalExtensions) error { + m, mu := e.extensionsRead() + if m == nil { + return nil // fast path + } + mu.Lock() + defer mu.Unlock() + return encodeExtensionsMap(m) +} + +// encode encodes any unmarshaled (unencoded) extensions in e. +func encodeExtensionsMap(m map[int32]Extension) error { for k, e := range m { if e.value == nil || e.desc == nil { // Extension is only in its encoded form. @@ -188,7 +289,17 @@ func encodeExtensionMap(m map[int32]Extension) error { return nil } -func sizeExtensionMap(m map[int32]Extension) (n int) { +func extensionsSize(e *XXX_InternalExtensions) (n int) { + m, mu := e.extensionsRead() + if m == nil { + return 0 + } + mu.Lock() + defer mu.Unlock() + return extensionsMapSize(m) +} + +func extensionsMapSize(m map[int32]Extension) (n int) { for _, e := range m { if e.value == nil || e.desc == nil { // Extension is only in its encoded form. @@ -215,28 +326,35 @@ func sizeExtensionMap(m map[int32]Extension) (n int) { // HasExtension returns whether the given extension is present in pb. func HasExtension(pb Message, extension *ExtensionDesc) bool { // TODO: Check types, field numbers, etc.? - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { return false } - _, ok = epb.ExtensionMap()[extension.Field] + extmap, mu := epb.extensionsRead() + if extmap == nil { + return false + } + mu.Lock() + _, ok = extmap[extension.Field] + mu.Unlock() return ok } // ClearExtension removes the given extension from pb. func ClearExtension(pb Message, extension *ExtensionDesc) { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { return } // TODO: Check types, field numbers, etc.? - delete(epb.ExtensionMap(), extension.Field) + extmap := epb.extensionsWrite() + delete(extmap, extension.Field) } // GetExtension parses and returns the given extension of pb. // If the extension is not present and has no default value it returns ErrMissingExtension. func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { return nil, errors.New("proto: not an extendable proto") } @@ -245,7 +363,12 @@ func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { return nil, err } - emap := epb.ExtensionMap() + emap, mu := epb.extensionsRead() + if emap == nil { + return defaultExtensionValue(extension) + } + mu.Lock() + defer mu.Unlock() e, ok := emap[extension.Field] if !ok { // defaultExtensionValue returns the default value or @@ -349,7 +472,7 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { // GetExtensions returns a slice of the extensions present in pb that are also listed in es. // The returned slice has the same length as es; missing extensions will appear as nil elements. func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { return nil, errors.New("proto: not an extendable proto") } @@ -368,7 +491,7 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e // SetExtension sets the specified extension of pb to the specified value. func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { return errors.New("proto: not an extendable proto") } @@ -388,17 +511,18 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) } - epb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} + extmap := epb.extensionsWrite() + extmap[extension.Field] = Extension{desc: extension, value: value} return nil } // ClearAllExtensions clears all extensions from pb. func ClearAllExtensions(pb Message) { - epb, ok := pb.(extendableProto) + epb, ok := extendable(pb) if !ok { return } - m := epb.ExtensionMap() + m := epb.extensionsWrite() for k := range m { delete(m, k) } diff --git a/proto/lib.go b/proto/lib.go index 0de8f8dffd..170b8e87d2 100644 --- a/proto/lib.go +++ b/proto/lib.go @@ -889,6 +889,10 @@ func isProto3Zero(v reflect.Value) bool { return false } +// ProtoPackageIsVersion2 is referenced from generated protocol buffer files +// to assert that that code is compatible with this version of the proto package. +const ProtoPackageIsVersion2 = true + // ProtoPackageIsVersion1 is referenced from generated protocol buffer files // to assert that that code is compatible with this version of the proto package. const ProtoPackageIsVersion1 = true diff --git a/proto/message_set.go b/proto/message_set.go index e25e01e637..fd982decd6 100644 --- a/proto/message_set.go +++ b/proto/message_set.go @@ -149,9 +149,21 @@ func skipVarint(buf []byte) []byte { // MarshalMessageSet encodes the extension map represented by m in the message set wire format. // It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. -func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { - if err := encodeExtensionMap(m); err != nil { - return nil, err +func MarshalMessageSet(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + if err := encodeExtensions(exts); err != nil { + return nil, err + } + m, _ = exts.extensionsRead() + case map[int32]Extension: + if err := encodeExtensionsMap(exts); err != nil { + return nil, err + } + m = exts + default: + return nil, errors.New("proto: not an extension map") } // Sort extension IDs to provide a deterministic encoding. @@ -178,7 +190,17 @@ func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { // UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. // It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. -func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { +func UnmarshalMessageSet(buf []byte, exts interface{}) error { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m = exts.extensionsWrite() + case map[int32]Extension: + m = exts + default: + return errors.New("proto: not an extension map") + } + ms := new(messageSet) if err := Unmarshal(buf, ms); err != nil { return err @@ -209,7 +231,16 @@ func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { // MarshalMessageSetJSON encodes the extension map represented by m in JSON format. // It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option. -func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { +func MarshalMessageSetJSON(exts interface{}) ([]byte, error) { + var m map[int32]Extension + switch exts := exts.(type) { + case *XXX_InternalExtensions: + m, _ = exts.extensionsRead() + case map[int32]Extension: + m = exts + default: + return nil, errors.New("proto: not an extension map") + } var b bytes.Buffer b.WriteByte('{') @@ -252,7 +283,7 @@ func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { // UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format. // It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option. -func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error { +func UnmarshalMessageSetJSON(buf []byte, exts interface{}) error { // Common-case fast path. if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) { return nil diff --git a/proto/message_set_test.go b/proto/message_set_test.go index ab8ac2f0eb..353a3ea769 100644 --- a/proto/message_set_test.go +++ b/proto/message_set_test.go @@ -50,13 +50,13 @@ func TestUnmarshalMessageSetWithDuplicate(t *testing.T) { } t.Logf("Marshaled bytes: %q", b) - m := make(map[int32]Extension) - if err := UnmarshalMessageSet(b, m); err != nil { + var extensions XXX_InternalExtensions + if err := UnmarshalMessageSet(b, &extensions); err != nil { t.Fatalf("UnmarshalMessageSet: %v", err) } - ext, ok := m[12345] + ext, ok := extensions.p.extensionMap[12345] if !ok { - t.Fatalf("Didn't retrieve extension 12345; map is %v", m) + t.Fatalf("Didn't retrieve extension 12345; map is %v", extensions.p.extensionMap) } // Skip wire type/field number and length varints. got := skipVarint(skipVarint(ext.enc)) diff --git a/proto/pointer_reflect.go b/proto/pointer_reflect.go index 989914177d..fb512e2e16 100644 --- a/proto/pointer_reflect.go +++ b/proto/pointer_reflect.go @@ -139,6 +139,11 @@ func structPointer_StringSlice(p structPointer, f field) *[]string { return structPointer_ifield(p, f).(*[]string) } +// Extensions returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return structPointer_ifield(p, f).(*XXX_InternalExtensions) +} + // ExtMap returns the address of an extension map field in the struct. func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return structPointer_ifield(p, f).(*map[int32]Extension) diff --git a/proto/pointer_unsafe.go b/proto/pointer_unsafe.go index ceece772a2..6b5567d47c 100644 --- a/proto/pointer_unsafe.go +++ b/proto/pointer_unsafe.go @@ -126,6 +126,10 @@ func structPointer_StringSlice(p structPointer, f field) *[]string { } // ExtMap returns the address of an extension map field in the struct. +func structPointer_Extensions(p structPointer, f field) *XXX_InternalExtensions { + return (*XXX_InternalExtensions)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) } diff --git a/proto/properties.go b/proto/properties.go index 880eb22d8f..39edea32fd 100644 --- a/proto/properties.go +++ b/proto/properties.go @@ -682,7 +682,8 @@ func getPropertiesLocked(t reflect.Type) *StructProperties { propertiesMap[t] = prop // build properties - prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) + prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) || + reflect.PtrTo(t).Implements(extendableProtoV1Type) prop.unrecField = invalidField prop.Prop = make([]*Properties, t.NumField()) prop.order = make([]int, t.NumField()) @@ -693,12 +694,15 @@ func getPropertiesLocked(t reflect.Type) *StructProperties { name := f.Name p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) - if f.Name == "XXX_extensions" { // special case + if f.Name == "XXX_InternalExtensions" { // special case + p.enc = (*Buffer).enc_exts + p.dec = nil // not needed + p.size = size_exts + } else if f.Name == "XXX_extensions" { // special case p.enc = (*Buffer).enc_map p.dec = nil // not needed p.size = size_map - } - if f.Name == "XXX_unrecognized" { // special case + } else if f.Name == "XXX_unrecognized" { // special case prop.unrecField = toField(&f) } oneof := f.Tag.Get("protobuf_oneof") // special case diff --git a/proto/text.go b/proto/text.go index 37c953570d..bd6f1ae5e4 100644 --- a/proto/text.go +++ b/proto/text.go @@ -455,7 +455,7 @@ func (tm *TextMarshaler) writeStruct(w *textWriter, sv reflect.Value) error { // Extensions (the XXX_extensions field). pv := sv.Addr() - if pv.Type().Implements(extendableProtoType) { + if _, ok := extendable(pv.Interface()); ok { if err := tm.writeExtensions(w, pv); err != nil { return err } @@ -689,17 +689,22 @@ func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // pv is assumed to be a pointer to a protocol message struct that is extendable. func (tm *TextMarshaler) writeExtensions(w *textWriter, pv reflect.Value) error { emap := extensionMaps[pv.Type().Elem()] - ep := pv.Interface().(extendableProto) + ep, _ := extendable(pv.Interface()) // Order the extensions by ID. // This isn't strictly necessary, but it will give us // canonical output, which will also make testing easier. - m := ep.ExtensionMap() + m, mu := ep.extensionsRead() + if m == nil { + return nil + } + mu.Lock() ids := make([]int32, 0, len(m)) for id := range m { ids = append(ids, id) } sort.Sort(int32Slice(ids)) + mu.Unlock() for _, extNum := range ids { ext := m[extNum] diff --git a/proto/text_parser.go b/proto/text_parser.go index b5fba5b2ae..8fa9f3af6e 100644 --- a/proto/text_parser.go +++ b/proto/text_parser.go @@ -550,7 +550,7 @@ func (p *textParser) readStruct(sv reflect.Value, terminator string) error { } reqFieldErr = err } - ep := sv.Addr().Interface().(extendableProto) + ep := sv.Addr().Interface().(Message) if !rep { SetExtension(ep, desc, ext.Interface()) } else { diff --git a/protoc-gen-go/generator/generator.go b/protoc-gen-go/generator/generator.go index 2994cc96b4..526919fdeb 100644 --- a/protoc-gen-go/generator/generator.go +++ b/protoc-gen-go/generator/generator.go @@ -62,7 +62,7 @@ import ( // It is incremented whenever an incompatibility between the generated code and // proto package is introduced; the generated code references // a constant, proto.ProtoPackageIsVersionN (where N is generatedCodeVersion). -const generatedCodeVersion = 1 +const generatedCodeVersion = 2 // A Plugin provides functionality to add to the output during Go code generation, // such as to produce RPC stubs. @@ -363,8 +363,6 @@ func (ms *messageSymbol) GenerateAlias(g *Generator, pkg string) { if ms.hasExtensions { g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ", "{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }") - g.P("func (m *", ms.sym, ") ExtensionMap() map[int32]", g.Pkg["proto"], ".Extension ", - "{ return (*", remoteSym, ")(m).ExtensionMap() }") if ms.isMessageSet { g.P("func (m *", ms.sym, ") Marshal() ([]byte, error) ", "{ return (*", remoteSym, ")(m).Marshal() }") @@ -1173,7 +1171,9 @@ func (g *Generator) generate(file *FileDescriptor) { // For one file in the package, assert version compatibility. g.P("// This is a compile-time assertion to ensure that this generated file") g.P("// is compatible with the proto package it is being compiled against.") - g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion) + g.P("// A compilation error at this line likely means your copy of the") + g.P("// proto package needs to be updated.") + g.P("const _ = ", g.Pkg["proto"], ".ProtoPackageIsVersion", generatedCodeVersion, " // please upgrade the proto package") g.P() } @@ -1684,7 +1684,8 @@ func (g *Generator) RecordTypeUse(t string) { } // Method names that may be generated. Fields with these names get an -// underscore appended. +// underscore appended. Any change to this set is a potential incompatible +// API change because it changes generated field names. var methodNames = [...]string{ "Reset", "String", @@ -1869,7 +1870,7 @@ func (g *Generator) generateMessage(message *Descriptor) { g.RecordTypeUse(field.GetTypeName()) } if len(message.ExtensionRange) > 0 { - g.P("XXX_extensions\t\tmap[int32]", g.Pkg["proto"], ".Extension `json:\"-\"`") + g.P(g.Pkg["proto"], ".XXX_InternalExtensions `json:\"-\"`") } if !message.proto3() { g.P("XXX_unrecognized\t[]byte `json:\"-\"`") @@ -1919,22 +1920,22 @@ func (g *Generator) generateMessage(message *Descriptor) { g.P() g.P("func (m *", ccTypeName, ") Marshal() ([]byte, error) {") g.In() - g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(m.ExtensionMap())") + g.P("return ", g.Pkg["proto"], ".MarshalMessageSet(&m.XXX_InternalExtensions)") g.Out() g.P("}") g.P("func (m *", ccTypeName, ") Unmarshal(buf []byte) error {") g.In() - g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, m.ExtensionMap())") + g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSet(buf, &m.XXX_InternalExtensions)") g.Out() g.P("}") g.P("func (m *", ccTypeName, ") MarshalJSON() ([]byte, error) {") g.In() - g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(m.XXX_extensions)") + g.P("return ", g.Pkg["proto"], ".MarshalMessageSetJSON(&m.XXX_InternalExtensions)") g.Out() g.P("}") g.P("func (m *", ccTypeName, ") UnmarshalJSON(buf []byte) error {") g.In() - g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, m.XXX_extensions)") + g.P("return ", g.Pkg["proto"], ".UnmarshalMessageSetJSON(buf, &m.XXX_InternalExtensions)") g.Out() g.P("}") g.P("// ensure ", ccTypeName, " satisfies proto.Marshaler and proto.Unmarshaler") @@ -1956,16 +1957,6 @@ func (g *Generator) generateMessage(message *Descriptor) { g.P("return extRange_", ccTypeName) g.Out() g.P("}") - g.P("func (m *", ccTypeName, ") ExtensionMap() map[int32]", g.Pkg["proto"], ".Extension {") - g.In() - g.P("if m.XXX_extensions == nil {") - g.In() - g.P("m.XXX_extensions = make(map[int32]", g.Pkg["proto"], ".Extension)") - g.Out() - g.P("}") - g.P("return m.XXX_extensions") - g.Out() - g.P("}") } // Default constants