Skip to content

Commit

Permalink
fix: lint and improve lint workflow (#28757)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed May 23, 2024
1 parent e00376d commit 2ba7e79
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 30 deletions.
25 changes: 22 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,25 @@ permissions:
contents: write

jobs:
lint:
uses: unifyai/workflows/.github/workflows/lint.yml@main
secrets: inherit
check-formatting:
runs-on: ubuntu-latest
steps:
- name: Checkout 🛎️ ${{ github.event.repository.name }}
uses: actions/checkout@v4

- name: Get changed files
if: github.event_name == 'pull_request'
id: changed-files
uses: lots0logs/gh-action-get-changed-files@2.1.4
with:
token: ${{ secrets.GITHUB_TOKEN }}

- name: Setup Python 🐍
uses: actions/setup-python@v4
with:
python-version: 3.10.14

- name: Run pre-commit 🚀
uses: pre-commit/action@v3.0.0
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--files {0}', join(fromJSON(steps.changed-files.outputs.all), ' ')) || '--all-files' }}
2 changes: 1 addition & 1 deletion docker/requirement_mappings_gpu.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"tensorflow": [
"tensorflow-probability"
]
}
}
2 changes: 1 addition & 1 deletion docker/requirement_mappings_multiversion.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@
"torch": [
"torchvision"
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,32 @@ def test_jax_nanmin(
)


# nanpercentile
@handle_frontend_test(
fn_tree="jax.numpy.nanpercentile",
dtype_and_x=_percentile_helper(),
keep_dims=st.booleans(),
test_gradients=st.just(False),
test_with_out=st.just(False),
)
def test_jax_nanpercentile(
*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device
):
input_dtype, x, axis, interpolation, q = dtype_and_x
helpers.test_function(
input_dtypes=input_dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
fn_name=fn_name,
on_device=on_device,
a=x[0],
q=q,
axis=axis,
interpolation=interpolation[0],
keepdims=keep_dims,
)


# nanstd
@handle_frontend_test(
fn_tree="jax.numpy.nanstd",
Expand Down Expand Up @@ -1320,28 +1346,3 @@ def test_jax_var(
atol=1e-3,
rtol=1e-3,
)


@handle_frontend_test(
fn_tree="jax.numpy.nanpercentile",
dtype_and_x=_percentile_helper(),
keep_dims=st.booleans(),
test_gradients=st.just(False),
test_with_out=st.just(False),
)
def test_jax_nanpercentile(
*, dtype_and_x, keep_dims, test_flags, backend_fw, fn_name, on_device
):
input_dtype, x, axis, interpolation, q = dtype_and_x
helpers.test_function(
input_dtypes=input_dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
fn_name=fn_name,
on_device=on_device,
a=x[0],
q=q,
axis=axis,
interpolation=interpolation[0],
keepdims=keep_dims,
)

0 comments on commit 2ba7e79

Please sign in to comment.