
[Training] Add Adagrad optimizer operator #1955

Merged
merged 44 commits into onnx:master on Mar 11, 2020

Conversation

@wschin (Contributor) commented Apr 21, 2019

PR #2314 is a single place for reviewing the whole training story.

This PR proposes a stateless optimizer based on a stochastic gradient method. The operator covers TensorFlow's and PyTorch's Adagrad optimizers.

An optimizer could instead be defined as a stateful operator which, for example, maintains accumulated momentum and accumulated squared gradients. However, that approach would implicitly introduce assignment semantics into ONNX and therefore break the SSA assumption.
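For concreteness, here is a minimal sketch of what one stateless update step could look like as an ONNX node built with `onnx.helper`. The operator name, domain, and attribute names are assumptions that mirror the `apply_adagrad` signature in the script below rather than the final schema text; the point is only that the state (H) flows in and out as ordinary graph inputs and outputs instead of being assigned in place.

```python
import onnx
from onnx import helper, TensorProto

# Hypothetical sketch of a single stateless Adagrad update step.
# Inputs:  R (learning rate), T (step count), X (weight), G (gradient),
#          H (accumulated squared gradient); outputs: updated X and H.
# Op name, domain, and attribute names are assumptions based on this PR's description.
node = helper.make_node(
    'Adagrad',
    inputs=['R', 'T', 'X', 'G', 'H'],
    outputs=['X_new', 'H_new'],
    domain='ai.onnx.preview.training',  # assumed training domain
    norm_coefficient=0.0,
    epsilon=1e-10,
    decay_factor=0.0,
)

graph = helper.make_graph(
    [node], 'adagrad_one_step',
    inputs=[
        helper.make_tensor_value_info('R', TensorProto.FLOAT, []),
        helper.make_tensor_value_info('T', TensorProto.INT64, []),
        helper.make_tensor_value_info('X', TensorProto.FLOAT, [10]),
        helper.make_tensor_value_info('G', TensorProto.FLOAT, [10]),
        helper.make_tensor_value_info('H', TensorProto.FLOAT, [10]),
    ],
    outputs=[
        helper.make_tensor_value_info('X_new', TensorProto.FLOAT, [10]),
        helper.make_tensor_value_info('H_new', TensorProto.FLOAT, [10]),
    ],
)
```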

Below is the script I used to verify the Adagrad function added in this PR's tests. It shows consistent results across TensorFlow, PyTorch, and the Adagrad function used to create the ONNX tests in this PR.

import numpy as np
import torch
import tensorflow as tf
from torch import optim
from torch import nn
import torch.random as rnd

lr=0.2
lr_decay=0.0
weight_decay=0.0
n = 10
l = 5
X_ = torch.randn(l, n)
Y_ = torch.randn(l, 1)

def apply_adagrad(r, t, x, g, h, norm_coefficient, epsilon, decay_factor):  # type: ignore
    # Compute adjusted learning-rate.
    r_ = r / (1 + t * decay_factor)
    # Add gradient of regularization term.
    g_regularized = norm_coefficient * x + g
    # Update squared accumulated gradient.
    h_new = h + g_regularized * g_regularized
    # Compute ADAGRAD's gradient scaling factors
    h_sqrt = np.sqrt(h_new) + epsilon
    # Apply ADAGRAD update rule.
    x_new = x - r_ * g_regularized / h_sqrt
    return (x_new, h_new)

def show_pytorch():
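    # PyTorch reference: optim.Adagrad updates the weights of a linear model.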
    X = X_.clone()
    Y = Y_.clone()

    rnd.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(n, 1, bias=False)
    )

    loss_fn = nn.MSELoss(reduction='sum')

    solver = optim.Adagrad(model.parameters(), lr=lr, lr_decay=lr_decay, weight_decay=weight_decay)

    for t in range(8):
        Y_pred = model(X)
        loss = loss_fn(Y_pred, Y)
        model.zero_grad()
        loss.backward()
        solver.step()
        print(loss)

def show_tensorflow():
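    # TensorFlow (TF1) reference: AdagradOptimizer on the same data, starting
    # from the same initial weights as the PyTorch model.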
    rnd.manual_seed(0)
    layer = nn.Linear(n, 1, bias=False)

    X = tf.placeholder('float', shape=[l, n])
    W = tf.Variable(torch.Tensor.numpy(layer.weight.detach()))
    Y = tf.placeholder('float', shape=[l, 1])
    Y_pred = tf.matmul(X, W, transpose_b=True)
    loss = tf.reduce_sum(tf.square(Y - Y_pred))
    optimizer = tf.train.AdagradOptimizer(learning_rate=lr, initial_accumulator_value=0.0000000001)
    minimizer = optimizer.minimize(loss)

    sess = tf.Session()

    init = tf.global_variables_initializer()

    sess.run(init)

    for t in range(8):
        print(sess.run(loss, {X: X_, Y: Y_}))
        result = sess.run([minimizer], {X: X_, Y: Y_})


def show_onnx():
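    # ONNX reference: weights and accumulator are updated manually with
    # apply_adagrad, the rule used to generate this PR's node tests.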
    X = X_.clone()
    Y = Y_.clone()

    rnd.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(n, 1, bias=False)
    )

    loss_fn = nn.MSELoss(reduction='sum')

    solver = optim.Adagrad(model.parameters(), lr=lr, lr_decay=lr_decay, weight_decay=weight_decay)

    for t in range(8):
        Y_pred = model(X)
        loss = loss_fn(Y_pred, Y)
        model.zero_grad()
        loss.backward()

        with torch.no_grad():
            for param in model.parameters():
                new_tensor, new_state = apply_adagrad(r=lr, t=t, x=param.data, g=param.grad.data, h=solver.state[param]['sum'].data,
                                                      norm_coefficient=weight_decay, epsilon=1e-10, decay_factor=lr_decay)
                solver.state[param]['sum'].data = new_state.data
                param.data = new_tensor.data

        print(loss)

show_tensorflow()
print('---------')
show_pytorch()
print('---------')
show_onnx()

@wschin wschin requested a review from a team as a code owner April 21, 2019 23:19
@wschin wschin changed the title Adagrad optimizer draft [WIP] Adagrad optimizer draft Apr 21, 2019
@SherlockNoMad (Contributor)

I agree that having a stateless optimizer is a good design: it works with existing ONNX without breaking the SSA assumption.

Different backends can come up with their own implementations of how to handle the new state and update the weights.

@gramalingam (Contributor)

Hi, is there a plan to add the complete Adagrad optimization (loop) as another operator? The proposed operator seems to be only one part of the solution. It would help to understand how the whole thing would work.

@wschin (Contributor, Author) commented Apr 25, 2019

> Hi, is there a plan to add the complete Adagrad optimization (loop) as another operator? The proposed operator seems to be only one part of the solution. It would help to understand how the whole thing would work.

There are two phases of supporting training in ONNX.

  1. Training the model with one iteration.
  2. Training the model with the whole data set.

These operators are defined for Phase 1 and will be used in Phase 2 to compose multi-iteration algorithms. Several difficulties that prevent us from going to Phase 2 directly are:

  1. Defining a full training algorithm is similar to writing a program.
  2. Defining a full training algorithm may put extra constraints on the inference stage. For example, if a written training algorithm wants to run 100 iterations, the user will have to feed 100 batches into ONNXRuntime.
  3. If we can define one training iteration properly, users can conceptually execute multiple iterations themselves (see the sketch below).
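
As a rough sketch of that last point (not part of this PR), a multi-iteration loop can be composed outside the graph by repeatedly applying the single-iteration update. The loop below re-uses the same update rule as apply_adagrad from the script above; the data, gradient, and hyper-parameters are made up for illustration.

```python
import numpy as np

# Hypothetical composition of a multi-iteration (Phase 2) loop from the
# single-iteration (Phase 1) Adagrad update.
rng = np.random.RandomState(0)
x = rng.randn(10).astype(np.float32)   # weights
h = np.zeros_like(x)                   # accumulated squared gradients
r, norm_coefficient, epsilon, decay_factor = 0.2, 0.0, 1e-10, 0.0

for t in range(8):
    g = 2.0 * x                                    # stand-in gradient (of sum(x**2))
    r_ = r / (1 + t * decay_factor)                # adjusted learning rate
    g_reg = norm_coefficient * x + g               # add regularization gradient
    h = h + g_reg * g_reg                          # update accumulator
    x = x - r_ * g_reg / (np.sqrt(h) + epsilon)    # Adagrad step
```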

Remove copied

Add momentum

Save

Remove momentum

Fix

Move constants to attributes
Add two node tests

Update test coverage
@wschin wschin changed the title [WIP] Adagrad optimizer draft Adagrad optimizer draft Apr 29, 2019
@wschin wschin changed the title Adagrad optimizer draft [Training] Add an optimizer operator Apr 30, 2019
@wschin wschin removed this from the 1.7 milestone Feb 26, 2020
@postrational (Contributor) left a comment

Overall the proposal looks very good. It seems that this operator could be composed as a Function; could we add a FunctionBody to its definition?

@postrational (Contributor) left a comment

Transforming this into a function may be problematic at this stage due to limitations of Split and of the Function mechanism.
We should use this op as a test case for expanding Function support and convert it to a function in the future.

wschin and others added 4 commits March 6, 2020 11:12
Co-Authored-By: Michał Karzyński <postrational@users.noreply.github.com>
@gramalingam (Contributor) left a comment

Many thanks!

@wschin wschin merged commit 8d15705 into onnx:master Mar 11, 2020
chinhuang007 added a commit that referenced this pull request Mar 11, 2020
* Fix Greater/LessOrEqual function definition (#2645)

* Fix Greater/LessOrEqual function definition

* Update test data

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Suppress a warning in unsqueeze (#2637)

I keep getting this warning when building PyTorch:

```
In file included from
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/utils.h:6,
                 from
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:4:
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc: In
lambda function:
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:1414:22:
warning: unnecessary parentheses in declaration of 'i'
[-Wparentheses]
           for (size_t(i) = 0; i < axes.size(); ++i) {
                      ^
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/schema.h:959:12:
note: in definition of macro 'ONNX_OPERATOR_SET_SCHEMA_EX'
     return impl.SetName(#name)
\
            ^~~~
/home/hong/wsrc/pytorch/third_party/onnx/onnx/defs/tensor/defs.cc:1369:1:
note: in expansion of macro 'ONNX_OPERATOR_SET_SCHEMA'
 ONNX_OPERATOR_SET_SCHEMA(
```

This commit should fix it and modernize the code a bit.

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* [Training] Add Adagrad optimizer operator (#1955)

* Adagrad draft

* MIMO

* Support multiple tensors to be optimized

* Address comments

* Move optimizers to a new place

Remove copied

Add momentum

Save

Remove momentum

Fix

Move constants to attributes

* Fix build

* Add shape test

Add two node tests

Update test coverage

* Fix shape inf

* Fix shape inf

* fix shape inf

* Format

* Add function type

* Merge lines

* Format

* Fix version number

* Update op version in model files

* Fix a test function and update related test files

* Update onnx/backend/test/case/node/adagrad.py

* Remove unused file

* sync docs

* Fix shape test

* sync doc

* sync with master

* Update onnx/defs/training/defs.cc

Co-Authored-By: Michał Karzyński <postrational@users.noreply.github.com>

* sync doc

* address comments

* address a minor comment

* Polish one line

Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>

* Add optimizer PR to release branch

This PR is to add the newly merged optimizer PR to the
release 1.7 branch and set a temporary version number
for unit test in TestPypi.

Co-authored-by: Takeshi Watanabe <take-cheeze@users.noreply.github.com>
Co-authored-by: Ke Zhang <kezhan@microsoft.com>
Co-authored-by: Hong Xu <hong@topbug.net>
Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
@chinhuang007 chinhuang007 added this to the 1.7 milestone Mar 12, 2020
linkerzhang added a commit that referenced this pull request Mar 31, 2020

* [Training] SG with Momentum Optimizer (#1959)

* SG with Momentum

* Registrate Op

Fix

Update other docs

* Add shape inference code and polish definition

* Update docs

* Add test cases and fix several bugs

* Remove accidently added copy

* Alpha -> alpha & Beta -> beta

* Clarify an attribute

* Fix an attribute

* Fix bug

* Fix missing attributes

* sync doc

* Remove unused domain

* sync with master

Co-authored-by: Chin Huang <chhuang@us.ibm.com>

* Change type of label tensor to int32/int64 in SoftmaxCrossEntropyLoss spec. (#2667)

* Update Pow input types in Opset 12 (#2666)

* Update Pow input types in Opset 12

* gen doc and tests

* remove uints and 8 bit ints

* add tests

* remove uint int x tets

* Adding CI for ONNX Debug mode (Linux, OSX) (#2651)

* adding an osx build, linux build, with and without onnx_ml for debug mode

* test debug mode with ONNX_ML=1

* Rename OPTIONAL to OPTIONAL_VALUE (#2682)

Co-authored-by: G. Ramalingam <grama@microsoft.com>

* Update Batchnorm test (#2674)

* Update Batchnorm test

* relax shape inference on scalar

* Remove unnecessary copies and std::move (#2684)

* Update sequence test case so input is not scalar and splits are specified (#2675)

* Update sequence test case to input is not scalar and splits are specified

* Add spaces to make the checker happy

* Use cmake GNUInstallDirs (#2661)

https://cmake.org/cmake/help/latest/module/GNUInstallDirs.html
this make allow install the libraries (and headers) in different location than `lib` (Gentoo uses lib64 for 64-bits libs)
also change the .cmake files for avoid conclicts if build both 32-bis and 64-bits (avoids conflict/overwrite files)

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Add 'ignore_index' input in the spec for SoftmaxCrossEntropyLoss and NLLLoss. (#2680)

* Add 'ignore_index' input in the spec for SoftmaxCrossEntropyLoss and NLLLoss.

* Add tests.

* build break.

* build break.

* clean up.

* build break.

* Change ignore_index to attribute.

* Change ignore_index to attribute.

* PR feedback.

* PR feedback.

* Make ignore_index optional in NLLLoss.

* Build break.

* remove trailing spaces to fix build break.

* Build break.

* Update spec doc.

* Fix NLLLoss function definition to fix test: test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded

* PR feedback.

* Fix test for softmax cross entropy loss to exclude ignored_index'ed weights from the sum of weights.

* Build break.

* Reduce binary size of libraries consuming ONNX (part 1/2) (#2643)

* Change the return type for the zipmap operator to match the description in the spec.

* Reduce binary size of libraries consuming ONNX (part 1/2)

* Fix build error

* Replace separate Get*Doc() functions with easy macro for greater convenience

* Add one more macro for complicated operator doc documentation.

Co-authored-by: Ke Zhang <kezhan@microsoft.com>

* Update pybind (#2340) (#2688)

* Change version number for release verification

Change version number for release verification

Co-authored-by: Takeshi Watanabe <take-cheeze@users.noreply.github.com>
Co-authored-by: Ke Zhang <kezhan@microsoft.com>
Co-authored-by: Hong Xu <hong@topbug.net>
Co-authored-by: Wei-Sheng Chin <wschin@outlook.com>
Co-authored-by: Michał Karzyński <postrational@users.noreply.github.com>
Co-authored-by: M. Zeeshan Siddiqui <mzs@microsoft.com>
Co-authored-by: Lara Haidar <haidar.lara@gmail.com>
Co-authored-by: Vinitra Swamy <vinitras@gmail.com>
Co-authored-by: Changming Sun <chasun@microsoft.com>
Co-authored-by: G. Ramalingam <grama@microsoft.com>
Co-authored-by: Changming Sun <me@sunchangming.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: Gustavo Alvarez <462213+sl1pkn07@users.noreply.github.com>
Co-authored-by: Pranav Sharma <prs@microsoft.com>
jcwchen pushed a commit to jcwchen/onnx that referenced this pull request Sep 23, 2020
Labels
operator (Issues related to ONNX operators), training (Issues related to ONNX training)
9 participants