generated from datumforge/go-template
-
Notifications
You must be signed in to change notification settings - Fork 7
/
auth.go
246 lines (204 loc) · 7.61 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
package auth
import (
"context"
"errors"
"fmt"
"time"
echo "github.com/datumforge/echox"
"github.com/datumforge/datum/internal/ent/generated"
"github.com/datumforge/datum/internal/ent/generated/apitoken"
"github.com/datumforge/datum/internal/ent/generated/organization"
"github.com/datumforge/datum/internal/ent/generated/personalaccesstoken"
"github.com/datumforge/datum/internal/ent/generated/privacy"
"github.com/datumforge/datum/internal/ent/generated/user"
"github.com/datumforge/datum/pkg/auth"
"github.com/datumforge/datum/pkg/rout"
"github.com/datumforge/datum/pkg/tokens"
)
// SessionSkipperFunc reports whether the session check should be skipped for a request.
// It returns true whenever the authentication type recorded on the echo context is
// anything other than JWT authentication (for example a PAT or API token request).
var SessionSkipperFunc = func(c echo.Context) bool {
	requestAuthType := auth.GetAuthTypeFromEchoContext(c)

	return requestAuthType != auth.JWTAuthentication
}
// Authenticate is a middleware function that is used to authenticate requests - it is
// not applied to all routes so be cognizant of that.
//
// For each request the middleware:
//  1. returns early to the next handler when the configured skipper says so
//  2. runs the optional BeforeFunc
//  3. obtains the access token from the request, falling back to the refresh flow
//     when the request carries no authorization at all
//  4. verifies the token as a JWT; if verification fails the token is checked against
//     the stored personal access tokens and API tokens instead
//  5. stores the resulting authenticated user on the echo context and calls next
//
// Any failure to resolve an authenticated user is reported through
// rout.HTTPErrorResponse and the next handler is not invoked.
func Authenticate(conf AuthOptions) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// if a skipper is configured and returns true, skip this middleware;
			// a missing skipper means "never skip" instead of a nil function call
			if conf.Skipper != nil && conf.Skipper(c) {
				return next(c)
			}

			// execute any before functions
			if conf.BeforeFunc != nil {
				conf.BeforeFunc(c)
			}

			validator, err := conf.Validator()
			if err != nil {
				return err
			}

			// Get access token from the request, if not available then attempt to refresh
			// using the refresh token cookie.
			accessToken, err := auth.GetAccessToken(c)
			if err != nil {
				// only a completely missing authorization is recoverable via refresh
				if !errors.Is(err, ErrNoAuthorization) {
					return rout.HTTPErrorResponse(err)
				}

				// Create a reauthenticator function to handle refresh tokens if they are provided.
				reauthenticate := Reauthenticate(conf, validator)

				if accessToken, err = reauthenticate(c); err != nil {
					return rout.HTTPErrorResponse(err)
				}
			}

			// Verify the access token is authorized for use with datum and extract claims.
			claims, verifyErr := validator.Verify(accessToken)

			// Both resolution paths below query the database, so without a client no
			// authenticated user can be built; reject the request here rather than
			// dereferencing a nil client further down.
			if conf.DBClient == nil {
				return rout.HTTPErrorResponse(rout.ErrInvalidCredentials)
			}

			var au *auth.AuthenticatedUser

			if verifyErr != nil {
				// if its not a JWT, check to see if its a PAT or API Token
				au, err = checkToken(c.Request().Context(), conf, accessToken)
			} else {
				// build the authenticated user from the verified claims
				au, err = createAuthenticatedUserFromClaims(c.Request().Context(), conf.DBClient, claims, auth.JWTAuthentication)
			}

			if err != nil {
				return rout.HTTPErrorResponse(rout.ErrInvalidCredentials)
			}

			// Add the authenticated user to the context for use in downstream processing
			auth.SetAuthenticatedUserContext(c, au)

			return next(c)
		}
	}
}
// Reauthenticate builds the helper the middleware uses to trade a refresh token found on
// the echo context for a new access token. The returned function yields the new access
// token on success; on any failure it returns an empty string and the error, and
// processing should stop.
func Reauthenticate(conf AuthOptions, validator tokens.Validator) func(c echo.Context) (string, error) {
	refresher := conf.reauth

	// Without a configured reauthenticator refreshing is disabled, so the helper always fails.
	if refresher == nil {
		return func(echo.Context) (string, error) {
			return "", ErrRefreshDisabled
		}
	}

	// A reauthenticator is configured: hand back a helper that drives it.
	return func(c echo.Context) (string, error) {
		// Pull the refresh token out of the request.
		token, err := auth.GetRefreshToken(c)
		if err != nil {
			return "", err
		}

		// The refresh token itself has to verify before it may be exchanged.
		_, err = validator.Verify(token)
		if err != nil {
			return "", err
		}

		// Exchange the refresh token for a fresh set of tokens.
		reply, err := refresher.Refresh(c.Request().Context(), &RefreshRequest{RefreshToken: token})
		if err != nil {
			return "", err
		}

		// Write the new access and refresh cookies onto the response.
		auth.SetAuthCookies(c.Response().Writer, reply.AccessToken, reply.RefreshToken)

		return reply.AccessToken, nil
	}
}
// createAuthenticatedUserFromClaims builds the authenticated user for a set of claims by
// loading the user and the organization whose mapping IDs are carried in those claims.
// An error from either lookup is returned unchanged together with a nil user.
func createAuthenticatedUserFromClaims(ctx context.Context, dbClient *generated.Client, claims *tokens.Claims, authType auth.AuthenticationType) (*auth.AuthenticatedUser, error) {
	// look up the user by the mapping ID from the claims, using the caller's context
	userQuery := dbClient.User.Query().Where(user.MappingID(claims.UserID))

	subject, err := userQuery.Only(ctx)
	if err != nil {
		return nil, err
	}

	// allow the query to get the organization, need to bypass the authz filter to get the org
	allowCtx := privacy.DecisionContext(ctx, privacy.Allow)
	orgQuery := dbClient.Organization.Query().Where(organization.MappingID(claims.OrgID))

	tenant, err := orgQuery.Only(allowCtx)
	if err != nil {
		return nil, err
	}

	authenticated := &auth.AuthenticatedUser{
		SubjectID:          subject.ID,
		SubjectName:        getSubjectName(subject),
		OrganizationID:     tenant.ID,
		OrganizationName:   tenant.Name,
		OrganizationIDs:    []string{tenant.ID},
		AuthenticationType: authType,
	}

	return authenticated, nil
}
// checkToken resolves a bearer authorization token against the database. The token is
// tried as a personal access token first and as an API token second; the first match
// yields the authenticated user. When neither matches, the error from the API token
// check is returned.
func checkToken(ctx context.Context, conf AuthOptions, token string) (*auth.AuthenticatedUser, error) {
	// the lookups below must bypass privacy rules
	allowCtx := privacy.DecisionContext(ctx, privacy.Allow)

	// first attempt: personal access token
	if patUser, patErr := isValidPersonalAccessToken(allowCtx, conf.DBClient, token); patErr == nil {
		return patUser, nil
	}

	// second attempt: API token
	apiUser, apiErr := isValidAPIToken(allowCtx, conf.DBClient, token)
	if apiErr != nil {
		return nil, apiErr
	}

	return apiUser, nil
}
// isValidPersonalAccessToken checks if the provided token is a valid personal access token
// and returns the authenticated user.
//
// The token is looked up together with its owner and organizations. It is rejected with
// rout.ErrExpiredCredentials when its expiration is before the current time, and with
// rout.ErrInvalidCredentials when it has no organizations or when its owner was not
// loaded. Any error from the lookup itself is returned unchanged.
func isValidPersonalAccessToken(ctx context.Context, dbClient *generated.Client, token string) (*auth.AuthenticatedUser, error) {
	pat, err := dbClient.PersonalAccessToken.Query().Where(personalaccesstoken.Token(token)).
		WithOwner().
		WithOrganizations().
		Only(ctx)
	if err != nil {
		return nil, err
	}

	// NOTE(review): unlike the API token check there is no zero-value exemption here, so
	// an unset expiration counts as expired - confirm this is intended
	if pat.ExpiresAt.Before(time.Now()) {
		return nil, rout.ErrExpiredCredentials
	}

	orgs := pat.Edges.Organizations
	if len(orgs) == 0 {
		// an access token must have at least one organization to be used
		return nil, rout.ErrInvalidCredentials
	}

	owner := pat.Edges.Owner
	if owner == nil {
		// without a loaded owner the subject cannot be described; fail the
		// authentication instead of dereferencing a nil user below
		return nil, rout.ErrInvalidCredentials
	}

	// gather the authorized organization IDs
	orgIDs := make([]string, 0, len(orgs))
	for _, org := range orgs {
		orgIDs = append(orgIDs, org.ID)
	}

	return &auth.AuthenticatedUser{
		SubjectID:          pat.OwnerID,
		SubjectName:        getSubjectName(owner),
		OrganizationIDs:    orgIDs,
		AuthenticationType: auth.PATAuthentication,
	}, nil
}
// isValidAPIToken checks if the provided token is a valid API token and returns the
// authenticated user for it. A token with a zero expiration is not treated as expired;
// otherwise a token whose expiration is before the current time is rejected with
// rout.ErrExpiredCredentials. Any error from the lookup is returned unchanged.
func isValidAPIToken(ctx context.Context, dbClient *generated.Client, token string) (*auth.AuthenticatedUser, error) {
	lookup := dbClient.APIToken.Query().Where(apitoken.Token(token))

	apiToken, err := lookup.Only(ctx)
	if err != nil {
		return nil, err
	}

	// only a set expiration that has already passed invalidates the token
	if expiry := apiToken.ExpiresAt; !expiry.IsZero() && expiry.Before(time.Now()) {
		return nil, rout.ErrExpiredCredentials
	}

	// the token's owner ID is used as the organization the token is authorized for
	ownerID := apiToken.OwnerID

	authenticated := &auth.AuthenticatedUser{
		SubjectID:          apiToken.ID,
		SubjectName:        fmt.Sprintf("service: %s", apiToken.Name),
		OrganizationID:     ownerID,
		OrganizationIDs:    []string{ownerID},
		AuthenticationType: auth.APITokenAuthentication,
	}

	return authenticated, nil
}
// getSubjectName returns the subject name for the user.
//
// The name is "<first> <last>" when both parts are set, or whichever single part is set
// (with no stray separator space). When neither part is set it falls back to the user's
// display name.
func getSubjectName(user *generated.User) string {
	first, last := user.FirstName, user.LastName

	switch {
	case first != "" && last != "":
		return first + " " + last
	case first != "":
		return first
	case last != "":
		return last
	default:
		// joining two empty parts would still produce " ", so the fallback has to be
		// chosen from the parts themselves rather than from the joined string
		return user.DisplayName
	}
}