44
55import google .api_core .exceptions
66import google .cloud .compute_v1 as compute_v1
7+ from google .cloud import tpu_v2
78
89import dstack ._internal .core .backends .gcp .auth as auth
910import dstack ._internal .core .backends .gcp .resources as gcp_resources
1011from dstack ._internal .core .backends .base .compute import (
1112 Compute ,
1213 get_gateway_user_data ,
1314 get_instance_name ,
15+ get_shim_commands ,
1416 get_user_data ,
1517)
1618from dstack ._internal .core .backends .base .offers import get_catalog_offers
@@ -45,6 +47,7 @@ def __init__(self, config: GCPConfig):
4547 self .firewalls_client = compute_v1 .FirewallsClient (credentials = self .credentials )
4648 self .regions_client = compute_v1 .RegionsClient (credentials = self .credentials )
4749 self .subnetworks_client = compute_v1 .SubnetworksClient (credentials = self .credentials )
50+ self .tpu_client = tpu_v2 .TpuClient (credentials = self .credentials )
4851
4952 def get_offers (
5053 self , requirements : Optional [Requirements ] = None
@@ -70,7 +73,7 @@ def get_offers(
7073 availability = InstanceAvailability .NO_QUOTA
7174 if _has_gpu_quota (quotas [region ], offer .instance .resources ):
7275 availability = InstanceAvailability .UNKNOWN
73- # todo quotas: cpu, memory, global gpu
76+ # todo quotas: cpu, memory, global gpu, tpu
7477 offers_with_availability .append (
7578 InstanceOfferWithAvailability (** offer .dict (), availability = availability )
7679 )
@@ -84,13 +87,22 @@ def terminate_instance(
8487 # Old instances have region set to zone, e.g. us-central1-a.
8588 # New instance have region set to region, e.g. us-central1. Zone is stored in backend_data.
8689 zone = region
90+ is_tpu = False
8791 if backend_data is not None :
8892 backend_data_dict = json .loads (backend_data )
8993 zone = backend_data_dict ["zone" ]
94+ is_tpu = backend_data_dict .get ("is_tpu" , False )
9095 try :
91- self .instances_client .delete (
92- project = self .config .project_id , zone = zone , instance = instance_id
93- )
96+ if is_tpu :
97+ name = f"projects/{ self .project_id } /locations/{ zone } /nodes/{ instance_id } "
98+ delete_request = tpu_v2 .DeleteNodeRequest (
99+ name = name ,
100+ )
101+ self .tpu_client .delete_node (request = delete_request )
102+ else :
103+ self .instances_client .delete (
104+ project = self .config .project_id , zone = zone , instance = instance_id
105+ )
94106 except google .api_core .exceptions .NotFound :
95107 pass
96108
@@ -120,21 +132,74 @@ def create_instance(
120132 network = self .config .vpc_resource_name ,
121133 )
122134 disk_size = round (instance_offer .instance .resources .disk .size_mib / 1024 )
123-
124135 # Choose any usable subnet in a VPC.
125136 # Configuring a specific subnet per region is not supported yet.
126137 subnetwork = _get_vpc_subnet (
127138 subnetworks_client = self .subnetworks_client ,
128139 config = self .config ,
129140 region = instance_offer .region ,
130141 )
142+ commands = get_shim_commands (authorized_keys = authorized_keys )
143+ startup_script = " " .join ([" && " .join (commands )])
144+ startup_script = "#! /bin/bash\n " + startup_script
145+ instance_id = f"tpu-{ instance_config .instance_name } "
131146
132147 labels = {
133148 "owner" : "dstack" ,
134149 "dstack_project" : instance_config .project_name .lower (),
135150 "dstack_user" : instance_config .user .lower (),
136151 }
137152 labels = {k : v for k , v in labels .items () if gcp_resources .is_valid_label_value (v )}
153+ tpu = (
154+ _is_tpu (instance_offer .instance .resources .gpus [0 ].name )
155+ if instance_offer .instance .resources .gpus
156+ else False
157+ )
158+ if tpu :
159+ for zone in _get_instance_zones (instance_offer ):
160+ tpu_node = gcp_resources .create_tpu_node_struct (
161+ instance_name = instance_offer .instance .name ,
162+ startup_script = startup_script ,
163+ authorized_keys = authorized_keys ,
164+ spot = instance_offer .instance .resources .spot ,
165+ labels = labels ,
166+ )
167+
168+ create_node_request = tpu_v2 .CreateNodeRequest (
169+ parent = f"projects/{ self .config .project_id } /locations/{ zone } " ,
170+ node_id = instance_id ,
171+ node = tpu_node ,
172+ )
173+ try :
174+ operation = self .tpu_client .create_node (request = create_node_request )
175+ gcp_resources .wait_for_operation (
176+ operation , verbose_name = "tpu instance creation"
177+ )
178+ except (
179+ google .api_core .exceptions .ServiceUnavailable ,
180+ google .api_core .exceptions .NotFound ,
181+ google .api_core .exceptions .ResourceExhausted ,
182+ ):
183+ continue
184+ node_request = tpu_v2 .GetNodeRequest (
185+ name = f"projects/dstack/locations/{ zone } /nodes/{ instance_id } " ,
186+ )
187+ instance = self .tpu_client .get_node (request = node_request )
188+ return JobProvisioningData (
189+ backend = instance_offer .backend ,
190+ instance_type = instance_offer .instance ,
191+ instance_id = instance_id ,
192+ hostname = instance .network_endpoints [0 ].access_config .external_ip ,
193+ internal_ip = None ,
194+ region = zone ,
195+ price = instance_offer .price ,
196+ ssh_port = 22 ,
197+ username = "ubuntu" ,
198+ ssh_proxy = None ,
199+ dockerized = True ,
200+ backend_data = json .dumps ({"is_tpu" : tpu , "zone" : zone }),
201+ )
202+ raise NoCapacityError ()
138203
139204 for zone in _get_instance_zones (instance_offer ):
140205 request = compute_v1 .InsertInstanceRequest ()
@@ -301,6 +366,9 @@ def _filter(offer: InstanceOffer) -> bool:
301366 # strip zone
302367 if offer .region [:- 2 ] not in regions :
303368 return False
369+ # remove TPU Pod for initial release
370+ if _is_tpu (f"tpu-{ offer .instance .name } " ) and _is_pod (offer .instance .name ):
371+ return False
304372 for family in [
305373 "e2-medium" ,
306374 "e2-standard-" ,
@@ -324,6 +392,8 @@ def _has_gpu_quota(quotas: Dict[str, float], resources: Resources) -> bool:
324392 if not resources .gpus :
325393 return True
326394 gpu = resources .gpus [0 ]
395+ if _is_tpu (gpu .name ):
396+ return True
327397 quota_name = f"NVIDIA_{ gpu .name } _GPUS"
328398 if gpu .name == "A100" and gpu .memory_mib == 80 * 1024 :
329399 quota_name = "NVIDIA_A100_80GB_GPUS"
@@ -352,3 +422,31 @@ def _get_instance_zones(instance_offer: InstanceOffer) -> List[str]:
352422 continue
353423 zones .append (offer .region )
354424 return zones
425+
426+
427+ def _is_tpu (name : str ) -> bool :
428+ tpu_versions = ["tpu-v2" , "tpu-v3" , "tpu-v4" , "tpu-v5p" , "tpu-v5litepod" ]
429+ parts = name .split ("-" )
430+ if len (parts ) == 3 :
431+ version = f"{ parts [0 ]} -{ parts [1 ]} "
432+ cores = parts [2 ]
433+ if version in tpu_versions and cores .isdigit ():
434+ return True
435+ return False
436+
437+
438+ def _is_pod (instance_name : str ) -> bool :
439+ parts = instance_name .split ("-" )
440+ if len (parts ) != 2 :
441+ raise ValueError (f"Invalid tpu type: { instance_name } " )
442+ version , tensor_cores = parts
443+ try :
444+ tensor_cores = int (tensor_cores )
445+ except ValueError :
446+ raise ValueError (f"Invalid number in tpu tensor cores: { tensor_cores } " )
447+ if version in ["v2" , "v3" ]:
448+ return tensor_cores > 8
449+ elif version in ["v4" , "v5p" , "v5litepod" ]:
450+ return True
451+ else :
452+ raise ValueError (f"Unknown TPU version: { version } " )
0 commit comments