diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 372806bd24..e0329a783e 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -17,6 +17,7 @@ import os import sys import unittest +from unittest.mock import patch sys.path[0:0] = [""] @@ -111,6 +112,63 @@ def test_poisoned_cache(self): client.get_database().test.find_one() self.assertNotEqual(auth.get_cached_credentials(), None) + def test_environment_variables_ignored(self): + creds = self.setup_cache() + self.assertIsNotNone(creds) + prev = os.environ.copy() + + client = MongoClient(self.uri) + self.addCleanup(client.close) + + client.get_database().test.find_one() + + self.assertIsNotNone(auth.get_cached_credentials()) + + mock_env = dict( + AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar", AWS_SESSION_TOKEN="baz" + ) + + with patch.dict("os.environ", mock_env): + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") + client.get_database().test.find_one() + + auth.set_cached_credentials(None) + + client2 = MongoClient(self.uri) + self.addCleanup(client2.close) + + with patch.dict("os.environ", mock_env): + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") + with self.assertRaises(OperationFailure): + client2.get_database().test.find_one() + + def test_no_cache_environment_variables(self): + creds = self.setup_cache() + self.assertIsNotNone(creds) + auth.set_cached_credentials(None) + + mock_env = dict(AWS_ACCESS_KEY_ID=creds.username, AWS_SECRET_ACCESS_KEY=creds.password) + if creds.token: + mock_env["AWS_SESSION_TOKEN"] = creds.token + + client = MongoClient(self.uri) + self.addCleanup(client.close) + + with patch.dict(os.environ, mock_env): + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username) + client.get_database().test.find_one() + + self.assertIsNone(auth.get_cached_credentials()) + + mock_env["AWS_ACCESS_KEY_ID"] = "foo" + + client2 = MongoClient(self.uri) + self.addCleanup(client2.close) + + with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure): + self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") + client2.get_database().test.find_one() + class TestAWSLambdaExamples(unittest.TestCase): def test_shared_client(self):