Skip to content

Commit

Permalink
Hook OAuth storage into database.
Browse files Browse the repository at this point in the history
  • Loading branch information
cjslep committed Jul 16, 2019
1 parent 55c38d2 commit f258a10
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 16 deletions.
193 changes: 193 additions & 0 deletions database.go
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
_ "github.com/lib/pq"
"gopkg.in/oauth2.v3"
)

type Database interface{}
Expand All @@ -53,6 +54,15 @@ type database struct {
instancePolicies *sql.Stmt
userPolicies *sql.Stmt
userResolutions *sql.Stmt
// Prepared statements for oauth
createTokenInfo *sql.Stmt
removeTokenByCode *sql.Stmt
removeTokenByAccess *sql.Stmt
removeTokenByRefresh *sql.Stmt
getTokenByCode *sql.Stmt
getTokenByAccess *sql.Stmt
getTokenByRefresh *sql.Stmt
getClientById *sql.Stmt
// Prepared statements for the database required by go-fed
inboxContains *sql.Stmt
getInbox *sql.Stmt
Expand Down Expand Up @@ -171,6 +181,40 @@ func newDatabase(c *config, a Application, debug bool) (db *database, err error)
return
}

// prepared statements for oauth
db.createTokenInfo, err = db.db.Prepare(sqlgen.CreateTokenInfo())
if err != nil {
return
}
db.removeTokenByCode, err = db.db.Prepare(sqlgen.RemoveTokenByCode())
if err != nil {
return
}
db.removeTokenByAccess, err = db.db.Prepare(sqlgen.RemoveTokenByAccess())
if err != nil {
return
}
db.removeTokenByRefresh, err = db.db.Prepare(sqlgen.RemoveTokenByRefresh())
if err != nil {
return
}
db.getTokenByCode, err = db.db.Prepare(sqlgen.GetTokenByCode())
if err != nil {
return
}
db.getTokenByAccess, err = db.db.Prepare(sqlgen.GetTokenByAccess())
if err != nil {
return
}
db.getTokenByRefresh, err = db.db.Prepare(sqlgen.GetTokenByRefresh())
if err != nil {
return
}
db.getClientById, err = db.db.Prepare(sqlgen.GetClientById())
if err != nil {
return
}

