Error in colab_demo #65

Closed
paramjeet2021 opened this issue Jul 20, 2021 · 3 comments

Comments

@paramjeet2021

I converted the model to the _slim format, not _slim_f16.

Now, when I run the Colab code, I get the error below:

loading netwrok from the Google storage
read from disk/gcs in 107.5s
Traceback (most recent call last):
  File "content_generation.py", line 90, in <module>
    print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])
  File "content_generation.py", line 77, in infer
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})
  File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
    return self.generate_xmap(self.state,
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 615, in fun_mapped
    out_flat = xmap_p.bind(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 818, in bind
    return core.call_bind(self, fun, *args, **params)  # type: ignore
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 821, in process
    return trace.process_xmap(self, fun, tracers, params)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 646, in xmap_impl
    xmap_callable = make_xmap_callable(
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
    _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
  File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
    raise TypeError(f"One of xmap results has an out_axes specification of "
TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard
@paramjeet2021
Author

With jax==0.2.12, I get the following error on import:

import jax
2021-07-20 11:07:32.634036: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped) 
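
Because the failure above aborts the whole interpreter, it can help to try the import in a child process first, so the calling shell or notebook kernel survives. The following is only a minimal sketch using the standard library; `probe_import` is a hypothetical helper name, not something from this repository.

```python
import subprocess
import sys


def probe_import(module, timeout=120.0):
    """Try `import <module>` in a child interpreter.

    Returns (ok, returncode, stderr_tail). On POSIX, a negative returncode
    means the child was killed by that signal number (SIGABRT is 6).
    """
    proc = subprocess.run(
        [sys.executable, "-c", f"import {module}"],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # Keep only the last few stderr lines; that is where the reason usually is.
    tail = proc.stderr.strip().splitlines()[-3:]
    return proc.returncode == 0, proc.returncode, tail


# Example use (hypothetical):
# ok, code, tail = probe_import("jax")
# print("import ok" if ok else f"import failed (returncode {code})", *tail, sep="\n")
```

If the child dies with a signal instead of raising an ordinary ImportError, the problem is likely in a native library loaded during import rather than in Python code.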

@kingoflolz
Owner

Please use this script to install your dependencies, since you need both an older version of jax and libtpu: https://github.com/kingoflolz/mesh-transformer-jax/blob/master/scripts/init_ray.sh
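
After running an install script like the one linked above, one way to confirm which versions actually ended up in the environment is to read the installed package metadata, which does not import the packages (useful when `import jax` itself aborts). This is a rough sketch, not part of the repository: `check_pins` and `PINS` are hypothetical names, and the only pin listed is jax 0.2.12 because that is the only version mentioned in this thread. Any other pins should be copied from the script itself.

```python
from importlib import metadata


def check_pins(pins, get_version=metadata.version):
    """Return {package: (wanted, found)} for every pin that is not satisfied.

    `found` is None when the package is not installed. Versions come from
    installed package metadata, so the packages are never imported.
    """
    mismatches = {}
    for package, wanted in pins.items():
        try:
            found = get_version(package)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


# Assumption: only the jax version named in this thread is listed here.
PINS = {"jax": "0.2.12"}

if __name__ == "__main__":
    for package, (wanted, found) in check_pins(PINS).items():
        print(f"{package}: wanted {wanted}, found {found}")
```

An empty result means every listed pin matches what is installed.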

@paramjeet2021
Author

paramjeet2021 commented Jul 21, 2021

I followed this script, running it line by line at the command prompt. After installation, importing jax gave the error below:

>>> import jax
2021-07-21 09:16:15.453982: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
Aborted (core dumped)

Then I created a new TPU node, SSHed into it, and ran the same script line by line at the command prompt. This time I got the following error:

import jax
2021-07-21 08:55:12.813348: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-07-21 08:55:13.039260: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-07-21 08:55:13.045162: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
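
The two outputs in this thread differ in the single-letter severity that follows the timestamp: the first is `F` (fatal, followed by an abort), while the lines just above are `W` (warning) about the CUDA runtime library, which would not normally be present on a machine without an NVIDIA GPU. A small sketch for telling them apart, assuming the line layout seen in the logs above (`classify` and `LOG_LINE` are hypothetical names); whether jax then works on the TPU is not shown in this thread.

```python
import re

# Layout assumed from the log lines in this thread:
#   <date> <time>: <severity letter> <source file>:<line>] <message>
LOG_LINE = re.compile(
    r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d+: "
    r"(?P<severity>[IWEF]) (?P<source>\S+):(?P<line>\d+)\] (?P<message>.*)$"
)

SEVERITY_NAMES = {"I": "info", "W": "warning", "E": "error", "F": "fatal"}


def classify(log_line):
    """Return (severity_name, message), or None if the line has another layout."""
    match = LOG_LINE.match(log_line.strip())
    if match is None:
        return None
    return SEVERITY_NAMES[match.group("severity")], match.group("message")
```

Only a fatal line is followed by `Aborted (core dumped)` in the outputs quoted here.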

Is there another way to properly set up jax 0.2.12, or could you suggest a change to the colab_demo code so that it works with newer jax versions?
