diff --git a/flintrock/ec2.py b/flintrock/ec2.py index 348abf9..2fd90a7 100644 --- a/flintrock/ec2.py +++ b/flintrock/ec2.py @@ -844,13 +844,16 @@ def launch( key_name, identity_file, instance_type, + master_instance_type, region, availability_zone, ami, user, security_groups, spot_price=None, + master_spot_price=None, spot_request_duration=None, + master_spot_request_duration=None, min_root_ebs_size_gb, vpc_id, subnet_id, @@ -922,29 +925,44 @@ def launch( else: user_data = '' + create_cluster_instances = functools.partial( + _create_instances, + region=region, + ami=ami, + assume_yes=assume_yes, + key_name=key_name, + block_device_mappings=block_device_mappings, + availability_zone=availability_zone, + placement_group=placement_group, + tenancy=tenancy, + security_group_ids=security_group_ids, + subnet_id=subnet_id, + instance_profile_arn=instance_profile_arn, + ebs_optimized=ebs_optimized, + instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior, + user_data=user_data) + try: - cluster_instances = _create_instances( - num_instances=num_instances, - region=region, - spot_price=spot_price, - spot_request_valid_until=duration_to_expiration(spot_request_duration), - ami=ami, - assume_yes=assume_yes, - key_name=key_name, - instance_type=instance_type, - block_device_mappings=block_device_mappings, - availability_zone=availability_zone, - placement_group=placement_group, - tenancy=tenancy, - security_group_ids=security_group_ids, - subnet_id=subnet_id, - instance_profile_arn=instance_profile_arn, - ebs_optimized=ebs_optimized, - instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior, - user_data=user_data) - - master_instance = cluster_instances[0] - slave_instances = cluster_instances[1:] + if master_instance_type: + master_instances = create_cluster_instances( + num_instances=1, + spot_price=master_spot_price, + spot_request_valid_until=duration_to_expiration(master_spot_request_duration), + instance_type=master_instance_type) + slave_instances = create_cluster_instances( + num_instances=num_slaves, + spot_price=spot_price, + spot_request_valid_until=duration_to_expiration(spot_request_duration), + instance_type=instance_type) + master_instance = master_instances[0] + else: + cluster_instances = create_cluster_instances( + num_instances=num_instances, + spot_price=spot_price, + spot_request_valid_until=duration_to_expiration(spot_request_duration), + instance_type=instance_type) + master_instance = cluster_instances[0] + slave_instances = cluster_instances[1:] master_tags = [ {'Key': 'flintrock-role', 'Value': 'master'}, diff --git a/flintrock/flintrock.py b/flintrock/flintrock.py index afcbc6b..809d0db 100644 --- a/flintrock/flintrock.py +++ b/flintrock/flintrock.py @@ -326,6 +326,7 @@ def cli(cli_context, config, provider, debug): type=click.Path(exists=True, dir_okay=False), help="Path to SSH .pem file for accessing nodes.") @click.option('--ec2-instance-type', default='m5.medium', show_default=True) +@click.option('--ec2-master-instance-type', default='m5.medium', show_default=True) @click.option('--ec2-region', default='us-east-1', show_default=True) # We set some of these defaults to empty strings because of boto3's parameter validation. # See: https://github.com/boto/boto3/issues/400 @@ -337,8 +338,11 @@ def cli(cli_context, config, provider, debug): help="Additional security groups names to assign to the instances. " "You can specify this option multiple times.") @click.option('--ec2-spot-price', type=float) +@click.option('--ec2-master-spot-price', type=float) @click.option('--ec2-spot-request-duration', default='7d', help="Duration a spot request is valid (e.g. 3d 2h 1m).") +@click.option('--ec2-master-spot-request-duration', default='7d', + help="Duration a spot request is valid (e.g. 3d 2h 1m).") @click.option('--ec2-min-root-ebs-size-gb', type=int, default=30) @click.option('--ec2-vpc-id', default='', help="Leave empty for default VPC.") @click.option('--ec2-subnet-id', default='') @@ -385,13 +389,16 @@ def launch( ec2_key_name, ec2_identity_file, ec2_instance_type, + ec2_master_instance_type, ec2_region, ec2_availability_zone, ec2_ami, ec2_user, ec2_security_groups, ec2_spot_price, + ec2_master_spot_price, ec2_spot_request_duration, + ec2_master_spot_request_duration, ec2_min_root_ebs_size_gb, ec2_vpc_id, ec2_subnet_id, @@ -492,13 +499,16 @@ def launch( key_name=ec2_key_name, identity_file=ec2_identity_file, instance_type=ec2_instance_type, + master_instance_type=ec2_master_instance_type, region=ec2_region, availability_zone=ec2_availability_zone, ami=ec2_ami, user=ec2_user, security_groups=ec2_security_groups, spot_price=ec2_spot_price, + master_spot_price=ec2_master_spot_price, spot_request_duration=ec2_spot_request_duration, + master_spot_request_duration=ec2_master_spot_request_duration, min_root_ebs_size_gb=ec2_min_root_ebs_size_gb, vpc_id=ec2_vpc_id, subnet_id=ec2_subnet_id, diff --git a/tests/test_scripts.py b/tests/test_scripts.py index f464df2..f7103f8 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -25,7 +25,7 @@ def destroy(): return tgz_file_name -@pytest.mark.parametrize('python', ['python', 'python2']) +@pytest.mark.parametrize('python', ['python', 'python3']) def test_download_package(python, project_root_dir, tgz_file): with tempfile.TemporaryDirectory() as temp_dir: subprocess.run(