[BUG] Error in tutorials/3_Meta_Optimizer.ipynb #70

Closed
3 tasks done
xianghang opened this issue Sep 9, 2022 · 3 comments · Fixed by #71
Labels
bug Something isn't working

Comments

@xianghang

Describe the bug

Running cell [6] of tutorials/3_Meta_Optimizer.ipynb (the first cell under section 2.1 Basic API) raises the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-ad332932d4e7> in <module>
     15 for i in range(2):
     16     inner_loss = net(x)
---> 17     optim.step(inner_loss)
     18 
     19 print(f'a = {net.a!r}')

/usr/local/lib/python3.7/dist-packages/torchopt/_src/optimizer/meta/base.py in step(self, loss)
     68                 new_state,
     69                 params=flattened_params,
---> 70                 inplace=False,
     71             )
     72             self.state_groups[i] = new_state

/usr/local/lib/python3.7/dist-packages/torchopt/_src/transform.py in update_fn(updates, state, params, inplace)
     65 
     66         flattened_updates, state = inner.update(
---> 67             flattened_updates, state, params=params, inplace=inplace
     68         )
     69         updates = pytree.tree_unflatten(treedef, flattened_updates)

/usr/local/lib/python3.7/dist-packages/torchopt/_src/base.py in update_fn(updates, state, params, inplace)
    181             new_state = []
    182             for s, fn in zip(state, update_fns):  # pylint: disable=invalid-name
--> 183                 updates, new_s = fn(updates, s, params=params, inplace=inplace)
    184                 new_state.append(new_s)
    185             return updates, tuple(new_state)

/usr/local/lib/python3.7/dist-packages/torchopt/_src/transform.py in update_fn(updates, state, params, inplace)
    312     def update_fn(updates, state, *, params=None, inplace=True):  # pylint: disable=unused-argument
    313         mu = _update_moment(
--> 314             updates, state.mu, b1, order=1, inplace=inplace, already_flattened=already_flattened
    315         )
    316         nu = _update_moment(

/usr/local/lib/python3.7/dist-packages/torchopt/_src/transform.py in _update_moment(updates, moments, decay, order, inplace, already_flattened)
    214 
    215     if already_flattened:
--> 216         return map_flattened(f, updates, moments)
    217     return pytree.tree_map(f, updates, moments)
    218 

/usr/local/lib/python3.7/dist-packages/torchopt/_src/transform.py in map_flattened(func, *args)
     49 def map_flattened(func: Callable, *args: Any) -> List[Any]:
     50     """Apply a function to each element of a flattened list."""
---> 51     return list(map(func, *args))
     52 
     53 

/usr/local/lib/python3.7/dist-packages/torchopt/_src/transform.py in f(g, t)
    211 
    212             def f(g, t):
--> 213                 return t.mul(decay).add_(g, alpha=1 - decay) if g is not None else t
    214 
    215     if already_flattened:

RuntimeError: output with shape [] doesn't match the broadcast shape [1]

To Reproduce

See the following Colab notebook for detailed steps to reproduce:
https://colab.research.google.com/drive/14S7DzcovrUwDkZuRxdq5OW5KwYqWTJsy?usp=sharing

System info

import torchopt, numpy, sys
print(torchopt.__version__, numpy.__version__, sys.version, sys.platform)
0.5.0 1.21.6 3.7.13 (default, Apr 24 2022, 01:04:09) 
[GCC 7.5.0] linux

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
xianghang added the bug label (Something isn't working) on Sep 9, 2022
@XuehaiPan
Member

@xianghang Thanks for reporting this! It's a bug that the step count promotes the tensor shape. I'll add a new commit to fix this.
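The failure mode described above can be reproduced in plain PyTorch, without torchopt. This is a minimal sketch under stated assumptions, not torchopt's implementation and not the actual patch in #71: it assumes (the thread does not say so explicitly) that the step count is a shape-[1] tensor used in Adam-style bias correction. `moment_step` and `run` are hypothetical helpers. The first step silently promotes a 0-dim parameter to shape [1]; on the second step the gradient has shape [1] while the moment state is still shape [], so the in-place `add_` (the same expression as `f(g, t)` in the traceback) fails.

```python
import torch


def moment_step(param, mu, count, lr=0.1, b1=0.9):
    """Simplified Adam-style first-moment step (illustrative, not torchopt's code)."""
    g = torch.ones_like(param)  # stand-in gradient with the parameter's current shape
    # Same expression as `f(g, t)` in the traceback: an in-place add_ cannot grow `mu`
    mu = mu.mul(b1).add_(g, alpha=1 - b1)
    # Bias correction: a shape-[1] count broadcasts the whole update to shape [1]
    mu_hat = mu / (1 - b1 ** count)
    return param - lr * mu_hat, mu


def run(count_shape, steps=2):
    param = torch.tensor(1.0)     # 0-dim parameter, shape []
    mu = torch.zeros_like(param)  # first-moment state, shape []
    for i in range(steps):
        # Hypothetical step counter with the given shape
        count = torch.full(count_shape, float(i + 1))
        param, mu = moment_step(param, mu, count)
    return param


try:
    run(count_shape=(1,))  # count shaped [1]: the second step raises
except RuntimeError as err:
    print(err)  # output with shape [] doesn't match the broadcast shape [1]

print(run(count_shape=()).shape)  # count shaped []: torch.Size([])
```

Keeping the counter 0-dim (`count_shape=()`) preserves the parameter's shape across steps, which is consistent with the maintainer's diagnosis that the step count was promoting the tensor shape.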

XuehaiPan added a commit to XuehaiPan/torchopt that referenced this issue Sep 9, 2022
XuehaiPan added a commit to XuehaiPan/torchopt that referenced this issue Sep 9, 2022
@XuehaiPan
Member

@xianghang Hi, we have just published new wheels with the fix on PyPI (0.5.0.post1). Running

pip3 install --upgrade torchopt

will fix your issue.

@xianghang
Author

pip3 install --upgrade torchopt

It works for me. Many thanks for the prompt response.
