Skip to content

Commit

Permalink
Generate topology from external TPU names and topology parameters.
Browse files Browse the repository at this point in the history
This should allow the topology API to be usable on C API / Cloud TPU.

PiperOrigin-RevId: 541232667
  • Loading branch information
IvyZX authored and jax authors committed Jun 17, 2023
1 parent b83e6fb commit 7e61479
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def using_pjrt_c_api(backend=None):


# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name=None, **kwargs):
def make_pjrt_tpu_topology(topology_name='', **kwargs):
# TODO(b/261484192): Make a system for lazily loading libtpu.so and call
# that inside make_tfrt_tpu_c_api_device_topology.
get_backend() # Properly initialize libtpu.so.
Expand Down
10 changes: 4 additions & 6 deletions jax/experimental/topologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ def get_attached_topology(platform=None) -> TopologyDescription:


def get_topology_desc(
topology_name: Optional[str] = None,
platform: Optional[str] = None,
**kwargs
topology_name: str = "", platform: Optional[str] = None, **kwargs
) -> TopologyDescription:
if platform == "tpu" or platform is None:
if topology_name is not None:
kwargs.update(topology_name=topology_name)
return TopologyDescription(
xb.make_pjrt_tpu_topology(**kwargs)._make_compile_only_devices()
xb.make_pjrt_tpu_topology(
topology_name, **kwargs
)._make_compile_only_devices()
)
raise NotImplementedError(
"get_topology_desc(platform=%s) is not implemented" % repr(platform)
Expand Down

0 comments on commit 7e61479

Please sign in to comment.