
Make APIs as uniform as possible #98

Open
frostedoyster opened this issue Jan 19, 2024 · 7 comments

@frostedoyster
Collaborator

At the moment, the APIs for C/C++/NumPy/torch, JAX, Julia and CUDA are all slightly different. We should discuss to what extent we should aim to make them uniform, and where we should instead defer to the idioms of each language/framework.

@cortner
Member

cortner commented Jan 19, 2024

I'm open to anything, really, since on our end we have to write wrappers anyway to make this compatible with how we organize computations.

@Luthaf
Contributor

Luthaf commented Apr 8, 2024

We had a discussion about this today, here is a quick summary:

  • use separate classes/objects for spherical harmonics and solid harmonics instead of a "normalize" parameter
  • use a functional-looking API for most things, with a custom __call__ for numpy and torch

Julia

Currently does something like this

basis = SolidHarmonics(10)

# one of
sph = basis(R)
sph = compute(basis, R)

Jax

Currently does something like this

sph = compute_spherical_harmonics(R, lmax=10, normalize=True)

We can change it to

calculator = SphericalHarmonics(lmax=10)
sph = compute_spherical_harmonics(calculator, R)
sph = calculator(R)

calculator = SolidHarmonics(lmax=10)
sph = compute_solid_harmonics(calculator, R)
sph = calculator(R)

Torch/Numpy

We can change these to

calculator = SphericalHarmonics(lmax=10)
sph = calculator(R)

# add this one
calculator = SolidHarmonics(lmax=10)
sph = calculator(R)
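A minimal NumPy sketch of the proposed Torch/NumPy shape, assuming nothing about sphericart's internals: two separate classes, no `normalize` flag, and a `__call__` that takes an `(n, 3)` array. The helpers `_solid` and `_Calculator` are hypothetical, only `lmax` of 0 or 1 is implemented, and the real-harmonic convention (columns ordered by `(l, m)` with `m` running from `-l` to `l`) is an assumption that may not match sphericart's exact normalization.

```python
import numpy as np

# Real spherical-harmonic prefactors for l = 0 and l = 1 (standard convention,
# assumed here rather than taken from sphericart).
C0 = np.sqrt(1.0 / (4.0 * np.pi))
C1 = np.sqrt(3.0 / (4.0 * np.pi))


def _solid(xyz, lmax):
    """Solid harmonics r**l * Y_l^m, columns ordered (0,0), (1,-1), (1,0), (1,1)."""
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    columns = [np.full_like(x, C0)]
    if lmax >= 1:
        columns += [C1 * y, C1 * z, C1 * x]
    return np.stack(columns, axis=1)


class _Calculator:
    def __init__(self, lmax):
        # Sketch limitation: only lmax = 0 or 1 is implemented
        if lmax not in (0, 1):
            raise ValueError("this sketch only supports lmax = 0 or 1")
        self.lmax = lmax


class SolidHarmonics(_Calculator):
    def __call__(self, xyz):
        return _solid(np.asarray(xyz, dtype=float), self.lmax)


class SphericalHarmonics(_Calculator):
    def __call__(self, xyz):
        xyz = np.asarray(xyz, dtype=float)
        r = np.linalg.norm(xyz, axis=1, keepdims=True)
        # Spherical harmonics are the solid harmonics divided by r**l
        degrees = np.array([0, 1, 1, 1])[: (self.lmax + 1) ** 2]
        return _solid(xyz, self.lmax) / r**degrees


R = np.array([[1.0, 2.0, 2.0]])
calculator = SphericalHarmonics(lmax=1)
sph = calculator(R)
```

The choice between spherical and solid harmonics is made once, when the calculator is constructed, so the call site stays `calculator(R)` in both cases.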

@sirmarcel

It would also be nice to be very explicit about which part of the "what sphericart computes" article maps to which arguments/classes. Currently it's technically written down, but a bit hard to parse.

@cortner
Member

cortner commented Apr 8, 2024

My only gripe with what you suggest is that the type of the calculator should determine whether it's spherical or solids - hence just "compute".

But as a general rule I'm not certain that unifying too much is even a good thing. Different languages and different frameworks make different usage conventions natural.
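For comparison, a Python sketch of the dispatch-on-type style described above, using `functools.singledispatch` so that a single `compute` chooses spherical or solid harmonics from the calculator's class. The classes only store `lmax`; `_Basis` and `_real_harmonics` are hypothetical helpers, only `lmax` of 0 or 1 is handled, and the harmonic convention is the same assumed real convention, not necessarily sphericart's.

```python
from functools import singledispatch

import numpy as np


class _Basis:
    """Hypothetical base: in this sketch a calculator only stores lmax."""

    def __init__(self, lmax):
        if lmax not in (0, 1):
            raise ValueError("this sketch only supports lmax = 0 or 1")
        self.lmax = lmax


class SphericalHarmonics(_Basis):
    pass


class SolidHarmonics(_Basis):
    pass


def _real_harmonics(xyz, lmax, divide_by_r):
    # Columns ordered (0,0), (1,-1), (1,0), (1,1); convention assumed
    xyz = np.asarray(xyz, dtype=float)
    x, y, z = xyz.T
    c0 = np.sqrt(1.0 / (4.0 * np.pi))
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    columns = [np.full_like(x, c0)]
    if lmax >= 1:
        scale = 1.0 / np.linalg.norm(xyz, axis=1) if divide_by_r else 1.0
        columns += [c1 * y * scale, c1 * z * scale, c1 * x * scale]
    return np.stack(columns, axis=1)


@singledispatch
def compute(basis, xyz):
    # Fallback for calculator types with no registered implementation
    raise TypeError(f"compute() is not defined for {type(basis).__name__}")


@compute.register
def _(basis: SphericalHarmonics, xyz):
    return _real_harmonics(xyz, basis.lmax, divide_by_r=True)


@compute.register
def _(basis: SolidHarmonics, xyz):
    return _real_harmonics(xyz, basis.lmax, divide_by_r=False)


sph = compute(SphericalHarmonics(1), [[1.0, 2.0, 2.0]])
```

`singledispatch` dispatches on the Python type of the first argument only, which is a narrower mechanism than Julia's multiple dispatch.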

@sirmarcel

For jax, I'd suggest following the e3x API, i.e., just having the functions solid_harmonics and spherical_harmonics. Having compute_spherical_harmonics accept some opaque calculator object as an argument doesn't seem very intuitive and probably doesn't play nicely with jit.

For torch, the suggested approach is good: the classes should have a forward(self, xyz) method and nothing else.
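A sketch of the function-only style suggested for jax, written in plain NumPy so it runs without JAX: two free functions that take the positions and an integer `lmax`, with no calculator object. The argument names are assumptions, not e3x's or sphericart's actual signatures, and only `lmax` of 0 or 1 is implemented, using the same assumed real-harmonic convention.

```python
import numpy as np

# Hypothetical function-only API: no calculator object, lmax is a plain int.
# Columns are ordered (0,0), (1,-1), (1,0), (1,1); convention assumed.


def solid_harmonics(xyz, lmax):
    xyz = np.asarray(xyz, dtype=float)
    if lmax not in (0, 1):
        raise ValueError("this sketch only supports lmax = 0 or 1")
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    c0 = np.sqrt(1.0 / (4.0 * np.pi))
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    columns = [np.full_like(x, c0)]
    if lmax == 1:
        columns += [c1 * y, c1 * z, c1 * x]
    return np.stack(columns, axis=-1)


def spherical_harmonics(xyz, lmax):
    xyz = np.asarray(xyz, dtype=float)
    r = np.linalg.norm(xyz, axis=-1, keepdims=True)
    # Degree l of each output column, used to divide solid harmonics by r**l
    degrees = np.array([0, 1, 1, 1])[: (lmax + 1) ** 2]
    return solid_harmonics(xyz, lmax) / r**degrees


sph = spherical_harmonics([[1.0, 2.0, 2.0]], lmax=1)
```

Under `jax.jit`, an integer like `lmax` would normally be marked static (for example through the `static_argnames` argument), since it determines the output shape.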

@Luthaf
Contributor

Luthaf commented Apr 9, 2024

The idea for jax was that we can use the calculator object to cache the state we need (right now we are using a hidden global cache). We are pretty confident it should be possible to make this work with jit (maybe by defining the state as a PyTree).

> My only gripe with what you suggest is that the type of the calculator should determine whether it's spherical or solids - hence just "compute".

This works really well in Julia, where type-based dispatch is king, but in jax in particular this pattern is much harder to integrate with the function transformations (jit, grad, vmap, …).

@sirmarcel

Hm... does this state even need to be visible to jax/Python at all if you're already calling out to custom code? If not, maybe you can initialise it at tracing/"compile" time; it seems a bit clunky to carry around state for this. I assume it needs to change based on the requested spherical harmonic order? Or what is being cached here?

Luthaf mentioned this issue Jul 8, 2024

4 participants