[Training] Add Adagrad optimizer operator (#1955)
* Adagrad draft

* MIMO

* Support multiple tensors to be optimized

* Address comments

* Move optimizers to a new place

Remove copied

Add momentum

Save

Remove momentum

Fix

Move constants to attributes

* Fix build

* Add shape test

Add two node tests

Update test coverage

* Fix shape inf

* Fix shape inf

* fix shape inf

* Format

* Add function type

* Merge lines

* Format

* Fix version number

* Update op version in model files

* Fix a test function and update related test files

* Update onnx/backend/test/case/node/adagrad.py

* Remove unused file

* sync docs

* Fix shape test

* sync doc

* sync with master

* Update onnx/defs/training/defs.cc

Co-Authored-By: Michał Karzyński <postrational@users.noreply.github.com>

* sync doc

* address comments

* address a minor comment

* Polish one line

Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
wschin and postrational committed Mar 11, 2020
1 parent 94b01cd commit 8d15705
Showing 30 changed files with 662 additions and 5 deletions.
99 changes: 98 additions & 1 deletion docs/Changelog.md
@@ -14997,7 +14997,7 @@ This version of the operator has been available since version 12 of the default
<br/>
Where number of blocks extracted from each spatial dimension d is:
```
-num_blocks[d] = floor((input_spatial_shape[d] + 2 * padding[d] dilation[d] * (kernel_size[d] 1) 1) / stride[d]) + 1
+num_blocks[d] = floor((input_spatial_shape[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d]) + 1
```

#### Version
@@ -15040,6 +15040,103 @@ This version of the operator has been available since version 12 of the default

# ai.onnx.training
## Version 1 of the 'ai.onnx.training' operator set
### <a name="ai.onnx.training.Adagrad-1"></a>**ai.onnx.training.Adagrad-1**</a>

Compute one iteration of ADAGRAD, a stochastic gradient based optimization
algorithm. This operator can conduct the optimization of multiple tensor variables.

Let's define the behavior of this operator. ADAGRAD requires
some parameters:

- The initial learning-rate "R".
- The update count "T". That is, the number of training iterations conducted.
- An L2-norm regularization coefficient "norm_coefficient".
- A learning-rate decay factor "decay_factor".
- A small constant "epsilon" to avoid dividing-by-zero.

At each ADAGRAD iteration, the optimized tensors are moved along a direction
computed based on their estimated gradient and accumulated squared gradient. Assume
that only a single tensor "X" is updated by this operator. We need the value of "X",
its gradient "G", and its accumulated squared gradient "H". Therefore, variables in
this operator's input list are sequentially "R", "T", "X", "G", and "H". Other
parameters are given as attributes because they are usually constants. Also, the
corresponding output tensors are the new value of "X" (called "X_new") and the new
accumulated squared gradient (called "H_new"). Those outputs are computed
from the given inputs following the pseudo code below.

Let "+", "-", "*", and "/" are all element-wise arithmetic operations with
numpy-style broadcasting support. The pseudo code to compute those outputs is:

// Compute a scalar learning-rate factor. At the first update of X, T is generally
// 0 (0-based update index) or 1 (1-based update index).
r = R / (1 + T * decay_factor);

// Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm.
G_regularized = norm_coefficient * X + G;

// Compute new accumulated squared gradient.
H_new = H + G_regularized * G_regularized;

// Compute the adaptive part of the per-coordinate learning rate. Note that Sqrt(...)
// computes the element-wise square root.
H_adaptive = Sqrt(H_new) + epsilon;

// Compute the new value of "X".
X_new = X - r * G_regularized / H_adaptive;
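
For a concrete sense of these formulas, here is an illustrative single-step trace using the values from the node test added in this commit (R = 0.1, T = 0, X = [1.0], G = [-1.0], H = [2.0]; norm_coefficient = 0.001, epsilon = 1e-5, decay_factor = 0.1). The results are rounded and shown for orientation only:

```
r             = 0.1 / (1 + 0 * 0.1)             = 0.1
G_regularized = 0.001 * 1.0 + (-1.0)            = -0.999
H_new         = 2.0 + (-0.999)^2                = 2.998001
H_adaptive    = Sqrt(2.998001) + 1e-5           ≈ 1.73148
X_new         = 1.0 - 0.1 * (-0.999) / 1.73148  ≈ 1.0577
```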

If this operator is used to optimize multiple inputs, for example "X_1" and "X_2", the same
pseudo code can be extended to handle all tensors jointly. More specifically, we can view "X" as a
concatenation of "X_1" and "X_2" (their gradients and accumulated squared gradients should be
concatenated too) and then just reuse the entire pseudo code.

Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
In that reference paper, this operator is a special case of Figure 1's composite mirror
descent update.

#### Version

This version of the operator has been available since version 1 of the 'ai.onnx.training' operator set.

#### Attributes

<dl>
<dt><tt>decay_factor</tt> : float (default is 0.0)</dt>
<dd>The decay factor of the learning rate after one update. The effective learning rate is computed by r = R / (1 + T * decay_factor). Defaults to 0 so that increasing the update count does not reduce the learning rate.</dd>
<dt><tt>epsilon</tt> : float (default is 0.0)</dt>
<dd>Small scalar to avoid dividing by zero.</dd>
<dt><tt>norm_coefficient</tt> : float (default is 0.0)</dt>
<dd>Regularization coefficient in 0.5 * norm_coefficient * ||X||_2^2. Defaults to 0, which means no regularization.</dd>
</dl>

#### Inputs (3 - &#8734;)

<dl>
<dt><tt>R</tt> : T1</dt>
<dd>The initial learning rate.</dd>
<dt><tt>T</tt> : T2</dt>
<dd>The update count of "X". It should be a scalar.</dd>
<dt><tt>inputs</tt> (variadic, heterogeneous) : T3</dt>
<dd>The current values of optimized tensors, followed by their respective gradients, followed by their respective accumulated squared gradients. For example, if two tensors "X_1" and "X_2" are optimized, the input list would be ["X_1", "X_2", gradient of "X_1", gradient of "X_2", accumulated squared gradient of "X_1", accumulated squared gradient of "X_2"].</dd>
</dl>

#### Outputs (1 - &#8734;)

<dl>
<dt><tt>outputs</tt> (variadic, heterogeneous) : T3</dt>
<dd>Updated values of optimized tensors, followed by their updated accumulated squared gradients. For example, if two tensors "X_1" and "X_2" are optimized, the output list would be [new value of "X_1", new value of "X_2", new accumulated squared gradient of "X_1", new accumulated squared gradient of "X_2"].</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain input types to float scalars.</dd>
<dt><tt>T2</tt> : tensor(int64)</dt>
<dd>Constrain input types to 64-bit integer scalars.</dd>
<dt><tt>T3</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>

### <a name="ai.onnx.training.Gradient-1"></a>**ai.onnx.training.Gradient-1**</a>

Gradient operator computes the partial derivatives of a specific tensor w.r.t.
189 changes: 188 additions & 1 deletion docs/Operators.md
@@ -172,6 +172,7 @@
* <a href="#Range">Range</a>
* <a href="#SoftmaxCrossEntropyLoss">SoftmaxCrossEntropyLoss</a>
* ai.onnx.training
* <a href="#ai.onnx.training.Adagrad">ai.onnx.training.Adagrad</a>
* <a href="#ai.onnx.training.Gradient">ai.onnx.training.Gradient</a>
* <a href="#ai.onnx.training.GraphCall">ai.onnx.training.GraphCall</a>

@@ -20122,7 +20123,7 @@ expect(node, inputs=[data], outputs=[transposed],
<br/>
Where number of blocks extracted from each spatial dimension d is:
```
-num_blocks[d] = floor((input_spatial_shape[d] + 2 * padding[d] dilation[d] * (kernel_size[d] 1) 1) / stride[d]) + 1
+num_blocks[d] = floor((input_spatial_shape[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d]) + 1
```

#### Version
@@ -20949,6 +20950,192 @@ expect(node, inputs=[x, y], outputs=[z],


## ai.onnx.training
### <a name="ai.onnx.training.Adagrad"></a><a name="ai.onnx.training.adagrad">**ai.onnx.training.Adagrad**</a>

Compute one iteration of ADAGRAD, a stochastic gradient based optimization
algorithm. This operator can conduct the optimization of multiple tensor variables.

Let's define the behavior of this operator. ADAGRAD requires
some parameters:

- The initial learning-rate "R".
- The update count "T". That is, the number of training iterations conducted.
- An L2-norm regularization coefficient "norm_coefficient".
- A learning-rate decay factor "decay_factor".
- A small constant "epsilon" to avoid dividing-by-zero.

At each ADAGRAD iteration, the optimized tensors are moved along a direction
computed based on their estimated gradient and accumulated squared gradient. Assume
that only a single tensor "X" is updated by this operator. We need the value of "X",
its gradient "G", and its accumulated squared gradient "H". Therefore, variables in
this operator's input list are sequentially "R", "T", "X", "G", and "H". Other
parameters are given as attributes because they are usually constants. Also, the
corresponding output tensors are the new value of "X" (called "X_new") and the new
accumulated squared gradient (called "H_new"). Those outputs are computed
from the given inputs following the pseudo code below.

Let "+", "-", "*", and "/" are all element-wise arithmetic operations with
numpy-style broadcasting support. The pseudo code to compute those outputs is:

// Compute a scalar learning-rate factor. At the first update of X, T is generally
// 0 (0-based update index) or 1 (1-based update index).
r = R / (1 + T * decay_factor);

// Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm.
G_regularized = norm_coefficient * X + G;

// Compute new accumulated squared gradient.
H_new = H + G_regularized * G_regularized;

// Compute the adaptive part of the per-coordinate learning rate. Note that Sqrt(...)
// computes the element-wise square root.
H_adaptive = Sqrt(H_new) + epsilon;

// Compute the new value of "X".
X_new = X - r * G_regularized / H_adaptive;

If this operator is used to optimize multiple inputs, for example "X_1" and "X_2", the same
pseudo code can be extended to handle all tensors jointly. More specifically, we can view "X" as a
concatenation of "X_1" and "X_2" (their gradients and accumulated squared gradients should be
concatenated too) and then just reuse the entire pseudo code.

Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
In that reference paper, this operator is a special case of Figure 1's composite mirror
descent update.

#### Version

This version of the operator has been available since version 1 of the 'ai.onnx.training' operator set.

#### Attributes

<dl>
<dt><tt>decay_factor</tt> : float (default is 0.0)</dt>
<dd>The decay factor of the learning rate after one update. The effective learning rate is computed by r = R / (1 + T * decay_factor). Defaults to 0 so that increasing the update count does not reduce the learning rate.</dd>
<dt><tt>epsilon</tt> : float (default is 0.0)</dt>
<dd>Small scalar to avoid dividing by zero.</dd>
<dt><tt>norm_coefficient</tt> : float (default is 0.0)</dt>
<dd>Regularization coefficient in 0.5 * norm_coefficient * ||X||_2^2. Defaults to 0, which means no regularization.</dd>
</dl>

#### Inputs (3 - &#8734;)

<dl>
<dt><tt>R</tt> : T1</dt>
<dd>The initial learning rate.</dd>
<dt><tt>T</tt> : T2</dt>
<dd>The update count of "X". It should be a scalar.</dd>
<dt><tt>inputs</tt> (variadic, heterogeneous) : T3</dt>
<dd>The current values of optimized tensors, followed by their respective gradients, followed by their respective accumulated squared gradients. For example, if two tensors "X_1" and "X_2" are optimized, the input list would be ["X_1", "X_2", gradient of "X_1", gradient of "X_2", accumulated squared gradient of "X_1", accumulated squared gradient of "X_2"].</dd>
</dl>

#### Outputs (1 - &#8734;)

<dl>
<dt><tt>outputs</tt> (variadic, heterogeneous) : T3</dt>
<dd>Updated values of optimized tensors, followed by their updated accumulated squared gradients. For example, if two tensors "X_1" and "X_2" are optimized, the output list would be [new value of "X_1", new value of "X_2", new accumulated squared gradient of "X_1", new accumulated squared gradient of "X_2"].</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain input types to float scalars.</dd>
<dt><tt>T2</tt> : tensor(int64)</dt>
<dd>Constrain input types to 64-bit integer scalars.</dd>
<dt><tt>T3</tt> : tensor(float), tensor(double)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


#### Examples

<details>
<summary>adagrad</summary>

```python
# Define operator attributes.
norm_coefficient = 0.001
epsilon = 1e-5
decay_factor = 0.1

# Create operator.
node = onnx.helper.make_node('Adagrad',
                             inputs=['R', 'T', 'X', 'G', 'H'],
                             outputs=['X_new', 'H_new'],
                             norm_coefficient=norm_coefficient,
                             epsilon=epsilon,
                             decay_factor=decay_factor,
                             domain='ai.onnx.training'
                             )

# Define operator inputs.
r = np.array(0.1, dtype=np.float32) # scalar
t = np.array(0, dtype=np.int64) # scalar
x = np.array([1.0], dtype=np.float32)
g = np.array([-1.0], dtype=np.float32)
h = np.array([2.0], dtype=np.float32)

# Compute expected outputs of Adagrad.
x_new, h_new = apply_adagrad(r, t, x, g, h,
                             norm_coefficient, epsilon, decay_factor)

# Check results.
expect(node, inputs=[r, t, x, g, h],
       outputs=[x_new, h_new], name='test_adagrad',
       opset_imports=[onnx.helper.make_opsetid('ai.onnx.training', 1)])
```

</details>
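
The examples above and below call a helper named apply_adagrad, which this commit adds in onnx/backend/test/case/node/adagrad.py. As a rough, non-authoritative sketch, a NumPy implementation consistent with the pseudo code in the operator description (not necessarily identical to the repository's helper) could look like this:

```python
import numpy as np

def apply_adagrad(r, t, x, g, h, norm_coefficient, epsilon, decay_factor):
    # Sketch only: behavior inferred from the operator's pseudo code above.
    # Scalar learning-rate factor for the current update count.
    r_ = r / (1 + t * decay_factor)
    # Gradient of the L2-regularized objective.
    g_regularized = norm_coefficient * x + g
    # Accumulate the squared (regularized) gradient.
    h_new = h + g_regularized * g_regularized
    # Per-coordinate adaptive scaling of the step size.
    h_adaptive = np.sqrt(h_new) + epsilon
    # ADAGRAD update of the optimized tensor.
    x_new = x - r_ * g_regularized / h_adaptive
    return x_new, h_new
```

For several optimized tensors, the examples simply invoke this helper once per (X, G, H) triple, as adagrad_multiple does below.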


<details>
<summary>adagrad_multiple</summary>

```python
# Define operator attributes.
norm_coefficient = 0.001
epsilon = 1e-5
decay_factor = 0.1

node = onnx.helper.make_node('Adagrad',
                             inputs=['R', 'T', 'X1', 'X2',
                                     'G1', 'G2', 'H1', 'H2'],
                             outputs=['X1_new', 'X2_new',
                                      'H1_new', 'H2_new'],
                             norm_coefficient=norm_coefficient,
                             epsilon=epsilon,
                             decay_factor=decay_factor,
                             domain='ai.onnx.training'
                             )

# Define operator inputs.
r = np.array(0.1, dtype=np.float32) # scalar
t = np.array(0, dtype=np.int64) # scalar

x1 = np.array([1.0], dtype=np.float32)
g1 = np.array([-1.0], dtype=np.float32)
h1 = np.array([2.0], dtype=np.float32)

x2 = np.array([1.0, 2.0], dtype=np.float32)
g2 = np.array([-1.0, -3.0], dtype=np.float32)
h2 = np.array([4.0, 1.0], dtype=np.float32)

# Compute expected outputs of Adagrad.
x1_new, h1_new = apply_adagrad(r, t, x1, g1, h1,
                               norm_coefficient, epsilon, decay_factor)
x2_new, h2_new = apply_adagrad(r, t, x2, g2, h2,
                               norm_coefficient, epsilon, decay_factor)

# Check results.
expect(node, inputs=[r, t, x1, x2, g1, g2, h1, h2],
       outputs=[x1_new, x2_new, h1_new, h2_new], name='test_adagrad_multiple',
       opset_imports=[onnx.helper.make_opsetid('ai.onnx.training', 1)])
```

</details>


### <a name="ai.onnx.training.Gradient"></a><a name="ai.onnx.training.gradient">**ai.onnx.training.Gradient**</a>

Gradient operator computes the partial derivatives of a specific tensor w.r.t.
