forked from mongodb/mongo-go-driver
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sasl.go
119 lines (97 loc) · 2.71 KB
/
sasl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
	"context"
	"fmt"

	"github.com/mongodb/mongo-go-driver/bson"
	"github.com/mongodb/mongo-go-driver/core/command"
	"github.com/mongodb/mongo-go-driver/core/description"
	"github.com/mongodb/mongo-go-driver/core/wiremessage"
)
// SaslClient is the client piece of a sasl conversation.
//
// A conversation begins with Start, continues with one Next call per server
// payload, and is finished from the client's side once Completed reports true.
type SaslClient interface {
	// Start names the mechanism to use and produces the payload that is sent
	// along with it in the opening saslStart command.
	Start() (mechanism string, initialPayload []byte, err error)

	// Next consumes the payload the server sent back and produces the payload
	// to send in the following saslContinue command.
	Next(challenge []byte) (response []byte, err error)

	// Completed reports whether the client considers the exchange finished.
	Completed() bool
}
// SaslClientCloser is a SaslClient that has resources to clean up.
//
// ConductSaslConversation checks whether its client implements this interface
// and, if so, defers a call to Close so it runs however the conversation ends.
type SaslClientCloser interface {
	// Close releases whatever resources the client holds.
	Close()

	SaslClient
}
// ConductSaslConversation handles running a sasl conversation with MongoDB.
//
// It sends a saslStart command built from client.Start. It then repeatedly
// hands each server payload to client.Next and replies with a saslContinue
// command, until the server reports done and client.Completed returns true.
//
// Arbiters are not authenticated: nil is returned immediately for them. An
// empty db falls back to defaultAuthDB. If client also implements
// SaslClientCloser, its Close method is deferred and runs on every return path.
func ConductSaslConversation(ctx context.Context, desc description.Server, rw wiremessage.ReadWriter, db string, client SaslClient) error {
	// Arbiters cannot be authenticated
	if desc.Kind == description.RSArbiter {
		return nil
	}

	if db == "" {
		db = defaultAuthDB
	}

	if closer, ok := client.(SaslClientCloser); ok {
		defer closer.Close()
	}

	mech, payload, err := client.Start()
	if err != nil {
		return newError(err, mech)
	}

	saslStartCmd := command.Read{
		DB: db,
		Command: bson.NewDocument(
			bson.EC.Int32("saslStart", 1),
			bson.EC.String("mechanism", mech),
			bson.EC.Binary("payload", payload),
		),
	}

	type saslResponse struct {
		ConversationID int    `bson:"conversationId"`
		Code           int    `bson:"code"`
		Done           bool   `bson:"done"`
		Payload        []byte `bson:"payload"`
	}

	ssdesc := description.SelectedServer{Server: desc}
	rdr, err := saslStartCmd.RoundTrip(ctx, ssdesc, rw)
	if err != nil {
		return newError(err, mech)
	}

	var saslResp saslResponse
	if err = bson.Unmarshal(rdr, &saslResp); err != nil {
		return newAuthError("unmarshal error", err)
	}

	// The conversation id from the saslStart reply is reused for every
	// saslContinue command below.
	cid := saslResp.ConversationID

	for {
		if saslResp.Code != 0 {
			// err is always nil here, so wrapping it would hide the cause;
			// report the server-supplied code instead.
			return newError(fmt.Errorf("server returned SASL error code %d", saslResp.Code), mech)
		}

		if saslResp.Done && client.Completed() {
			return nil
		}

		payload, err = client.Next(saslResp.Payload)
		if err != nil {
			return newError(err, mech)
		}

		// Processing the server's latest payload may be what completes the
		// client, so check again before sending another round.
		if saslResp.Done && client.Completed() {
			return nil
		}

		saslContinueCmd := command.Read{
			DB: db,
			Command: bson.NewDocument(
				bson.EC.Int32("saslContinue", 1),
				bson.EC.Int32("conversationId", int32(cid)),
				bson.EC.Binary("payload", payload),
			),
		}

		rdr, err = saslContinueCmd.RoundTrip(ctx, ssdesc, rw)
		if err != nil {
			return newError(err, mech)
		}

		// Decode into a zeroed value. If the decoder leaves fields that are
		// absent from this reply untouched, this stops the previous round's
		// payload or code from leaking into the next iteration.
		saslResp = saslResponse{}
		if err = bson.Unmarshal(rdr, &saslResp); err != nil {
			return newAuthError("unmarshal error", err)
		}
	}
}