Skip to content

Commit

Permalink
Support multiple JWKS matches in client assertion auth (#1921)
Browse files Browse the repository at this point in the history
* Add jose mocks

* mock error

* Add Mock error class

* Comments

* Readjust lint

* Add test coverage and returned errors rather than throw

* Fix lint error

* Move catch logic for multiple keys to helper
  • Loading branch information
jamestouri committed Apr 27, 2023
1 parent e2bf308 commit f710952
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 3 deletions.
92 changes: 91 additions & 1 deletion packages/server/src/oauth/token.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { OAuthGrantType, OAuthTokenType, createReference, parseJWTPayload, parse
import { AccessPolicy, ClientApplication, Login, Project, SmartAppLaunch } from '@medplum/fhirtypes';
import { randomUUID } from 'crypto';
import express from 'express';
import { SignJWT, generateKeyPair } from 'jose';
import { SignJWT, generateKeyPair, jwtVerify } from 'jose';
import fetch from 'node-fetch';
import request from 'supertest';
import { createClient } from '../admin/client';
Expand All @@ -18,12 +18,27 @@ import { hashCode } from './token';
jest.mock('jose', () => {
const core = jest.requireActual('@medplum/core');
const original = jest.requireActual('jose');
let count = 0;
return {
...original,
jwtVerify: jest.fn((credential: string) => {
const payload = core.parseJWTPayload(credential);
if (payload.invalid) {
throw new Error('Verification failed');
} else if (payload.multipleMatching) {
count = payload.successVerified ? count + 1 : 0;
let error: MockJoseMultipleMatchingError;
if (count <= 1) {
error = new MockJoseMultipleMatchingError(
'multiple matching keys found in the JSON Web Key Set',
'ERR_JWKS_MULTIPLE_MATCHING_KEYS'
);
} else if (count === 2) {
error = new MockJoseMultipleMatchingError('Verification fail', 'ERR_JWS_SIGNATURE_VERIFICATION_FAILED');
} else {
return { payload };
}
throw error;
}
return { payload };
}),
Expand Down Expand Up @@ -100,6 +115,10 @@ describe('OAuth2 Token', () => {
});
});

afterEach(() => {
jest.clearAllMocks();
});

afterAll(async () => {
await shutdownApp();
});
Expand Down Expand Up @@ -1279,6 +1298,64 @@ describe('OAuth2 Token', () => {
});
});

test('Client assertion multiple matching 3rd check success', async () => {
// Create a new client
const client2 = await createClient(systemRepo, { project, name: 'Test Client 2' });

// Set the client jwksUri
await systemRepo.updateResource<ClientApplication>({ ...client2, jwksUri: 'https://example.com/jwks.json' });

// Create the JWT
const keyPair = await generateKeyPair('ES384');
const jwt = await new SignJWT({ multipleMatching: true, successVerified: true })
.setProtectedHeader({ alg: 'ES384' })
.setIssuedAt()
.setIssuer(client2.id as string)
.setSubject(client2.id as string)
.setAudience('http://localhost:8103/oauth2/token')
.setExpirationTime('2h')
.sign(keyPair.privateKey);
expect(jwt).toBeDefined();

// Then use the JWT for a client credentials grant
const res = await request(app).post('/oauth2/token').type('form').send({
grant_type: 'client_credentials',
client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
client_assertion: jwt,
});
expect(res.status).toBe(200);
expect(jwtVerify).toBeCalledTimes(3);
});

test('Client assertion multiple inner error', async () => {
// Create a new client
const client2 = await createClient(systemRepo, { project, name: 'Test Client 2' });

// Set the client jwksUri
await systemRepo.updateResource<ClientApplication>({ ...client2, jwksUri: 'https://example.com/jwks.json' });

// Create the JWT
const keyPair = await generateKeyPair('ES384');
const jwt = await new SignJWT({ multipleMatching: true })
.setProtectedHeader({ alg: 'ES384' })
.setIssuedAt()
.setIssuer(client2.id as string)
.setSubject(client2.id as string)
.setAudience('http://localhost:8103/oauth2/token')
.setExpirationTime('2h')
.sign(keyPair.privateKey);
expect(jwt).toBeDefined();

// Then use the JWT for a client credentials grant
const res = await request(app).post('/oauth2/token').type('form').send({
grant_type: 'client_credentials',
client_assertion_type: 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer',
client_assertion: jwt,
});
expect(res.status).toBe(400);
expect(jwtVerify).toBeCalledTimes(2);
});

test('Client assertion invalid assertion type', async () => {
// Create a new client
const client2 = await createClient(systemRepo, { project, name: 'Test Client 2' });
Expand Down Expand Up @@ -1429,3 +1506,16 @@ describe('OAuth2 Token', () => {
expect(res.body.error_description).toBe('Invalid subject_token_type');
});
});

class MockJoseMultipleMatchingError extends Error {
code: string;
[Symbol.asyncIterator]!: () => AsyncIterableIterator<any>;
constructor(message: string, code: string) {
super(message);
this.name = 'CustomError';
this.code = code;
this[Symbol.asyncIterator] = async function* () {
yield 'key1', yield 'key2';
};
}
}
9 changes: 8 additions & 1 deletion packages/server/src/oauth/token.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
revokeLogin,
timingSafeEqualStr,
tryLogin,
verifyMultipleMatchingException,
} from './utils';

type ClientIdAndSecret = { error?: string; clientId?: string; clientSecret?: string };
Expand Down Expand Up @@ -462,7 +463,13 @@ async function parseClientAssertion(clientAssertiontype: string, clientAssertion

try {
await jwtVerify(clientAssertion, JWKS, verifyOptions);
} catch (err) {
} catch (error: any) {
// There are some edge cases where there are multiple matching JWKS
// and we need to iterate throught the JWKSMultipleMatchingKeys error
// and return the first verified match
if (error?.code === 'ERR_JWKS_MULTIPLE_MATCHING_KEYS') {
return await verifyMultipleMatchingException(error, clientId, clientAssertion, verifyOptions, client);
}
return { error: 'Invalid client assertion signature' };
}

Expand Down
29 changes: 28 additions & 1 deletion packages/server/src/oauth/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {
} from '@medplum/fhirtypes';
import bcrypt from 'bcryptjs';
import { timingSafeEqual } from 'crypto';
import { JWTPayload } from 'jose';
import { JWTPayload, jwtVerify, VerifyOptions } from 'jose';
import fetch from 'node-fetch';
import { authenticator } from 'otplib';
import { getAccessPolicyForLogin } from '../fhir/accesspolicy';
Expand Down Expand Up @@ -715,3 +715,30 @@ export async function getExternalUserInfo(
throw new OperationOutcomeError(badRequest('Failed to verify code - check your identity provider configuration'));
}
}

interface ValidationAssertion {
clientId?: string;
clientSecret?: string;
error?: string;
}
export async function verifyMultipleMatchingException(
publicKeys: AsyncIterableIterator<any>,
clientId: string,
clientAssertion: string,
verifyOptions: VerifyOptions,
client: ClientApplication
): Promise<ValidationAssertion> {
for await (const publicKey of publicKeys) {
try {
await jwtVerify(clientAssertion, publicKey, verifyOptions);
// If we validate successfully inside the catch we can validate the client assertion
return { clientId, clientSecret: client.secret };
} catch (innerError: any) {
if (innerError?.code === 'ERR_JWS_SIGNATURE_VERIFICATION_FAILED') {
continue;
}
return { error: innerError.code };
}
}
return { error: 'ERR_JWS_SIGNATURE_VERIFICATION_FAILED' };
}

0 comments on commit f710952

Please sign in to comment.