diff --git a/protoc-gen-go/generator/bson_extensions.go b/protoc-gen-go/generator/bson_extensions.go new file mode 100644 index 0000000000..3ffd97e2d0 --- /dev/null +++ b/protoc-gen-go/generator/bson_extensions.go @@ -0,0 +1,74 @@ +package generator + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +const ( + bsonTagPattern = "@bson_tag: (.*)" + bsonCompatiblePattern = "@bson_compatible" + bsonUpsertablePattern = "@bson_upsertable" + goInjectPattern = `(?s)@go_inject\s(.+)` +) + +var bsonTagRegex, bsonCompatibleRegex, bsonUpsertableRegex, goInjectRegex *regexp.Regexp + +func init() { + bsonTagRegex = regexp.MustCompile(bsonTagPattern) + bsonCompatibleRegex = regexp.MustCompile(bsonCompatiblePattern) + bsonUpsertableRegex = regexp.MustCompile(bsonUpsertablePattern) + goInjectRegex = regexp.MustCompile(goInjectPattern) +} + +func (g *Generator) IsMessageBsonCompatible(message *Descriptor) bool { + if loc, ok := g.file.comments[message.path]; ok { + preMessageComments := strings.TrimSuffix(loc.GetLeadingComments(), "\n") + return bsonCompatibleRegex.Match([]byte(preMessageComments)) + } + + return false +} + +func (g *Generator) IsMessageBsonUpsertable(message *Descriptor) bool { + if loc, ok := g.file.comments[message.path]; ok { + preMessageComments := strings.TrimSuffix(loc.GetLeadingComments(), "\n") + return bsonUpsertableRegex.Match([]byte(preMessageComments)) + } + + return false +} + +func (g *Generator) GetBsonTagForField(message *Descriptor, fieldNumber int) string { + fieldPath := fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, fieldNumber) + if loc, ok := g.file.comments[fieldPath]; ok { + comment := strings.TrimSuffix(loc.GetTrailingComments(), "\n") + matchedGroups := bsonTagRegex.FindStringSubmatch(comment) + if matchedGroups == nil { + return "" + } + + return strings.TrimSpace(matchedGroups[1]) + } + + return "" +} + +func (g *Generator) GetGoInjectForMessage(message *Descriptor) string { + if loc, ok := g.file.comments[message.path]; ok { + allLeadingComments := loc.GetLeadingDetachedComments() + allLeadingComments = append(allLeadingComments, loc.GetLeadingComments()) + fmt.Fprintf(os.Stderr, "ALL LEDING: %q\n", allLeadingComments) + + for _, leadingComment := range allLeadingComments { + matchedGroups := goInjectRegex.FindStringSubmatch(leadingComment) + if matchedGroups != nil { + return matchedGroups[1] + } + } + } + + return "" +} diff --git a/protoc-gen-go/generator/generator.go b/protoc-gen-go/generator/generator.go index a5879fe676..9d6168b59e 100644 --- a/protoc-gen-go/generator/generator.go +++ b/protoc-gen-go/generator/generator.go @@ -1747,6 +1747,9 @@ func (g *Generator) generateMessage(message *Descriptor) { oneofTypeName := make(map[*descriptor.FieldDescriptorProto]string) // without star oneofInsertPoints := make(map[int32]int) // oneof_index => offset of g.Buffer + messageIsBsonCompatible := g.IsMessageBsonCompatible(message) + messageIsBsonUpsertable := g.IsMessageBsonUpsertable(message) + g.PrintComments(message.path) g.P("type ", ccTypeName, " struct {") g.In() @@ -1783,7 +1786,22 @@ func (g *Generator) generateMessage(message *Descriptor) { fieldName, fieldGetterName := ns[0], ns[1] typename, wiretype := g.GoType(message, field) jsonName := *field.Name + + bsonTag := "" + bsonOverride := g.GetBsonTagForField(message, i) + if bsonOverride != "" { + bsonTag = bsonOverride + } else if messageIsBsonCompatible || messageIsBsonUpsertable { + bsonTag = LowerCamelCase(*field.Name) + if messageIsBsonUpsertable { + bsonTag += ",omitempty" + } + } + tag := fmt.Sprintf("protobuf:%s json:%q", g.goTag(message, field, wiretype), jsonName+",omitempty") + if bsonTag != "" { + tag += fmt.Sprintf(" bson:%q", bsonTag) + } fieldNames[field] = fieldName fieldGetterNames[field] = fieldGetterName @@ -2497,6 +2515,9 @@ func (g *Generator) generateMessage(message *Descriptor) { g.P() } + // Injections + g.P(g.GetGoInjectForMessage(message)) + for _, ext := range message.ext { g.generateExtension(ext) } @@ -2665,6 +2686,19 @@ func isASCIIDigit(c byte) bool { return '0' <= c && c <= '9' } +func LowerCamelCase(s string) string { + if s == "" { + return "" + } + + r := []byte(CamelCase(s)) + // It can be assumed that the first character here is upper-case, convert it + // lower-case. + r[0] = r[0] | ('a' - 'A') + + return string(r) +} + // CamelCase returns the CamelCased name. // If there is an interior underscore followed by a lower case letter, // drop the underscore and convert the letter to upper case.