// go-fed statement preparations
db.inboxContains, err = db.db.Prepare(sqlgen.InboxContains())
if err != nil {
Expand Down Expand Up @@ -302,6 +346,7 @@ func postgresConn(pg postgresConfig) (s string, err error) {
}

func (d *database) Close() error {
// apcore
d.hashPassForUserID.Close()
d.userIdForEmail.Close()
d.userIdForBoxPath.Close()
Expand All @@ -313,6 +358,16 @@ func (d *database) Close() error {
d.instancePolicies.Close()
d.userPolicies.Close()
d.userResolutions.Close()
// oauth
d.createTokenInfo.Close()
d.removeTokenByCode.Close()
d.removeTokenByAccess.Close()
d.removeTokenByRefresh.Close()
d.getTokenByCode.Close()
d.getTokenByAccess.Close()
d.getTokenByRefresh.Close()
d.getClientById.Close()
// go-fed
d.inboxContains.Close()
d.getInbox.Close()
d.actorForOutbox.Close()
Expand Down Expand Up @@ -559,6 +614,144 @@ func (d *database) UserResolutions(c context.Context, userId string) (r []resolu
return
}

// apcore oauth functions

func (d *database) CreateTokenInfo(c context.Context, info oauth2.TokenInfo) error {
_, err := d.createTokenInfo.ExecContext(
c,
info.GetClientID(),
info.GetUserID(),
info.GetRedirectURI(),
info.GetScope(),
info.GetCode(),
info.GetCodeCreateAt(),
info.GetCodeExpiresIn(),
info.GetAccess(),
info.GetAccessCreateAt(),
info.GetAccessExpiresIn(),
info.GetRefresh(),
info.GetRefreshCreateAt(),
info.GetRefreshExpiresIn())
return err
}

func (d *database) RemoveTokenByCode(c context.Context, code string) error {
_, err := d.removeTokenByCode.ExecContext(
c,
code)
return err
}

func (d *database) RemoveTokenByAccess(c context.Context, access string) error {
_, err := d.removeTokenByAccess.ExecContext(
c,
access)
return err
}

func (d *database) RemoveTokenByRefresh(c context.Context, refresh string) error {
_, err := d.removeTokenByRefresh.ExecContext(
c,
refresh)
return err
}

func (d *database) mustScanRowsForOneToken(r *sql.Rows, ti *tokenInfo) (err error) {
var n int
for r.Next() {
if n > 0 {
err = fmt.Errorf("multiple rows when obtaining OAuth2 token")
return
}
if err = r.Scan(
&ti.clientId,
&ti.userId,
&ti.redirectURI,
&ti.scope,
&ti.code,
&ti.codeCreated,
&ti.codeExpires,
&ti.access,
&ti.accessCreated,
&ti.accessExpires,
&ti.refresh,
&ti.refreshCreated,
&ti.refreshExpires); err != nil {
return
}
n++
}
err = r.Err()
return
}

func (d *database) GetTokenByCode(c context.Context, code string) (oti oauth2.TokenInfo, err error) {
ti := &tokenInfo{}
oti = ti
var r *sql.Rows
r, err = d.getTokenByCode.QueryContext(c, code)
if err != nil {
return
}
defer r.Close()
err = d.mustScanRowsForOneToken(r, ti)
return
}

func (d *database) GetTokenByAccess(c context.Context, access string) (oti oauth2.TokenInfo, err error) {
ti := &tokenInfo{}
oti = ti
var r *sql.Rows
r, err = d.getTokenByAccess.QueryContext(c, access)
if err != nil {
return
}
defer r.Close()
err = d.mustScanRowsForOneToken(r, ti)
return
}

func (d *database) GetTokenByRefresh(c context.Context, refresh string) (oti oauth2.TokenInfo, err error) {
ti := &tokenInfo{}
oti = ti
var r *sql.Rows
r, err = d.getTokenByRefresh.QueryContext(c, refresh)
if err != nil {
return
}
defer r.Close()
err = d.mustScanRowsForOneToken(r, ti)
return
}

func (d *database) GetClientById(c context.Context, id string) (oci oauth2.ClientInfo, err error) {
ci := &clientInfo{}
oci = ci
var r *sql.Rows
r, err = d.getClientById.QueryContext(c, id)
if err != nil {
return
}
defer r.Close()
var n int
for r.Next() {
if n > 0 {
err = fmt.Errorf("multiple rows when obtaining OAuth2 client")
return
}
if err = r.Scan(
&ci.id,
&ci.secret,
&ci.domain,
&ci.userId); err != nil {
return
}
n++
}
err = r.Err()
return
}

// go-fed ActivityPub implementation

func (d *database) InboxContains(c context.Context, inbox, id *url.URL) (contains bool, err error) {
Expand Down
50 changes: 50 additions & 0 deletions db_postgres.go
Expand Up @@ -39,6 +39,16 @@ type sqlGenerator interface {
UserPolicies() string
InsertResolutions() string
UserResolutions() string

CreateTokenInfo() string
RemoveTokenByCode() string
RemoveTokenByAccess() string
RemoveTokenByRefresh() string
GetTokenByCode() string
GetTokenByAccess() string
GetTokenByRefresh() string
GetClientById() string

InboxContains() string
GetInbox() string
SetInboxUpdate() string
Expand Down Expand Up @@ -372,6 +382,46 @@ func (p *pgV0) UserResolutions() string {
return ""
}

func (p *pgV0) CreateTokenInfo() string {
// TODO
return ""
}

func (p *pgV0) RemoveTokenByCode() string {
// TODO
return ""
}

func (p *pgV0) RemoveTokenByAccess() string {
// TODO
return ""
}

func (p *pgV0) RemoveTokenByRefresh() string {
// TODO
return ""
}

func (p *pgV0) GetTokenByCode() string {
// TODO
return ""
}

func (p *pgV0) GetTokenByAccess() string {
// TODO
return ""
}

func (p *pgV0) GetTokenByRefresh() string {
// TODO
return ""
}

func (p *pgV0) GetClientById() string {
// TODO
return ""
}

func (p *pgV0) InboxContains() string {
// TODO
return ""
Expand Down
25 changes: 9 additions & 16 deletions oauth_stores.go
Expand Up @@ -17,6 +17,7 @@
package apcore

import (
"context"
"time"

"gopkg.in/oauth2.v3"
Expand Down Expand Up @@ -164,44 +165,37 @@ func newTokenStore(d *database) (t *tokenStore, err error) {

// Create and store the new token information
func (t *tokenStore) Create(info oauth2.TokenInfo) error {
// TODO
return nil
return t.d.CreateTokenInfo(context.Background(), info)
}

// Delete the authorization code
func (t *tokenStore) RemoveByCode(code string) error {
// TODO
return nil
return t.d.RemoveTokenByCode(context.Background(), code)
}

// Use the access token to delete the token information
func (t *tokenStore) RemoveByAccess(access string) error {
// TODO
return nil
return t.d.RemoveTokenByAccess(context.Background(), access)
}

// Use the refresh token to delete the token information
func (t *tokenStore) RemoveByRefresh(refresh string) error {
// TODO
return nil
return t.d.RemoveTokenByRefresh(context.Background(), refresh)
}

// Use the authorization code for token information data
func (t *tokenStore) GetByCode(code string) (oauth2.TokenInfo, error) {
// TODO
return nil, nil
return t.d.GetTokenByCode(context.Background(), code)
}

// Use the access token for token information data
func (t *tokenStore) GetByAccess(access string) (oauth2.TokenInfo, error) {
// TODO
return nil, nil
return t.d.GetTokenByAccess(context.Background(), access)
}

// Use the refresh token for token information data
func (t *tokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error) {
// TODO
return nil, nil
return t.d.GetTokenByRefresh(context.Background(), refresh)
}

var _ oauth2.ClientInfo = &clientInfo{}
Expand Down Expand Up @@ -245,6 +239,5 @@ func newClientStore(d *database) (t *clientStore, err error) {

// According to the ID for the client information
func (c *clientStore) GetByID(id string) (oauth2.ClientInfo, error) {
// TODO
return nil, nil
return c.d.GetClientById(context.Background(), id)
}

0 comments on commit f258a10

Please sign in to comment.