From a3b4f97f9a806edbce5905ec1250892c5a70774e Mon Sep 17 00:00:00 2001 From: Maria Khrustaleva Date: Sat, 10 Dec 2022 21:16:51 +0100 Subject: [PATCH] Improve social authentication (#5349) --- cvat-core/src/api-implementation.ts | 11 ++ cvat-core/src/api.ts | 24 +-- cvat-core/src/server-proxy.ts | 36 ++++- cvat-ui/src/components/cvat-app.tsx | 6 + .../login-with-social-app.tsx | 48 ++++++ cvat/apps/engine/schema.py | 7 - cvat/apps/engine/serializers.py | 6 - cvat/apps/engine/views.py | 16 +- cvat/apps/iam/serializers.py | 21 ++- cvat/apps/iam/urls.py | 4 +- cvat/apps/iam/views.py | 141 ++++++++++++++++-- cvat/settings/base.py | 2 + cvat/settings/development.py | 1 + 13 files changed, 269 insertions(+), 54 deletions(-) create mode 100644 cvat-ui/src/components/login-with-social-app/login-with-social-app.tsx diff --git a/cvat-core/src/api-implementation.ts b/cvat-core/src/api-implementation.ts index dc2dff7c681..df4276b78ad 100644 --- a/cvat-core/src/api-implementation.ts +++ b/cvat-core/src/api-implementation.ts @@ -127,6 +127,17 @@ export default function implementAPI(cvat) { return result; }; + cvat.server.loginWithSocialAccount.implementation = async ( + provider: string, + code: string, + authParams?: string, + process?: string, + scope?: string, + ) => { + const result = await serverProxy.server.loginWithSocialAccount(provider, code, authParams, process, scope); + return result; + }; + cvat.users.get.implementation = async (filter) => { checkFilter(filter, { id: isInteger, diff --git a/cvat-core/src/api.ts b/cvat-core/src/api.ts index 165e33d7a9a..a4c773b6b48 100644 --- a/cvat-core/src/api.ts +++ b/cvat-core/src/api.ts @@ -177,18 +177,6 @@ function build() { const result = await PluginRegistry.apiWrapper(cvat.server.advancedAuthentication); return result; }, - /** - * Method returns enabled advanced authentication methods - * @method advancedAuthentication - * @async - * @memberof module:API.cvat.server - * @throws {module:API.cvat.exceptions.ServerError} - * @throws {module:API.cvat.exceptions.PluginError} - */ - async advancedAuthentication() { - const result = await PluginRegistry.apiWrapper(cvat.server.advancedAuthentication); - return result; - }, /** * Method allows to change user password * @method changePassword @@ -306,6 +294,18 @@ function build() { const result = await PluginRegistry.apiWrapper(cvat.server.installedApps); return result; }, + async loginWithSocialAccount( + provider: string, + code: string, + authParams?: string, + process?: string, + scope?: string, + ) { + const result = await PluginRegistry.apiWrapper( + cvat.server.loginWithSocialAccount, provider, code, authParams, process, scope, + ); + return result; + }, }, /** * Namespace is used for getting projects diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts index 31955c3e4d2..59068f696be 100644 --- a/cvat-core/src/server-proxy.ts +++ b/cvat-core/src/server-proxy.ts @@ -366,6 +366,35 @@ async function login(credential, password) { Axios.defaults.headers.common.Authorization = `Token ${token}`; } +async function loginWithSocialAccount( + provider: string, + code: string, + authParams?: string, + process?: string, + scope?: string, +) { + removeToken(); + const data = { + code, + ...(process ? { process } : {}), + ...(scope ? { scope } : {}), + ...(authParams ? { auth_params: authParams } : {}), + }; + let authenticationResponse = null; + try { + authenticationResponse = await Axios.post(`${config.backendAPI}/auth/${provider}/login/token`, data, + { + proxy: config.proxy, + }); + } catch (errorData) { + throw generateError(errorData); + } + + token = authenticationResponse.data.key; + store.set('token', token); + Axios.defaults.headers.common.Authorization = `Token ${token}`; +} + async function logout() { try { await Axios.post(`${config.backendAPI}/auth/logout`, { @@ -447,11 +476,7 @@ async function getSelf() { async function authorized() { try { - const response = await getSelf(); - if (!store.get('token')) { - store.set('token', response.key); - Axios.defaults.headers.common.Authorization = `Token ${response.key}`; - } + await getSelf(); } catch (serverError) { if (serverError.code === 401) { // In CVAT app we use two types of authentication, @@ -2255,6 +2280,7 @@ export default Object.freeze({ request: serverRequest, userAgreements, installedApps, + loginWithSocialAccount, }), projects: Object.freeze({ diff --git a/cvat-ui/src/components/cvat-app.tsx b/cvat-ui/src/components/cvat-app.tsx index 07f29153d45..8a92bbc7e33 100644 --- a/cvat-ui/src/components/cvat-app.tsx +++ b/cvat-ui/src/components/cvat-app.tsx @@ -19,6 +19,7 @@ import 'antd/dist/antd.css'; import LogoutComponent from 'components/logout-component'; import LoginPageContainer from 'containers/login-page/login-page'; import LoginWithTokenComponent from 'components/login-with-token/login-with-token'; +import LoginWithSocialAppComponent from 'components/login-with-social-app/login-with-social-app'; import RegisterPageContainer from 'containers/register-page/register-page'; import ResetPasswordPageConfirmComponent from 'components/reset-password-confirm-page/reset-password-confirm-page'; import ResetPasswordPageComponent from 'components/reset-password-page/reset-password-page'; @@ -502,6 +503,11 @@ class CVATApplication extends React.PureComponent + { + const provider = search.get('provider'); + const code = search.get('code'); + const process = search.get('process'); + const scope = search.get('scope'); + const authParams = search.get('auth_params'); + + if (provider && code) { + cvat.server.loginWithSocialAccount(provider, code, authParams, process, scope) + .then(() => window.location.reload()) + .catch((exception: Error) => { + if (exception.message.includes('Unverified email')) { + history.push('/auth/email-verification-sent'); + } + history.push('/auth/login'); + notification.error({ + message: 'Could not log in with social account', + description: 'Go to developer console', + }); + return Promise.reject(exception); + }); + } + }, []); + + return ( +
+ +
+ ); +} diff --git a/cvat/apps/engine/schema.py b/cvat/apps/engine/schema.py index f18749c2c53..8ebd21def6a 100644 --- a/cvat/apps/engine/schema.py +++ b/cvat/apps/engine/schema.py @@ -178,13 +178,6 @@ class MetaUserSerializerExtension(AnyOfProxySerializerExtension): # field here, because these serializers don't have such. target_component = 'MetaUser' -class MetaSelfUserSerializerExtension(AnyOfProxySerializerExtension): - # Need to replace oneOf to anyOf for MetaUser variants - # Otherwise, clients cannot distinguish between classes - # using just input data. Also, we can't use discrimintator - # field here, because these serializers don't have such. - target_component = 'MetaSelfUser' - class PolymorphicProjectSerializerExtension(AnyOfProxySerializerExtension): # Need to replace oneOf to anyOf for PolymorphicProject variants # Otherwise, clients cannot distinguish between classes diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 4b6c604acb6..fd2448b5ad0 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -53,12 +53,6 @@ class Meta: 'last_login': { 'allow_null': True } } -class SelfUserSerializer(UserSerializer): - key = serializers.CharField(allow_blank=True, required=False) - - class Meta(UserSerializer.Meta): - fields = UserSerializer.Meta.fields + ('key',) - class AttributeSerializer(serializers.ModelSerializer): values = serializers.ListField(allow_empty=True, child=serializers.CharField(max_length=200), diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 03cc9f16b34..d06f0a97e43 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -24,9 +24,6 @@ from django.http import HttpResponse, HttpResponseNotFound, HttpResponseBadRequest from django.utils import timezone -from dj_rest_auth.models import get_token_model -from dj_rest_auth.app_settings import create_token - from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import ( OpenApiParameter, OpenApiResponse, PolymorphicProxySerializer, @@ -58,7 +55,7 @@ ) from cvat.apps.engine.models import CloudStorage as CloudStorageModel from cvat.apps.engine.serializers import ( - AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, SelfUserSerializer, + AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, DataMetaReadSerializer, DataMetaWriteSerializer, DataSerializer, ExceptionSerializer, FileInfoSerializer, JobReadSerializer, JobWriteSerializer, LabeledDataSerializer, LogEventSerializer, ProjectReadSerializer, ProjectWriteSerializer, ProjectSearchSerializer, @@ -1938,18 +1935,18 @@ def get_serializer_class(self): is_self = int(self.kwargs.get("pk", 0)) == user.id or \ self.action == "self" if user.is_staff: - return UserSerializer if not is_self else SelfUserSerializer + return UserSerializer if not is_self else UserSerializer else: if is_self and self.request.method in SAFE_METHODS: - return SelfUserSerializer + return UserSerializer else: return BasicUserSerializer @extend_schema(summary='Method returns an instance of a user who is currently authorized', responses={ - '200': PolymorphicProxySerializer(component_name='MetaSelfUser', + '200': PolymorphicProxySerializer(component_name='MetaUser', serializers=[ - SelfUserSerializer, BasicUserSerializer, + UserSerializer, BasicUserSerializer, ], resource_type_field_name=None), }) @action(detail=False, methods=['GET']) @@ -1957,9 +1954,6 @@ def self(self, request): """ Method returns an instance of a user who is currently authorized """ - token_model = get_token_model() - token = create_token(token_model, request.user, None) - request.user.key = token serializer_class = self.get_serializer_class() serializer = serializer_class(request.user, context={ "request": request }) return Response(serializer.data) diff --git a/cvat/apps/iam/serializers.py b/cvat/apps/iam/serializers.py index 9fa85b8a2b4..e6548ead876 100644 --- a/cvat/apps/iam/serializers.py +++ b/cvat/apps/iam/serializers.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: MIT -from dj_rest_auth.registration.serializers import RegisterSerializer +from dj_rest_auth.registration.serializers import RegisterSerializer, SocialLoginSerializer from dj_rest_auth.serializers import PasswordResetSerializer, LoginSerializer from rest_framework.exceptions import ValidationError from rest_framework import serializers @@ -79,3 +79,22 @@ def is_username_authentication(): raise ValidationError('Unable to login with provided credentials') return self._validate_username_email(username, email, password) + + +class SocialLoginSerializerEx(SocialLoginSerializer): + auth_params = serializers.CharField(required=False, allow_blank=True, default='') + process = serializers.CharField(required=False, allow_blank=True, default='login') + scope = serializers.CharField(required=False, allow_blank=True, default='') + + def get_social_login(self, adapter, app, token, response): + request = self._get_request() + social_login = adapter.complete_login(request, app, token, response=response) + social_login.token = token + + social_login.state = { + 'process': self.initial_data.get('process'), + 'scope': self.initial_data.get('scope'), + 'auth_params': self.initial_data.get('auth_params'), + } + + return social_login diff --git a/cvat/apps/iam/urls.py b/cvat/apps/iam/urls.py index 959fa9b9387..d5b7ecbe05e 100644 --- a/cvat/apps/iam/urls.py +++ b/cvat/apps/iam/urls.py @@ -19,7 +19,7 @@ github_oauth2_callback as github_callback, google_oauth2_login as google_login, google_oauth2_callback as google_callback, - LoginViewEx, + LoginViewEx, GitHubLogin, GoogleLogin, ) urlpatterns = [ @@ -52,8 +52,10 @@ urlpatterns += [ path('github/login/', github_login, name='github_login'), path('github/login/callback/', github_callback, name='github_callback'), + path('github/login/token', GitHubLogin.as_view()), path('google/login/', google_login, name='google_login'), path('google/login/callback/', google_callback, name='google_callback'), + path('google/login/token', GoogleLogin.as_view()), ] urlpatterns = [path('auth/', include(urlpatterns))] diff --git a/cvat/apps/iam/views.py b/cvat/apps/iam/views.py index 7309f68a443..1112a16603e 100644 --- a/cvat/apps/iam/views.py +++ b/cvat/apps/iam/views.py @@ -12,24 +12,29 @@ from rest_framework import views, serializers from rest_framework.exceptions import ValidationError from rest_framework.permissions import AllowAny +from rest_framework.decorators import api_view, permission_classes from django.conf import settings from django.http import HttpResponse from django.views.decorators.http import etag as django_etag from rest_framework.response import Response -from dj_rest_auth.registration.views import RegisterView +from dj_rest_auth.registration.views import RegisterView, SocialLoginView from dj_rest_auth.views import LoginView from allauth.account import app_settings as allauth_settings from allauth.account.views import ConfirmEmailView from allauth.account.utils import has_verified_email, send_email_confirmation +from allauth.socialaccount.models import SocialLogin from allauth.socialaccount.providers.oauth2.views import OAuth2CallbackView, OAuth2LoginView +from allauth.socialaccount.providers.oauth2.client import OAuth2Client +from allauth.utils import get_request_param from furl import furl from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer, extend_schema_view +from drf_spectacular.utils import OpenApiResponse, OpenApiParameter, extend_schema, inline_serializer, extend_schema_view from drf_spectacular.contrib.rest_auth import get_token_serializer_class from cvat.apps.iam.adapters import GitHubAdapter, GoogleAdapter from .authentication import Signer +from cvat.apps.iam.serializers import SocialLoginSerializerEx def get_context(request): from cvat.apps.organizations.models import Organization, Membership @@ -215,16 +220,87 @@ def get(self, request): class OAuth2CallbackViewEx(OAuth2CallbackView): def dispatch(self, request, *args, **kwargs): # Distinguish cancel from error - if (auth_error := request.GET.get('error', None)) and \ - auth_error == self.adapter.login_cancelled_error: - return HttpResponseRedirect(settings.SOCIALACCOUNT_CALLBACK_CANCELLED_URL) - return super().dispatch(request, *args, **kwargs) - -github_oauth2_login = OAuth2LoginView.adapter_view(GitHubAdapter) -github_oauth2_callback = OAuth2CallbackViewEx.adapter_view(GitHubAdapter) + if (auth_error := request.GET.get('error', None)): + if auth_error == self.adapter.login_cancelled_error: + return HttpResponseRedirect(settings.SOCIALACCOUNT_CALLBACK_CANCELLED_URL) + else: # unknown error + raise ValidationError(auth_error) + + code = request.GET.get('code') + + # verify request state + if self.adapter.supports_state: + state = SocialLogin.verify_and_unstash_state( + request, get_request_param(request, 'state') + ) + else: + state = SocialLogin.unstash_state(request) + + if not code: + return HttpResponseBadRequest('Parameter code not found in request') + return HttpResponseRedirect( + f'{settings.SOCIAL_APP_LOGIN_REDIRECT_URL}/?provider={self.adapter.provider_id}&code={code}' + f'&auth_params={state.get("auth_params")}&process={state.get("process")}' + f'&scope={state.get("scope")}') + + +@extend_schema( + summary="Redirets to Github authentication page", + description="Redirects to the Github authentication page. " + "After successful authentication on the provider side, " + "a redirect to the callback endpoint is performed", +) +@api_view(["GET"]) +@permission_classes([AllowAny]) +def github_oauth2_login(*args, **kwargs): + return OAuth2LoginView.adapter_view(GitHubAdapter)(*args, **kwargs) + +@extend_schema( + summary="Checks the authentication response from Github, redirects to the CVAT client if successful.", + description="Accepts a request from Github with code and state query parameters. " + "In case of successful authentication on the provider side, it will " + "redirect to the CVAT client", + parameters=[ + OpenApiParameter('code', description='Returned by github', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + OpenApiParameter('state', description='Returned by github', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + ], +) +@api_view(["GET"]) +@permission_classes([AllowAny]) +def github_oauth2_callback(*args, **kwargs): + return OAuth2CallbackViewEx.adapter_view(GitHubAdapter)(*args, **kwargs) + + +@extend_schema( + summary="Redirects to Google authentication page", + description="Redirects to the Google authentication page. " + "After successful authentication on the provider side, " + "a redirect to the callback endpoint is performed.", +) +@api_view(["GET"]) +@permission_classes([AllowAny]) +def google_oauth2_login(*args, **kwargs): + return OAuth2LoginView.adapter_view(GoogleAdapter)(*args, **kwargs) + +@extend_schema( + summary="Checks the authentication response from Google, redirects to the CVAT client if successful.", + description="Accepts a request from Google with code and state query parameters. " + "In case of successful authentication on the provider side, it will " + "redirect to the CVAT client", + parameters=[ + OpenApiParameter('code', description='Returned by google', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + OpenApiParameter('state', description='Returned by google', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + ], +) +@api_view(["GET"]) +@permission_classes([AllowAny]) +def google_oauth2_callback(*args, **kwargs): + return OAuth2CallbackViewEx.adapter_view(GoogleAdapter)(*args, **kwargs) -google_oauth2_login = OAuth2LoginView.adapter_view(GoogleAdapter) -google_oauth2_callback = OAuth2CallbackViewEx.adapter_view(GoogleAdapter) class ConfirmEmailViewEx(ConfirmEmailView): template_name = 'account/email/email_confirmation_signup_message.html' @@ -236,3 +312,46 @@ def get(self, *args, **kwargs): return self.post(*args, **kwargs) except Http404: return HttpResponseRedirect(settings.INCORRECT_EMAIL_CONFIRMATION_URL) + +@extend_schema( + methods=['POST'], + summary='Method returns an authentication token based on code parameter', + description="After successful authentication on the provider side, " + "the provider returns the 'code' parameter used to receive " + "an authentication token required for CVAT authentication.", + parameters=[ + OpenApiParameter('auth_params', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + OpenApiParameter('process', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + OpenApiParameter('scope', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR), + ], + responses=get_token_serializer_class() +) +class SocialLoginViewEx(SocialLoginView): + serializer_class = SocialLoginSerializerEx + + def post(self, request, *args, **kwargs): + # we have to re-implement this method because + # there is one case not covered by dj_rest_auth but covered by allauth + # user can be logged in with social account and "unverified" email + # (e.g. the provider doesn't provide information about email verification) + + self.request = request + self.serializer = self.get_serializer(data=self.request.data) + self.serializer.is_valid(raise_exception=True) + + if allauth_settings.EMAIL_VERIFICATION == allauth_settings.EmailVerificationMethod.MANDATORY and \ + not has_verified_email(self.serializer.validated_data.get('user')): + return HttpResponseBadRequest('Unverified email') + + self.login() + return self.get_response() + +class GitHubLogin(SocialLoginViewEx): + adapter_class = GitHubAdapter + client_class = OAuth2Client + callback_url = getattr(settings, 'GITHUB_CALLBACK_URL', None) + +class GoogleLogin(SocialLoginViewEx): + adapter_class = GoogleAdapter + client_class = OAuth2Client + callback_url = getattr(settings, 'GOOGLE_CALLBACK_URL', None) diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 9e4ccc070a1..52fc7dca8b1 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -604,6 +604,8 @@ def add_ssh_keys(): # default = ACCOUNT_EMAIL_REQUIRED SOCIALACCOUNT_QUERY_EMAIL = True SOCIALACCOUNT_CALLBACK_CANCELLED_URL = '/auth/login' + # custom variable because by default LOGIN_REDIRECT_URL will be used + SOCIAL_APP_LOGIN_REDIRECT_URL = 'http://localhost:8080/auth/login-with-social-app' GITHUB_CALLBACK_URL = 'http://localhost:8080/api/auth/github/login/callback/' GOOGLE_CALLBACK_URL = 'http://localhost:8080/api/auth/google/login/callback/' diff --git a/cvat/settings/development.py b/cvat/settings/development.py index 26b55e0c8fa..2fadbfaa8e9 100644 --- a/cvat/settings/development.py +++ b/cvat/settings/development.py @@ -50,3 +50,4 @@ GITHUB_CALLBACK_URL = f'{UI_URL}/api/auth/github/login/callback/' GOOGLE_CALLBACK_URL = f'{UI_URL}/api/auth/google/login/callback/' SOCIALACCOUNT_CALLBACK_CANCELLED_URL = f'{UI_URL}/auth/login' + SOCIAL_APP_LOGIN_REDIRECT_URL = f'{UI_URL}/auth/login-with-social-app'