-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add two node tests
- Loading branch information
Showing
26 changed files
with
228 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
import numpy as np # type: ignore | ||
|
||
import onnx | ||
from ..base import Base | ||
from . import expect | ||
|
||
def apply_adagrad(r, t, x, g, h, norm_coefficient, epsilon, decay_factor): | ||
# Compute adjusted learning-rate. | ||
r_ = r * (1 + t * decay_factor) | ||
# Add gradient of regularization term. | ||
g_regularized = norm_coefficient * x + g | ||
# Update squared accumulated gradient. | ||
h_new = h + g * g | ||
# Compute ADAGRAD's gradient scaling factors | ||
h_sqrt = np.sqrt(h_new) + epsilon | ||
# Apply ADAGRAD update rule. | ||
x_new = x - r_ * g_regularized / h_sqrt | ||
return (x_new, h_new) | ||
|
||
class Adagrad(Base): | ||
|
||
@staticmethod | ||
def export_adagrad(): # type: () -> None | ||
# Define operator attributes. | ||
norm_coefficient = 0.001 | ||
epsilon = 1e-5 | ||
decay_factor = 0.1 | ||
|
||
# Create operator. | ||
node = onnx.helper.make_node('Adagrad', | ||
inputs=['R', 'T', 'X', 'G', 'H'], | ||
outputs=['X_new', 'H_new'], | ||
norm_coefficient=norm_coefficient, | ||
epsilon=epsilon, | ||
decay_factor=decay_factor | ||
) | ||
|
||
# Define operator inputs. | ||
r = np.array(0.1, dtype=np.float32) # scalar | ||
t = np.array(0, dtype=np.int64) # scalar | ||
x = np.array([1.0], dtype=np.float32) | ||
g = np.array([-1.0], dtype=np.float32) | ||
h = np.array([2.0], dtype=np.float32) | ||
|
||
# Compute expected outputs of Adagrad. | ||
x_new, h_new = apply_adagrad(r, t, x, g, h, | ||
norm_coefficient, epsilon, decay_factor) | ||
|
||
# Check results. | ||
expect(node, inputs=[r, t, x, g, h], | ||
outputs=[x_new, h_new], name='test_adagrad') | ||
|
||
@staticmethod | ||
def export_adagrad_multiple(): # type: () -> None | ||
# Define operator attributes. | ||
norm_coefficient = 0.001 | ||
epsilon = 1e-5 | ||
decay_factor = 0.1 | ||
|
||
node = onnx.helper.make_node('Adagrad', | ||
inputs=['R', 'T', 'X1', 'X2', | ||
'G1', 'G2', 'H1', 'H2'], | ||
outputs=['X1_new', 'X2_new', | ||
'H1_new', 'H2_new'], | ||
norm_coefficient=norm_coefficient, | ||
epsilon=epsilon, | ||
decay_factor=decay_factor | ||
) | ||
|
||
# Define operator inputs. | ||
r = np.array(0.1, dtype=np.float32) # scalar | ||
t = np.array(0, dtype=np.int64) # scalar | ||
|
||
x1 = np.array([1.0], dtype=np.float32) | ||
g1 = np.array([-1.0], dtype=np.float32) | ||
h1 = np.array([2.0], dtype=np.float32) | ||
|
||
x2 = np.array([1.0, 2.0], dtype=np.float32) | ||
g2 = np.array([-1.0, -3.0], dtype=np.float32) | ||
h2 = np.array([4.0, 1.0], dtype=np.float32) | ||
|
||
# Compute expected outputs of Adagrad. | ||
x1_new, h1_new = apply_adagrad(r, t, x1, g1, h1, | ||
norm_coefficient, epsilon, decay_factor) | ||
x2_new, h2_new = apply_adagrad(r, t, x2, g2, h2, | ||
norm_coefficient, epsilon, decay_factor) | ||
|
||
# Check results. | ||
expect(node, inputs=[r, t, x1, x2, g1, g2, h1, h2], | ||
outputs=[x1_new, x2_new, h1_new, h2_new], name='test_adagrad_multiple') |
Binary file not shown.
1 change: 1 addition & 0 deletions
1
onnx/backend/test/data/node/test_adagrad/test_data_set_0/input_0.pb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
BRJ���= |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions
1
onnx/backend/test/data/node/test_adagrad/test_data_set_0/output_0.pb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
BX_newJ�a�? |
Binary file added
BIN
+17 Bytes
onnx/backend/test/data/node/test_adagrad/test_data_set_0/output_1.pb
Binary file not shown.
Binary file not shown.
1 change: 1 addition & 0 deletions
1
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_0.pb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
BRJ���= |
Binary file added
BIN
+15 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_1.pb
Binary file not shown.
Binary file added
BIN
+14 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_2.pb
Binary file not shown.
Binary file added
BIN
+18 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_3.pb
Binary file not shown.
Binary file added
BIN
+14 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_4.pb
Binary file not shown.
Binary file added
BIN
+18 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_5.pb
Binary file not shown.
Binary file added
BIN
+14 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_6.pb
Binary file not shown.
Binary file added
BIN
+18 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/input_7.pb
Binary file not shown.
1 change: 1 addition & 0 deletions
1
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/output_0.pb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
BX1_newJ�a�? |
1 change: 1 addition & 0 deletions
1
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/output_1.pb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
BX2_newJ���?H@ |
Binary file added
BIN
+18 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/output_2.pb
Binary file not shown.
Binary file added
BIN
+22 Bytes
onnx/backend/test/data/node/test_adagrad_multiple/test_data_set_0/output_3.pb
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters