Skip to content

Commit

Permalink
Add allowedLocations option to google batch (#3453)
Browse files Browse the repository at this point in the history

Signed-off-by: Ben Sherman <bentshermann@gmail.com>
Co-authored-by: Paolo Di Tommaso <paolo.ditommaso@gmail.com>
  • Loading branch information
bentsherman and pditommaso committed Dec 5, 2022
1 parent cc0dc54 commit c619eb8
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/google.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Name Description
google.project The Google Project Id to use for the pipeline execution.
google.location The Google *location* where the job executions are deployed (default: ``us-central1``).
google.enableRequesterPaysBuckets When ``true`` uses the configured Google project id as the billing project for storage access. This is required when accessing data from *requester pays enabled* buckets. See `Requester Pays on Google Cloud Storage documentation <https://cloud.google.com/storage/docs/requester-pays>`_ (default: ``false``).
google.batch.allowedLocations Define the set of allowed locations for VMs to be provisioned. See `Google documentation <https://cloud.google.com/batch/docs/reference/rest/v1/projects.locations.jobs#locationpolicy>`_ for details (default: no restriction. Requires version ``22.12.0-edge`` or later).
google.batch.bootDiskSize Set the size of the virtual machine boot disk, e.g ``50.GB`` (default: none).
google.batch.cpuPlatform Set the minimum CPU Platform, e.g. ``'Intel Skylake'``. See `Specifying a minimum CPU Platform for VM instances <https://cloud.google.com/compute/docs/instances/specify-min-cpu-platform#specifications>`_ (default: none).
google.batch.spot When ``true`` enables the usage of *spot* virtual machines or ``false`` otherwise (default: ``false``).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ class GoogleBatchTaskHandler extends TaskHandler {
final instancePolicyOrTemplate = AllocationPolicy.InstancePolicyOrTemplate.newBuilder()
final instancePolicy = AllocationPolicy.InstancePolicy.newBuilder()

if( executor.config.getAllowedLocations() )
allocationPolicy.setLocation(
AllocationPolicy.LocationPolicy.newBuilder()
.addAllAllowedLocations( executor.config.getAllowedLocations() )
)

if( task.config.getAccelerator() ) {
final accelerator = AllocationPolicy.Accelerator.newBuilder()
.setCount( task.config.getAccelerator().getRequest() )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class BatchConfig {

private GoogleOpts googleOpts
private GoogleCredentials credentials
private List<String> allowedLocations
private MemoryUnit bootDiskSize
private String cpuPlatform
private boolean spot
Expand All @@ -45,6 +46,7 @@ class BatchConfig {

GoogleOpts getGoogleOpts() { return googleOpts }
GoogleCredentials getCredentials() { return credentials }
List<String> getAllowedLocations() { allowedLocations }
MemoryUnit getBootDiskSize() { bootDiskSize }
String getCpuPlatform() { cpuPlatform }
boolean getPreemptible() { preemptible }
Expand All @@ -58,6 +60,7 @@ class BatchConfig {
final result = new BatchConfig()
result.googleOpts = GoogleOpts.create(session)
result.credentials = result.googleOpts.credentials
result.allowedLocations = session.config.navigate('google.batch.allowedLocations', List.of()) as List<String>
result.bootDiskSize = session.config.navigate('google.batch.bootDiskSize') as MemoryUnit
result.cpuPlatform = session.config.navigate('google.batch.cpuPlatform')
result.spot = session.config.navigate('google.batch.spot',false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class GoogleBatchTaskHandlerTest extends Specification {
then:
def taskGroup = req.getTaskGroups(0)
def runnable = taskGroup.getTaskSpec().getRunnables(0)
def instancePolicy = req.getAllocationPolicy().getInstances(0).getPolicy()
def allocationPolicy = req.getAllocationPolicy()
def instancePolicy = allocationPolicy.getInstances(0).getPolicy()
and:
taskGroup.getTaskSpec().getComputeResource().getBootDiskMib() == 0
taskGroup.getTaskSpec().getComputeResource().getCpuMilli() == 2_000
Expand All @@ -81,7 +82,8 @@ class GoogleBatchTaskHandlerTest extends Specification {
instancePolicy.getMinCpuPlatform() == ''
instancePolicy.getProvisioningModel().toString() == 'PROVISIONING_MODEL_UNSPECIFIED'
and:
req.getAllocationPolicy().getNetwork().getNetworkInterfacesCount() == 0
allocationPolicy.getLocation().getAllowedLocationsCount() == 0
allocationPolicy.getNetwork().getNetworkInterfacesCount() == 0
and:
req.getLogsPolicy().getDestination().toString() == 'CLOUD_LOGGING'
}
Expand All @@ -103,6 +105,7 @@ class GoogleBatchTaskHandlerTest extends Specification {
and:
def exec = Mock(GoogleBatchExecutor) {
getConfig() >> Mock(BatchConfig) {
getAllowedLocations() >> ['zones/us-central1-a', 'zones/us-central1-c']
getBootDiskSize() >> BOOT_DISK
getCpuPlatform() >> CPU_PLATFORM
getSpot() >> true
Expand Down Expand Up @@ -157,6 +160,9 @@ class GoogleBatchTaskHandlerTest extends Specification {
'/var/lib/nvidia/bin:/usr/local/nvidia/bin'
]
and:
allocationPolicy.getLocation().getAllowedLocationsCount() == 2
allocationPolicy.getLocation().getAllowedLocations(0) == 'zones/us-central1-a'
allocationPolicy.getLocation().getAllowedLocations(1) == 'zones/us-central1-c'
allocationPolicy.getInstances(0).getInstallGpuDrivers() == true
allocationPolicy.getLabelsMap() == [foo: 'bar']
allocationPolicy.getServiceAccount().getEmail() == 'foo@bar.baz'
Expand Down

0 comments on commit c619eb8

Please sign in to comment.