Skip to content

Commit

Permalink
Log a warning after 60 secs to remind the user to run code on all hos…
Browse files Browse the repository at this point in the history
…ts for Cloud TPU 1VM
  • Loading branch information
yashk2810 committed Jul 28, 2021
1 parent 28d9e92 commit bf28b88
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
18 changes: 17 additions & 1 deletion jax/lib/xla_bridge.py
Expand Up @@ -145,6 +145,22 @@ def _make_tpu_driver_client():
return tpu_driver_client.TpuBackend.create(worker=FLAGS.jax_backend_target)


def tpu_client_timer_callback(timer: float):
def _log_warning():
logging.warning('Did you run your code on all the hosts?')

# Will log a warning after `timer` secs.
t = threading.Timer(timer, _log_warning)
t.start()

try:
client = xla_client.make_tpu_client()
finally:
t.cancel()

return client


# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
Expand All @@ -170,7 +186,7 @@ def register_backend_factory(name, factory, *, priority=0):
priority=100)
register_backend_factory('gpu', xla_client.make_gpu_client,
priority=200)
register_backend_factory('tpu', xla_client.make_tpu_client,
register_backend_factory('tpu', partial(tpu_client_timer_callback, timer=60.0),
priority=300)

_default_backend = None
Expand Down
15 changes: 15 additions & 0 deletions tests/xla_bridge_test.py
Expand Up @@ -12,12 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import warnings

from absl.testing import absltest
from jax.lib import xla_bridge as xb
from jax.lib import xla_client as xc
from jax import test_util as jtu

mock = absltest.mock


def mock_tpu_client():
time.sleep(0.03)
return None

class XlaBridgeTest(absltest.TestCase):

Expand Down Expand Up @@ -56,6 +65,12 @@ def test_local_devices(self):
with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"):
xb.local_devices(backend="foo")

@mock.patch('jax.lib.xla_client.make_tpu_client', side_effect=mock_tpu_client)
def test_timer_tpu_warning_1vm(self, _):
with self.assertLogs('absl', level='WARNING') as al:
xb.tpu_client_timer_callback(0.01)
self.assertIn('Did you run your code on all the hosts?', al.output[0])


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit bf28b88

Please sign in to comment.