Make APIs as uniform as possible #98
I'm open to anything, really, since on our end we have to write wrappers anyhow to make this compatible with how we organize computations.
We had a discussion about this today, here is a quick summary:
Julia currently does something like this:

```julia
basis = SolidHarmonics(10)
# one of
sph = basis(R)
sph = compute(basis, R)
```

JAX currently does something like this:

```python
sph = compute_spherical_harmonics(R, lmax=10, normalize=True)
```

We can change it to:

```python
calculator = SphericalHarmonics(lmax=10)
sph = compute_spherical_harmonics(calculator, R)
sph = calculator(R)

calculator = SolidHarmonics(lmax=10)
sph = compute_solid_harmonics(calculator, R)
sph = calculator(R)
```

Torch/NumPy: we can change these to:

```python
calculator = SphericalHarmonics(lmax=10)
sph = calculator(R)

# add this one
calculator = SolidHarmonics(lmax=10)
sph = calculator(R)
```
It would also be nice to be very explicit about which part of the "what sphericart computes" article maps to which arguments/classes. Currently it is technically written down, but a bit hard to parse.
My only gripe with what you suggest is that the type of the calculator should determine whether it computes spherical or solid harmonics, hence just `compute`. But as a general rule, I'm not certain that unifying too much is even a good thing: different languages and frameworks make different usage conventions natural.
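A minimal NumPy sketch of the "the type decides, hence just `compute`" idea, assuming hypothetical `SphericalHarmonics`/`SolidHarmonics` calculator classes (this is not sphericart's actual Python API). To stay short it only implements `l_max <= 1` via the closed-form real harmonics; a real implementation would use sphericart's recurrences. `functools.singledispatch` stands in for Julia's type-based dispatch, and `__call__` delegates to `compute`, so both call styles from the proposal share one code path.

```python
import math
from functools import singledispatch

import numpy as np


class _Calculator:
    """Hypothetical base: stores l_max (sketch: l_max <= 1 only)."""

    def __init__(self, l_max):
        if l_max > 1:
            raise NotImplementedError("this sketch only covers l_max <= 1")
        self.l_max = l_max

    def __call__(self, R):
        # calculator(R) and compute(calculator, R) are the same thing
        return compute(self, R)


class SolidHarmonics(_Calculator):
    """Hypothetical calculator for r^l Y_l^m."""


class SphericalHarmonics(_Calculator):
    """Hypothetical calculator for Y_l^m evaluated on the direction R / |R|."""


def _real_solid_harmonics(R, l_max):
    # Closed-form real solid harmonics: l = 0, then l = 1 with m = -1, 0, 1 -> y, z, x
    x, y, z = R[..., 0], R[..., 1], R[..., 2]
    out = [np.full(x.shape, 0.5 / math.sqrt(math.pi))]
    if l_max >= 1:
        c1 = math.sqrt(3.0 / (4.0 * math.pi))
        out += [c1 * y, c1 * z, c1 * x]
    return np.stack(out, axis=-1)


@singledispatch
def compute(calculator, R):
    raise TypeError(f"unsupported calculator type: {type(calculator).__name__}")


@compute.register
def _(calculator: SolidHarmonics, R):
    return _real_solid_harmonics(np.asarray(R, dtype=float), calculator.l_max)


@compute.register
def _(calculator: SphericalHarmonics, R):
    # Spherical harmonics are the solid harmonics evaluated on unit vectors
    R = np.asarray(R, dtype=float)
    r = np.linalg.norm(R, axis=-1, keepdims=True)
    return _real_solid_harmonics(R / r, calculator.l_max)


R = np.array([[1.0, 2.0, 2.0]])
sph = SphericalHarmonics(l_max=1)(R)      # shape (1, 4)
solid = compute(SolidHarmonics(l_max=1), R)  # shape (1, 4)
```

Here the calculator's type alone selects spherical versus solid harmonics, so no `compute_spherical_harmonics`/`compute_solid_harmonics` pair is needed; whether this reads naturally outside Julia is exactly the open question.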
The idea for JAX was that we can use the calculator object to cache the state we need (right now we are using a hidden global cache). We are pretty confident it should be possible to make this work with `jit` (maybe by defining the state as a PyTree).
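A minimal sketch of the PyTree idea, assuming a hypothetical `SolidHarmonics` JAX calculator (again only `l_max <= 1`, with two prefactors standing in for whatever state sphericart actually caches). Registering the class with `jax.tree_util.register_pytree_node_class` lets it be passed directly into `jax.jit` and `jax.grad`: the cached arrays become traced leaves, while `l_max` is static auxiliary data, so a new `l_max` triggers a recompile but new array values do not.

```python
import math

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SolidHarmonics:
    """Hypothetical JAX calculator carrying its own state (sketch: l_max <= 1)."""

    def __init__(self, l_max):
        if l_max > 1:
            raise NotImplementedError("this sketch only covers l_max <= 1")
        self.l_max = l_max
        # State computed once per calculator instead of in a hidden global cache
        self.prefactors = jnp.array(
            [0.5 / math.sqrt(math.pi), math.sqrt(3.0 / (4.0 * math.pi))]
        )

    def tree_flatten(self):
        # Arrays are leaves (traced under jit/grad/vmap); l_max is static aux data
        return (self.prefactors,), self.l_max

    @classmethod
    def tree_unflatten(cls, l_max, children):
        # Bypass __init__: JAX may rebuild the object with tracers as leaves
        obj = object.__new__(cls)
        obj.l_max = l_max
        (obj.prefactors,) = children
        return obj

    def __call__(self, R):
        x, y, z = R[..., 0], R[..., 1], R[..., 2]
        c0, c1 = self.prefactors[0], self.prefactors[1]
        out = [c0 * jnp.ones_like(x)]
        if self.l_max >= 1:  # plain Python branch is fine: l_max is static
            out += [c1 * y, c1 * z, c1 * x]
        return jnp.stack(out, axis=-1)


@jax.jit
def sum_of_squares(calculator, R):
    # The calculator is an ordinary jit argument, no global state involved
    return jnp.sum(calculator(R) ** 2)


calculator = SolidHarmonics(l_max=1)
R = jnp.array([[1.0, 2.0, 2.0]])
value = sum_of_squares(calculator, R)
grad_R = jax.grad(sum_of_squares, argnums=1)(calculator, R)
```

The alternative this competes with is a plain function with `l_max` marked static (for example via `jax.jit(..., static_argnames=...)`); which of the two composes better with `jit`, `grad` and `vmap` in practice is the trade-off being discussed.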
This works really well in Julia, where type-based dispatch is king, but in JAX in particular this pattern is much harder to integrate with the function transformations (`jit`, `grad`, `vmap`, …).
Hm... does this state even need to be visible to
At the moment, the APIs for C/C++/NumPy/torch, JAX, Julia and CUDA are all slightly different. We should discuss up to what point we should aim at making them uniform and where we should instead give way to the idioms of each language/framework.