From faa79470af9647b9eff1855d3b8e510d7fdf979a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?To=CF=80Senpai?= Date: Mon, 20 Feb 2023 04:06:56 +0000 Subject: [PATCH] simplify oauth2 by removing session controller & converting a session to a struct (#241) --- _examples/oauth2/example.go | 8 +++-- oauth2/client.go | 61 +++++++++++++++++++++++----------- oauth2/client_impl.go | 58 ++++++++++++++++++++------------- oauth2/config.go | 11 +------ oauth2/session.go | 63 ------------------------------------ oauth2/session_controller.go | 58 --------------------------------- 6 files changed, 84 insertions(+), 175 deletions(-) delete mode 100644 oauth2/session.go delete mode 100644 oauth2/session_controller.go diff --git a/_examples/oauth2/example.go b/_examples/oauth2/example.go index c9878bd8..46111f75 100644 --- a/_examples/oauth2/example.go +++ b/_examples/oauth2/example.go @@ -25,6 +25,7 @@ var ( logger = log.Default() httpClient = http.DefaultClient client oauth2.Client + sessions map[string]oauth2.Session ) func init() { @@ -49,8 +50,8 @@ func handleRoot(w http.ResponseWriter, r *http.Request) { var body string cookie, err := r.Cookie("token") if err == nil { - session := client.SessionController().GetSession(cookie.Value) - if session != nil { + session, ok := sessions[cookie.Value] + if ok { var user *discord.OAuth2User user, err = client.GetUser(session) if err != nil { @@ -100,11 +101,12 @@ func handleTryLogin(w http.ResponseWriter, r *http.Request) { ) if code != "" && state != "" { identifier := randStr(32) - _, err := client.StartSession(code, state, identifier) + session, _, err := client.StartSession(code, state) if err != nil { writeError(w, "error while starting session", err) return } + sessions[identifier] = session http.SetCookie(w, &http.Cookie{Name: "token", Value: identifier}) } http.Redirect(w, r, "/", http.StatusTemporaryRedirect) diff --git a/oauth2/client.go b/oauth2/client.go index 83e0d642..fefbda7d 100644 --- a/oauth2/client.go +++ b/oauth2/client.go @@ -3,6 +3,7 @@ package oauth2 import ( "errors" "fmt" + "time" "github.com/disgoorg/snowflake/v2" @@ -14,8 +15,8 @@ var ( // ErrStateNotFound is returned when the state is not found in the SessionController. ErrStateNotFound = errors.New("state could not be found") - // ErrAccessTokenExpired is returned when the access token has expired. - ErrAccessTokenExpired = errors.New("access token expired. refresh the session") + // ErrSessionExpired is returned when the Session has expired. + ErrSessionExpired = errors.New("access token expired. refresh the session") // ErrMissingOAuth2Scope is returned when a specific OAuth2 scope is missing. ErrMissingOAuth2Scope = func(scope discord.OAuth2Scope) error { @@ -23,40 +24,62 @@ var ( } ) +// Session represents a discord access token response (https://discord.com/developers/docs/topics/oauth2#authorization-code-grant-access-token-response) +type Session struct { + // AccessToken allows requesting user information + AccessToken string `json:"access_token"` + + // RefreshToken allows refreshing the AccessToken + RefreshToken string `json:"refresh_token"` + + // Scopes returns the discord.OAuth2Scope(s) of the Session + Scopes []discord.OAuth2Scope `json:"scope"` + + // TokenType returns the discord.TokenType of the AccessToken + TokenType discord.TokenType `json:"token_type"` + + // Expiration returns the time.Time when the AccessToken expires and needs to be refreshed + Expiration time.Time `json:"expiration"` +} + +func (s Session) Expired() bool { + return s.Expiration.Before(time.Now()) +} + // Client is a high level wrapper around Discord's OAuth2 API. type Client interface { - // ID returns the configured client ID + // ID returns the configured client ID. ID() snowflake.ID - // Secret returns the configured client secret + // Secret returns the configured client secret. Secret() string - // Rest returns the underlying rest.OAuth2 + // Rest returns the underlying rest.OAuth2. Rest() rest.OAuth2 - // SessionController returns the configured SessionController - SessionController() SessionController - // StateController returns the configured StateController + // StateController returns the configured StateController. StateController() StateController - // GenerateAuthorizationURL generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated + // GenerateAuthorizationURL generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated. GenerateAuthorizationURL(redirectURI string, permissions discord.Permissions, guildID snowflake.ID, disableGuildSelect bool, scopes ...discord.OAuth2Scope) string - // GenerateAuthorizationURLState generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated & returned + // GenerateAuthorizationURLState generates an authorization URL with the given redirect URI, permissions, guildID, disableGuildSelect & scopes. State is automatically generated & returned. GenerateAuthorizationURLState(redirectURI string, permissions discord.Permissions, guildID snowflake.ID, disableGuildSelect bool, scopes ...discord.OAuth2Scope) (string, string) - // StartSession starts a new Session with the given authorization code & state - StartSession(code string, state string, identifier string, opts ...rest.RequestOpt) (Session, error) - // RefreshSession refreshes the given Session with the refresh token - RefreshSession(identifier string, session Session, opts ...rest.RequestOpt) (Session, error) + // StartSession starts a new Session with the given authorization code & state. + StartSession(code string, state string, opts ...rest.RequestOpt) (Session, *discord.IncomingWebhook, error) + // RefreshSession refreshes the given Session with the refresh token. + RefreshSession(session Session, opts ...rest.RequestOpt) (Session, error) + // VerifySession verifies the given Session & refreshes it if needed. + VerifySession(session Session, opts ...rest.RequestOpt) (Session, error) - // GetUser returns the discord.OAuth2User associated with the given Session. Fields filled in the struct depend on the Session.Scopes + // GetUser returns the discord.OAuth2User associated with the given Session. Fields filled in the struct depend on the Session.Scopes. GetUser(session Session, opts ...rest.RequestOpt) (*discord.OAuth2User, error) // GetMember returns the discord.Member associated with the given Session in a specific guild. GetMember(session Session, guildID snowflake.ID, opts ...rest.RequestOpt) (*discord.Member, error) - // GetGuilds returns the discord.OAuth2Guild(s) the user is a member of. This requires the discord.OAuth2ScopeGuilds scope in the Session + // GetGuilds returns the discord.OAuth2Guild(s) the user is a member of. This requires the discord.OAuth2ScopeGuilds scope in the Session. GetGuilds(session Session, opts ...rest.RequestOpt) ([]discord.OAuth2Guild, error) - // GetConnections returns the discord.Connection(s) the user has connected. This requires the discord.OAuth2ScopeConnections scope in the Session + // GetConnections returns the discord.Connection(s) the user has connected. This requires the discord.OAuth2ScopeConnections scope in the Session. GetConnections(session Session, opts ...rest.RequestOpt) ([]discord.Connection, error) - // GetApplicationRoleConnection returns the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session + // GetApplicationRoleConnection returns the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session. GetApplicationRoleConnection(session Session, applicationID snowflake.ID, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error) - // UpdateApplicationRoleConnection updates the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session + // UpdateApplicationRoleConnection updates the discord.ApplicationRoleConnection for the given application. This requires the discord.OAuth2ScopeRoleConnectionsWrite scope in the Session. UpdateApplicationRoleConnection(session Session, applicationID snowflake.ID, update discord.ApplicationRoleConnectionUpdate, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error) } diff --git a/oauth2/client_impl.go b/oauth2/client_impl.go index e5075af8..1220b936 100644 --- a/oauth2/client_impl.go +++ b/oauth2/client_impl.go @@ -35,10 +35,6 @@ func (c *clientImpl) Rest() rest.OAuth2 { return c.config.OAuth2 } -func (c *clientImpl) SessionController() SessionController { - return c.config.SessionController -} - func (c *clientImpl) StateController() StateController { return c.config.StateController } @@ -70,74 +66,92 @@ func (c *clientImpl) GenerateAuthorizationURLState(redirectURI string, permissio return discord.AuthorizeURL(values), state } -func (c *clientImpl) StartSession(code string, state string, identifier string, opts ...rest.RequestOpt) (Session, error) { +func (c *clientImpl) StartSession(code string, state string, opts ...rest.RequestOpt) (Session, *discord.IncomingWebhook, error) { redirectURI := c.StateController().ConsumeState(state) if redirectURI == "" { - return nil, ErrStateNotFound + return Session{}, nil, ErrStateNotFound } - exchange, err := c.Rest().GetAccessToken(c.id, c.secret, code, redirectURI, opts...) + accessToken, err := c.Rest().GetAccessToken(c.id, c.secret, code, redirectURI, opts...) if err != nil { - return nil, err + return Session{}, nil, err } - return c.SessionController().CreateSessionFromResponse(identifier, *exchange), nil + + return newSession(*accessToken), accessToken.Webhook, nil } -func (c *clientImpl) RefreshSession(identifier string, session Session, opts ...rest.RequestOpt) (Session, error) { - exchange, err := c.Rest().RefreshAccessToken(c.id, c.secret, session.RefreshToken(), opts...) +func (c *clientImpl) RefreshSession(session Session, opts ...rest.RequestOpt) (Session, error) { + accessToken, err := c.Rest().RefreshAccessToken(c.id, c.secret, session.RefreshToken, opts...) if err != nil { - return nil, err + return Session{}, err } - return c.SessionController().CreateSessionFromResponse(identifier, *exchange), nil + return newSession(*accessToken), nil +} + +func (c *clientImpl) VerifySession(session Session, opts ...rest.RequestOpt) (Session, error) { + if session.Expired() { + return c.RefreshSession(session, opts...) + } + return session, nil } func (c *clientImpl) GetUser(session Session, opts ...rest.RequestOpt) (*discord.OAuth2User, error) { if err := checkSession(session, discord.OAuth2ScopeIdentify); err != nil { return nil, err } - return c.Rest().GetCurrentUser(session.AccessToken(), opts...) + return c.Rest().GetCurrentUser(session.AccessToken, opts...) } func (c *clientImpl) GetMember(session Session, guildID snowflake.ID, opts ...rest.RequestOpt) (*discord.Member, error) { if err := checkSession(session, discord.OAuth2ScopeGuildsMembersRead); err != nil { return nil, err } - return c.Rest().GetCurrentMember(session.AccessToken(), guildID, opts...) + return c.Rest().GetCurrentMember(session.AccessToken, guildID, opts...) } func (c *clientImpl) GetGuilds(session Session, opts ...rest.RequestOpt) ([]discord.OAuth2Guild, error) { if err := checkSession(session, discord.OAuth2ScopeGuilds); err != nil { return nil, err } - return c.Rest().GetCurrentUserGuilds(session.AccessToken(), 0, 0, 0, opts...) + return c.Rest().GetCurrentUserGuilds(session.AccessToken, 0, 0, 0, opts...) } func (c *clientImpl) GetConnections(session Session, opts ...rest.RequestOpt) ([]discord.Connection, error) { if err := checkSession(session, discord.OAuth2ScopeConnections); err != nil { return nil, err } - return c.Rest().GetCurrentUserConnections(session.AccessToken(), opts...) + return c.Rest().GetCurrentUserConnections(session.AccessToken, opts...) } func (c *clientImpl) GetApplicationRoleConnection(session Session, applicationID snowflake.ID, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error) { if err := checkSession(session, discord.OAuth2ScopeRoleConnectionsWrite); err != nil { return nil, err } - return c.Rest().GetCurrentUserApplicationRoleConnection(session.AccessToken(), applicationID, opts...) + return c.Rest().GetCurrentUserApplicationRoleConnection(session.AccessToken, applicationID, opts...) } func (c *clientImpl) UpdateApplicationRoleConnection(session Session, applicationID snowflake.ID, update discord.ApplicationRoleConnectionUpdate, opts ...rest.RequestOpt) (*discord.ApplicationRoleConnection, error) { if err := checkSession(session, discord.OAuth2ScopeRoleConnectionsWrite); err != nil { return nil, err } - return c.Rest().UpdateCurrentUserApplicationRoleConnection(session.AccessToken(), applicationID, update, opts...) + return c.Rest().UpdateCurrentUserApplicationRoleConnection(session.AccessToken, applicationID, update, opts...) } func checkSession(session Session, scope discord.OAuth2Scope) error { - if session.Expiration().Before(time.Now()) { - return ErrAccessTokenExpired + if session.Expired() { + return ErrSessionExpired } - if !discord.HasScope(scope, session.Scopes()...) { + if !discord.HasScope(scope, session.Scopes...) { return ErrMissingOAuth2Scope(scope) } return nil } + +func newSession(accessToken discord.AccessTokenResponse) Session { + return Session{ + AccessToken: accessToken.AccessToken, + RefreshToken: accessToken.RefreshToken, + Scopes: accessToken.Scope, + TokenType: accessToken.TokenType, + Expiration: time.Now().Add(accessToken.ExpiresIn * time.Second), + } +} diff --git a/oauth2/config.go b/oauth2/config.go index 92d9f31b..4e9703ec 100644 --- a/oauth2/config.go +++ b/oauth2/config.go @@ -9,8 +9,7 @@ import ( // DefaultConfig is the configuration which is used by default func DefaultConfig() *Config { return &Config{ - Logger: log.Default(), - SessionController: NewSessionController(), + Logger: log.Default(), } } @@ -20,7 +19,6 @@ type Config struct { RestClient rest.Client RestClientConfigOpts []rest.ConfigOpt OAuth2 rest.OAuth2 - SessionController SessionController StateController StateController StateControllerConfigOpts []StateControllerConfigOpt } @@ -72,13 +70,6 @@ func WithOAuth2(oauth2 rest.OAuth2) ConfigOpt { } } -// WithSessionController applies a custom SessionController to the OAuth2 client -func WithSessionController(sessionController SessionController) ConfigOpt { - return func(config *Config) { - config.SessionController = sessionController - } -} - // WithStateController applies a custom StateController to the OAuth2 client func WithStateController(stateController StateController) ConfigOpt { return func(config *Config) { diff --git a/oauth2/session.go b/oauth2/session.go deleted file mode 100644 index 90f663ca..00000000 --- a/oauth2/session.go +++ /dev/null @@ -1,63 +0,0 @@ -package oauth2 - -import ( - "time" - - "github.com/disgoorg/disgo/discord" -) - -var _ Session = (*sessionImpl)(nil) - -// Session represents a discord access token response (https://discord.com/developers/docs/topics/oauth2#authorization-code-grant-access-token-response) -type Session interface { - // AccessToken allows requesting user information - AccessToken() string - - // RefreshToken allows refreshing the AccessToken - RefreshToken() string - - // Scopes returns the discord.OAuth2Scope(s) of the Session - Scopes() []discord.OAuth2Scope - - // TokenType returns the discord.TokenType of the AccessToken - TokenType() discord.TokenType - - // Expiration returns the time.Time when the AccessToken expires and needs to be refreshed - Expiration() time.Time - - // Webhook returns the discord.IncomingWebhook when the discord.OAuth2ScopeWebhookIncoming is set - Webhook() *discord.IncomingWebhook -} - -type sessionImpl struct { - accessToken string - refreshToken string - scopes []discord.OAuth2Scope - tokenType discord.TokenType - expiration time.Time - webhook *discord.IncomingWebhook -} - -func (s *sessionImpl) AccessToken() string { - return s.accessToken -} - -func (s *sessionImpl) RefreshToken() string { - return s.refreshToken -} - -func (s *sessionImpl) Scopes() []discord.OAuth2Scope { - return s.scopes -} - -func (s *sessionImpl) TokenType() discord.TokenType { - return s.tokenType -} - -func (s *sessionImpl) Expiration() time.Time { - return s.expiration -} - -func (s *sessionImpl) Webhook() *discord.IncomingWebhook { - return s.webhook -} diff --git a/oauth2/session_controller.go b/oauth2/session_controller.go deleted file mode 100644 index 6aa61fde..00000000 --- a/oauth2/session_controller.go +++ /dev/null @@ -1,58 +0,0 @@ -package oauth2 - -import ( - "time" - - "github.com/disgoorg/disgo/discord" -) - -var _ SessionController = (*sessionControllerImpl)(nil) - -// SessionController lets you manage your Session(s) -type SessionController interface { - // GetSession returns the Session for the given identifier or nil if none was found - GetSession(identifier string) Session - - // CreateSession creates a new Session from the given identifier, access token, refresh token, scope, token type, expiration and webhook - CreateSession(identifier string, accessToken string, refreshToken string, scopes []discord.OAuth2Scope, tokenType discord.TokenType, expiration time.Time, webhook *discord.IncomingWebhook) Session - - // CreateSessionFromResponse creates a new Session from the given identifier and discord.AccessTokenResponse payload - CreateSessionFromResponse(identifier string, response discord.AccessTokenResponse) Session -} - -// NewSessionController returns a new empty SessionController -func NewSessionController() SessionController { - return NewSessionControllerWithSessions(map[string]Session{}) -} - -// NewSessionControllerWithSessions returns a new SessionController with the given Session(s) -func NewSessionControllerWithSessions(sessions map[string]Session) SessionController { - return &sessionControllerImpl{sessions: sessions} -} - -type sessionControllerImpl struct { - sessions map[string]Session -} - -func (c *sessionControllerImpl) GetSession(identifier string) Session { - return c.sessions[identifier] -} - -func (c *sessionControllerImpl) CreateSession(identifier string, accessToken string, refreshToken string, scopes []discord.OAuth2Scope, tokenType discord.TokenType, expiration time.Time, webhook *discord.IncomingWebhook) Session { - session := &sessionImpl{ - accessToken: accessToken, - refreshToken: refreshToken, - scopes: scopes, - tokenType: tokenType, - expiration: expiration, - webhook: webhook, - } - - c.sessions[identifier] = session - - return session -} - -func (c *sessionControllerImpl) CreateSessionFromResponse(identifier string, response discord.AccessTokenResponse) Session { - return c.CreateSession(identifier, response.AccessToken, response.RefreshToken, response.Scope, response.TokenType, time.Now().Add(response.ExpiresIn*time.Second), response.Webhook) -}