Upstream breaking change in kfac-jax #70

Closed
gcassella opened this issue Nov 6, 2023 · 0 comments · Fixed by #71

@gcassella (Contributor)

The most recent commit to the kfac-jax repository (at the time of writing, f466559d86b07d6a2291cc699ac769c8e0931592) introduces a breaking change for FermiNet: `kfac_jax.utils` no longer exposes `psd_inv_cholesky`, which `ferminet/curvature_tags_and_blocks.py` uses at import time. The last working commit is bacdf8eaf4f5bd1a467b7e9d9703e571ed37c897. As a result, following the installation and usage instructions in README.md currently produces a broken installation.
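To check whether an existing environment is affected without launching a training run, you can test for the missing attribute directly. The sketch below uses a hypothetical helper, `has_attribute_path`, which is not part of FermiNet or kfac-jax and relies only on the standard library. It looks up `kfac_jax.utils.psd_inv_cholesky`, the name the traceback reports as missing.

```python
import functools
import importlib


def has_attribute_path(module_name: str, attr_path: str) -> bool:
  """Hypothetical helper: True if `module_name` imports and exposes `attr_path`.

  `attr_path` is a dotted path such as "utils.psd_inv_cholesky".
  """
  try:
    module = importlib.import_module(module_name)
    # Walk the dotted path one attribute at a time; raises AttributeError
    # as soon as any component is missing.
    functools.reduce(getattr, attr_path.split("."), module)
  except (ImportError, AttributeError):
    return False
  return True


if __name__ == "__main__":
  if not has_attribute_path("kfac_jax", "utils.psd_inv_cholesky"):
    print("kfac-jax is missing or lacks utils.psd_inv_cholesky; "
          "importing ferminet.train is expected to fail.")
```

The check resolves attributes with `getattr` rather than importing `kfac_jax.utils` as a submodule. In kfac-jax, `utils` is an attribute re-exported from `kfac_jax._src`, not a standalone module file, so a submodule import may not find it.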

To reproduce, install FermiNet following the usual instructions and run:

```python
import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Settings in a config file are loaded by executing the get_config
# function.
def get_config():
  # Get default options.
  cfg = base_config.default()
  # Set up molecule
  cfg.system.electrons = (1,1)
  cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

  # Set training hyperparameters
  cfg.batch_size = 256
  cfg.pretrain.iterations = 100

  return cfg

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Build the H2 config and start training.
cfg = get_config()
train.train(cfg)
```

This fails while importing `ferminet.train`, with the following stack trace:

```
Traceback (most recent call last):
  File "/home/ettore/ferminet/test.py", line 6, in <module>
    from ferminet import train
  File "/home/ettore/ferminet/ferminet/train.py", line 24, in <module>
    from ferminet import checkpoint
  File "/home/ettore/ferminet/ferminet/checkpoint.py", line 24, in <module>
    from ferminet import networks
  File "/home/ettore/ferminet/ferminet/networks.py", line 21, in <module>
    from ferminet import envelopes
  File "/home/ettore/ferminet/ferminet/envelopes.py", line 21, in <module>
    from ferminet import curvature_tags_and_blocks
  File "/home/ettore/ferminet/ferminet/curvature_tags_and_blocks.py", line 27, in <module>
    vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0)
AttributeError: module 'kfac_jax._src.utils' has no attribute 'psd_inv_cholesky'
```
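One possible local workaround is to fall back to a local implementation when the attribute is missing. Pinning kfac-jax to bacdf8eaf4f5bd1a467b7e9d9703e571ed37c897 avoids the problem altogether, and this sketch is not necessarily how #71 resolves it. The sketch rests on one assumption, inferred from the `(0, None)` vmap axes in the failing line: the helper takes a batch of symmetric positive semi-definite matrices plus a shared scalar damping, and returns `(matrix + damping * I)^-1`. The name `_local_psd_inv_cholesky` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy import linalg as jsp_linalg


def _local_psd_inv_cholesky(matrix: jax.Array, damping) -> jax.Array:
  """Hypothetical stand-in: returns (matrix + damping * I)^-1.

  Assumes `matrix` is symmetric PSD and `damping` > 0, so the damped matrix
  is positive definite and has a Cholesky factorisation.
  """
  identity = jnp.eye(matrix.shape[-1], dtype=matrix.dtype)
  factor = jsp_linalg.cho_factor(matrix + damping * identity, lower=True)
  # Solving against the identity with the Cholesky factor yields the inverse.
  return jsp_linalg.cho_solve(factor, identity)


try:
  import kfac_jax
  # Use the upstream helper if this kfac-jax version still provides it.
  psd_inv_cholesky = getattr(
      kfac_jax.utils, "psd_inv_cholesky", _local_psd_inv_cholesky)
except (ImportError, AttributeError):
  psd_inv_cholesky = _local_psd_inv_cholesky

# Same batching as the failing line: map over the leading matrix axis and
# share a single damping value across the batch.
vmap_psd_inv_cholesky = jax.vmap(psd_inv_cholesky, (0, None), 0)
```

The Cholesky solve exploits the symmetry of the damped matrix. It is generally preferable to a general-purpose `jnp.linalg.inv` for positive-definite inputs.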