/
sasl.go
214 lines (196 loc) · 6.22 KB
/
sasl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
// Copyright 2016 The Mellium Contributors.
// Use of this source code is governed by the BSD 2-clause
// license that can be found in the LICENSE file.
package xmpp
import (
"context"
"crypto/tls"
"encoding/xml"
"errors"
"fmt"
"io"
"mellium.im/sasl"
"mellium.im/xmlstream"
"mellium.im/xmpp/internal/ns"
"mellium.im/xmpp/internal/saslerr"
"mellium.im/xmpp/stream"
)
// BUG(ssw): SASL feature does not have security layer byte precision.
// SASL returns a stream feature for performing authentication using the Simple
// Authentication and Security Layer (SASL) as defined in RFC 4422.
// It panics if no mechanisms are specified.
// The order in which mechanisms are specified will be the preferred order, so
// stronger mechanisms should be listed first.
//
// Identity is used when a user wants to act on behalf of another user.
// For instance, an admin might want to log in as another user to help them
// troubleshoot an issue.
// Normally it is left blank and the localpart of the Origin JID is used.
//
// The username given to the mechanism is the localpart of the session's local
// address; password and identity are passed to the mechanism unchanged. The
// returned feature has Necessary set to Secure and Prohibited set to Authn, and
// a successful negotiation reports the Authn state. Only the initiating
// (client) side is implemented: negotiating on a received stream panics.
func SASL(identity, password string, mechanisms ...sasl.Mechanism) StreamFeature {
	if len(mechanisms) == 0 {
		panic("xmpp: Must specify at least 1 SASL mechanism")
	}
	return StreamFeature{
		Name:       xml.Name{Space: ns.SASL, Local: "mechanisms"},
		Necessary:  Secure,
		Prohibited: Authn,
		// List writes the <mechanisms/> element with one <mechanism/> child per
		// configured mechanism, in preference order. It always reports the
		// feature as required.
		List: func(ctx context.Context, e xmlstream.TokenWriter, start xml.StartElement) (req bool, err error) {
			req = true
			if err = e.EncodeToken(start); err != nil {
				return
			}
			startMechanism := xml.StartElement{Name: xml.Name{Space: "", Local: "mechanism"}}
			for _, m := range mechanisms {
				select {
				case <-ctx.Done():
					return true, ctx.Err()
				default:
				}
				if err = e.EncodeToken(startMechanism); err != nil {
					return
				}
				if err = e.EncodeToken(xml.CharData(m.Name)); err != nil {
					return
				}
				if err = e.EncodeToken(startMechanism.End()); err != nil {
					return
				}
			}
			return req, e.EncodeToken(start.End())
		},
		// Parse decodes the remote <mechanisms/> element and returns the
		// advertised mechanism names as a []string (presumably handed back to
		// Negotiate as its data argument, which asserts that type).
		Parse: func(ctx context.Context, r xml.TokenReader, start *xml.StartElement) (bool, interface{}, error) {
			parsed := struct {
				XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"`
				List    []string `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanism"`
			}{}
			err := xml.NewTokenDecoder(r).DecodeElement(&parsed, start)
			return true, parsed.List, err
		},
		// Negotiate runs the client side of the SASL exchange: it picks a
		// mechanism, sends <auth/>, answers challenges with <response/>, and
		// returns only once the server has reported <success/>.
		Negotiate: func(ctx context.Context, session *Session, data interface{}) (mask SessionState, rw io.ReadWriter, err error) {
			if (session.State() & Received) == Received {
				panic("SASL server not yet implemented")
			}

			// data should be the mechanism list produced by Parse; report a
			// mismatch as an error rather than panicking on the assertion.
			remote, ok := data.([]string)
			if !ok {
				return mask, nil, fmt.Errorf("xmpp: unexpected SASL feature data of type %T", data)
			}

			c := session.Conn()

			// Select a mechanism, preferring the client order.
			var selected sasl.Mechanism
		selectmechanism:
			for _, m := range mechanisms {
				for _, name := range remote {
					if name == m.Name {
						selected = m
						break selectmechanism
					}
				}
			}
			// No matching mechanism found…
			if selected.Name == "" {
				return mask, nil, errors.New(`No matching SASL mechanisms found`)
			}

			opts := []sasl.Option{
				sasl.Credentials(func() ([]byte, []byte, []byte) {
					return []byte(session.LocalAddr().Localpart()), []byte(password), []byte(identity)
				}),
				sasl.RemoteMechanisms(remote...),
			}
			// Make the TLS connection state available to the mechanism when
			// the underlying connection is a *tls.Conn.
			if tlsConn, isTLS := c.(*tls.Conn); isTLS {
				opts = append(opts, sasl.TLSState(tlsConn.ConnectionState()))
			}
			client := sasl.NewClient(selected, opts...)

			more, resp, err := client.Step(nil)
			if err != nil {
				return mask, nil, err
			}

			// RFC6120 §6.4.2:
			// If the initiating entity needs to send a zero-length initial
			// response, it MUST transmit the response as a single equals sign
			// character ("="), which indicates that the response is present but
			// contains no data.
			if len(resp) == 0 {
				resp = []byte{'='}
			}

			// Send <auth/> and the initial payload to start SASL auth.
			if _, err = fmt.Fprintf(c,
				`<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='%s'>%s</auth>`,
				selected.Name, resp,
			); err != nil {
				return mask, nil, err
			}

			d := xml.NewTokenDecoder(session)

			// next waits for the server's next top-level element, which must be
			// a start element, and decodes it with decodeSASLChallenge. A
			// <challenge/> is accepted only when allowChallenge is true.
			next := func(allowChallenge bool) ([]byte, bool, error) {
				select {
				case <-ctx.Done():
					return nil, false, ctx.Err()
				default:
				}
				tok, tokErr := d.Token()
				if tokErr != nil {
					return nil, false, tokErr
				}
				start, isStart := tok.(xml.StartElement)
				if !isStart {
					return nil, false, stream.BadFormat
				}
				return decodeSASLChallenge(d, start, allowChallenge)
			}

			// Answer challenges until the mechanism has nothing more to send or
			// the server reports success.
			success := false
			for more {
				var challenge []byte
				if challenge, success, err = next(true); err != nil {
					return mask, nil, err
				}
				if more, resp, err = client.Step(challenge); err != nil {
					return mask, nil, err
				}
				if success {
					// The server has already ended the exchange, so the client
					// cannot send another <response/>. If the mechanism still
					// expects more, authentication did not complete.
					if more {
						return mask, nil, errors.New("xmpp: SASL mechanism expected more data after the server reported success")
					}
					break
				}
				if _, err = fmt.Fprintf(c,
					`<response xmlns='urn:ietf:params:xml:ns:xmpp-sasl'>%s</response>`, resp); err != nil {
					return mask, nil, err
				}
			}

			// The client finished its side (after <auth/> or after the final
			// <response/>) without having seen <success/>. The server's outcome
			// must still be read, both to confirm authentication and so it is
			// not left in the stream. No further challenge is acceptable here:
			// <failure/> or an unexpected element yields an error.
			if !success {
				if _, _, err = next(false); err != nil {
					return mask, nil, err
				}
			}
			return Authn, c, nil
		},
	}
}
// decodeSASLChallenge decodes the SASL element whose start token is start,
// reading the rest of the element from d.
//
// For a <challenge/> or <success/> element the character data is returned
// as-is in challenge (no base64 decoding happens here). success reports
// whether the element was <success/>. A <challenge/> is rejected with
// stream.UnsupportedStanzaType unless allowChallenge is true.
//
// A <failure/> element is decoded and returned as a saslerr.Failure error. Any
// other element, including one outside the SASL namespace, yields
// stream.UnsupportedStanzaType without its remaining tokens being consumed.
func decodeSASLChallenge(d *xml.Decoder, start xml.StartElement, allowChallenge bool) (challenge []byte, success bool, err error) {
	if start.Name.Space != ns.SASL {
		return nil, false, stream.UnsupportedStanzaType
	}

	switch start.Name.Local {
	case "failure":
		var failure saslerr.Failure
		if err = d.DecodeElement(&failure, &start); err != nil {
			return nil, false, err
		}
		return nil, false, failure
	case "challenge":
		// Challenges are only valid while the mechanism is still running.
		if !allowChallenge {
			return nil, false, stream.UnsupportedStanzaType
		}
	case "success":
		// Decoded below, together with challenges.
	default:
		return nil, false, stream.UnsupportedStanzaType
	}

	var payload struct {
		Data []byte `xml:",chardata"`
	}
	if err = d.DecodeElement(&payload, &start); err != nil {
		return nil, false, err
	}
	return payload.Data, start.Name.Local == "success", nil
}