Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mongo): passes the password in the function arguments #1729

Merged
merged 10 commits into from
Mar 27, 2024
6 changes: 3 additions & 3 deletions pkg/core/proxy/integrations/mongo/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"go.uber.org/zap"
)

func decodeMongo(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientConn net.Conn, dstCfg *integrations.ConditionalDstCfg, mockDb integrations.MockMemDb, _ models.OutgoingOptions) error {
func decodeMongo(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientConn net.Conn, dstCfg *integrations.ConditionalDstCfg, mockDb integrations.MockMemDb, opts models.OutgoingOptions) error {
startedDecoding := time.Now()
requestBuffers := [][]byte{reqBuf}

Expand Down Expand Up @@ -220,7 +220,7 @@ func decodeMongo(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientC
if len(configMocks[bestMatchIndex].Spec.MongoRequests) > 0 {
expectedRequestSections = configMocks[bestMatchIndex].Spec.MongoRequests[0].Message.(*models.MongoOpMessage).Sections
}
message, err := encodeOpMsg(respMessage, mongoRequest.(*models.MongoOpMessage).Sections, expectedRequestSections, logger)
message, err := encodeOpMsg(respMessage, mongoRequest.(*models.MongoOpMessage).Sections, expectedRequestSections, opts.MongoPassword, logger)
if err != nil {
utils.LogError(logger, err, "failed to encode the recorded OpMsg response", zap.Any("for request with id", responseTo))
errCh <- err
Expand Down Expand Up @@ -264,7 +264,7 @@ func decodeMongo(ctx context.Context, logger *zap.Logger, reqBuf []byte, clientC
if len(matchedMock.Spec.MongoRequests) > 0 {
expectedRequestSections = matchedMock.Spec.MongoRequests[0].Message.(*models.MongoOpMessage).Sections
}
message, err := encodeOpMsg(respMessage, mongoRequest.(*models.MongoOpMessage).Sections, expectedRequestSections, logger)
message, err := encodeOpMsg(respMessage, mongoRequest.(*models.MongoOpMessage).Sections, expectedRequestSections, opts.MongoPassword, logger)
if err != nil {
utils.LogError(logger, err, "failed to encode the recorded OpMsg response", zap.Any("for request with id", responseTo))
errCh <- err
Expand Down
2 changes: 0 additions & 2 deletions pkg/core/proxy/integrations/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ func init() {

// TODO: Remove these global variables, and find a better way to handle this
var configRequests = []string{""}
var password string

type Mongo struct {
logger *zap.Logger
Expand Down Expand Up @@ -60,7 +59,6 @@ func (m *Mongo) RecordOutgoing(ctx context.Context, src net.Conn, dst net.Conn,
}

func (m *Mongo) MockOutgoing(ctx context.Context, src net.Conn, dstCfg *integrations.ConditionalDstCfg, mockDb integrations.MockMemDb, opts models.OutgoingOptions) error {
password = opts.MongoPassword
logger := m.logger.With(zap.Any("Client IP Address", src.RemoteAddr().String()), zap.Any("Client ConnectionID", util.GetNextID()), zap.Any("Destination ConnectionID", util.GetNextID()))
reqBuf, err := util.ReadInitialBuf(ctx, logger, src)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/core/proxy/integrations/mongo/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ func extractSectionSingle(data string) (string, error) {
return content, nil
}

func encodeOpMsg(responseOpMsg *models.MongoOpMessage, actualRequestMsgSections []string, expectedRequestMsgSections []string, logger *zap.Logger) (*opMsg, error) {
func encodeOpMsg(responseOpMsg *models.MongoOpMessage, actualRequestMsgSections []string, expectedRequestMsgSections []string, mongoPassword string, logger *zap.Logger) (*opMsg, error) {
message := &opMsg{
flags: wiremessage.MsgFlag(responseOpMsg.FlagBits),
checksum: uint32(responseOpMsg.Checksum),
Expand Down Expand Up @@ -505,7 +505,7 @@ func encodeOpMsg(responseOpMsg *models.MongoOpMessage, actualRequestMsgSections
return nil, err
}

resultStr, ok, err := handleScramAuth(actualRequestMsgSections, expectedRequestMsgSections, sectionStr, logger)
resultStr, ok, err := handleScramAuth(actualRequestMsgSections, expectedRequestMsgSections, sectionStr, mongoPassword, logger)
if err != nil {
return nil, err
}
Expand Down
23 changes: 15 additions & 8 deletions pkg/core/proxy/integrations/mongo/scramAuth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"strconv"
"strings"
"sync"

"go.keploy.io/server/v2/pkg/core/proxy/integrations/scram"
"go.keploy.io/server/v2/pkg/core/proxy/util"
Expand Down Expand Up @@ -46,7 +47,7 @@ func isScramAuthRequest(actualRequestSections []string, logger *zap.Logger) bool

// authMessageMap stores the auth message from the saslStart request for the converstionIds. So, that
// it can be used in the saslContinue request to generate the new server proof
var authMessageMap map[string]string
var authMessageMap = sync.Map{}

// handleScramAuth handles the SCRAM authentication requests by generating the
// appropriate response string.
Expand All @@ -61,7 +62,7 @@ var authMessageMap map[string]string
// - The generated response string.
// - A boolean indicating if the processing was successful.
// - An error, if any, that occurred during processing.
func handleScramAuth(actualRequestSections, expectedRequestSections []string, responseSection string, logger *zap.Logger) (string, bool, error) {
func handleScramAuth(actualRequestSections, expectedRequestSections []string, responseSection, mongoPassword string, logger *zap.Logger) (string, bool, error) {
// Iterate over each section in the actual request sections
for i, v := range actualRequestSections {
// single document do not uses section sequence for SCRAM auth
Expand All @@ -88,7 +89,7 @@ func handleScramAuth(actualRequestSections, expectedRequestSections []string, re
// Check if the message is for final request of the SASL (authentication) process
} else if _, exists := actualMsg["saslContinue"]; exists {
if _, exists := actualMsg["payload"]; exists {
return handleSaslContinue(actualMsg, responseSection, logger)
return handleSaslContinue(actualMsg, responseSection, mongoPassword, logger)
}
}
}
Expand Down Expand Up @@ -313,7 +314,8 @@ func handleSaslStart(i int, actualMsg map[string]interface{}, expectedRequestSec
// generate the auth message from the recieved first request and recorded first response
authMessage := scram.GenerateAuthMessage(string(decodedActualReqPayload), newFirstAuthResponse, logger)
// store the auth message in the global map for the conversationId
authMessageMap[conversationID] = authMessage
authMessageMap.Store(conversationID, authMessage)

logger.Debug("genrate the new auth message for the recieved auth request", zap.String("msg", authMessage))

// marshal the new first response for the SCRAM authentication
Expand All @@ -337,7 +339,7 @@ func handleSaslStart(i int, actualMsg map[string]interface{}, expectedRequestSec
// - The updated response section string.
// - A boolean indicating if the processing was successful.
// - An error, if any, that occurred during processing.
func handleSaslContinue(actualMsg map[string]interface{}, responseSection string, logger *zap.Logger) (string, bool, error) {
func handleSaslContinue(actualMsg map[string]interface{}, responseSection, mongoPassword string, logger *zap.Logger) (string, bool, error) {
var responseMsg map[string]interface{}

err := json.Unmarshal([]byte(responseSection), &responseMsg)
Expand Down Expand Up @@ -380,10 +382,15 @@ func handleSaslContinue(actualMsg map[string]interface{}, responseSection string
salt := ""
itr := 0
// get the authMessage from the saslStart conversation. Since, saslContinue have the same conversationId
authMsg := authMessageMap[conversationID]
// authMsg := authMessageMap[conversationID]
authMessage, ok := authMessageMap.Load(conversationID)
authMessageStr := ""
if ok {
authMessageStr = authMessage.(string)
}

// get the salt and iteration from the authMessage to generate salted password
fields = strings.Split(authMsg, ",")
fields = strings.Split(authMessageStr, ",")
for _, part := range fields {
if strings.HasPrefix(part, "s=") {
// Split based on "=" and get the value of "s"
Expand All @@ -405,7 +412,7 @@ func handleSaslContinue(actualMsg map[string]interface{}, responseSection string
}
// Since, the server proof is the signature generated by the authMessage and salted password.
// So, need to return the new server proof according to the new authMessage which is different from the recorded.
newVerifier, err := scram.GenerateServerFinalMessage(authMessageMap[conversationID], "SCRAM-SHA-1", password, salt, itr, logger)
newVerifier, err := scram.GenerateServerFinalMessage(authMessageStr, "SCRAM-SHA-1", mongoPassword, salt, itr, logger)
if err != nil {
utils.LogError(logger, err, "failed to get the new server proof")
return "", false, err
Expand Down
Loading