Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Training] Refactor Optimizer and Gradient #121

Merged
merged 13 commits into from
Feb 7, 2023

Conversation

Ubospica
Copy link
Contributor

@Ubospica Ubospica commented Feb 7, 2023

Update optimizer APIs.

  • Remove @property state and @state.setter
  • Add init() interface
  • Remove Optimizer.__call__()
  • Remove underscores before attributes, and unnecessary attributes

Current interfaces:

class Optimizer:
    dtype: str
    name: str
    param_list: List[Var]
    state: tvm.runtime.container.ADT

    def __init__(self, name: str) -> None:
        self.name = name
        self.param_list = None
        self.state = None
        self.dtype = None

    def init(self, params: Union[Var, List[Var]]) -> "Optimizer":
        """Set the parameters, determine the dtype, and build the initial state for the optimizer."""
		pass

    def get_function(self) -> Function:
        """Use blockbuilder to build an optimizer function that executes updates of the parameters
        and the optimizer state."""
		pass

Use examples:

See https://github.com/ACMClass-TVM-20/AD-Example/blob/dc255150dc6a4a6de2fffc2c093a8b2bacc1b030/optimizer_api_example.py

And also updates Gradient APIs:

  • Before: def Gradient(global_var: GlobalVar, require_grads: Optional[Union[Var, List[Var]]]) -> tvm.ir.transform.Pass
  • After: def Gradient(func_name: str, require_grads: Optional[Union[Var, List[Var]]]) -> tvm.ir.transform.Pass

Unit tests are changed accordingly.

@Ubospica Ubospica changed the title Update optimizer [Relax][Training] Update optimizer Feb 7, 2023
@Ubospica Ubospica changed the title [Relax][Training] Update optimizer [Relax][Training] Refactorize Optimizer and Gradient Feb 7, 2023
@Ubospica Ubospica changed the title [Relax][Training] Refactorize Optimizer and Gradient [Relax][Training] Refactor Optimizer and Gradient Feb 7, 2023
@MasterJH5574 MasterJH5574 merged commit 86768e6 into mlc-ai:relax Feb 7, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 8, 2023
Update optimizer APIs.
- Remove `@property state` and `@state.setter`
- Add `init()` interface
- Remove `Optimizer.__call__()`
- Remove underscores before attributes, and unnecessary attributes

Current interfaces:
```python
class Optimizer:
    dtype: str
    name: str
    param_list: List[Var]
    state: tvm.runtime.container.ADT

    def __init__(self, name: str) -> None:
        self.name = name
        self.param_list = None
        self.state = None
        self.dtype = None

    def init(self, params: Union[Var, List[Var]]) -> "Optimizer":
        """Set the parameters, determine the dtype, and build the initial state for the optimizer."""
		pass

    def get_function(self) -> Function:
        """Use blockbuilder to build an optimizer function that executes updates of the parameters
        and the optimizer state."""
		pass
```

Use examples:

See
<https://github.com/ACMClass-TVM-20/AD-Example/blob/dc255150dc6a4a6de2fffc2c093a8b2bacc1b030/optimizer_api_example.py>

And also updates Gradient APIs:
- Before: `def Gradient(global_var: GlobalVar, require_grads:
Optional[Union[Var, List[Var]]]) -> tvm.ir.transform.Pass`
- After: `def Gradient(func_name: str, require_grads:
Optional[Union[Var, List[Var]]]) -> tvm.ir.transform.Pass`

Unit tests are changed accordingly.
spectrometerHBH pushed a commit to spectrometerHBH/relax that referenced this pull request Feb 9, 2023
MasterJH5574 pushed a commit that referenced this pull request Feb 12, 2023
Update optimizer APIs.
- Remove `@property state` and `@state.setter`
- Add `init()` interface
- Remove `Optimizer.__call__()`
- Remove underscores before attributes, and unnecessary attributes

Current interfaces:
```python
class Optimizer:
    dtype: str
    name: str
    param_list: List[Var]
    state: tvm.runtime.container.ADT

    def __init__(self, name: str) -> None:
        self.name = name
        self.param_list = None
        self.state = None
        self.dtype = None

    def init(self, params: Union[Var, List[Var]]) -> "Optimizer":
        """Set the parameters, determine the dtype, and build the initial state for the optimizer."""
		pass

    def get_function(self) -> Function:
        """Use blockbuilder to build an optimizer function that executes updates of the parameters
        and the optimizer state."""
		pass
```

Use examples:

See
<https://github.com/ACMClass-TVM-20/AD-Example/blob/dc255150dc6a4a6de2fffc2c093a8b2bacc1b030/optimizer_api_example.py>

And also updates Gradient APIs:
- Before: `def Gradient(global_var: GlobalVar, require_grads:
Optional[Union[Var, List[Var]]]) -> tvm.ir.transform.Pass`
- After: `def Gradient(func_name: str, require_grads:
Optional[Union[Var, List[Var]]]) -> tvm.ir.transform.Pass`

Unit tests are changed accordingly.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants