Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add affine_grid op #5225

Merged
merged 39 commits into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
c54ccf4
initial change
liqunfu May 13, 2023
6de815c
passed 2d non-align_corners function body in parser test
liqunfu May 14, 2023
275293d
passed 2/3-d align or not corvers with parser test
liqunfu May 14, 2023
6a02339
update def FunctionBody and test models
liqunfu May 14, 2023
13ff153
shape inference (incomplete), reference implementation, and more tests
liqunfu May 22, 2023
3cec59d
all tests passed, fixed most comments
liqunfu May 22, 2023
cfdd63c
replace mul of ones with add of zeros
liqunfu May 22, 2023
ba3d48f
merge
liqunfu May 23, 2023
eed4463
lint and undo parser test
liqunfu May 23, 2023
eee0c33
size to be int64 only - shape dtype
liqunfu May 23, 2023
86233f7
skip affine_grid test_backend_onnxruntime.py
liqunfu May 23, 2023
945b5be
skip test comment
liqunfu May 23, 2023
f195e19
lint
liqunfu May 23, 2023
82329ed
lint
liqunfu May 23, 2023
3122077
all_tensor_types_ir4
liqunfu May 23, 2023
30ef8c4
docs
liqunfu May 23, 2023
2de0700
formatting
liqunfu May 23, 2023
9f20a65
formatting
liqunfu May 23, 2023
0cb6410
formatting
liqunfu May 23, 2023
2639591
function mody to avoid using epsilon
liqunfu May 30, 2023
dcd567a
Merge branch 'main' into liqun/affine_grid
liqunfu Jun 21, 2023
6de7e00
use reference implementation to generate test case for affine_grid
liqunfu Jun 22, 2023
364c39c
has both test types
liqunfu Jun 27, 2023
5bd2de8
merge main
liqunfu Jul 9, 2023
090ae54
lint
liqunfu Jul 10, 2023
76b0aad
lint
liqunfu Jul 10, 2023
05e6ae5
crlf to lf
liqunfu Jul 11, 2023
ac4627e
Merge branch 'main' into liqun/affine_grid
liqunfu Jul 11, 2023
a5b9d23
TestCoverage.md
liqunfu Jul 11, 2023
ba2f2e3
symbolic tests
liqunfu Jul 11, 2023
31d670e
Merge branch 'main' into liqun/affine_grid
liqunfu Jul 19, 2023
63baf6d
docs/TestCoverage.md
liqunfu Jul 19, 2023
07e48cc
reviewers' comments
liqunfu Jul 27, 2023
29e72ac
Merge branch 'main' into liqun/affine_grid
liqunfu Jul 27, 2023
c9b9cc8
remove unneeded test cases that use ref evaluator
liqunfu Jul 27, 2023
57c49e8
lint
liqunfu Jul 27, 2023
bc12bd6
add comment for assert
liqunfu Jul 31, 2023
45ac099
Merge branch 'main' into liqun/affine_grid
liqunfu Jul 31, 2023
f47cdcd
lint
liqunfu Jul 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 66 additions & 0 deletions docs/Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -23881,6 +23881,72 @@ This version of the operator has been available since version 19 of the default
</dl>

## Version 20 of the default ONNX operator set
### <a name="AffineGrid-20"></a>**AffineGrid-20**</a>

Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices theta
(https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html).
An affine matrix `theta` is applied to a position tensor represented in its homogeneous expression. Here is an example in 3D:
```
[r00, r01, r02, t0] [x] [x']
[r10, r11, r12, t1] * [y] = [y']
[r20, r21, r22, t2] [z] [z']
[0, 0, 0, 1 ] [1] [1 ]
```
where `(x, y, z)` is the position in the original space, `(x', y', z')` is the position in the output space.
The last row is always `[0, 0, 0, 1]` and is not stored in the affine matrix. Therefore we have `theta` of shape `(N, 2, 3)` for 2D or `(N, 3, 4)` for 3D.

Input `size` is used to define grid of positions evenly spaced in the original 2D or 3D space, with dimensions ranging from `-1` to `1`.
The output `grid` contains positions in the output space.

When `align_corners=1`, consider `-1` and `1` to refer to the centers of the corner pixels (mark `v` in illustration).
```
v v v v
|-------------------|------------------|
-1 0 1
```
When `align_corners=0`, consider `-1` and `1` to refer to the outer edge of the corner pixels.
```
v v v v
|------------------|-------------------|
-1 0 1
```

#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>align_corners</tt> : int (default is 0)</dt>
<dd>if align_corners=1, consider -1 and 1 to refer to the centers of the corner pixels. if align_corners=0, consider -1 and 1 to refer to the outer edge the corner pixels.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>theta</tt> (non-differentiable) : T1</dt>
<dd>input batch of affine matrices with shape (N, 2, 3) for 2D or (N, 3, 4) for 3D</dd>
<dt><tt>size</tt> (non-differentiable) : T2</dt>
<dd>the target output image size (N, C, H, W) for 2D or (N, C, D, H, W) for 3D</dd>
</dl>

#### Outputs

<dl>
<dt><tt>grid</tt> (differentiable) : T1</dt>
<dd>output tensor of shape (N, C, H, W, 2) of 2D sample coordinates or (N, C, D, H, W, 3) of 3D sample coordinates.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(bfloat16), tensor(float16), tensor(float), tensor(double)</dt>
<dd>Constrain grid types to float tensors.</dd>
<dt><tt>T2</tt> : tensor(int64)</dt>
<dd>Constrain size's type to int64 tensors.</dd>
</dl>

### <a name="ConstantOfShape-20"></a>**ConstantOfShape-20**</a>

Generate a tensor with given value and shape.
Expand Down
134 changes: 134 additions & 0 deletions docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Where">Where</a>|<a href="Changelog.md#Where-16">16</a>, <a href="Changelog.md#Where-9">9</a>|
|<a href="#Xor">Xor</a>|<a href="Changelog.md#Xor-7">7</a>, <a href="Changelog.md#Xor-1">1</a>|
|**Function**|**Since version**|**Function version**|
|<a href="#AffineGrid">AffineGrid</a>|<a href="Changelog.md#AffineGrid-20">20</a>|20|
|<a href="#Bernoulli">Bernoulli</a>|<a href="Changelog.md#Bernoulli-15">15</a>|15|
|<a href="#BlackmanWindow">BlackmanWindow</a>|<a href="Changelog.md#BlackmanWindow-17">17</a>|17|
|<a href="#CastLike">CastLike</a>|<a href="Changelog.md#CastLike-19">19</a>, <a href="Changelog.md#CastLike-15">15</a>|19|
Expand Down Expand Up @@ -487,6 +488,139 @@ expect(node, inputs=[x, y], outputs=[x + y], name="test_add_uint8")
</details>


### <a name="AffineGrid"></a><a name="affinegrid">**AffineGrid**</a>

Generates a 2D or 3D flow field (sampling grid), given a batch of affine matrices theta
(https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html).
An affine matrix `theta` is applied to a position tensor represented in its homogeneous expression. Here is an example in 3D:
```
[r00, r01, r02, t0] [x] [x']
[r10, r11, r12, t1] * [y] = [y']
[r20, r21, r22, t2] [z] [z']
[0, 0, 0, 1 ] [1] [1 ]
```
where `(x, y, z)` is the position in the original space, `(x', y', z')` is the position in the output space.
The last row is always `[0, 0, 0, 1]` and is not stored in the affine matrix. Therefore we have `theta` of shape `(N, 2, 3)` for 2D or `(N, 3, 4)` for 3D.

Input `size` is used to define grid of positions evenly spaced in the original 2D or 3D space, with dimensions ranging from `-1` to `1`.
The output `grid` contains positions in the output space.

When `align_corners=1`, consider `-1` and `1` to refer to the centers of the corner pixels (mark `v` in illustration).
```
v v v v
|-------------------|------------------|
-1 0 1
```
When `align_corners=0`, consider `-1` and `1` to refer to the outer edge of the corner pixels.
```
v v v v
|------------------|-------------------|
-1 0 1
```

#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>align_corners</tt> : int (default is 0)</dt>
<dd>if align_corners=1, consider -1 and 1 to refer to the centers of the corner pixels. if align_corners=0, consider -1 and 1 to refer to the outer edge the corner pixels.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>theta</tt> (non-differentiable) : T1</dt>
<dd>input batch of affine matrices with shape (N, 2, 3) for 2D or (N, 3, 4) for 3D</dd>
<dt><tt>size</tt> (non-differentiable) : T2</dt>
<dd>the target output image size (N, C, H, W) for 2D or (N, C, D, H, W) for 3D</dd>
</dl>

#### Outputs

