How to make JAX code run on a single GPU instead of TPU? #33

Closed
ghost opened this issue Aug 13, 2021 · 2 comments

Comments

ghost commented Aug 13, 2021

I'm trying to run this example (JAX branch):

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

When calling train.train(cfg), the code seems to try to run on a TPU by default. How can I make it run on a single GPU instead?

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:absl:Starting QMC with 1 XLA devices
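A minimal sketch, assuming a machine with one or more NVIDIA GPUs and a CUDA-enabled jaxlib. JAX already falls back from TPU to GPU on its own, so pinning a run to one particular card only needs the standard CUDA_VISIBLE_DEVICES environment variable. The helper name restrict_to_single_gpu is hypothetical. The variable must be set before JAX initializes its GPU backend; since ferminet imports jax, the safest place is before the first ferminet import.

```python
import os


def restrict_to_single_gpu(gpu_index: int = 0) -> None:
    """Hypothetical helper: expose only one CUDA device to this process.

    Call it before importing jax (or ferminet, which imports jax): the CUDA
    runtime reads CUDA_VISIBLE_DEVICES when it is initialized, so changing the
    variable after JAX has set up its GPU backend has no effect.
    """
    if gpu_index < 0:
        raise ValueError(f"gpu_index must be non-negative, got {gpu_index}")
    # The chosen physical card is then the only GPU JAX sees (as gpu:0).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
```

In the script above, this would be called as restrict_to_single_gpu(0) right after import sys and before the ferminet imports. Provided the GPU build of jaxlib is installed, jax.local_devices() should then list a single GPU device.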

ghost (author) commented Aug 13, 2021

Seems I'm having a JAX installation issue.

ghost closed this as completed Aug 13, 2021
jsspencer (collaborator) commented

This is a standard message. By default, JAX first tries to initialize a TPU backend; if it can't find one (as the second and third log lines show), it falls back to GPU and then to CPU.

>>> import jax
>>> jax.local_devices()

will show which devices JAX is running on.
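Because this fallback is silent, a run can end up on the CPU without any error, for example when only the CPU-only jaxlib is installed. A minimal sketch of a guard that makes this explicit, built on jax.local_devices() and each device's platform attribute. The function names local_platforms and require_gpu are hypothetical, and the check assumes GPU devices report the platform name "gpu", as CUDA devices do in JAX.

```python
from typing import List, Sequence


def require_gpu(platforms: Sequence[str]) -> int:
    """Hypothetical guard: return the number of GPU devices, or raise if none.

    `platforms` is a list of device platform names such as "cpu", "gpu" or
    "tpu", e.g. the result of local_platforms() below.
    """
    n_gpu = sum(1 for p in platforms if p == "gpu")
    if n_gpu == 0:
        raise RuntimeError(
            f"No GPU visible to JAX (device platforms: {list(platforms)}); "
            "check that a CUDA-enabled jaxlib is installed."
        )
    return n_gpu


def local_platforms() -> List[str]:
    """Platform name of each device JAX can use in this process."""
    import jax  # deferred so this module can be imported without JAX

    return [d.platform for d in jax.local_devices()]
```

Calling require_gpu(local_platforms()) just before train.train(cfg) would stop the run early instead of letting it train on the CPU unnoticed.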
