/
callback.go
243 lines (207 loc) · 6.95 KB
/
callback.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
package login
import (
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
"runtime"
"strings"
"sync"
"github.com/bwplotka/oidc"
)
// Names of the query parameters the OIDC provider sends when redirecting back
// to the callback URL.
const (
	// stateParam carries the state value; callbackHandler compares it with the
	// expected state of the pending callbackRequest.
	stateParam = "state"
	// codeParam carries the authorization code exchanged for a token.
	codeParam = "code"
	// errParam and errDescParam carry an error reported by the provider.
	errParam     = "error"
	errDescParam = "error_description"
)
// rand128Bits returns 128 bits of cryptographically secure randomness encoded as
// unpadded URL-safe base64 (22 characters). It panics if the system random
// source cannot be read.
func rand128Bits() string {
	var raw [16]byte // 16 bytes == 128 bits.
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		panic(err)
	}
	encoded := base64.URLEncoding.EncodeToString(raw[:])
	// Strip the '=' padding so the value can be embedded in URLs verbatim.
	return strings.TrimRight(encoded, "=")
}
// openBrowser opens the specified URL in the default browser of the user.
//
// The browser launcher is started asynchronously: the function returns as soon
// as the launcher process has started and does not wait for it to exit. A
// non-nil error means the launcher could not be started (for example, xdg-open
// is not installed).
func openBrowser(url string) error {
	name, args := browserCommand(runtime.GOOS, url)
	proc := exec.Command(name, args...)
	if err := proc.Start(); err != nil {
		return err
	}
	// Reap the launcher in the background so it does not linger as a zombie
	// process on Unix-like systems. Its exit status is intentionally ignored.
	go func() { _ = proc.Wait() }()
	return nil
}

