forked from drone/go-scm
-
Notifications
You must be signed in to change notification settings - Fork 86
/
oauth1.go
150 lines (131 loc) · 3.81 KB
/
oauth1.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
// Copyright 2018 Drone.IO Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth1
import (
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"net/http"
"strconv"
"strings"
"time"
"github.com/jenkins-x/go-scm/scm"
"github.com/jenkins-x/go-scm/scm/transport/internal"
)
// Injectable dependencies of Transport, so the time and nonce sources used
// when signing can be substituted instead of calling time.Now and
// crypto/rand directly.
type (
	// clock is implemented by anything that can report the current time.
	clock interface {
		Now() time.Time
	}

	// noncer is implemented by anything that can produce nonce strings.
	noncer interface {
		Nonce() string
	}
)
// Transport is an http.RoundTripper that signs outgoing requests with
// OAuth 1.0 using the RSA-SHA1 signature method. For each request it asks
// Source for a token. If a token is returned, it adds an OAuth Authorization
// header to a clone of the request and forwards the clone to Base. If Source
// returns a nil token, the request is forwarded unsigned. The transport does
// not refresh tokens itself; it uses whatever Source returns.
type Transport struct {
	// ConsumerKey is sent as the oauth_consumer_key protocol parameter.
	ConsumerKey string
	// PrivateKey is the consumer's RSA key used to compute the
	// oauth_signature.
	PrivateKey *rsa.PrivateKey
	// Source supplies the Token whose value is sent as oauth_token in the
	// request's Authorization header.
	Source scm.TokenSource
	// Base is the base RoundTripper used to make requests.
	// If nil, http.DefaultTransport is used.
	Base http.RoundTripper
	// noncer, when non-nil, replaces random nonce generation.
	noncer noncer
	// clock, when non-nil, replaces time.Now for oauth_timestamp.
	clock clock
}
// RoundTrip implements http.RoundTripper. It obtains a token from t.Source
// using the request's context.
//
// When Source returns a non-nil token, RoundTrip signs a clone of the request
// with an OAuth1 Authorization header and sends the clone through the base
// transport. When Source returns a nil token, the original request is
// forwarded unmodified.
//
// Per the http.RoundTripper contract, the request body is closed on every
// error path, including failures that happen before the request reaches the
// base transport.
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
	// closeBody releases the request body when RoundTrip fails before
	// handing the request to the base transport (which would otherwise
	// be responsible for closing it).
	closeBody := func() {
		if r.Body != nil {
			// The close error is deliberately dropped: the caller is
			// already receiving the more relevant error.
			_ = r.Body.Close()
		}
	}

	token, err := t.Source.Token(r.Context())
	if err != nil {
		closeBody()
		return nil, err
	}
	if token == nil {
		return t.base().RoundTrip(r)
	}

	// Sign a clone rather than r itself: a RoundTripper must not modify
	// the caller's request.
	r2 := internal.CloneRequest(r)
	if err := t.setRequestAuthHeader(r2, token); err != nil {
		closeBody()
		return nil, err
	}
	return t.base().RoundTrip(r2)
}
// base picks the RoundTripper that actually sends requests: t.Base when it
// is configured, otherwise http.DefaultTransport.
func (t *Transport) base() http.RoundTripper {
	rt := t.Base
	if rt == nil {
		rt = http.DefaultTransport
	}
	return rt
}
// setRequestAuthHeader signs r with the transport's RSA private key and sets
// the resulting OAuth1 Authorization header on r (RFC 5849 section 3.1).
//
// The token's value is sent as oauth_token. The signature is computed over
// the request's parameters merged with the protocol parameters, then added
// as oauth_signature. Any error from signing is returned unchanged and r's
// headers are left untouched.
func (t *Transport) setRequestAuthHeader(r *http.Request, token *scm.Token) error {
	protocol := t.commonOAuthParams()
	protocol["oauth_token"] = token.Token

	sigBase := signatureBase(r, collectParameters(r, protocol))
	sig, err := sign(t.PrivateKey, sigBase)
	if err != nil {
		return err
	}
	protocol["oauth_signature"] = sig

	r.Header.Set("Authorization", authHeaderValue(protocol))
	return nil
}
// commonOAuthParams builds the OAuth1 protocol parameters shared by every
// signed request: consumer key, signature method (RSA-SHA1), timestamp,
// nonce and version.
//
// oauth_token and oauth_signature are not included; the caller adds them.
// The map is sized with room for those two extra keys.
func (t *Transport) commonOAuthParams() map[string]string {
	params := make(map[string]string, 7)
	params["oauth_consumer_key"] = t.ConsumerKey
	params["oauth_signature_method"] = "RSA-SHA1"
	params["oauth_timestamp"] = strconv.FormatInt(t.epoch(), 10)
	params["oauth_nonce"] = t.nonce()
	params["oauth_version"] = "1.0"
	return params
}
// nonce returns the oauth_nonce value for a request (RFC 5849 section 3.3).
//
// If t.noncer is set, its Nonce result is returned. Otherwise the nonce is
// 32 bytes read from crypto/rand, encoded with standard base64.
//
// A nonce must not repeat, so a fixed fallback value would defeat its
// purpose. The previous code returned the error text, making every nonce
// identical. A failure of the system's secure random source is instead
// treated as fatal, matching crypto/rand.Read's own behaviour since Go 1.24.
func (t *Transport) nonce() string {
	if t.noncer != nil {
		return t.noncer.Nonce()
	}
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		panic("oauth1: reading random nonce: " + err.Error())
	}
	return base64.StdEncoding.EncodeToString(b)
}
// epoch reports the current time in whole seconds since the Unix epoch,
// which is used as oauth_timestamp. An injected clock takes precedence over
// the system clock.
func (t *Transport) epoch() int64 {
	now := time.Now
	if t.clock != nil {
		now = t.clock.Now
	}
	return now().Unix()
}
// authHeaderValue formats OAuth parameters according to
// RFC 5849 3.5.1.
func authHeaderValue(oauthParams map[string]string) string {
pairs := sortParameters(encodeParameters(oauthParams), `%s="%s"`)
return "OAuth " + strings.Join(pairs, ", ")
}
// collectParameters merges the request's query-string parameters with the
// OAuth protocol parameters. The result is the parameter set that is
// normalized into the signature base string (RFC 5849 section 3.4.1.3).
// On a key collision the OAuth parameter wins.
//
// NOTE: RFC 5849 3.4.1.3 requires every value of a repeated query parameter,
// plus form-encoded body parameters. This implementation keeps only the first
// value for each query key and does not read the request body.
func collectParameters(r *http.Request, oauthParams map[string]string) map[string]string {
	query := r.URL.Query()
	merged := make(map[string]string, len(query)+len(oauthParams))
	for name, values := range query {
		merged[name] = values[0]
	}
	for name, value := range oauthParams {
		merged[name] = value
	}
	return merged
}