Skip to content

Commit

Permalink
feat: add JAX as Computation Backend (#1646)
Browse files Browse the repository at this point in the history
Signed-off-by: agaraman0 <agaraman0@gmail.com>
  • Loading branch information
agaraman0 committed Jul 18, 2023
1 parent d2e1858 commit b306c80
Show file tree
Hide file tree
Showing 34 changed files with 2,169 additions and 60 deletions.
48 changes: 46 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jobs:
python -m pip install poetry
poetry install --without dev
poetry run pip install tensorflow==2.11.0
poetry run pip install jax
- name: Test basic import
run: poetry run python -c 'from docarray import DocList, BaseDoc'

Expand Down Expand Up @@ -111,7 +112,7 @@ jobs:
- name: Test
id: test
run: |
poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
Expand Down Expand Up @@ -158,7 +159,7 @@ jobs:
- name: Test
id: test
run: |
poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py
poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
Expand Down Expand Up @@ -357,6 +358,49 @@ jobs:
flags: ${{ steps.test.outputs.codecov_flag }}
fail_ci_if_error: false

# CI job that installs jax/jaxlib on top of the full poetry environment and
# runs only the tests carrying the `jax` pytest marker.
docarray-test-jax:
  needs: [import-test]
  runs-on: ubuntu-latest
  strategy:
    fail-fast: false
    matrix:
      # Quoted so YAML keeps the version a string instead of parsing a float.
      python-version: ["3.8"]
  steps:
    - uses: actions/checkout@v2.5.0
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    - name: Prepare environment
      run: |
        python -m pip install --upgrade pip
        python -m pip install poetry
        poetry install --all-extras
        poetry run pip install jaxlib
        poetry run pip install jax
    - name: Test
      id: test
      run: |
        poetry run pytest -m 'jax' --cov=docarray --cov-report=xml tests
        echo "flag it as docarray for codeoverage"
        echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
      timeout-minutes: 30
    - name: Check codecov file
      id: check_files
      uses: andstor/file-existence-action@v1
      with:
        files: "coverage.xml"
    - name: Upload coverage from test to Codecov
      uses: codecov/codecov-action@v3.1.1
      # Written as one bare expression: embedding `${{ }}` inside part of an
      # `if:` turns the whole value into an interpolated (non-empty, hence
      # truthy) string, so the file-existence guard would never take effect.
      if: steps.check_files.outputs.files_exists == 'true' && matrix.python-version == '3.8'
      with:
        file: coverage.xml
        name: benchmark-test-codecov
        flags: ${{ steps.test.outputs.codecov_flag }}
        fail_ci_if_error: false



docarray-test-benchmarks:
needs: [import-test]
Expand Down
28 changes: 26 additions & 2 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
from docarray.utils._internal.misc import is_tf_available, is_torch_available
from docarray.utils._internal.misc import (
is_jax_available,
is_tf_available,
is_torch_available,
)

if TYPE_CHECKING:
import csv
Expand Down Expand Up @@ -60,6 +64,14 @@
else:
TensorFlowTensor = None # type: ignore

jnp_available = is_jax_available()
if jnp_available:
import jax.numpy as jnp # type: ignore

from docarray.typing import JaxArray # noqa: F401
else:
JaxArray = None # type: ignore

T_doc = TypeVar('T_doc', bound=BaseDoc)
T = TypeVar('T', bound='DocVec')
T_io_mixin = TypeVar('T_io_mixin', bound='IOMixinArray')
Expand Down Expand Up @@ -262,6 +274,19 @@ def _check_doc_field_not_none(field_name, doc):

stacked: tf.Tensor = tf.stack(tf_stack)
tensor_columns[field_name] = TensorFlowTensor(stacked)
elif jnp_available and issubclass(field_type, JaxArray):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
tensor_columns[field_name] = None
else:
tf_stack = []
for i, doc in enumerate(docs):
val = getattr(doc, field_name)
_check_doc_field_not_none(field_name, doc)
tf_stack.append(val.tensor)

jax_stacked: jnp.ndarray = jnp.stack(tf_stack)
tensor_columns[field_name] = JaxArray(jax_stacked)

elif safe_issubclass(field_type, AbstractTensor):
if first_doc_is_none:
Expand Down Expand Up @@ -835,7 +860,6 @@ def to_doc_list(self: T) -> DocList[T_doc]:
unstacked_doc_column[field] = doc_col.to_doc_list() if doc_col else None

for field, da_col in self._storage.docs_vec_columns.items():

unstacked_da_column[field] = (
[docs.to_doc_list() for docs in da_col] if da_col else None
)
Expand Down
15 changes: 14 additions & 1 deletion docarray/array/list_advance_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing_extensions import SupportsIndex

from docarray.utils._internal.misc import (
is_torch_available,
is_jax_available,
is_tf_available,
is_torch_available,
)

torch_available = is_torch_available()
Expand All @@ -24,7 +25,13 @@
tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
jax_available = is_jax_available()
if jax_available:
import jax.numpy as jnp

from docarray.typing.tensor.jaxarray import JaxArray

T_item = TypeVar('T_item')
T = TypeVar('T', bound='ListAdvancedIndexing')
Expand Down Expand Up @@ -100,6 +107,12 @@ def _normalize_index_item(
if isinstance(item, TensorFlowTensor):
return item.tensor.numpy().tolist()

if jax_available:
if isinstance(item, jnp.ndarray):
return item.__array__().tolist()
if isinstance(item, JaxArray):
return item.tensor.__array__().tolist()

return item

def _get_from_indices(self: T, item: Iterable[int]) -> T:
Expand Down

0 comments on commit b306c80

Please sign in to comment.