forked from grafana/k6
/
http.go
150 lines (120 loc) · 3.33 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package netext
import (
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/http"
"strings"
"sync"
"github.com/ThomsonReutersEikon/go-ntlm/ntlm"
"github.com/pkg/errors"
)
// HTTPTransport wraps an *http.Transport and layers NTLM authentication on top
// of it. All methods of the embedded transport are promoted; RoundTrip and
// CloseIdleConnections are overridden by this type.
type HTTPTransport struct {
	*http.Transport

	// mu is held while authCache is read or written.
	mu sync.Mutex

	// authCache is keyed by the full request URL string. An entry is added once
	// an NTLM authenticate message has been generated for that URL.
	authCache map[string]bool

	// enableCache gates the use of authCache: while it is false, cached entries
	// are ignored and every NTLM request starts with a negotiate message.
	enableCache bool
}
// NewHTTPTransport returns an HTTPTransport that wraps transport. The returned
// value starts with an empty NTLM authorization cache, and the cache is
// enabled.
func NewHTTPTransport(transport *http.Transport) *HTTPTransport {
	wrapped := new(HTTPTransport)
	wrapped.Transport = transport
	wrapped.authCache = map[string]bool{}
	wrapped.enableCache = true
	return wrapped
}
// CloseIdleConnections turns the NTLM authorization cache off and then closes
// the idle connections of the wrapped transport.
//
// The cache is never switched back on by this type, so after the first call
// every NTLM request starts with a negotiate message again.
// NOTE(review): presumably the cache is dropped because cached entries refer to
// connections that were already authenticated — confirm the intent.
//
// Calling it on an HTTPTransport without an underlying transport is a no-op
// apart from disabling the cache.
func (t *HTTPTransport) CloseIdleConnections() {
	// Write the flag under mu, the mutex guarding the cache state: requests
	// running on other goroutines consult it.
	t.mu.Lock()
	t.enableCache = false
	t.mu.Unlock()

	// RoundTrip tolerates a missing transport by returning an error; do not
	// panic here either.
	if t.Transport == nil {
		return
	}
	t.Transport.CloseIdleConnections()
}
// RoundTrip executes a single HTTP transaction for req.
//
// It returns the error "no roundtrip defined" when there is no underlying
// transport. A request goes through the NTLM handshake only when GetAuth
// reports "ntlm" for the request context and the request URL carries user
// info; every other request is handed straight to the wrapped transport.
func (t *HTTPTransport) RoundTrip(req *http.Request) (res *http.Response, err error) {
	if t.Transport == nil {
		return nil, errors.New("no roundtrip defined")
	}

	// Anything that is not an NTLM request with credentials takes the plain path.
	if GetAuth(req.Context()) != "ntlm" || req.URL.User == nil {
		return t.Transport.RoundTrip(req)
	}

	return t.roundtripWithNTLM(req)
}
// roundtripWithNTLM sends req through the wrapped transport using the NTLM
// challenge/response handshake. The username and password come from the user
// info of the request URL.
//
// The request body, if any, is read fully into memory so that it can be sent
// again with the authenticated request. The first attempt carries a fixed NTLM
// negotiate message, or no Authorization header at all when the authorization
// cache is enabled and already holds an entry for the URL. Any response other
// than 401 Unauthorized is returned unchanged. On a 401 the response body is
// drained and closed, the challenge from the WWW-Authenticate header is
// answered, the URL is recorded in the authorization cache, and the request is
// sent a second time; that second response is what the caller receives.
//
// An error is returned when reading or closing a body fails, when either
// round trip fails, when the WWW-Authenticate header does not start with
// "NTLM ", or when the challenge cannot be decoded or processed.
//
// Note that req itself is modified: its Authorization header and its Body are
// overwritten.
func (t *HTTPTransport) roundtripWithNTLM(req *http.Request) (res *http.Response, err error) {
	rt := t.Transport

	username := req.URL.User.Username()
	// The second result only reports whether a password was present; a missing
	// password is used as the empty string.
	password, _ := req.URL.User.Password()

	// Buffer the request body so that it can be replayed after the challenge.
	var body bytes.Buffer
	if req.Body != nil {
		if _, err = body.ReadFrom(req.Body); err != nil {
			// A RoundTripper must close the request body even on failure. The read
			// error is the one worth reporting, so a close error is dropped here.
			_ = req.Body.Close()
			return nil, err
		}
		if err = req.Body.Close(); err != nil {
			return nil, err
		}
		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
	}
	// At this point req.Body is non-nil exactly when the caller supplied a body.
	hasBody := req.Body != nil

	// Read the flag under mu: CloseIdleConnections may clear it from another
	// goroutine.
	t.mu.Lock()
	cacheEnabled := t.enableCache
	t.mu.Unlock()

	// Before making the request, check whether this URL was authorized before.
	// Only the presence of the cache entry matters, not its value.
	if _, cached := t.getAuthCache(req.URL.String()); cacheEnabled && cached {
		req.Header.Del("Authorization")
	} else {
		// Fixed, base64-encoded NTLM negotiate message.
		req.Header.Set("Authorization", "NTLM TlRMTVNTUAABAAAAB4IAAAAAAAAAAAAAAAAAAAAAAAAAAAAAMAAAAAAAMAA=")
	}

	res, err = rt.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	if res.StatusCode != http.StatusUnauthorized {
		return res, err
	}

	// Drain and close the 401 response before sending the follow-up request.
	if _, err = io.Copy(ioutil.Discard, res.Body); err != nil {
		// Do not leak the response body when draining fails; the copy error is
		// the one reported.
		_ = res.Body.Close()
		return nil, err
	}
	if err = res.Body.Close(); err != nil {
		return nil, err
	}

	// Rewind the body for the second attempt. A request that had no body keeps
	// its nil Body: for net/http a non-nil empty body is not equivalent to a nil
	// one, so the retry would otherwise differ from the first attempt.
	if hasBody {
		req.Body = ioutil.NopCloser(bytes.NewReader(body.Bytes()))
	}

	// Retrieve the WWW-Authenticate header from the response.
	ntlmChallenge := res.Header.Get("WWW-Authenticate")
	if !strings.HasPrefix(ntlmChallenge, "NTLM ") {
		return nil, errors.New("Invalid WWW-Authenticate header")
	}

	challengeBytes, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(ntlmChallenge, "NTLM "))
	if err != nil {
		return nil, err
	}

	session, err := ntlm.CreateClientSession(ntlm.Version2, ntlm.ConnectionlessMode)
	if err != nil {
		return nil, err
	}
	// The third argument (presumably the domain — confirm against the ntlm
	// package) is left empty.
	session.SetUserInfo(username, password, "")

	// Parse the NTLM challenge and feed it to the session.
	challenge, err := ntlm.ParseChallengeMessage(challengeBytes)
	if err != nil {
		return nil, err
	}
	if err = session.ProcessChallengeMessage(challenge); err != nil {
		return nil, err
	}

	// Build the authenticate message that answers the challenge.
	authenticate, err := session.GenerateAuthenticateMessage()
	if err != nil {
		return nil, err
	}

	// Set the NTLM Authorization header.
	header := "NTLM " + base64.StdEncoding.EncodeToString(authenticate.Bytes())
	req.Header.Set("Authorization", header)

	// The entry is recorded before the authenticated request is sent, so it is
	// present regardless of how that request turns out.
	t.setAuthCache(req.URL.String(), true)

	return rt.RoundTrip(req)
}
// setAuthCache stores value under key in the authorization cache while holding
// the transport's mutex.
//
// The cache map is created on first use, so an HTTPTransport that was built
// without NewHTTPTransport (for example as a struct literal) does not panic on
// its first NTLM request.
func (t *HTTPTransport) setAuthCache(key string, value bool) {
	t.mu.Lock()
	defer t.mu.Unlock()

	// Writing to a nil map panics; allocate it lazily.
	if t.authCache == nil {
		t.authCache = make(map[string]bool)
	}
	t.authCache[key] = value
}
// getAuthCache looks key up in the authorization cache while holding the
// transport's mutex. The first result is the stored value and the second
// reports whether an entry exists for key.
func (t *HTTPTransport) getAuthCache(key string) (bool, bool) {
	t.mu.Lock()
	cached, found := t.authCache[key]
	t.mu.Unlock()

	return cached, found
}