-
Notifications
You must be signed in to change notification settings - Fork 590
/
helper.go
168 lines (140 loc) · 5.41 KB
/
helper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package oidctest
import (
"context"
"database/sql"
"encoding/json"
"net/http"
"net/url"
"testing"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// LoginHelper helps with logging in a user and refreshing their oauth tokens.
// It exists mainly because refreshing oauth tokens is a bit tricky and
// requires some database manipulation.
//
// fake is the fake identity provider that logins and refreshes go through.
// client is, within this file, only consulted for its URL: logins are made
// with fresh, unauthenticated clients pointed at that URL.
type LoginHelper struct {
	fake   *FakeIDP
	client *codersdk.Client
}
// NewLoginHelper builds a LoginHelper that logs users in through the fake IDP
// against the deployment at client's URL. Both arguments are required; a nil
// client or fake IDP is a programmer error and causes a panic.
func NewLoginHelper(client *codersdk.Client, fake *FakeIDP) *LoginHelper {
	switch {
	case client == nil:
		panic("client must not be nil")
	case fake == nil:
		panic("fake must not be nil")
	}

	helper := LoginHelper{client: client, fake: fake}
	return &helper
}
// Login logs in through the fake IDP with the given ID token claims, using a
// brand-new unauthenticated client aimed at the same URL as the helper's
// client. Logins are always expected to start unauthenticated, so this spares
// callers from building that client themselves. The return values are passed
// through from FakeIDP.Login (presumably the logged-in client and the login
// response; confirm against FakeIDP).
func (h *LoginHelper) Login(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
	t.Helper()
	return h.fake.Login(t, codersdk.New(h.client.URL), idTokenClaims)
}
// AttemptLogin is the non-asserting variant of Login: it tries to log in
// through the fake IDP with the given ID token claims from a fresh
// unauthenticated client at the helper's URL, but does not assert that the
// login succeeded. The results of FakeIDP.AttemptLogin are returned as-is.
func (h *LoginHelper) AttemptLogin(t *testing.T, idTokenClaims jwt.MapClaims) (*codersdk.Client, *http.Response) {
	t.Helper()
	return h.fake.AttemptLogin(t, codersdk.New(h.client.URL), idTokenClaims)
}
// ExpireOauthToken makes the OIDC user link behind the given client's session
// look expired. It resolves the client's session token to its API key, loads
// that user's OIDC link, and writes it back with an OAuth expiry one hour in
// the past. The stored access and refresh tokens are kept, the key IDs are
// cleared, and the debug context is reset to "{}". It returns the updated
// link; any database error fails the test.
func (*LoginHelper) ExpireOauthToken(t *testing.T, db database.Store, user *codersdk.Client) database.UserLink {
	t.Helper()

	//nolint:gocritic // Testing
	ctx := dbauthz.AsSystemRestricted(testutil.Context(t, testutil.WaitMedium))

	// The session token carries the API key ID, which leads to the owning user.
	keyID, _, err := httpmw.SplitAPIToken(user.SessionToken())
	require.NoError(t, err)
	apiKey, err := db.GetAPIKeyByID(ctx, keyID)
	require.NoError(t, err, "get api key")

	// Load the OIDC link so it can be rewritten with a past expiry.
	oidcLink, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
		LoginType: database.LoginTypeOIDC,
		UserID:    apiKey.UserID,
	})
	require.NoError(t, err, "get user link")

	// Key IDs are left empty on purpose: dbcrypt will update as required.
	params := database.UpdateUserLinkParams{
		UserID:                 oidcLink.UserID,
		LoginType:              oidcLink.LoginType,
		OAuthAccessToken:       oidcLink.OAuthAccessToken,
		OAuthAccessTokenKeyID:  sql.NullString{},
		OAuthRefreshToken:      oidcLink.OAuthRefreshToken,
		OAuthRefreshTokenKeyID: sql.NullString{},
		OAuthExpiry:            time.Now().Add(-time.Hour),
		DebugContext:           json.RawMessage("{}"),
	}
	expired, err := db.UpdateUserLink(ctx, params)
	require.NoError(t, err, "expire user link")
	return expired
}
// ForceRefresh forces the client to refresh its oauth token. It expires the
// stored oauth token (via ExpireOauthToken), tells the fake IDP which claims
// to hand out when that refresh token is redeemed, and then performs an
// authenticated call. This forces the API Key middleware to refresh the oauth
// token.
//
// idToken replaces the claims the fake IDP returns for the refresh. Without
// this, the IDP reuses the claims from the original oauth token.
//
// A test cleanup assertion makes sure the refresh token was actually used.
func (h *LoginHelper) ForceRefresh(t *testing.T, db database.Store, user *codersdk.Client, idToken jwt.MapClaims) {
	t.Helper()

	link := h.ExpireOauthToken(t, db, user)
	// Updates the claims that the IDP will return. By default, it always
	// uses the original claims for the original oauth token.
	h.fake.UpdateRefreshClaims(link.OAuthRefreshToken, idToken)

	t.Cleanup(func() {
		// ForceRefresh returns nothing for the caller to invoke, so the old
		// hint about "calling the returned function" was misleading. If this
		// fails, the authenticated request below did not trigger a refresh.
		require.True(t, h.fake.RefreshUsed(link.OAuthRefreshToken), "refresh token must be used, but has not been. The authenticated request made by ForceRefresh did not trigger an oauth token refresh.")
	})

	// Do any authenticated call to force the refresh.
	_, err := user.User(testutil.Context(t, testutil.WaitShort), "me")
	require.NoError(t, err, "user must be able to be fetched")
}
// OAuth2GetCode emulates a user clicking "allow" on the IDP page. When doing
// unit tests, it's easier to skip this step sometimes. It does make an actual
// request to the IDP, so it should be equivalent to doing this "manually" with
// actual requests.
//
// It sends a GET for rawAuthURL through doRequest and expects a 307 Temporary
// Redirect. It returns the "code" query parameter of the redirect Location.
// The redirect's "state" parameter must equal the one in rawAuthURL. If the
// IDP redirects back with an OAuth2 error response ("error" and optionally
// "error_description", RFC 6749 section 4.1.2.1) instead of a code, that error
// is returned.
func OAuth2GetCode(rawAuthURL string, doRequest func(req *http.Request) (*http.Response, error)) (string, error) {
	authURL, err := url.Parse(rawAuthURL)
	if err != nil {
		return "", xerrors.Errorf("failed to parse auth URL: %w", err)
	}

	r, err := http.NewRequestWithContext(context.Background(), http.MethodGet, rawAuthURL, nil)
	if err != nil {
		return "", xerrors.Errorf("failed to create auth request: %w", err)
	}

	expCode := http.StatusTemporaryRedirect
	resp, err := doRequest(r)
	if err != nil {
		return "", xerrors.Errorf("request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != expCode {
		return "", codersdk.ReadBodyAsError(resp)
	}

	to := resp.Header.Get("Location")
	if to == "" {
		return "", xerrors.Errorf("expected redirect location")
	}

	toURL, err := url.Parse(to)
	if err != nil {
		return "", xerrors.Errorf("failed to parse redirect location: %w", err)
	}
	query := toURL.Query()

	// An authorization server reports a denied or failed request by
	// redirecting with "error" instead of "code". Surface that error rather
	// than the less useful "expected code" message.
	if errCode := query.Get("error"); errCode != "" {
		return "", xerrors.Errorf("authorization server returned error %q (description: %q)", errCode, query.Get("error_description"))
	}

	code := query.Get("code")
	if code == "" {
		return "", xerrors.Errorf("expected code in redirect location")
	}

	state := authURL.Query().Get("state")
	newState := query.Get("state")
	if newState != state {
		return "", xerrors.Errorf("expected state %q, got %q", state, newState)
	}
	return code, nil
}