Skip to content

Commit 71365d5

Browse files
BihanBihan  Rana
andauthored
Add tpu support in gcp (#1323)
* Add TPU support in gcp * Filter TPU Pods for initial release * Modify pretty_resources for TPU --------- Co-authored-by: Bihan Rana <bihan@Bihans-MacBook-Pro.local>
1 parent 81a22fa commit 71365d5

File tree

4 files changed

+152
-8
lines changed

4 files changed

+152
-8
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def get_long_description():
9191
"google-cloud-logging>=2.0.0",
9292
"google-api-python-client>=2.80.0",
9393
"google-cloud-billing>=1.11.0",
94+
"google-cloud-tpu>=1.18.3",
9495
]
9596

9697
DATACRUNCH_DEPS = ["datacrunch"]

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44

55
import google.api_core.exceptions
66
import google.cloud.compute_v1 as compute_v1
7+
from google.cloud import tpu_v2
78

89
import dstack._internal.core.backends.gcp.auth as auth
910
import dstack._internal.core.backends.gcp.resources as gcp_resources
1011
from dstack._internal.core.backends.base.compute import (
1112
Compute,
1213
get_gateway_user_data,
1314
get_instance_name,
15+
get_shim_commands,
1416
get_user_data,
1517
)
1618
from dstack._internal.core.backends.base.offers import get_catalog_offers
@@ -45,6 +47,7 @@ def __init__(self, config: GCPConfig):
4547
self.firewalls_client = compute_v1.FirewallsClient(credentials=self.credentials)
4648
self.regions_client = compute_v1.RegionsClient(credentials=self.credentials)
4749
self.subnetworks_client = compute_v1.SubnetworksClient(credentials=self.credentials)
50+
self.tpu_client = tpu_v2.TpuClient(credentials=self.credentials)
4851

