Skip to content

Conversation

TroyGarden
Copy link
Contributor

Summary:

details

  • serialization schedule
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
  • parent_fqn's children
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None

Differential Revision: D58221182

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 6, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Jun 6, 2024
Summary:

# details
* serialization schedule
```
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
```
* `parent_fqn`'s children
```
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
```
```
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
```
```
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None
```

Differential Revision: D58221182
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

Summary:
Pull Request resolved: meta-pytorch#2082

# context
* to support compound module serialization such as fpEBC and fpPEA, we modified the serializer interface a little
* the basic idea is that:
a. if a compound module A consists of child module B and C, such that module B and C have their own serializer available.
b. for serialization of module A, we can just capture it relation from B and C, and used it when deserialization
c. specifically, during the deserialization, A's children will be passed in as a dict of (child_fqn => child_module)

# design doc
[**TorchRec Composable Serializer Design**](https://docs.google.com/document/d/1WUtmzdcqZmwLd4Do8g1fQRjChnRw0ZimyUrCNhxa4nA/edit#heading=h.ezrtdguw0lwq)

# note
* in order to apply this approach, it requires that module A's construction **takes in it's children as objects**
* **DO NOT** create A's children in the A's construction

# details
* serialization schedule
```
comp.ebc does NOT require further serialization of its children
comp.ebc.embedding_bags is skipped for further serialization
comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp is resumed for serialization
comp.comp Requires further serialization of its children
comp.comp.ebc does NOT require further serialization of its children
comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
comp.comp.comp is resumed for serialization
comp.comp.comp Requires further serialization of its children
comp.comp.comp.ebc does NOT require further serialization of its children
comp.comp.comp.ebc.embedding_bags is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t1 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t2 is skipped for further serialization
comp.comp.comp.ebc.embedding_bags.t3 is skipped for further serialization
```
* `parent_fqn`'s children
```
comp.comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
)} None None
```
```
comp.comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
)} None None
```
```
comp {'ebc': EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (t1): EmbeddingBag(10, 4, mode='sum')
    (t2): EmbeddingBag(10, 4, mode='sum')
    (t3): EmbeddingBag(10, 4, mode='sum')
  )
), 'comp': CompoundModule(
  (ebc): EmbeddingBagCollection(
    (embedding_bags): ModuleDict(
      (t1): EmbeddingBag(10, 4, mode='sum')
      (t2): EmbeddingBag(10, 4, mode='sum')
      (t3): EmbeddingBag(10, 4, mode='sum')
    )
  )
  (comp): CompoundModule(
    (ebc): EmbeddingBagCollection(
      (embedding_bags): ModuleDict(
        (t1): EmbeddingBag(10, 4, mode='sum')
        (t2): EmbeddingBag(10, 4, mode='sum')
        (t3): EmbeddingBag(10, 4, mode='sum')
      )
    )
  )
)} None None
```

Differential Revision: D58221182
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D58221182

@TroyGarden TroyGarden closed this Jul 15, 2024
@TroyGarden TroyGarden deleted the export-D58221182 branch August 8, 2024 22:14
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Apr 16, 2025
Summary:

# context
* previously we use a util function [`create_sharding_infos_by_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150-L229) to group the `sharding_info`s so that a sharded module can create an [`EmbeddingSharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L601-L643) with grouped sharding_infos.
* after recent refactoring [meta-pytorch#2887](meta-pytorch#2887) the `create_embedding_sharding` becomes a public API, it's also reasonable to promote create_sharding_infos_by_sharding as a public API (classmethod) so that user can subclass it and overrides it. 
* since "grouping" is more relevant to this function, we'll rename it as "create_grouped_sharding_infos".

Differential Revision: D58221182
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants