README.md: 8 changes (5 additions & 3 deletions)
@@ -41,7 +41,7 @@ This code is written for modern PyTorch (version 2.7 or newer) using DTensor-bas

## Quick Start

-Install dependences:
+Install dependencies:
```bash
pip install -r requirements.txt
```
@@ -139,7 +139,7 @@ We summarize the above in this table. Let `d_in` be the input dimension of the u
| Embedding | `nn.Embedding.weight` | `"lion"` / `"adamw"` | `lr` |
| Unembedding | `nn.Linear.weight` (must identify manually) | `"lion"` / `"adamw"` | `lr / math.sqrt(d_in)` |

-We emphasize again that **particular care** needs to be taken with **embedding and unembedding layers**. They must be isolated from ordinary matrix parameters, and the unembedding layer futhermore should use a scaled learning rate. Merely checking the dimensions of a parameter (such as `if p.ndim == 2`) or the type of the module (such as `if isinstance(module, nn.Linear)`) **is not sufficient** to identify these special parameters. This is why we require manual parameter group creation.
+We emphasize again that **particular care** needs to be taken with **embedding and unembedding layers**. They must be isolated from ordinary matrix parameters, and the unembedding layer furthermore should use a scaled learning rate. Merely checking the dimensions of a parameter (such as `if p.ndim == 2`) or the type of the module (such as `if isinstance(module, nn.Linear)`) **is not sufficient** to identify these special parameters. This is why we require manual parameter group creation.

The optimizer cannot tell if a given parameter is a weight matrix, embedding, or unembedding, because they are all two-dimensional tensors. You will not receive any errors if these parameters are incorrectly grouped with matrix weights!

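For concreteness, here is a minimal sketch of such manual parameter group creation for a GPT-like model. The module attribute names (`model.tok_emb`, `model.lm_head`), the base learning rate, and the existence of a separate element-wise group for 1D parameters are placeholders and assumptions; the per-group keys (`algorithm`, `lr`) follow the table above, and any other constructor arguments are omitted.

```python
import math

lr = 0.01  # placeholder base learning rate

# Identify the embedding and unembedding weights explicitly; checking p.ndim
# or isinstance(module, nn.Linear) alone is not enough to find them.
embedding_params = [model.tok_emb.weight]    # nn.Embedding.weight
unembedding_params = [model.lm_head.weight]  # unembedding nn.Linear.weight
special_ids = {id(p) for p in embedding_params + unembedding_params}

# Remaining 2D weights get the default matrix update; 1D parameters
# (biases, norms) are assumed to go into an element-wise group.
matrix_params = [p for p in model.parameters()
                 if p.ndim == 2 and id(p) not in special_ids]
elementwise_params = [p for p in model.parameters()
                      if p.ndim != 2 and id(p) not in special_ids]

d_in = model.lm_head.weight.shape[1]  # input dimension of the unembedding

param_groups = [
    {"params": matrix_params},
    {"params": elementwise_params, "algorithm": "lion"},   # or "adamw"
    {"params": embedding_params, "algorithm": "lion", "lr": lr},
    {"params": unembedding_params, "algorithm": "lion",
     "lr": lr / math.sqrt(d_in)},                          # scaled LR, per the table
]

optimizer = Dion(param_groups, lr=lr)  # other Dion arguments (device meshes, etc.) omitted
```
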
@@ -242,7 +242,7 @@ optimizer = Dion(

Muon uses different device mesh arguments from Dion.

-Our implementation of Muon takes a single 1D device mesh as a generic `distributed_mesh` argument. If this mesh is used for sharding parameters, Muon will efficiently perform unsharding using all-to-all. If this mesh is not used for sharding, Muon will distribue work across this mesh and all-gather the final results.
+Our implementation of Muon takes a single 1D device mesh as a generic `distributed_mesh` argument. If this mesh is used for sharding parameters, Muon will efficiently perform unsharding using all-to-all. If this mesh is not used for sharding, Muon will distribute work across this mesh and all-gather the final results.

2D sharding is not supported by Muon---use Dion instead. For hybrid-sharded data parallel, with a replicated mesh dimension and a sharded dimension, pass only the sharded sub-mesh to Muon.

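As an illustration of the hybrid-sharded case, here is a hedged sketch: the mesh dimension names, the `num_replicas`/`num_shards` layout, and any Muon constructor arguments other than `distributed_mesh` are assumptions rather than the repository's exact API.

```python
from torch.distributed.device_mesh import init_device_mesh

num_replicas, num_shards = 2, 4  # placeholder HSDP layout

# Hybrid-sharded data parallel: one replicated dimension, one sharded dimension.
mesh = init_device_mesh(
    "cuda",
    (num_replicas, num_shards),
    mesh_dim_names=("replicate", "shard"),
)

# Pass only the 1D sharded sub-mesh to Muon, not the full 2D mesh.
optimizer = Muon(
    param_groups,                    # built as described above
    distributed_mesh=mesh["shard"],
    lr=lr,
)
```
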
@@ -457,3 +457,5 @@ If you use Dion in your research, please cite:
year={2025}
}
```

+LocalWords: Orthonormal
Comment on lines +460 to +461
Suggested change
LocalWords: Orthonormal