// browserCommand returns the launcher program and its arguments that open
// target in the default browser on the given GOOS.
func browserCommand(goos, target string) (name string, args []string) {
	switch goos {
	case "windows":
		// Do not use `cmd /c start <url>`: cmd.exe interprets metacharacters
		// such as '&' (present in every OAuth2 query string) as command
		// separators and truncates the URL. rundll32 receives the URL as a
		// plain argument with no shell parsing.
		return "rundll32", []string{"url.dll,FileProtocolHandler", target}
	case "darwin":
		return "open", []string{target}
	default: // "linux", "freebsd", "openbsd", "netbsd", ...
		return "xdg-open", []string{target}
	}
}
// callbackResponse contains the outcome of one expected callback, delivered on
// the CallbackServer's callback channel: either a token or an error.
// callbackHandler sets only token on success; errRespond sets only err.
type callbackResponse struct {
	// token is the result of the code exchange on success.
	token *oidc.Token
	// err describes why the callback failed.
	err error
}
// callbackRequest specifies values that are needed for expected callback handling.
// It is registered with CallbackServer.ExpectCallback before the browser is redirected.
type callbackRequest struct {
	// ctx bounds how long the handler waits to deliver the result on the
	// callback channel. It also supplies the HTTP client (under
	// oidc.HTTPClientCtxKey) for the code exchange, unless the incoming request's
	// context already carries one (see mergeContexts).
	ctx context.Context
	// expectedState must equal the `state` parameter received in the callback.
	expectedState string
	// cfg is passed to client.Exchange together with the received code.
	cfg oidc.Config
	// client performs the code-for-token exchange.
	client *oidc.Client
}
// CallbackServer carries a callback handler for the OIDC auth code flow.
// The handler is armed with ExpectCallback and reports each outcome on the
// channel returned by Callback.
// NOTE: This is not thread-safe in terms of multiple logins in the same time.
type CallbackServer struct {
	// redirectURL is the URL the provider redirects to, as returned by RedirectURL.
	redirectURL string
	// callbackCh delivers the outcome of an expected callback. Both constructors
	// in this file create it unbuffered, so the handler blocks until the result
	// is received or the expecting context is done.
	callbackCh chan *callbackResponse
	// callbackReqMu guards callbackReq, which ExpectCallback writes while the
	// HTTP handler reads and resets it from another goroutine.
	callbackReqMu sync.Mutex
	// callbackReq is the currently expected callback. If nil, nothing is
	// expected and the handler answers 412 Precondition Failed without sending
	// anything on callbackCh.
	callbackReq *callbackRequest
}
// NewServer creates an HTTP server with the OIDC callback handler, listening on
// bindAddress.
//
// bindAddress is ultimately the redirect URL that every client MUST register
// first with the OIDC provider. It can (and is recommended to) point to
// localhost. It must include a port. Port 0 makes the OS pick a free port (the
// actual address is reflected in RedirectURL), which only works if your OIDC
// provider supports wildcard ports (almost none do). The URL path is where the
// callback handler is mounted; an empty path mounts it on "/".
//
// The returned closeSrv stops the server and then closes the channel returned
// by Callback. Stopping the server closes the listener and all open
// connections. closeSrv is safe to call more than once.
//
// NOTE(review): closeSrv does not wait for an in-flight callbackHandler. A
// handler that is still delivering its result when the channel is closed would
// panic. Call closeSrv only after the callback has been received, or when none
// can be in progress.
func NewServer(bindAddress string) (srv *CallbackServer, closeSrv func(), err error) {
	bindURL, err := url.Parse(bindAddress)
	if err != nil {
		return nil, nil, fmt.Errorf("BindAddress is not in a form of URL. Err: %w", err)
	}
	listener, err := net.Listen("tcp", bindURL.Host)
	if err != nil {
		return nil, nil, fmt.Errorf("Failed to Listen for tcp on: %s. Err: %w", bindURL.Host, err)
	}
	s := &CallbackServer{
		redirectURL: fmt.Sprintf("http://%s%s", listener.Addr().String(), bindURL.Path),
		callbackCh:  make(chan *callbackResponse),
	}

	// ServeMux panics on an empty pattern. A bare "http://host:port" redirect
	// arrives as a request for "/", so mount the handler there.
	pattern := bindURL.Path
	if pattern == "" {
		pattern = "/"
	}
	mux := http.NewServeMux()
	mux.HandleFunc(pattern, s.callbackHandler)

	httpSrv := &http.Server{Handler: mux}
	go func() {
		// Serve returns http.ErrServerClosed once closeSrv runs; nothing to report.
		_ = httpSrv.Serve(listener)
	}()

	var once sync.Once
	return s, func() {
		once.Do(func() {
			// Close shuts the listener AND any open (e.g. keep-alive) connections,
			// so no new callback can reach the handler after this point.
			_ = httpSrv.Close()
			close(s.callbackCh)
		})
	}, nil
}
// NewReuseServer creates a CallbackServer whose OIDC callback handler is
// registered on the given mux under pattern. The returned server serves
// nothing itself: serving mux on listenAddress is the caller's responsibility.
// The redirect URL is "http://" + listenAddress + pattern.
func NewReuseServer(pattern string, listenAddress string, mux *http.ServeMux) *CallbackServer {
	srv := new(CallbackServer)
	srv.redirectURL = "http://" + listenAddress + pattern
	srv.callbackCh = make(chan *callbackResponse)

	mux.HandleFunc(pattern, srv.callbackHandler)
	return srv
}
// callbackHandler handles the redirect from the OIDC provider, which carries
// either `code`+`state` or `error` parameters.
//
// If no callback is expected (callbackReq is nil), it responds 412 Precondition
// Failed and delivers nothing on the callback channel.
//
// Otherwise it consumes the expectation, whatever the outcome. It then checks
// that the state matches and exchanges the code for a token. Finally it
// delivers either the token or the error on the callback channel, unless the
// expected request's context is done first.
//
// The mutex is held for the whole request, so a concurrent ExpectCallback
// cannot swap the expected request mid-flight.
// NOTE: This is not thread-safe in terms of multiple logins in the same time.
func (s *CallbackServer) callbackHandler(w http.ResponseWriter, r *http.Request) {
	s.callbackReqMu.Lock()
	// Deferred so that every path releases the mutex, including the
	// "not expected" early return below. A leaked lock would deadlock every
	// later callback and ExpectCallback.
	defer s.callbackReqMu.Unlock()

	if s.callbackReq == nil {
		w.WriteHeader(http.StatusPreconditionFailed)
		w.Write([]byte("Did not expect OIDC callback"))
		return
	}
	// The expectation is one-shot. Deferred calls run LIFO, so this reset
	// happens before the unlock registered above.
	defer func() { s.callbackReq = nil }()

	if err := r.ParseForm(); err != nil {
		s.errRespond(w, r, fmt.Errorf("Failed to parse request form. Err: %w", err))
		return
	}

	code, state, err := parseCallbackRequest(r.Form)
	if err != nil {
		s.errRespond(w, r, err)
		return
	}
	if state != s.callbackReq.expectedState {
		s.errRespond(w, r, fmt.Errorf("Invalid state parameter. Got %s, expected: %s", state, s.callbackReq.expectedState))
		return
	}

	ctx := mergeContexts(r.Context(), s.callbackReq.ctx)
	oidcToken, err := s.callbackReq.client.Exchange(ctx, s.callbackReq.cfg, code)
	if err != nil {
		s.errRespond(w, r, err)
		return
	}

	OKCallbackResponse(w, r)
	select {
	case <-s.callbackReq.ctx.Done():
	case s.callbackCh <- &callbackResponse{token: oidcToken}:
	}
}
// parseCallbackRequest extracts the authorization code and state from the
// provider's redirect parameters.
//
// Checks run in this order:
//  1. a missing `state` is an error;
//  2. a provider `error` (with optional `error_description`) is passed through;
//  3. a missing `code` is an error.
//
// The state is not compared with the expected value here; callbackHandler
// does that.
func parseCallbackRequest(form url.Values) (code string, state string, err error) {
	gotState := form.Get(stateParam)
	providerErr := form.Get(errParam)
	gotCode := form.Get(codeParam)

	switch {
	case gotState == "":
		return "", "", errors.New("User session error. No state parameter.")
	case providerErr != "":
		// Got error from provider. Passing through.
		return "", "", fmt.Errorf("Got error from provider: %s Desc: %s", providerErr, form.Get(errDescParam))
	case gotCode == "":
		return "", "", errors.New("Missing code token.")
	default:
		return gotCode, gotState, nil
	}
}
// OKCallbackResponse writes the HTTP response shown in the browser after a
// successful OIDC `code` flow: status 200 and a short plain-text message.
// It is a package-level variable so callers can override it.
var OKCallbackResponse = func(w http.ResponseWriter, _ *http.Request) {
	const msg = "OIDC authentication flow is completed. You can close browser tab."
	w.WriteHeader(http.StatusOK)
	w.Write([]byte(msg))
}
// ErrCallbackResponse writes the HTTP response shown in the browser when the
// OIDC `code` flow fails. By default it deliberately shows the same status 200
// and neutral text as the default OKCallbackResponse: the user should not see
// anything wrong in the browser. The error itself reaches the command through
// the callback channel. Override this variable if the browser should show the
// failure.
var ErrCallbackResponse = func(w http.ResponseWriter, _ *http.Request, _ error) {
	body := []byte("OIDC authentication flow is completed. You can close browser tab.")
	w.WriteHeader(http.StatusOK)
	w.Write(body)
}
// errRespond writes ErrCallbackResponse to the browser, then delivers err on
// the callback channel. It gives up on delivery if the expected request's
// context is done first.
// It dereferences s.callbackReq, so it must only be called while a callback is
// expected. callbackHandler calls it with s.callbackReqMu held.
func (s *CallbackServer) errRespond(w http.ResponseWriter, r *http.Request, err error) {
	ErrCallbackResponse(w, r, err)

	failure := &callbackResponse{err: err}
	waitCtx := s.callbackReq.ctx
	select {
	case s.callbackCh <- failure:
	case <-waitCtx.Done():
	}
}
// mergeContexts returns the context used for the token exchange, derived from
// originalCtx (callbackHandler passes the HTTP request's context).
//
// If originalCtx already carries a value under oidc.HTTPClientCtxKey, it is
// returned unchanged. Otherwise the value stored under that key in oidcCtx is
// attached to originalCtx; that value may be nil if oidcCtx has none.
func mergeContexts(originalCtx context.Context, oidcCtx context.Context) context.Context {
	if originalCtx.Value(oidc.HTTPClientCtxKey) == nil {
		inherited := oidcCtx.Value(oidc.HTTPClientCtxKey)
		return context.WithValue(originalCtx, oidc.HTTPClientCtxKey, inherited)
	}
	return originalCtx
}
// ExpectCallback arms the server for the next OIDC callback. The next request
// reaching the callback handler is validated and exchanged using callbackReq.
// Passing nil disarms the server, so callbacks are answered with 412.
func (s *CallbackServer) ExpectCallback(callbackReq *callbackRequest) {
	s.callbackReqMu.Lock()
	s.callbackReq = callbackReq
	s.callbackReqMu.Unlock()
}
// Callback returns the receive-only channel on which the outcome (token or
// error) of each expected callback is delivered.
func (s *CallbackServer) Callback() <-chan *callbackResponse {
	var results <-chan *callbackResponse = s.callbackCh
	return results
}
// RedirectURL returns the callback URL the OIDC provider must redirect to.
// This exact URL has to be registered with the provider.
func (s *CallbackServer) RedirectURL() string {
	target := s.redirectURL
	return target
}