Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions project-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Babuska
Benadjila
Bursztein
CHES
CUDA
Cassiers
Chmielewski
DSOWINDOW
Expand Down Expand Up @@ -68,6 +69,7 @@ allclose
arange
argmax
argmin
argnames
argsort
ascad
astrojs
Expand Down Expand Up @@ -203,6 +205,7 @@ udevadm
usbutils
vals
vispy
vmap
vsintellicode
webassets
xlabel
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ capture = [
]
dev = [
"flake8",
"jax[cuda13]",
"mypy",
"psutil",
"pydantic",
Expand Down
16 changes: 14 additions & 2 deletions scaaml/stats/cpa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,21 @@
# limitations under the License.
"""Correlation power analysis (CPA) module. This module is structured to
support multiple backends (e.g., a GPU accelerated JAX implementation and a
NumPy implementation). It currently provides the NumPy version.
NumPy implementation). If JAX is installed defaults to JAX implementation
otherwise falls back to the NumPy one.

If a concrete version is needed use:
```python
from scaaml.stats.cpa.cpa import CPA
# from scaaml.stats.cpa.cpa_jax import CPA
```
"""
from scaaml.stats.cpa.cpa import CPA # NumPy based
try:
# JAX based if JAX is installed
from scaaml.stats.cpa.cpa_jax import CPA
except ImportError:
# NumPy based default
from scaaml.stats.cpa.cpa import CPA # type: ignore[assignment]

__all__ = [
"CPA",
Expand Down
Loading
Loading