Hi, I would like to install trax locally. First, I found that the jax package I had installed was not built for GPU, so I followed the instructions on the JAX GitHub page to install the CUDA version of jax. I then verified that jax can detect the GPU on my local machine, but I still cannot run the sample code, such as the Transformer and fast math examples below.
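For reference, this is roughly how I validated GPU detection (a minimal check using standard JAX calls; the exact device listing will differ per machine):

import jax

print(jax.default_backend())  # prints 'gpu' when the CUDA build of jaxlib is active
print(jax.devices())          # lists the visible GPU device(s), e.g. [gpu(id=0)]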
# Error logs:
1) Run the sample code for the pre-trained Transformer from your README tutorial
- code:
import os
import numpy as np
import trax

# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=64, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)
                     # input_signature=input_signature)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)
- Error Output:
2023-06-22 15:58:35.266959: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
2023-06-22 15:58:56.630331: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Failed to get stream's capture status: out of memory
2023-06-22 15:58:56.630403: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
Traceback (most recent call last):
File "/home/littleliu/Documents/project/trax_learning/tryTrax.py", line 22, in <module>
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 349, in init_from_file
self.init(input_signature)
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 310, in init
raise LayerError(name, 'init', self._caller,
trax.layers.base.LayerError: Exception passing through layer Serial (in init):
layer created in file [...]/trax/models/transformer.py, line 371
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:float32})
File [...]/trax/layers/combinators.py, line 108, in init_weights_and_state
outputs, _ = sublayer._forward_abstract(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File [...]/trax/layers/base.py, line 641, in _forward_abstract
layer created in file [...]/trax/models/transformer.py, line 372
layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
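Note: the failure above is an out-of-memory error raised while XLA compiles jit_PRNGKey. A standard workaround from the JAX GPU memory documentation is to disable memory preallocation before jax is imported; the sketch below uses real JAX environment variables, but I have not confirmed that it fixes this particular issue:

import os

# Must run before jax/trax are imported; otherwise the GPU backend has
# already preallocated most of the device memory at startup.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Alternative: cap the fraction of GPU memory JAX may claim.
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

import trax  # import after setting the environment variables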
2) Run the Fast Math sample code:
- code:
import trax
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.
matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix =\n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
- Error Output:
matrix =
[[1 2 3]
[4 5 6]
[7 8 9]]
2023-06-22 16:03:23.041313: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total.
2023-06-22 16:03:23.041476: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 525.85.5
Traceback (most recent call last):
File "/home/littleliu/Documents/project/trax_learning/fastnumpy.py", line 7, in <module>
vector = fastnp.ones(3)
^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2161, in ones
return lax.full(shape, 1, _jnp_dtype(dtype))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1205, in full
return broadcast(fill_value, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
return broadcast_in_dim_p.bind(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
compiled_fun = xla_primitive_callable(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
return cached(config._trace_context(), *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
compiled = _xla_callable_uncached(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
return computation.compile().unsafe_call
^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
xla_executable, compile_options = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
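Given the cuDNN messages above ("Possibly insufficient driver version: 525.85.5", with only ~36 MB of the 4 GB card free), it is probably also worth dumping the installed versions. A minimal sketch, assuming a recent jax 0.4.x release (jax.print_environment_info() may not exist in older versions; the first two prints work everywhere):

import jax
import jaxlib

print('jax', jax.__version__, 'jaxlib', jaxlib.__version__)
print(jax.devices())
# Recent jax releases can report CUDA, cuDNN and driver details directly:
jax.print_environment_info()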