forked from keybase/managed-bots
/
db.go
139 lines (124 loc) · 3.04 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package base
import (
"database/sql"
"fmt"
"time"
"golang.org/x/oauth2"
)
// DB wraps an opened *sql.DB (embedded, so all of its methods are promoted)
// and adds the transaction helper shared by the storage types in this file.
type DB struct {
	*sql.DB
}

// NewDB wraps an already-opened *sql.DB. The returned DB shares that handle.
func NewDB(db *sql.DB) *DB {
	return &DB{
		DB: db,
	}
}

// RunTxn runs fn inside a single database transaction.
//
// If fn returns nil the transaction is committed and the Commit error (if
// any) is returned. If fn returns an error the transaction is rolled back and
// fn's error is returned unchanged; a rollback failure is only logged so it
// never masks the original cause. If fn panics, the transaction is rolled
// back before the panic is re-raised, so the pooled connection is not left
// holding an open transaction when a caller further up recovers the panic.
func (d *DB) RunTxn(fn func(tx *sql.Tx) error) error {
	tx, err := d.Begin()
	if err != nil {
		return err
	}
	// Best-effort rollback: report failures, but never replace the error or
	// panic that triggered it.
	rollback := func() {
		if rerr := tx.Rollback(); rerr != nil {
			fmt.Printf("unable to rollback: %v\n", rerr)
		}
	}
	defer func() {
		if p := recover(); p != nil {
			rollback()
			panic(p)
		}
	}()
	if err := fn(tx); err != nil {
		rollback()
		return err
	}
	return tx.Commit()
}
// BaseOAuthDB records in-progress OAuth handshakes in the oauth_state table,
// keyed by the state string. It embeds *DB, so RunTxn and the *sql.DB
// methods are available on it.
type BaseOAuthDB struct {
	*DB
}

// NewBaseOAuthDB builds a BaseOAuthDB on top of an opened *sql.DB.
func NewBaseOAuthDB(db *sql.DB) *BaseOAuthDB {
	wrapped := NewDB(db)
	return &BaseOAuthDB{DB: wrapped}
}

// GetState fetches the OAuth request stored under state.
//
// A missing row is reported as (nil, nil) rather than as an error; any other
// query or scan failure is returned as-is with a nil request.
func (d *BaseOAuthDB) GetState(state string) (*OAuthRequest, error) {
	const query = `SELECT identifier, conv_id, msg_id, is_complete
FROM oauth_state
WHERE state = ?`
	var req OAuthRequest
	err := d.DB.QueryRow(query, state).Scan(
		&req.TokenIdentifier,
		&req.ConvID,
		&req.MsgID,
		&req.IsComplete,
	)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return &req, nil
}

// PutState upserts the request for state inside a transaction: a new row is
// inserted, or an existing row's identifier, conv_id and msg_id are
// overwritten. is_complete is neither set nor reset by this write.
func (d *BaseOAuthDB) PutState(state string, oauthState *OAuthRequest) error {
	const upsert = `INSERT INTO oauth_state
(state, identifier, conv_id, msg_id)
VALUES (?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
identifier=VALUES(identifier),
conv_id=VALUES(conv_id),
msg_id=VALUES(msg_id)
`
	return d.RunTxn(func(tx *sql.Tx) error {
		_, execErr := tx.Exec(upsert, state, oauthState.TokenIdentifier, oauthState.ConvID, oauthState.MsgID)
		return execErr
	})
}

// CompleteState sets is_complete for state inside a transaction. The
// statement's result is not inspected, so a state with no row is not
// reported as an error.
func (d *BaseOAuthDB) CompleteState(state string) error {
	const markDone = `UPDATE oauth_state
SET is_complete=true
WHERE state = ?`
	return d.RunTxn(func(tx *sql.Tx) error {
		_, execErr := tx.Exec(markDone, state)
		return execErr
	})
}
// GoogleOAuthDB persists Google OAuth2 tokens in the oauth table, keyed by
// identifier, on top of the oauth_state bookkeeping provided by BaseOAuthDB.
type GoogleOAuthDB struct {
	*BaseOAuthDB
}

// NewGoogleOAuthDB builds a GoogleOAuthDB on top of an opened *sql.DB.
func NewGoogleOAuthDB(db *sql.DB) *GoogleOAuthDB {
	return &GoogleOAuthDB{
		BaseOAuthDB: NewBaseOAuthDB(db),
	}
}

// GetToken loads the token stored for identifier. It returns (nil, nil) when
// no row exists, and (nil, err) on any other query or scan failure.
//
// The expiry column is read back as whole Unix seconds. MySQL's
// UNIX_TIMESTAMP returns 0 for out-of-range dates, which presumably includes
// the value written for a token whose Expiry is the zero time. A result of 0
// is therefore mapped back to the zero time.Time, which oauth2 treats as
// "never expires"; time.Unix(0, 0) would instead make the token look expired
// on every load.
// NOTE(review): if the expiry column is nullable, a NULL would fail to Scan
// into int64 — confirm against the schema.
func (d *GoogleOAuthDB) GetToken(identifier string) (*oauth2.Token, error) {
	var token oauth2.Token
	var expiry int64
	row := d.DB.QueryRow(`SELECT access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry))
FROM oauth
WHERE identifier = ?`, identifier)
	err := row.Scan(&token.AccessToken, &token.TokenType,
		&token.RefreshToken, &expiry)
	switch err {
	case nil:
		if expiry != 0 {
			token.Expiry = time.Unix(expiry, 0)
		}
		return &token, nil
	case sql.ErrNoRows:
		return nil, nil
	default:
		return nil, err
	}
}

// PutToken upserts the token stored for identifier inside a transaction.
// On a duplicate key every token field (access_token, token_type,
// refresh_token, expiry) and mtime are overwritten; ctime keeps the time of
// the original insert.
func (d *GoogleOAuthDB) PutToken(identifier string, token *oauth2.Token) error {
	return d.RunTxn(func(tx *sql.Tx) error {
		// token_type is refreshed together with the other token fields so a
		// re-stored token never keeps a stale type from the first insert.
		_, err := tx.Exec(`INSERT INTO oauth
(identifier, access_token, token_type, refresh_token, expiry, ctime, mtime)
VALUES (?, ?, ?, ?, ?, NOW(), NOW())
ON DUPLICATE KEY UPDATE
access_token=VALUES(access_token),
token_type=VALUES(token_type),
refresh_token=VALUES(refresh_token),
expiry=VALUES(expiry),
mtime=VALUES(mtime)
`, identifier, token.AccessToken, token.TokenType, token.RefreshToken, token.Expiry)
		return err
	})
}

// DeleteToken removes any token stored for identifier inside a transaction.
// The statement's result is not inspected, so deleting a missing token is
// not reported as an error.
func (d *GoogleOAuthDB) DeleteToken(identifier string) error {
	return d.RunTxn(func(tx *sql.Tx) error {
		_, err := tx.Exec(`DELETE FROM oauth
WHERE identifier = ?`, identifier)
		return err
	})
}