4952
def get_offers(
5053
self, requirements: Optional[Requirements] = None
@@ -70,7 +73,7 @@ def get_offers(
7073
availability = InstanceAvailability.NO_QUOTA
7174
if _has_gpu_quota(quotas[region], offer.instance.resources):
7275
availability = InstanceAvailability.UNKNOWN
73-
# todo quotas: cpu, memory, global gpu
76+
# todo quotas: cpu, memory, global gpu, tpu
7477
offers_with_availability.append(
7578
InstanceOfferWithAvailability(**offer.dict(), availability=availability)
7679
)
@@ -84,13 +87,22 @@ def terminate_instance(
8487
# Old instances have region set to zone, e.g. us-central1-a.
8588
# New instance have region set to region, e.g. us-central1. Zone is stored in backend_data.
8689
zone = region
90+
is_tpu = False
8791
if backend_data is not None:
8892
backend_data_dict = json.loads(backend_data)
8993
zone = backend_data_dict["zone"]
94+
is_tpu = backend_data_dict.get("is_tpu", False)
9095
try:
91-
self.instances_client.delete(
92-
project=self.config.project_id, zone=zone, instance=instance_id
93-
)
96+
if is_tpu:
97+
name = f"projects/{self.project_id}/locations/{zone}/nodes/{instance_id}"
98+
delete_request = tpu_v2.DeleteNodeRequest(
99+
name=name,
100+
)
101+
self.tpu_client.delete_node(request=delete_request)
102+
else:
103+
self.instances_client.delete(
104+
project=self.config.project_id, zone=zone, instance=instance_id
105+
)
94106
except google.api_core.exceptions.NotFound:
95107
pass
96108

@@ -120,21 +132,74 @@ def create_instance(
120132
network=self.config.vpc_resource_name,
121133
)
122134
disk_size = round(instance_offer.instance.resources.disk.size_mib / 1024)
123-
124135
# Choose any usable subnet in a VPC.
125136
# Configuring a specific subnet per region is not supported yet.
126137
subnetwork = _get_vpc_subnet(
127138
subnetworks_client=self.subnetworks_client,
128139
config=self.config,
129140
region=instance_offer.region,
130141
)
142+
commands = get_shim_commands(authorized_keys=authorized_keys)
143+
startup_script = " ".join([" && ".join(commands)])
144+
startup_script = "#! /bin/bash\n" + startup_script
145+
instance_id = f"tpu-{instance_config.instance_name}"
131146

132147
labels = {
133148
"owner": "dstack",
134149
"dstack_project": instance_config.project_name.lower(),
135150
"dstack_user": instance_config.user.lower(),
136151
}
137152
labels = {k: v for k, v in labels.items() if gcp_resources.is_valid_label_value(v)}
153+
tpu = (
154+
_is_tpu(instance_offer.instance.resources.gpus[0].name)
155+
if instance_offer.instance.resources.gpus
156+
else False
157+
)
158+
if tpu:
159+
for zone in _get_instance_zones(instance_offer):
160+
tpu_node = gcp_resources.create_tpu_node_struct(
161+
instance_name=instance_offer.instance.name,
162+
startup_script=startup_script,
163+
authorized_keys=authorized_keys,
164+
spot=instance_offer.instance.resources.spot,
165+
labels=labels,
166+
)
167+
168+
create_node_request = tpu_v2.CreateNodeRequest(
169+
parent=f"projects/{self.config.project_id}/locations/{zone}",
170+
node_id=instance_id,
171+
node=tpu_node,
172+
)
173+
try:
174+
operation = self.tpu_client.create_node(request=create_node_request)
175+
gcp_resources.wait_for_operation(
176+
operation, verbose_name="tpu instance creation"
177+
)
178+
except (
179+
google.api_core.exceptions.ServiceUnavailable,
180+
google.api_core.exceptions.NotFound,
181+
google.api_core.exceptions.ResourceExhausted,
182+
):
183+
continue
184+
node_request = tpu_v2.GetNodeRequest(
185+
name=f"projects/dstack/locations/{zone}/nodes/{instance_id}",
186+
)
187+
instance = self.tpu_client.get_node(request=node_request)
188+
return JobProvisioningData(
189+
backend=instance_offer.backend,
190+
instance_type=instance_offer.instance,
191+
instance_id=instance_id,
192+
hostname=instance.network_endpoints[0].access_config.external_ip,
193+
internal_ip=None,
194+
region=zone,
195+
price=instance_offer.price,
196+
ssh_port=22,
197+
username="ubuntu",
198+
ssh_proxy=None,
199+
dockerized=True,
200+
backend_data=json.dumps({"is_tpu": tpu, "zone": zone}),
201+
)
202+
raise NoCapacityError()
138203

139204
for zone in _get_instance_zones(instance_offer):
140205
request = compute_v1.InsertInstanceRequest()
@@ -301,6 +366,9 @@ def _filter(offer: InstanceOffer) -> bool:
301366
# strip zone
302367
if offer.region[:-2] not in regions:
303368
return False
369+
# remove TPU Pod for initial release
370+
if _is_tpu(f"tpu-{offer.instance.name}") and _is_pod(offer.instance.name):
371+
return False
304372
for family in [
305373
"e2-medium",
306374
"e2-standard-",
@@ -324,6 +392,8 @@ def _has_gpu_quota(quotas: Dict[str, float], resources: Resources) -> bool:
324392
if not resources.gpus:
325393
return True
326394
gpu = resources.gpus[0]
395+
if _is_tpu(gpu.name):
396+
return True
327397
quota_name = f"NVIDIA_{gpu.name}_GPUS"
328398
if gpu.name == "A100" and gpu.memory_mib == 80 * 1024:
329399
quota_name = "NVIDIA_A100_80GB_GPUS"
@@ -352,3 +422,31 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]:
352422
continue
353423
zones.append(offer.region)
354424
return zones
425+
426+
427+
def _is_tpu(name: str) -> bool:
428+
tpu_versions = ["tpu-v2", "tpu-v3", "tpu-v4", "tpu-v5p", "tpu-v5litepod"]
429+
parts = name.split("-")
430+
if len(parts) == 3:
431+
version = f"{parts[0]}-{parts[1]}"
432+
cores = parts[2]
433+
if version in tpu_versions and cores.isdigit():
434+
return True
435+
return False
436+
437+
438+
def _is_pod(instance_name: str) -> bool:
439+
parts = instance_name.split("-")
440+
if len(parts) != 2:
441+
raise ValueError(f"Invalid tpu type: {instance_name}")
442+
version, tensor_cores = parts
443+
try:
444+
tensor_cores = int(tensor_cores)
445+
except ValueError:
446+
raise ValueError(f"Invalid number in tpu tensor cores: {tensor_cores}")
447+
if version in ["v2", "v3"]:
448+
return tensor_cores > 8
449+
elif version in ["v4", "v5p", "v5litepod"]:
450+
return True
451+
else:
452+
raise ValueError(f"Unknown TPU version: {version}")

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import google.api_core.exceptions
77
import google.cloud.compute_v1 as compute_v1
88
from google.api_core.extended_operation import ExtendedOperation
9+
from google.api_core.operation import Operation
10+
from google.cloud import tpu_v2
911

1012
import dstack.version as version
1113
from dstack._internal.core.errors import ComputeError
@@ -278,3 +280,32 @@ def generate_random_resource_name(length: int = 40) -> str:
278280
return random.choice(string.ascii_lowercase) + "".join(
279281
random.choice(string.ascii_lowercase + string.digits) for _ in range(length)
280282
)
283+
284+
285+
def create_tpu_node_struct(
286+
instance_name: str,
287+
startup_script: str,
288+
authorized_keys: List[str],
289+
spot: bool,
290+
labels: Dict[str, str],
291+
) -> tpu_v2.Node:
292+
node = tpu_v2.Node()
293+
if spot:
294+
node.scheduling_config = tpu_v2.SchedulingConfig(preemptible=True)
295+
node.accelerator_type = instance_name
296+
node.runtime_version = "tpu-ubuntu2204-base"
297+
node.network_config = tpu_v2.NetworkConfig(enable_external_ips=True)
298+
ssh_keys = "\n".join(f"ubuntu:{key}" for key in authorized_keys)
299+
node.metadata = {"ssh-keys": ssh_keys, "startup-script": startup_script}
300+
node.labels = labels
301+
return node
302+
303+
304+
def wait_for_operation(operation: Operation, verbose_name: str = "operation", timeout: int = 300):
305+
try:
306+
result = operation.result(timeout=timeout)
307+
except Exception as e:
308+
logger.error("Error during %s: %s", verbose_name, e)
309+
logger.error("Operation ID: %s", operation)
310+
raise operation.exception() or RuntimeError(str(e))
311+
return result

src/dstack/_internal/utils/common.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,27 @@ def pretty_resources(
8585
"""
8686
parts = []
8787
if cpus is not None:
88-
parts.append(f"{cpus}xCPU")
88+
if isinstance(cpus, int):
89+
if cpus > 0:
90+
parts.append(f"{cpus}xCPU")
91+
else:
92+
parts.append(f"{cpus}xCPU")
8993
if memory is not None:
90-
parts.append(f"{memory}")
94+
if isinstance(memory, str):
95+
memory_value = int(memory[:-2])
96+
if memory_value > 0:
97+
parts.append(f"{memory}")
98+
else:
99+
parts.append(f"{memory}")
91100
if gpu_count:
92101
gpu_parts = []
93102
if gpu_memory is not None:
94-
gpu_parts.append(f"{gpu_memory}")
103+
if isinstance(gpu_memory, str):
104+
gpu_memory_value = int(gpu_memory[:-2])
105+
if gpu_memory_value > 0:
106+
parts.append(f"{gpu_memory}")
107+
else:
108+
gpu_parts.append(f"{gpu_memory}")
95109
if total_gpu_memory is not None:
96110
gpu_parts.append(f"total {total_gpu_memory}")
97111
if compute_capability is not None:

0 commit comments

Comments
 (0)