forked from vitessio/vitess
-
Notifications
You must be signed in to change notification settings - Fork 13
/
auth_server_ldap.go
239 lines (216 loc) · 6.89 KB
/
auth_server_ldap.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
/*
Copyright 2017 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ldapauthserver
import (
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"net"
"sync"
"time"
ldap "gopkg.in/ldap.v2"
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/netutil"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vttls"
)
var (
	// ldapAuthConfigFile and ldapAuthConfigString are the two alternative
	// sources of the JSON AuthServerLdap config read by Init; Init refuses to
	// configure anything unless exactly one of them is non-empty.
	ldapAuthConfigFile   = flag.String("mysql_ldap_auth_config_file", "", "JSON File from which to read LDAP server config.")
	ldapAuthConfigString = flag.String("mysql_ldap_auth_config_string", "", "JSON representation of LDAP server config.")
	// ldapAuthMethod selects the MySQL client-side auth method; Init only
	// accepts mysql.MysqlClearPassword or mysql.MysqlDialog.
	ldapAuthMethod = flag.String("mysql_ldap_auth_method", mysql.MysqlClearPassword, "client-side authentication method to use. Supported values: mysql_clear_password, dialog.")
)
// AuthServerLdap implements AuthServer with an LDAP backend.
//
// Init fills it from a JSON config. Because ServerConfig is embedded, its
// fields (LdapServer, LdapCert, ...) appear at the top level of that JSON.
type AuthServerLdap struct {
	// Client is the LDAP connection used for user binds and group searches.
	// NOTE(review): this single Client is shared by every authentication and
	// every background group refresh.
	Client
	// ServerConfig holds the LDAP server address and TLS material.
	ServerConfig
	// Method is the MySQL client-side auth method returned by AuthMethod and
	// used by Negotiate.
	Method string
	// User and Password are the service-account credentials getGroups binds
	// with before searching for group memberships.
	User     string
	Password string
	// GroupQuery is the search base DN used by getGroups.
	GroupQuery string
	// UserDnPattern is a fmt format string; validate substitutes the MySQL
	// username into it to form the user's bind DN.
	UserDnPattern string
	// RefreshSeconds is how long cached groups are considered fresh. Despite
	// its type, the value is a count of seconds: LdapUserData.Get multiplies
	// it by time.Second. Zero makes effectively every Get trigger a refresh.
	RefreshSeconds time.Duration
}
// Init is public so it can be called from plugin_auth_ldap.go (go/cmd/vtgate).
//
// It reads the LDAP server config (JSON) from exactly one of
// -mysql_ldap_auth_config_file or -mysql_ldap_auth_config_string, builds an
// AuthServerLdap from it and registers it with mysql under the name "ldap".
// If neither or both flags are set it only logs and returns without
// registering anything; other configuration errors are reported with
// log.Exitf.
func Init() {
	if *ldapAuthConfigFile == "" && *ldapAuthConfigString == "" {
		log.Infof("Not configuring AuthServerLdap because mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are empty")
		return
	}
	if *ldapAuthConfigFile != "" && *ldapAuthConfigString != "" {
		log.Infof("Both mysql_ldap_auth_config_file and mysql_ldap_auth_config_string are non-empty, can only use one.")
		return
	}
	if *ldapAuthMethod != mysql.MysqlClearPassword && *ldapAuthMethod != mysql.MysqlDialog {
		log.Exitf("Invalid mysql_ldap_auth_method value: only support mysql_clear_password or dialog")
	}
	ldapAuthServer := &AuthServerLdap{
		Client:       &ClientImpl{},
		ServerConfig: ServerConfig{},
		Method:       *ldapAuthMethod,
	}
	data := []byte(*ldapAuthConfigString)
	if *ldapAuthConfigFile != "" {
		var err error
		data, err = ioutil.ReadFile(*ldapAuthConfigFile)
		if err != nil {
			log.Exitf("Failed to read mysql_ldap_auth_config_file: %v", err)
		}
	}
	if err := json.Unmarshal(data, ldapAuthServer); err != nil {
		log.Exitf("Error parsing AuthServerLdap config: %v", err)
	}
	// Method is an exported field, so a "Method" key in the JSON config
	// overrides the flag value validated above. Re-check it so an unsupported
	// method can never be handed out by AuthMethod.
	if ldapAuthServer.Method != mysql.MysqlClearPassword && ldapAuthServer.Method != mysql.MysqlDialog {
		log.Exitf("Invalid Method %q in AuthServerLdap config: only support mysql_clear_password or dialog", ldapAuthServer.Method)
	}
	mysql.RegisterAuthServerImpl("ldap", ldapAuthServer)
}
// AuthMethod is part of the AuthServer interface. The same configured Method
// is returned for every user, so user is not consulted.
func (asl *AuthServerLdap) AuthMethod(user string) (string, error) {
	method := asl.Method
	return method, nil
}
// Salt is part of the AuthServer interface. AuthServerLdap never inspects the
// salt (passwords are checked with an LDAP bind instead), but it still hands
// out a fresh one from mysql.NewSalt.
func (asl *AuthServerLdap) Salt() ([]byte, error) {
	salt, err := mysql.NewSalt()
	return salt, err
}
// ValidateHash is part of the AuthServer interface but is not supported by
// AuthServerLdap: passwords are only obtained in clear text (see Negotiate)
// and checked with an LDAP bind.
//
// It returns an error rather than panicking, so an unexpected call fails this
// one authentication instead of crashing the whole process.
func (asl *AuthServerLdap) ValidateHash(salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) {
	return nil, fmt.Errorf("ldap: ValidateHash is not supported by AuthServerLdap (auth method %q)", asl.Method)
}
// Negotiate is part of the AuthServer interface. It completes the exchange on
// c with the configured Method to obtain the client's password, then checks
// user and password against LDAP via validate.
func (asl *AuthServerLdap) Negotiate(c *mysql.Conn, user string, remoteAddr net.Addr) (mysql.Getter, error) {
	var (
		clearPassword string
		negotiateErr  error
	)
	if clearPassword, negotiateErr = mysql.AuthServerNegotiateClearOrDialog(c, asl.Method); negotiateErr != nil {
		return nil, negotiateErr
	}
	return asl.validate(user, clearPassword)
}
// validate authenticates username/password with an LDAP bind and, on success,
// returns the user's group memberships wrapped in an LdapUserData.
//
// The bind DN is built by substituting username into UserDnPattern. After the
// user bind succeeds, getGroups re-binds on the same connection as the
// service account to look up groups.
func (asl *AuthServerLdap) validate(username, password string) (mysql.Getter, error) {
	// A simple bind with a DN and an empty password is an "unauthenticated
	// bind" (RFC 4513 §5.1.2), which many servers report as successful without
	// checking any credential. Reject it so an empty MySQL password can never
	// authenticate.
	if password == "" {
		return nil, fmt.Errorf("ldap: empty password is not allowed for user %q", username)
	}
	if err := asl.Client.Connect("tcp", &asl.ServerConfig); err != nil {
		return nil, err
	}
	defer asl.Client.Close()
	// NOTE(review): username is substituted into the DN unescaped, so DN
	// special characters (',', '+', '=', ...) in it are not neutralized.
	if err := asl.Client.Bind(fmt.Sprintf(asl.UserDnPattern, username), password); err != nil {
		return nil, err
	}
	groups, err := asl.getGroups(username)
	if err != nil {
		return nil, err
	}
	return &LdapUserData{asl: asl, groups: groups, username: username, lastUpdated: time.Now(), updating: false}, nil
}
// getGroups returns the first value of each attribute (only "cn" is
// requested) of every entry under GroupQuery whose memberUid equals username.
//
// It re-binds the client as the service account (User/Password) before
// searching. The client must already be connected via Connect; this is not
// checked here.
func (asl *AuthServerLdap) getGroups(username string) ([]string, error) {
	if err := asl.Client.Bind(asl.User, asl.Password); err != nil {
		return nil, err
	}
	req := ldap.NewSearchRequest(
		asl.GroupQuery,
		ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
		// username comes from the MySQL client. Escape it so filter
		// metacharacters such as '*', '(' or ')' cannot change the
		// search filter (LDAP injection, RFC 4515 escaping).
		fmt.Sprintf("(memberUid=%s)", ldap.EscapeFilter(username)),
		[]string{"cn"},
		nil,
	)
	res, err := asl.Client.Search(req)
	if err != nil {
		return nil, err
	}
	var groups []string
	for _, entry := range res.Entries {
		for _, attr := range entry.Attributes {
			// Skip attributes returned without values instead of
			// panicking on Values[0].
			if len(attr.Values) == 0 {
				continue
			}
			groups = append(groups, attr.Values[0])
		}
	}
	return groups, nil
}
// LdapUserData holds username and LDAP groups as well as enough data to
// intelligently update itself.
//
// validate returns it as the mysql.Getter for an authenticated user. update,
// which Get starts in its own goroutine, writes groups, lastUpdated and
// updating while holding the embedded Mutex.
type LdapUserData struct {
	// asl is the server whose Client and config are used for refreshes.
	asl *AuthServerLdap
	// groups holds the group names found by getGroups.
	groups []string
	// username is the authenticated user name; nothing here modifies it.
	username string
	// lastUpdated is when groups were last fetched successfully.
	lastUpdated time.Time
	// updating is true while a refresh is in flight, so at most one update
	// runs at a time per LdapUserData.
	updating bool
	sync.Mutex
}
// update refreshes lud.groups from LDAP. Get starts it in a background
// goroutine; the updating flag ensures at most one refresh is in flight per
// LdapUserData.
//
// On failure the error is logged, the cached groups are kept, and lastUpdated
// is left unchanged, so a later Get will try again.
func (lud *LdapUserData) update() {
	lud.Lock()
	if lud.updating {
		lud.Unlock()
		return
	}
	lud.updating = true
	lud.Unlock()

	// Clear the in-flight flag on every exit path. Previously an error return
	// left updating == true forever, silently disabling all later refreshes.
	defer func() {
		lud.Lock()
		lud.updating = false
		lud.Unlock()
	}()

	// NOTE(review): lud.asl.Client is shared with validate and with other
	// users' refreshes, so these Connect/Close calls can race with theirs.
	// Fixing that needs a per-call client.
	if err := lud.asl.Client.Connect("tcp", &lud.asl.ServerConfig); err != nil {
		log.Errorf("Error updating LDAP user data: %v", err)
		return
	}
	defer lud.asl.Client.Close() // only after a successful Connect

	groups, err := lud.asl.getGroups(lud.username)
	if err != nil {
		log.Errorf("Error updating LDAP user data: %v", err)
		return
	}

	lud.Lock()
	lud.groups = groups
	lud.lastUpdated = time.Now()
	lud.Unlock()
}
// Get returns wrapped username and LDAP groups and possibly updates the cache.
//
// If the cached groups are older than RefreshSeconds, a background refresh is
// started. The groups returned are the current cached values, which may be
// stale; the refreshed values only show up on a later Get.
func (lud *LdapUserData) Get() *querypb.VTGateCallerID {
	// update writes lastUpdated and groups from another goroutine under the
	// mutex, so read them under the same lock to avoid a data race. update
	// replaces the groups slice rather than mutating it, so handing out the
	// slice after unlocking is safe.
	lud.Lock()
	stale := time.Since(lud.lastUpdated) > lud.asl.RefreshSeconds*time.Second
	groups := lud.groups
	lud.Unlock()

	if stale {
		go lud.update()
	}
	return &querypb.VTGateCallerID{Username: lud.username, Groups: groups}
}
// ServerConfig holds the config for an LDAP server.
type ServerConfig struct {
	// LdapServer is the server address including the port, e.g.
	// "ldap.example.com:389". Its host part is also used as the TLS server
	// name when Connect upgrades the connection.
	LdapServer string
	// LdapCert and LdapKey are the client certificate and key passed to
	// vttls.ClientConfig (presumably file paths — confirm).
	LdapCert string
	LdapKey  string
	// LdapCA is the CA passed to vttls.ClientConfig for verifying the server
	// (presumably a file path — confirm).
	LdapCA string
}
// Client provides an interface we can mock.
//
// Connect must succeed before Bind or Search is called; Close releases the
// connection that Connect opened.
type Client interface {
	// Connect opens a connection to config.LdapServer over network.
	Connect(network string, config *ServerConfig) error
	// Close closes the connection opened by Connect.
	Close()
	// Bind performs an LDAP bind with the given name (DN) and password.
	Bind(string, string) error
	// Search runs an LDAP search on the connection.
	Search(*ldap.SearchRequest) (*ldap.SearchResult, error)
}
// ClientImpl is the real implementation of Client, backed by gopkg.in/ldap.v2.
//
// Bind, Search and Close are promoted from the embedded *ldap.Conn, which is
// set by Connect. Before a successful Connect they would operate on a nil
// *ldap.Conn.
type ClientImpl struct {
	*ldap.Conn
}
// Connect calls ldap.Dial and then upgrades the connection to TLS with
// StartTLS. This must be called before any other methods.
//
// On success the TLS-upgraded connection becomes lci.Conn. On failure no
// connection is left open and lci.Conn is not replaced, so callers must not
// call Close after a failed Connect.
func (lci *ClientImpl) Connect(network string, config *ServerConfig) error {
	// Build the TLS config first, so a bad address or bad cert/key/CA setting
	// fails before any network connection is opened. The host part of
	// LdapServer is the server name the certificate is verified against.
	serverName, _, err := netutil.SplitHostPort(config.LdapServer)
	if err != nil {
		return err
	}
	tlsConfig, err := vttls.ClientConfig(config.LdapCert, config.LdapKey, config.LdapCA, serverName)
	if err != nil {
		return err
	}

	// Previously this error was discarded, and a failed dial led to a nil
	// dereference in StartTLS below.
	conn, err := ldap.Dial(network, config.LdapServer)
	if err != nil {
		return err
	}
	// StartTLS upgrades a plain LDAP connection in place, rather than dialing
	// an LDAPS port with DialTLS. NOTE(review): confirm deployments point
	// LdapServer at a StartTLS-capable (plain LDAP) port.
	if err := conn.StartTLS(tlsConfig); err != nil {
		// Don't leak the plain-text connection.
		conn.Close()
		return err
	}
	lci.Conn = conn
	return nil
}