-
Notifications
You must be signed in to change notification settings - Fork 116
/
sql.go
76 lines (61 loc) · 1.9 KB
/
sql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package authn
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/lyft/clutch/backend/service"
"github.com/lyft/clutch/backend/service/db/postgres"
)
// authnToken mirrors one row of the authn_tokens table as it is written by
// createOrUpdateProviderToken and read back by readProviderToken.
type authnToken struct {
	// userID and provider together are the conflict target of the upsert
	// statement, so at most one row is kept per pair.
	userID, provider string

	// Token values, carried as raw bytes.
	accessToken, refreshToken, idToken []byte

	// expiry is stored alongside the tokens.
	// NOTE(review): presumably the access token's expiration time — confirm
	// against the code that populates it.
	expiry time.Time
}
// repository groups the SQL operations on stored authn tokens; its methods
// execute their statements through the wrapped database handle.
type repository struct {
// db is the database handle used by the repository's methods.
db *sql.DB
}
// newRepository builds a repository on top of the service registered under
// postgres.Name in service.Registry.
//
// It returns an error when no service is registered under that name, or when
// the registered service cannot be asserted to postgres.Client. On success the
// repository wraps the handle returned by the client's DB method.
func newRepository() (*repository, error) {
	registered, found := service.Registry[postgres.Name]
	if !found {
		return nil, fmt.Errorf("database '%s' not registered", postgres.Name)
	}

	client, isClient := registered.(postgres.Client)
	if !isClient {
		return nil, fmt.Errorf("database does not implement the required interface")
	}

	repo := &repository{db: client.DB()}
	return repo, nil
}
// #nosec G101
//
// createOrUpdateProviderToken is the upsert statement for authn_tokens. It
// inserts a row from six positional parameters (user_id, provider,
// access_token, refresh_token, id_token, expiry, in that order). When a row
// with the same (user_id, provider) pair already exists, it overwrites that
// row's columns with the incoming values (EXCLUDED) instead.
//
// NOTE(review): the #nosec marker presumably silences a hardcoded-credentials
// false positive triggered by the word "token"; the statement contains only
// placeholders, no secret values.
// NOTE(review): user_id and provider are reassigned in the DO UPDATE clause
// even though they are the conflict target — looks redundant; confirm before
// removing.
const createOrUpdateProviderToken = `
INSERT INTO authn_tokens (user_id, provider, access_token, refresh_token, id_token, expiry) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (user_id, provider) DO UPDATE SET
user_id = EXCLUDED.user_id,
provider = EXCLUDED.provider,
access_token = EXCLUDED.access_token,
refresh_token = EXCLUDED.refresh_token,
id_token = EXCLUDED.id_token,
expiry = EXCLUDED.expiry
`
// createOrUpdateProviderToken executes the upsert statement of the same name,
// binding the token's fields to its six placeholders in column order.
//
// token is dereferenced unconditionally and therefore must not be nil. The
// statement's result is discarded; the error from ExecContext is returned
// unchanged.
func (r *repository) createOrUpdateProviderToken(ctx context.Context, token *authnToken) error {
	// Order matches $1..$6 in the statement.
	args := []interface{}{
		token.userID,
		token.provider,
		token.accessToken,
		token.refreshToken,
		token.idToken,
		token.expiry,
	}

	if _, err := r.db.ExecContext(ctx, createOrUpdateProviderToken, args...); err != nil {
		return err
	}
	return nil
}
// #nosec G101
//
// readProviderToken selects the stored token columns (user_id, provider,
// access_token, refresh_token, id_token, expiry) from authn_tokens for the row
// matching user_id = $1 and provider = $2.
//
// NOTE(review): the #nosec marker presumably silences a hardcoded-credentials
// false positive triggered by the word "token"; the statement contains only
// placeholders, no secret values.
const readProviderToken = `
SELECT user_id, provider, access_token, refresh_token, id_token, expiry FROM authn_tokens WHERE user_id = $1 AND provider = $2
`
// readProviderToken runs the select statement of the same name for the given
// userID and provider and scans the resulting row into a new authnToken.
//
// If scanning fails, the token is nil and the Scan error is returned
// unchanged. Per the database/sql documentation, that error is sql.ErrNoRows
// when no row matched.
func (r *repository) readProviderToken(ctx context.Context, userID, provider string) (*authnToken, error) {
	var tok authnToken

	row := r.db.QueryRowContext(ctx, readProviderToken, userID, provider)
	if err := row.Scan(
		&tok.userID,
		&tok.provider,
		&tok.accessToken,
		&tok.refreshToken,
		&tok.idToken,
		&tok.expiry,
	); err != nil {
		return nil, err
	}

	return &tok, nil
}