diff --git a/api/src/auth/drivers/oauth2.ts b/api/src/auth/drivers/oauth2.ts index 4be584d427b8f..a3a9a1b3eaea2 100644 --- a/api/src/auth/drivers/oauth2.ts +++ b/api/src/auth/drivers/oauth2.ts @@ -7,7 +7,12 @@ import { getAuthProvider } from '../../auth'; import env from '../../env'; import { AuthenticationService, UsersService } from '../../services'; import { AuthDriverOptions, User, AuthData, SessionData } from '../../types'; -import { InvalidCredentialsException, ServiceUnavailableException, InvalidConfigException } from '../../exceptions'; +import { + InvalidCredentialsException, + ServiceUnavailableException, + InvalidConfigException, + InvalidTokenException, +} from '../../exceptions'; import { respond } from '../../middleware/respond'; import asyncHandler from '../../utils/async-handler'; import { Url } from '../../utils/url'; @@ -38,7 +43,8 @@ export class OAuth2AuthDriver extends LocalAuthDriver { authorization_endpoint: authorizeUrl, token_endpoint: accessUrl, userinfo_endpoint: profileUrl, - issuer: additionalConfig.provider, + // Required for openid providers (openid flow should be preferred!) + issuer: additionalConfig.issuerUrl, }); this.client = new issuer.Client({ @@ -53,16 +59,20 @@ export class OAuth2AuthDriver extends LocalAuthDriver { return generators.codeVerifier(); } - generateAuthUrl(codeVerifier: string): string { + generateAuthUrl(codeVerifier: string, prompt = false): string { try { const codeChallenge = generators.codeChallenge(codeVerifier); + const paramsConfig = typeof this.config.params === 'object' ? this.config.params : {}; + return this.client.authorizationUrl({ scope: this.config.scope ?? 'email', + access_type: 'offline', + prompt: prompt ? 'consent' : undefined, + ...paramsConfig, code_challenge: codeChallenge, code_challenge_method: 'S256', // Some providers require state even with PKCE state: codeChallenge, - access_type: 'offline', }); } catch (e) { throw handleError(e); @@ -160,16 +170,21 @@ export class OAuth2AuthDriver extends LocalAuthDriver { } } - if (!authData?.refreshToken) { - return sessionData; + if (authData?.refreshToken) { + try { + const tokenSet = await this.client.refresh(authData.refreshToken); + // Update user refreshToken if provided + if (tokenSet.refresh_token) { + await this.usersService.updateOne(user.id, { + auth_data: JSON.stringify({ refreshToken: tokenSet.refresh_token }), + }); + } + } catch (e) { + throw handleError(e); + } } - try { - const tokenSet = await this.client.refresh(authData.refreshToken); - return { accessToken: tokenSet.access_token }; - } catch (e) { - throw handleError(e); - } + return sessionData; } } @@ -177,7 +192,7 @@ const handleError = (e: any) => { if (e instanceof errors.OPError) { if (e.error === 'invalid_grant') { // Invalid token - return new InvalidCredentialsException(); + return new InvalidTokenException(); } // Server response error return new ServiceUnavailableException('Service returned unexpected response', { @@ -199,7 +214,8 @@ export function createOAuth2AuthRouter(providerName: string): Router { (req, res) => { const provider = getAuthProvider(providerName) as OAuth2AuthDriver; const codeVerifier = provider.generateCodeVerifier(); - const token = jwt.sign({ verifier: codeVerifier, redirect: req.query.redirect }, env.SECRET as string, { + const prompt = !!req.query.prompt; + const token = jwt.sign({ verifier: codeVerifier, redirect: req.query.redirect, prompt }, env.SECRET as string, { expiresIn: '5m', issuer: 'directus', }); @@ -209,7 +225,7 @@ export function createOAuth2AuthRouter(providerName: string): Router { sameSite: 'lax', }); - return res.redirect(provider.generateAuthUrl(codeVerifier)); + return res.redirect(provider.generateAuthUrl(codeVerifier, prompt)); }, respond ); @@ -223,12 +239,14 @@ export function createOAuth2AuthRouter(providerName: string): Router { tokenData = jwt.verify(req.cookies[`oauth2.${providerName}`], env.SECRET as string, { issuer: 'directus' }) as { verifier: string; redirect?: string; + prompt: boolean; }; } catch (e) { + logger.warn(`Couldn't verify OAuth2 cookie`); throw new InvalidCredentialsException(); } - const { verifier, redirect } = tokenData; + const { verifier, redirect, prompt } = tokenData; const authenticationService = new AuthenticationService({ accountability: { @@ -254,6 +272,11 @@ export function createOAuth2AuthRouter(providerName: string): Router { state: req.query.state, }); } catch (error: any) { + // Prompt user for a new refresh_token if invalidated + if (error instanceof InvalidTokenException && !prompt) { + return res.redirect(`./?${redirect ? `redirect=${redirect}&` : ''}prompt=true`); + } + logger.warn(error); if (redirect) { @@ -263,6 +286,8 @@ export function createOAuth2AuthRouter(providerName: string): Router { reason = 'SERVICE_UNAVAILABLE'; } else if (error instanceof InvalidCredentialsException) { reason = 'INVALID_USER'; + } else if (error instanceof InvalidTokenException) { + reason = 'INVALID_TOKEN'; } return res.redirect(`${redirect.split('?')[0]}?reason=${reason}`); diff --git a/api/src/auth/drivers/openid.ts b/api/src/auth/drivers/openid.ts index ae9f8e9bba87b..438d77c733cd7 100644 --- a/api/src/auth/drivers/openid.ts +++ b/api/src/auth/drivers/openid.ts @@ -7,7 +7,12 @@ import { getAuthProvider } from '../../auth'; import env from '../../env'; import { AuthenticationService, UsersService } from '../../services'; import { AuthDriverOptions, User, AuthData, SessionData } from '../../types'; -import { InvalidCredentialsException, ServiceUnavailableException, InvalidConfigException } from '../../exceptions'; +import { + InvalidCredentialsException, + ServiceUnavailableException, + InvalidConfigException, + InvalidTokenException, +} from '../../exceptions'; import { respond } from '../../middleware/respond'; import asyncHandler from '../../utils/async-handler'; import { Url } from '../../utils/url'; @@ -62,17 +67,21 @@ export class OpenIDAuthDriver extends LocalAuthDriver { return generators.codeVerifier(); } - async generateAuthUrl(codeVerifier: string): Promise { + async generateAuthUrl(codeVerifier: string, prompt = false): Promise { try { const client = await this.client; const codeChallenge = generators.codeChallenge(codeVerifier); + const paramsConfig = typeof this.config.params === 'object' ? this.config.params : {}; + return client.authorizationUrl({ scope: this.config.scope ?? 'openid profile email', + access_type: 'offline', + prompt: prompt ? 'consent' : undefined, + ...paramsConfig, code_challenge: codeChallenge, code_challenge_method: 'S256', // Some providers require state even with PKCE state: codeChallenge, - access_type: 'offline', }); } catch (e) { throw handleError(e); @@ -173,17 +182,22 @@ export class OpenIDAuthDriver extends LocalAuthDriver { } } - if (!authData?.refreshToken) { - return sessionData; + if (authData?.refreshToken) { + try { + const client = await this.client; + const tokenSet = await client.refresh(authData.refreshToken); + // Update user refreshToken if provided + if (tokenSet.refresh_token) { + await this.usersService.updateOne(user.id, { + auth_data: JSON.stringify({ refreshToken: tokenSet.refresh_token }), + }); + } + } catch (e) { + throw handleError(e); + } } - try { - const client = await this.client; - const tokenSet = await client.refresh(authData.refreshToken); - return { accessToken: tokenSet.access_token }; - } catch (e) { - throw handleError(e); - } + return sessionData; } } @@ -191,7 +205,7 @@ const handleError = (e: any) => { if (e instanceof errors.OPError) { if (e.error === 'invalid_grant') { // Invalid token - return new InvalidCredentialsException(); + return new InvalidTokenException(); } // Server response error return new ServiceUnavailableException('Service returned unexpected response', { @@ -213,7 +227,8 @@ export function createOpenIDAuthRouter(providerName: string): Router { asyncHandler(async (req, res) => { const provider = getAuthProvider(providerName) as OpenIDAuthDriver; const codeVerifier = provider.generateCodeVerifier(); - const token = jwt.sign({ verifier: codeVerifier, redirect: req.query.redirect }, env.SECRET as string, { + const prompt = !!req.query.prompt; + const token = jwt.sign({ verifier: codeVerifier, redirect: req.query.redirect, prompt }, env.SECRET as string, { expiresIn: '5m', issuer: 'directus', }); @@ -223,7 +238,7 @@ export function createOpenIDAuthRouter(providerName: string): Router { sameSite: 'lax', }); - return res.redirect(await provider.generateAuthUrl(codeVerifier)); + return res.redirect(await provider.generateAuthUrl(codeVerifier, prompt)); }), respond ); @@ -237,12 +252,14 @@ export function createOpenIDAuthRouter(providerName: string): Router { tokenData = jwt.verify(req.cookies[`openid.${providerName}`], env.SECRET as string, { issuer: 'directus' }) as { verifier: string; redirect?: string; + prompt: boolean; }; } catch (e) { + logger.warn(`Couldn't verify OpenID cookie`); throw new InvalidCredentialsException(); } - const { verifier, redirect } = tokenData; + const { verifier, redirect, prompt } = tokenData; const authenticationService = new AuthenticationService({ accountability: { @@ -268,6 +285,11 @@ export function createOpenIDAuthRouter(providerName: string): Router { state: req.query.state, }); } catch (error: any) { + // Prompt user for a new refresh_token if invalidated + if (error instanceof InvalidTokenException && !prompt) { + return res.redirect(`./?${redirect ? `redirect=${redirect}&` : ''}prompt=true`); + } + logger.warn(error); if (redirect) { @@ -277,6 +299,8 @@ export function createOpenIDAuthRouter(providerName: string): Router { reason = 'SERVICE_UNAVAILABLE'; } else if (error instanceof InvalidCredentialsException) { reason = 'INVALID_USER'; + } else if (error instanceof InvalidTokenException) { + reason = 'INVALID_TOKEN'; } return res.redirect(`${redirect.split('?')[0]}?reason=${reason}`); diff --git a/api/src/exceptions/index.ts b/api/src/exceptions/index.ts index 5d5f255707aba..2903ec4728a1d 100644 --- a/api/src/exceptions/index.ts +++ b/api/src/exceptions/index.ts @@ -8,6 +8,7 @@ export * from './invalid-ip'; export * from './invalid-otp'; export * from './invalid-payload'; export * from './invalid-query'; +export * from './invalid-token'; export * from './method-not-allowed'; export * from './range-not-satisfiable'; export * from './route-not-found'; diff --git a/api/src/exceptions/invalid-token.ts b/api/src/exceptions/invalid-token.ts new file mode 100644 index 0000000000000..223fff4ccc226 --- /dev/null +++ b/api/src/exceptions/invalid-token.ts @@ -0,0 +1,7 @@ +import { BaseException } from '@directus/shared/exceptions'; + +export class InvalidTokenException extends BaseException { + constructor(message = 'Invalid token') { + super(message, 403, 'INVALID_TOKEN'); + } +}