
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. #8

Open
johnnytam100 opened this issue Dec 8, 2023 · 2 comments

@johnnytam100

Hi AF2Rank authors! I was trying to run the Colab notebook but ran into the following "DNN library initialization failed" error when I ran the first cell under "## rank structures":

NAME = "1mjc"
CHAIN = "A" # this can be multiple chains
NATIVE_PATH = f"{NAME}.pdb"
DECOY_DIR = f"{NAME}"

if save_output_pdbs:
  os.makedirs(f"{NAME}_output", exist_ok=True)


# get data
%shell wget -qnc https://files.ipd.uw.edu/pub/decoyset/natives/{NAME}.pdb
%shell wget -qnc https://files.ipd.uw.edu/pub/decoyset/decoys/{NAME}.zip
%shell unzip -qqo {NAME}.zip

# setup model
clear_mem()
af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])

I got the following error:

---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
<ipython-input-4-5740c8be5964> in <cell line: 17>()
     15 # setup model
     16 clear_mem()
---> 17 af = af2rank(NATIVE_PATH, CHAIN, model_name=SETTINGS["model_name"])

21 frames
<ipython-input-2-b0744074ff0e> in __init__(self, pdb, chain, model_name, model_names)
     75                  "model_name":model_name,
     76                  "model_names":model_names}
---> 77     self.reset()
     78
     79   def reset(self):

<ipython-input-2-b0744074ff0e> in reset(self)
     78
     79   def reset(self):
---> 80     self.model = mk_af_model(protocol="fixbb",
     81                              use_templates=True,
     82                              use_multimer=self.args["use_multimer"],

/content/colabdesign/af/model.py in __init__(self, protocol, use_multimer, use_templates, debug, data_dir, **kwargs)
    118     self._model_params, self._model_names = [],[]
    119     for model_name in model_names:
--> 120       params = data.get_model_haiku_params(model_name=model_name, data_dir=data_dir, fuse=True)
    121       if params is not None:
    122         if not self._args["use_multimer"] and not self._args["use_templates"]:

/content/colabdesign/af/alphafold/model/data.py in get_model_haiku_params(model_name, data_dir, fuse)
     39     with open(path, 'rb') as f:
     40       params = np.load(io.BytesIO(f.read()), allow_pickle=False)
---> 41     return utils.flat_params_to_haiku(params, fuse=fuse)

/content/colabdesign/af/alphafold/model/utils.py in flat_params_to_haiku(params, fuse)
    108             P[f"{k}/{c}"] = {}
    109             for d in ["bias","weights"]:
--> 110               P[f"{k}/{c}"][d] = jnp.concatenate([L[d],R[d]],-1)
    111           P[f"{k}/center_norm"] = P.pop(f"{k}/center_layer_norm")
    112           P[f"{k}/left_norm_input"] = P.pop(f"{k}/layer_norm_input")

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis, dtype)
   1852   k = 16
   1853   while len(arrays_out) > 1:
-> 1854     arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1855                   for i in range(0, len(arrays_out), k)]
   1856   return arrays_out[0]

/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in <listcomp>(.0)
   1852   k = 16
   1853   while len(arrays_out) > 1:
-> 1854     arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1855                   for i in range(0, len(arrays_out), k)]
   1856   return arrays_out[0]

/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py in concatenate(operands, dimension)
    615     if isinstance(op, Array):
    616       return type_cast(Array, op)
--> 617   return concatenate_p.bind(*operands, dimension=dimension)
    618
    619

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in bind(self, *args, **params)
    384     assert (not config.jax_enable_checks or
    385             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386     return self.bind_with_trace(find_top_trace(args), args, params)
    387
    388   def bind_with_trace(self, trace, args, params):

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in bind_with_trace(self, trace, args, params)
    387
    388   def bind_with_trace(self, trace, args, params):
--> 389     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    390     return map(full_lower, out) if self.multiple_results else full_lower(out)
    391

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in process_primitive(self, primitive, tracers, params)
    819
    820   def process_primitive(self, primitive, tracers, params):
--> 821     return primitive.impl(*tracers, **params)
    822
    823   def process_call(self, primitive, f, tracers, params):

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
    129   try:
    130     in_avals, in_shardings = util.unzip2([arg_spec(a) for a in args])
--> 131     compiled_fun = xla_primitive_callable(
    132         prim, in_avals, OrigShardings(in_shardings), **params)
    133   except pxla.DeviceAssignmentMismatchError as e:

/usr/local/lib/python3.10/dist-packages/jax/_src/util.py in wrapper(*args, **kwargs)
    261         return f(*args, **kwargs)
    262       else:
--> 263         return cached(config._trace_context(), *args, **kwargs)
    264
    265     wrapper.cache_clear = cached.cache_clear

/usr/local/lib/python3.10/dist-packages/jax/_src/util.py in cached(_, *args, **kwargs)
    254     @functools.lru_cache(max_size)
    255     def cached(_, *args, **kwargs):
--> 256       return f(*args, **kwargs)
    257
    258     @functools.wraps(f)

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in xla_primitive_callable(prim, in_avals, orig_in_shardings, **params)
    220       return out,
    221   donated_invars = (False,) * len(in_avals)
--> 222   compiled = _xla_callable_uncached(
    223       lu.wrap_init(prim_fun), prim.name, donated_invars, False, in_avals,
    224       orig_in_shardings)

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in _xla_callable_uncached(fun, name, donated_invars, keep_unused, in_avals, orig_in_shardings)
    250       fun, name, donated_invars, keep_unused, True, in_avals, orig_in_shardings,
    251       lowering_platform=None)
--> 252   return computation.compile().unsafe_call
    253
    254

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in compile(self, compiler_options)
   2204             **self.compile_args)
   2205       else:
-> 2206         executable = UnloadedMeshExecutable.from_hlo(
   2207             self._name,
   2208             self._hlo,

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in from_hlo(***failed resolving arguments***)
   2542           break
   2543
-> 2544     xla_executable, compile_options = _cached_compilation(
   2545         hlo, name, mesh, spmd_lowering,
   2546         tuple_args, auto_spmd_lowering, allow_prop_to_outputs,

/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py in _cached_compilation(computation, name, mesh, spmd_lowering, tuple_args, auto_spmd_lowering, _allow_propagation_to_outputs, host_callbacks, backend, da, pmap_nreps, compiler_options_keys, compiler_options_values)
   2452       "Finished XLA compilation of {fun_name} in {elapsed_time} sec",
   2453       fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
-> 2454     xla_executable = dispatch.compile_or_get_cached(
   2455         backend, computation, dev, compile_options, host_callbacks)
   2456   return xla_executable, compile_options

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in compile_or_get_cached(backend, computation, devices, compile_options, host_callbacks)
    494
    495   if not use_compilation_cache:
--> 496     return backend_compile(backend, computation, compile_options,
    497                            host_callbacks)
    498

/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    312   def wrapper(*args, **kwargs):
    313     with TraceAnnotation(name, **decorator_kwargs):
--> 314       return func(*args, **kwargs)
    315     return wrapper
    316   return wrapper

/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py in backend_compile(backend, module, options, host_callbacks)
    462   # TODO(sharadmv): remove this fallback when all backends allow `compile`
    463   # to take in `host_callbacks`
--> 464   return backend.compile(built_c, compile_options=options)
    465 
    466 _ir_dump_counter = itertools.count()

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Do you know how to deal with this error?
Thanks!
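
For context, "FAILED_PRECONDITION: DNN library initialization failed" is raised when XLA's GPU backend cannot initialize cuDNN, most often because the jaxlib build installed in the runtime does not match the runtime's CUDA/cuDNN versions. A minimal sketch that forces backend initialization outside of ColabDesign (assuming a Colab GPU runtime) can confirm whether the failure is environment-level rather than AF2Rank-specific:

# Diagnostic sketch (not from the AF2Rank notebook): force JAX to
# initialize the GPU backend and compile one trivial op. If cuDNN
# initialization is broken, this raises the same XlaRuntimeError.
import jax
import jax.numpy as jnp

print(jax.__version__)   # jax version installed in the runtime
print(jax.devices())     # should list a GPU device on a GPU runtime

x = jnp.ones((2, 2))
print(jnp.dot(x, x))     # triggers compilation and backend setup

If this snippet fails the same way, the problem lies in the runtime's JAX/CUDA stack rather than in the notebook's own code.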

@sokrypton
Collaborator

Hmmm... I don't get this error. Which notebook version are you using?

I tried with:
https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/af/examples/AF2Rank.ipynb
[screenshot: the notebook cell runs without the error]

@johnnytam100
Author

Hi Sergey! I am using the latest Google Colab, but it seems the version info is not provided.
May I know what environment you are using?
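
A quick way to capture the environment details being asked about here (a sketch for a Colab cell; the grep pattern is illustrative):

# Print runtime details relevant to this error: GPU model and
# driver/CUDA version via nvidia-smi, plus the installed Python
# and JAX-related package versions.
!nvidia-smi
!python --version
!pip list | grep -iE "jax|dm-haiku|colabdesign"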