<dl>
<dt><tt>grid</tt> (differentiable) : T1</dt>
<dd>output tensor of shape (N, C, H, W, 2) of 2D sample coordinates or (N, C, D, H, W, 3) of 3D sample coordinates.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(bfloat16), tensor(float16), tensor(float), tensor(double)</dt>
<dd>Constrain grid types to float tensors.</dd>
<dt><tt>T2</tt> : tensor(int64)</dt>
<dd>Constrain size's type to int64 tensors.</dd>
</dl>


#### Examples

<details>
<summary>2d_no_reference_evaluator</summary>

```python
theta_2d = create_theta_2d()
N, C, W, H = len(theta_2d), 3, 5, 6
data_size = (W, H)
for align_corners in (0, 1):
node = onnx.helper.make_node(
"AffineGrid",
inputs=["theta", "size"],
outputs=["grid"],
align_corners=align_corners,
)

original_grid = construct_original_grid(data_size, align_corners)
grid = apply_affine_transform(theta_2d, original_grid)

test_name = "test_affine_grid_2d"
if align_corners == 1:
test_name += "_align_corners"
expect(
node,
inputs=[theta_2d, np.array([N, C, W, H], dtype=np.int64)],
outputs=[grid],
name=test_name,
)
```

</details>


<details>
<summary>3d_no_reference_evaluator</summary>

```python
theta_3d = create_theta_3d()
N, C, D, W, H = len(theta_3d), 3, 4, 5, 6
data_size = (D, W, H)
for align_corners in (0, 1):
node = onnx.helper.make_node(
"AffineGrid",
inputs=["theta", "size"],
outputs=["grid"],
align_corners=align_corners,
)

original_grid = construct_original_grid(data_size, align_corners)
grid = apply_affine_transform(theta_3d, original_grid)

test_name = "test_affine_grid_3d"
if align_corners == 1:
test_name += "_align_corners"
expect(
node,
inputs=[theta_3d, np.array([N, C, D, W, H], dtype=np.int64)],
outputs=[grid],
name=test_name,
)
```

</details>


### <a name="And"></a><a name="and">**And**</a>

Returns the tensor resulted from performing the `and` logical operation
Expand Down
66 changes: 65 additions & 1 deletion docs/TestCoverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* [Overall Test Coverage](#overall-test-coverage)
# Node Test Coverage
## Summary
Node tests have covered 175/188 (93.09%, 5 generators excluded) common operators.
Node tests have covered 176/189 (93.12%, 5 generators excluded) common operators.

Node tests have covered 0/0 (N/A) experimental operators.

Expand Down Expand Up @@ -344,6 +344,70 @@ expect(node, inputs=[x, y], outputs=[x + y], name="test_add_uint8")
</details>


### AffineGrid
There are 2 test cases, listed as following:
<details>
<summary>2d_no_reference_evaluator</summary>

```python
theta_2d = create_theta_2d()
N, C, W, H = len(theta_2d), 3, 5, 6
data_size = (W, H)
for align_corners in (0, 1):
node = onnx.helper.make_node(
"AffineGrid",
inputs=["theta", "size"],
outputs=["grid"],
align_corners=align_corners,
)

original_grid = construct_original_grid(data_size, align_corners)
grid = apply_affine_transform(theta_2d, original_grid)

test_name = "test_affine_grid_2d"
if align_corners == 1:
test_name += "_align_corners"
expect(
node,
inputs=[theta_2d, np.array([N, C, W, H], dtype=np.int64)],
outputs=[grid],
name=test_name,
)
```

</details>
<details>
<summary>3d_no_reference_evaluator</summary>

```python
theta_3d = create_theta_3d()
N, C, D, W, H = len(theta_3d), 3, 4, 5, 6
data_size = (D, W, H)
for align_corners in (0, 1):
node = onnx.helper.make_node(
"AffineGrid",
inputs=["theta", "size"],
outputs=["grid"],
align_corners=align_corners,
)

original_grid = construct_original_grid(data_size, align_corners)
grid = apply_affine_transform(theta_3d, original_grid)

test_name = "test_affine_grid_3d"
if align_corners == 1:
test_name += "_align_corners"
expect(
node,
inputs=[theta_3d, np.array([N, C, D, W, H], dtype=np.int64)],
outputs=[grid],
name=test_name,
)
```

</details>


### And
There are 2 test cases, listed as following:
<details>
Expand Down