venv/lib/python3.8/site-packages/jax/experimental/maps.py:527: UserWarning: xmap is an experimental feature and probably has bugs!
warn("xmap is an experimental feature and probably has bugs!")
venv/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:429: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
warnings.warn(
venv/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:416: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
warnings.warn(
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
Stacktrace
2022-03-07 18:41:42.980600: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2124] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
Traceback (most recent call last):
  File "./to_hf_weights.py", line 488, in <module>
    save_sharded_to_hf_format(input_ckpt, params, output_path, np_dtype, torch_dtype)
  File "./to_hf_weights.py", line 464, in save_sharded_to_hf_format
    network = CausalTransformer(params_local)
  File "/home/jonathan.hendler/finishing-school/mesh_transformer/transformer_shard.py", line 277, in __init__
    self.state = self.init_xmap(jnp.array(key.take(mp_per_host)), x)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 666, in fun_mapped
    out_flat = xmap_p.bind(fun_flat, *args_flat, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 871, in bind
    return core.map_bind(self, fun, *args, in_axes=in_axes, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/core.py", line 1801, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 874, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/core.py", line 594, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/experimental/maps.py", line 703, in xmap_impl
    return xmap_callable(*args)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/jonathan.hendler/venv/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1524, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 256.00MiB (268435456B) on device ordinal 0
Posting here because 256MiB seems particularly small for a TPU VM.
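For anyone comparing numbers: as far as I can tell, the 256.00MiB in the message is the size of the single request that failed, not the total memory on the device, so the device may already have been nearly full when this request arrived. Below is a minimal sketch for pulling the request size and device ordinal out of the log line. The helper `parse_failed_allocation` is hypothetical (it is not part of JAX or of `to_hf_weights.py`), and the regular expression assumes the message format shown in the stack trace above.

```python
import re

# Hypothetical helper for reading the log line; not a JAX API.
# Assumes the message format seen in the stack trace above.
_ALLOC_RE = re.compile(
    r"Failed to allocate request for [\d.]+\s*\w+ \((\d+)B\) on device ordinal (\d+)"
)


def parse_failed_allocation(message):
    """Return (request_bytes, device_ordinal), or None if the line does not match."""
    match = _ALLOC_RE.search(message)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


line = (
    "RuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for "
    "256.00MiB (268435456B) on device ordinal 0"
)
request_bytes, ordinal = parse_failed_allocation(line)

# Convert the failed request from bytes to MiB (1 MiB = 2**20 bytes).
print(request_bytes / 2**20, ordinal)
```

The byte count in parentheses is parsed rather than the rounded MiB figure, so the conversion does not depend on how the message formats the human-readable size.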
Configuration info:
#202 (comment)
TPU_VERSION = "v2-alpha"
Python version:
Python 3.8.10
Pip freeze