Skip to content

Commit

Permalink
fix(NODE-6051): only provide expected allowed keys to libmongocrypt a…
Browse files Browse the repository at this point in the history
…fter fetching aws kms credentials (#4057)
  • Loading branch information
baileympearson committed Apr 4, 2024
1 parent 0e3d6ea commit c604e74
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 60 deletions.
27 changes: 17 additions & 10 deletions src/client-side-encryption/providers/aws.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
import { getAwsCredentialProvider } from '../../deps';
import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials';
import { type KMSProviders } from '.';

/**
* @internal
*/
export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise<KMSProviders> {
const credentialProvider = getAwsCredentialProvider();
const credentialProvider = new AWSSDKCredentialProvider();

if ('kModuleError' in credentialProvider) {
return kmsProviders;
}
// We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey`
// or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings
// and let libmongocrypt error if we're unable to fetch the required keys.
const {
SecretAccessKey = '',
AccessKeyId = '',
Token
} = await credentialProvider.getCredentials();
const aws: NonNullable<KMSProviders['aws']> = {
secretAccessKey: SecretAccessKey,
accessKeyId: AccessKeyId
};
// the AWS session token is only required for temporary credentials so only attach it to the
// result if it's present in the response from the aws sdk
Token != null && (aws.sessionToken = Token);

const { fromNodeProviderChain } = credentialProvider;
const provider = fromNodeProviderChain();
// The state machine is the only place calling this so it will
// catch if there is a rejection here.
const aws = await provider();
return { ...kmsProviders, aws };
}
73 changes: 60 additions & 13 deletions test/integration/auth/mongodb_aws.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,26 @@ import * as http from 'http';
import { performance } from 'perf_hooks';
import * as sinon from 'sinon';

// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import { refreshKMSCredentials } from '../../../src/client-side-encryption/providers';
import {
AWSTemporaryCredentialProvider,
MongoAWSError,
type MongoClient,
MongoDBAWS,
MongoMissingCredentialsError,
MongoServerError
MongoServerError,
setDifference
} from '../../mongodb';

function awsSdk() {
try {
return require('@aws-sdk/credential-providers');
} catch {
return null;
}
}
const isMongoDBAWSAuthEnvironment = (process.env.MONGODB_URI ?? '').includes('MONGODB-AWS');

describe('MONGODB-AWS', function () {
let awsSdkPresent;
let client: MongoClient;

beforeEach(function () {
const MONGODB_URI = process.env.MONGODB_URI;
if (!MONGODB_URI || MONGODB_URI.indexOf('MONGODB-AWS') === -1) {
if (!isMongoDBAWSAuthEnvironment) {
this.currentTest.skipReason = 'requires MONGODB_URI to contain MONGODB-AWS auth mechanism';
return this.skip();
}
Expand All @@ -39,7 +35,7 @@ describe('MONGODB-AWS', function () {
`Always inform the AWS tests if they run with or without the SDK (MONGODB_AWS_SDK=${MONGODB_AWS_SDK})`
).to.include(MONGODB_AWS_SDK);

awsSdkPresent = !!awsSdk();
awsSdkPresent = AWSTemporaryCredentialProvider.isAWSSDKInstalled;
expect(
awsSdkPresent,
MONGODB_AWS_SDK === 'true'
Expand Down Expand Up @@ -244,8 +240,10 @@ describe('MONGODB-AWS', function () {

const envCheck = () => {
const { AWS_WEB_IDENTITY_TOKEN_FILE = '' } = process.env;
credentialProvider = awsSdk();
return AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 || credentialProvider == null;
return (
AWS_WEB_IDENTITY_TOKEN_FILE.length === 0 ||
!AWSTemporaryCredentialProvider.isAWSSDKInstalled
);
};

beforeEach(function () {
Expand All @@ -255,6 +253,9 @@ describe('MONGODB-AWS', function () {
return this.skip();
}

// @ts-expect-error We intentionally access a protected variable.
credentialProvider = AWSTemporaryCredentialProvider.awsSDK;

storedEnv = process.env;
if (test.env.AWS_STS_REGIONAL_ENDPOINTS === undefined) {
delete process.env.AWS_STS_REGIONAL_ENDPOINTS;
Expand Down Expand Up @@ -324,3 +325,49 @@ describe('MONGODB-AWS', function () {
}
});
});

describe('AWS KMS Credential Fetching', function () {
context('when the AWS SDK is not installed', function () {
beforeEach(function () {
this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment
? 'Test must run in an AWS auth testing environment'
: AWSTemporaryCredentialProvider.isAWSSDKInstalled
? 'This test must run in an environment where the AWS SDK is not installed.'
: undefined;
this.currentTest?.skipReason && this.skip();
});
it('fetching AWS KMS credentials throws an error', async function () {
const error = await refreshKMSCredentials({ aws: {} }).catch(e => e);
expect(error).to.be.instanceOf(MongoAWSError);
});
});

context('when the AWS SDK is installed', function () {
beforeEach(function () {
this.currentTest.skipReason = !isMongoDBAWSAuthEnvironment
? 'Test must run in an AWS auth testing environment'
: !AWSTemporaryCredentialProvider.isAWSSDKInstalled
? 'This test must run in an environment where the AWS SDK is installed.'
: undefined;
this.currentTest?.skipReason && this.skip();
});
it('KMS credentials are successfully fetched.', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

expect(aws).to.have.property('accessKeyId');
expect(aws).to.have.property('secretAccessKey');
});

it('does not return any extra keys for the `aws` credential provider', async function () {
const { aws } = await refreshKMSCredentials({ aws: {} });

const keys = new Set(Object.keys(aws ?? {}));
const allowedKeys = ['accessKeyId', 'secretAccessKey', 'sessionToken'];

expect(
Array.from(setDifference(keys, allowedKeys)),
'received an unexpected key in the response refreshing KMS credentials'
).to.deep.equal([]);
});
});
});
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { expect } from 'chai';
import * as dns from 'dns';
import { once } from 'events';
import { coerce } from 'semver';
import { satisfies } from 'semver';
import * as sinon from 'sinon';

import {
Expand Down Expand Up @@ -51,11 +51,9 @@ describe('Polling Srv Records for Mongos Discovery', () => {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const test = this.currentTest!;

const { major } = coerce(process.version);
test.skipReason =
major === 18 || major === 20
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;
test.skipReason = satisfies(process.version, '>=18.0.0')
? `TODO(NODE-5666): fix failing unit tests on Node18 (Running with Nodejs ${process.version})`
: undefined;

if (test.skipReason) this.skip();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {
} from '../../../../src/client-side-encryption/providers/azure';
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import * as utils from '../../../../src/client-side-encryption/providers/utils';
// eslint-disable-next-line @typescript-eslint/no-restricted-imports
import { AWSSDKCredentialProvider } from '../../../../src/cmap/auth/aws_temporary_credentials';
import * as requirements from '../requirements.helper';

const originalAccessKeyId = process.env.AWS_ACCESS_KEY_ID;
Expand Down Expand Up @@ -154,25 +156,25 @@ describe('#refreshKMSCredentials', function () {
});
});

context('when the sdk is not installed', function () {
const kmsProviders = {
local: {
key: Buffer.alloc(96)
},
aws: {}
};

before(function () {
if (requirements.credentialProvidersInstalled.aws && this.currentTest) {
this.currentTest.skipReason = 'Credentials will be loaded when sdk present';
this.currentTest.skip();
return;
}
context('when the AWS SDK returns unknown fields', function () {
beforeEach(() => {
sinon.stub(AWSSDKCredentialProvider.prototype, 'getCredentials').resolves({
Token: 'example',
SecretAccessKey: 'example',
AccessKeyId: 'example',
Expiration: new Date()
});
});

it('does not refresh credentials', async function () {
const providers = await refreshKMSCredentials(kmsProviders);
expect(providers).to.deep.equal(kmsProviders);
afterEach(() => sinon.restore());
it('only returns fields libmongocrypt expects', async function () {
const credentials = await refreshKMSCredentials({ aws: {} });
expect(credentials).to.deep.equal({
aws: {
accessKeyId: accessKey,
secretAccessKey: secretKey,
sessionToken: sessionToken
}
});
});
});
});
Expand Down
7 changes: 3 additions & 4 deletions test/unit/connection_string.spec.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { coerce } from 'semver';
import { satisfies } from 'semver';

import { loadSpecTests } from '../spec';
import { executeUriValidationTest } from '../tools/uri_spec_runner';
Expand All @@ -15,14 +15,13 @@ describe('Connection String spec tests', function () {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const test = this.currentTest!;

const { major } = coerce(process.version);
const skippedTests = [
'Invalid port (zero) with IP literal',
'Invalid port (zero) with hostname'
];
test.skipReason =
major === 20 && skippedTests.includes(test.title)
? 'TODO(NODE-5666): fix failing unit tests on Node18'
satisfies(process.version, '>=20.0.0') && skippedTests.includes(test.title)
? 'TODO(NODE-5666): fix failing unit tests on Node20+'
: undefined;

if (test.skipReason) this.skip();
Expand Down
5 changes: 2 additions & 3 deletions test/unit/sdam/monitor.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { once } from 'node:events';
import * as net from 'node:net';

import { expect } from 'chai';
import { coerce } from 'semver';
import { satisfies } from 'semver';
import * as sinon from 'sinon';
import { setTimeout } from 'timers';
import { setTimeout as setTimeoutPromise } from 'timers/promises';
Expand Down Expand Up @@ -57,7 +57,6 @@ describe('monitoring', function () {
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const test = this.currentTest!;

const { major } = coerce(process.version);
const failingTests = [
'should connect and issue an initial server check',
'should ignore attempts to connect when not already closed',
Expand All @@ -67,7 +66,7 @@ describe('monitoring', function () {
'correctly returns the mean of the heartbeat durations'
];
test.skipReason =
(major === 18 || major === 20) && failingTests.includes(test.title)
satisfies(process.version, '>=18.0.0') && failingTests.includes(test.title)
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;

Expand Down
10 changes: 4 additions & 6 deletions test/unit/sdam/topology.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { expect } from 'chai';
import { once } from 'events';
import * as net from 'net';
import { type AddressInfo } from 'net';
import { coerce, type SemVer } from 'semver';
import { satisfies } from 'semver';
import * as sinon from 'sinon';
import { clearTimeout } from 'timers';

Expand Down Expand Up @@ -284,11 +284,9 @@ describe('Topology (unit)', function () {
it('should encounter a server selection timeout on garbled server responses', function () {
const test = this.test;

const { major } = coerce(process.version) as SemVer;
test.skipReason =
major === 18 || major === 20
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;
test.skipReason = satisfies(process.version, '>=18.0.0')
? 'TODO(NODE-5666): fix failing unit tests on Node18'
: undefined;

if (test.skipReason) this.skip();

Expand Down

0 comments on commit c604e74

Please sign in to comment.