Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Don't Merge Yet] Add support for dual reads #51

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.13

require (
github.com/DataDog/datadog-go v3.7.1+incompatible
github.com/hashicorp/go-getter v1.5.9
github.com/hashicorp/go-getter v1.5.11
github.com/kr/pretty v0.2.0 // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/stretchr/testify v1.6.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9n
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-getter v1.5.9 h1:b7ahZW50iQiUek/at3CvZhPK1/jiV6CtKcsJiR6E4R0=
github.com/hashicorp/go-getter v1.5.9/go.mod h1:BrrV/1clo8cCYu6mxvboYg+KutTiFnXjMEgDD8+i7ZI=
github.com/hashicorp/go-getter v1.5.11 h1:wioTuNmaBU3IE9vdFtFMcmZWj0QzLc6DYaP6sNe5onY=
github.com/hashicorp/go-getter v1.5.11/go.mod h1:9i48BP6wpWweI/0/+FBjqLrp9S8XtwUGjiu0QkWHEaY=
github.com/hashicorp/go-safetemp v1.0.0 h1:2HR189eFNrjHQyENnQMMpCiBAsRxzbTMIgBhEyExpmo=
github.com/hashicorp/go-safetemp v1.0.0/go.mod h1:oaerMy3BhqiTbVye6QuFhFtIceqFoDHxNAB65b+Rj1I=
github.com/hashicorp/go-version v1.1.0 h1:bPIoEKD27tNdebFGGxxYwcL4nepeY4j1QP23PFRGzg0=
Expand Down
14 changes: 11 additions & 3 deletions mongo/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ func IsWrite(command Command) bool {
return false
}

func IsRead(command Command) bool {
switch command {
case Aggregate, Count, Find, GetMore, ListCollections, ListIndexes:
return true
}
return false
}

func CommandAndCollection(msg bsoncore.Document) (Command, string) {
for _, s := range collectionCommands {
if coll, ok := msg.Lookup(string(s)).StringValueOK(); ok {
Expand Down Expand Up @@ -83,10 +91,10 @@ func IsIsMasterDoc(doc bsoncore.Document) bool {

func IsIsMasterValueTruthy(val bsoncore.Value) bool {
if intValue, isInt := val.Int32OK(); intValue > 0 {
return true;
return true
} else if !isInt {
boolValue, isBool := val.BooleanOK()
return boolValue && isBool
}
return false;
}
return false
}
28 changes: 27 additions & 1 deletion mongo/cursor_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mongo

import (
"fmt"
"strconv"
"time"

"go.mongodb.org/mongo-driver/x/mongo/driver"
Expand Down Expand Up @@ -30,7 +31,6 @@ func (c *cursorCache) count() int {
}

func (c *cursorCache) peek(cursorID int64, collection string) (server driver.Server, ok bool) {

v, ok := c.c.Peek(buildKey(cursorID, collection))
if !ok {
return
Expand All @@ -46,6 +46,32 @@ func (c *cursorCache) remove(cursorID int64, collection string) {
c.c.Remove(buildKey(cursorID, collection))
}

func (c *cursorCache) peekDualCursorID(cursorID int64, collection string) (int64, bool) {
v, ok := c.c.Peek(buildKey(cursorID, collection))
if !ok {
return 0, ok
}

str, ok := v.(string)
if !ok {
return 0, ok
}

id, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return 0, false
}
return int64(id), true
}

func (c *cursorCache) addDualCursorID(cursorID int64, collection string, dualCursorID int64) {
c.c.Add(buildKey(cursorID, collection), strconv.FormatInt(dualCursorID, 10))
}

func (c *cursorCache) removeDualCursorID(cursorID int64, collection string) {
c.c.Remove(buildKey(cursorID, collection))
}

func buildKey(cursorID int64, collection string) string {
return fmt.Sprintf("%d-%s", cursorID, collection)
}
33 changes: 33 additions & 0 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ func (m *Mongo) RoundTrip(msg *Message, tags []string) (_ *Message, err error) {
requestCursorID, _ := msg.Op.CursorID()
requestCommand, collection := msg.Op.CommandAndCollection()
transactionDetails := msg.Op.TransactionDetails()

server, err := m.selectServer(requestCursorID, collection, transactionDetails)
if err != nil {
return nil, err
Expand Down Expand Up @@ -265,6 +266,38 @@ func (m *Mongo) RoundTrip(msg *Message, tags []string) (_ *Message, err error) {
}, nil
}

func (m *Mongo) RoundTripWithDualCursor(msg *Message, tags []string, originalCursorID int64) (*Message, error) {
requestCursorID, _ := msg.Op.CursorID()
_, collection := msg.Op.CommandAndCollection()
if dualCursorID, ok := m.cursors.peekDualCursorID(requestCursorID, collection); ok {
requestCursorID = dualCursorID
}

// Rewrite message with new cursor ID if doing a getMore
if opMsg, ok := (msg.Op).(*opMsg); ok {
encodedMsg := opMsg.EncodeWithCursorID(msg.Op.RequestID(), requestCursorID, true)
decodedMsg, err := Decode(encodedMsg)
if err == nil {
msg = &Message{
Wm: encodedMsg,
Op: decodedMsg,
}
}
}

dualMsg, err := m.RoundTrip(msg, tags)

if responseCursorID, ok := dualMsg.Op.CursorID(); ok {
if responseCursorID != 0 {
m.cursors.addDualCursorID(originalCursorID, collection, responseCursorID)
} else if requestCursorID != 0 {
m.cursors.removeDualCursorID(requestCursorID, collection)
}
}

return dualMsg, err
}

func (m *Mongo) selectServer(requestCursorID int64, collection string, transDetails *TransactionDetails) (server driver.Server, err error) {
defer func(start time.Time) {
_ = m.statsd.Timing("server_selection", time.Since(start), []string{fmt.Sprintf("success:%v", err == nil)}, 1)
Expand Down
5 changes: 3 additions & 2 deletions mongo/mongo_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package mongo_test

import (
"os"
"testing"

"github.com/DataDog/datadog-go/statsd"
"github.com/coinbase/mongobetween/mongo"
"github.com/coinbase/mongobetween/proxy"
Expand All @@ -12,8 +15,6 @@ import (
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.uber.org/zap"
"os"
"testing"
)

func insertOpMsg(t *testing.T) *mongo.Message {
Expand Down
98 changes: 96 additions & 2 deletions mongo/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ func Decode(wm []byte) (Operation, error) {
return op, nil
}

func CopyMessage(msg *Message) (*Message, error) {
wmCopy := make([]byte, len(msg.Wm))
copy(wmCopy, msg.Wm)

copyOp, err := Decode(wmCopy)
if err != nil {
return nil, err
}

return &Message{
Wm: wmCopy,
Op: copyOp,
}, nil
}

type opUnknown struct {
opCode wiremessage.OpCode
reqID int32
Expand Down Expand Up @@ -250,6 +265,48 @@ type opMsgSectionSingle struct {
msg bsoncore.Document
}

func (o *opMsgSectionSingle) rebuildWithCursorID(cursorID int64) (*opMsgSectionSingle, error) {
elements, err := o.msg.Elements()
if err != nil {
return nil, err
}

db := bsoncore.NewDocumentBuilder()
for _, element := range elements {
if element.CompareKey([]byte("getMore")) {
db = db.AppendInt64(element.Key(), cursorID)
} else {
db = db.AppendValue(element.Key(), element.Value())
}
}

doc := db.Build()
if err = doc.Validate(); err != nil {
return nil, err
}
return &opMsgSectionSingle{doc}, nil
}

func (o *opMsgSectionSingle) stripCluserTime(cursorID int64) (*opMsgSectionSingle, error) {
elements, err := o.msg.Elements()
if err != nil {
return nil, err
}

db := bsoncore.NewDocumentBuilder()
for _, element := range elements {
if !element.CompareKey([]byte("$clusterTime")) {
db = db.AppendValue(element.Key(), element.Value())
}
}

doc := db.Build()
if err = doc.Validate(); err != nil {
return nil, err
}
return &opMsgSectionSingle{doc}, nil
}

func (o *opMsgSectionSingle) cursorID() (cursorID int64, ok bool) {
if getMore, ok := o.msg.Lookup("getMore").Int64OK(); ok {
return getMore, ok
Expand Down Expand Up @@ -320,6 +377,21 @@ func (o *opMsgSectionSequence) String() string {
return fmt.Sprintf("{ SectionSingle identifier: %s, msgs: [%s] }", o.identifier, strings.Join(msgs, ", "))
}

func MustOpMsgCursorSection(op Operation) ([]byte, bool) {
opmsg, _ := op.(*opMsg)
section := opmsg.sections[0].(*opMsgSectionSingle)
cursor := section.msg.Lookup("cursor")
cursorDoc, ok := cursor.DocumentOK()
if !ok {
return nil, false
}

if batchDoc := cursorDoc.Lookup("firstBatch").Data; len(batchDoc) > 0 {
return batchDoc, true
}
return cursorDoc.Lookup("nextBatch").Data, true
}

// see https://github.com/mongodb/mongo-go-driver/blob/v1.7.2/x/mongo/driver/operation.go#L1387-L1423
func decodeMsg(reqID int32, wm []byte) (*opMsg, error) {
var ok bool
Expand Down Expand Up @@ -378,11 +450,33 @@ func (m *opMsg) OpCode() wiremessage.OpCode {

// see https://github.com/mongodb/mongo-go-driver/blob/v1.7.2/x/mongo/driver/operation.go#L898-L904
func (m *opMsg) Encode(responseTo int32) []byte {
return m.EncodeWithCursorID(responseTo, 0, false)
}

func (m *opMsg) EncodeWithCursorID(responseTo int32, newCursorID int64, forDualRead bool) []byte {
var buffer []byte
idx, buffer := wiremessage.AppendHeaderStart(buffer, 0, responseTo, wiremessage.OpMsg)
buffer = wiremessage.AppendMsgFlags(buffer, m.flags)
for _, section := range m.sections {
buffer = section.append(buffer)
for idx, section := range m.sections {
sectionSingle := section.(*opMsgSectionSingle)
if forDualRead {
sectionWithoutTime, err := sectionSingle.stripCluserTime(newCursorID)
if err == nil {
sectionSingle = sectionWithoutTime
}
}

// Assume when doing getMores with a cursor, there's only one section.
if idx == 0 && newCursorID != 0 {
sectionWithNewCursor, err := sectionSingle.rebuildWithCursorID(newCursorID)
if err != nil {
buffer = section.append(buffer)
} else {
buffer = sectionWithNewCursor.append(buffer)
}
} else {
buffer = sectionSingle.append(buffer)
}
}
if m.flags&wiremessage.ChecksumPresent == wiremessage.ChecksumPresent {
// The checksum is a uint32, but we can use appendi32 to encode it. Overflow/underflow when casting to int32 is
Expand Down
Loading