diff --git a/lib/omniauth/strategies/oauth2.rb b/lib/omniauth/strategies/oauth2.rb index 3ffff1b..890cb3e 100644 --- a/lib/omniauth/strategies/oauth2.rb +++ b/lib/omniauth/strategies/oauth2.rb @@ -64,13 +64,15 @@ def token_params end def callback_phase # rubocop:disable AbcSize, CyclomaticComplexity, MethodLength, PerceivedComplexity - error = request.params["error_reason"] || request.params["error"] + params = request.params + error = params['error_reason'] || params['error'] if error - fail!(error, CallbackError.new(request.params["error"], request.params["error_description"] || request.params["error_reason"], request.params["error_uri"])) - elsif !options.provider_ignores_state && (request.params["state"].to_s.empty? || request.params["state"] != session.delete("omniauth.state")) - fail!(:csrf_detected, CallbackError.new(:csrf_detected, "CSRF detected")) + description = params['error_description'] || params['error_reason'] + fail!(error, CallbackError.new(params['error'], description, params['error_uri'])) + elsif csrf_detected? + fail!(:csrf_detected, CallbackError.new(:csrf_detected, 'CSRF detected')) else - self.access_token = build_access_token + self.access_token = params['access_token'] ? build_access_token_from_token : build_access_token self.access_token = access_token.refresh! if access_token.expired? super end @@ -82,12 +84,35 @@ def callback_phase # rubocop:disable AbcSize, CyclomaticComplexity, MethodLength fail!(:failed_to_connect, e) end + def csrf_detected? + params = request.params + if params['code'] + return if options.provider_ignores_state + state = params['state'].to_s + state.empty? || state != session.delete('omniauth.state') + elsif params['access_token'] + csrf_token = params['csrf_token'] + csrf_token.empty? || csrf_token != session['_csrf_token'] + end + end + protected def build_access_token verifier = request.params["code"] client.auth_code.get_token(verifier, {:redirect_uri => callback_url}.merge(token_params.to_hash(:symbolize_keys => true)), deep_symbolize(options.auth_token_params)) end + + def build_access_token_from_token + token_opts = request.params.slice( + 'access_token', + 'expires_at', + 'expires_in', + 'refresh_token', + ) + token_opts.merge!(options.access_token_options || {}) + ::OAuth2::AccessToken.from_hash(client, token_opts) + end def deep_symbolize(options) hash = {}