From 540b1da2fcfc1f1e97368cfefdf8b3144e6747fe Mon Sep 17 00:00:00 2001 From: miteshvp Date: Mon, 9 Dec 2019 14:11:02 +0530 Subject: [PATCH] add emr secret as a separate key --- rudra/deployments/emr_scripts/emr_script_builder.py | 8 ++++++-- tests/deployments/emr_scripts/test_emr_script_builder.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/rudra/deployments/emr_scripts/emr_script_builder.py b/rudra/deployments/emr_scripts/emr_script_builder.py index 999359e..7c765fe 100644 --- a/rudra/deployments/emr_scripts/emr_script_builder.py +++ b/rudra/deployments/emr_scripts/emr_script_builder.py @@ -40,6 +40,10 @@ def construct_job(self, input_dict): or input_dict.get('aws_access_key') aws_secret_key = os.getenv("AWS_S3_SECRET_ACCESS_KEY")\ or input_dict.get('aws_secret_key') + aws_emr_access_key = os.getenv("AWS_EMR_ACCESS_KEY_ID") \ + or input_dict.get('aws_emr_access_key') + aws_emr_secret_key = os.getenv("AWS_EMR_SECRET_ACCESS_KEY")\ + or input_dict.get('aws_emr_secret_key') github_token = os.getenv("GITHUB_TOKEN", input_dict.get('github_token')) self.bucket_name = input_dict.get('bucket_name') if self.hyper_params: @@ -59,8 +63,8 @@ def construct_job(self, input_dict): 'GITHUB_TOKEN': github_token } - self.aws_emr = AmazonEmr(aws_access_key_id=aws_access_key, - aws_secret_access_key=aws_secret_key) + self.aws_emr = AmazonEmr(aws_access_key_id=aws_emr_access_key, + aws_secret_access_key=aws_emr_secret_key) self.aws_emr_client = self.aws_emr.connect() diff --git a/tests/deployments/emr_scripts/test_emr_script_builder.py b/tests/deployments/emr_scripts/test_emr_script_builder.py index 8d350b2..3fd409a 100644 --- a/tests/deployments/emr_scripts/test_emr_script_builder.py +++ b/tests/deployments/emr_scripts/test_emr_script_builder.py @@ -32,7 +32,9 @@ def test_construct_job_without_required_params(self): assert not set(ast.literal_eval(grps[1])) - req_params @patch.dict('os.environ', {'AWS_S3_ACCESS_KEY_ID': 'fake_id', - 'AWS_S3_SECRET_ACCESS_KEY': 'fake_secret'}) + 'AWS_S3_SECRET_ACCESS_KEY': 'fake_secret', + 'AWS_EMR_ACCESS_KEY_ID': 'fake_id', + 'AWS_EMR_SECRET_ACCESS_KEY': 'fake_secret'}) def test_construct_job(self): emr_builder_obj = EMRScriptBuilder() req_params = {'environment': 'dev',