730 changes: 19 additions & 711 deletions torchrec/distributed/embedding.py

Large diffs are not rendered by default.

794 changes: 794 additions & 0 deletions torchrec/distributed/embeddingbag.py

Large diffs are not rendered by default.
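
Neither large diff is rendered, but together they show the core of this change: the EmbeddingBag sharding code moves out of `torchrec/distributed/embedding.py` into the new `torchrec/distributed/embeddingbag.py`. For downstream code the visible effect is an import-path change, as the remaining diffs below show. A minimal sketch of the migration, assuming only the module path changes and the class names stay the same:

```python
# Sketch of the import migration this PR performs in downstream modules.
# Assumes only the module path changes; the class names are unchanged.

# Before (old location):
# from torchrec.distributed.embedding import (
#     EmbeddingBagCollectionSharder,
#     EmbeddingBagSharder,
#     QuantEmbeddingBagCollectionSharder,
# )

# After (new location introduced by this PR):
from torchrec.distributed.embeddingbag import (
    EmbeddingBagCollectionSharder,
    EmbeddingBagSharder,
    QuantEmbeddingBagCollectionSharder,
)
```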

8 changes: 5 additions & 3 deletions torchrec/distributed/model_parallel.py
@@ -8,10 +8,9 @@
from torch import nn
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding import (
from torchrec.distributed.embeddingbag import (
EmbeddingBagCollectionSharder,
QuantEmbeddingBagCollectionSharder,
filter_state_dict,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, sharder_name
from torchrec.distributed.types import (
@@ -21,6 +20,7 @@
ShardingEnv,
)
from torchrec.distributed.utils import append_prefix
from torchrec.distributed.utils import filter_state_dict
from torchrec.optim.fused import FusedOptimizerModule
from torchrec.optim.keyed import KeyedOptimizer, CombinedOptimizer

@@ -53,7 +53,9 @@ def init_weights(m):
device: this device, defaults to cpu,
plan: plan to use when sharding, defaults to EmbeddingShardingPlanner.collective_plan(),
sharders: ModuleSharders available to shard with, defaults to EmbeddingBagCollectionSharder(),
init_data_parallel: data-parallel modules can be lazy, i.e. they delay parameter initialization until the first forward pass. Pass True if that's a case to delay initialization of data parallel modules. Do first forward pass and then call DistributedModelParallel.init_data_parallel().
init_data_parallel: data-parallel modules can be lazy, i.e. they delay parameter initialization until
the first forward pass. Pass True if that's a case to delay initialization of data parallel modules.
Do first forward pass and then call DistributedModelParallel.init_data_parallel().
init_parameters: initialize parameters for modules still on meta device.

Call Args:
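
The docstring above describes the `DistributedModelParallel` constructor arguments; a minimal single-process usage sketch follows, using the new `embeddingbag` import path. The table and feature names are made up for illustration, and the gloo/world-size-1 setup is only an assumption about the simplest way to exercise the default CPU path:

```python
import os

import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Single-rank process group so DistributedModelParallel can build its sharding plan.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo")

# Hypothetical embedding table; any EmbeddingBagCollection works here.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=8,
            num_embeddings=100,
            feature_names=["f1"],
        )
    ]
)

# Per the docstring, device defaults to cpu, plan to collective_plan(),
# and sharders to EmbeddingBagCollectionSharder(); they are spelled out here.
model = DistributedModelParallel(
    ebc,
    device=torch.device("cpu"),
    sharders=[EmbeddingBagCollectionSharder()],
)
```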
4 changes: 2 additions & 2 deletions torchrec/distributed/planner/new/tests/test_calculators.py
@@ -2,10 +2,10 @@

import unittest

from torchrec.distributed.embedding import (
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.embeddingbag import (
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator
from torchrec.distributed.planner.new.enumerators import ShardingEnumerator
from torchrec.distributed.planner.new.types import Topology
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/new/tests/test_enumerators.py
@@ -4,10 +4,10 @@
import unittest
from typing import List

from torchrec.distributed.embedding import EmbeddingBagCollectionSharder
from torchrec.distributed.embedding_types import (
EmbeddingComputeKernel,
)
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.new.constants import (
BIGINT_DTYPE,
)
@@ -5,8 +5,8 @@
from typing import List

from torch import nn
from torchrec.distributed.embedding import EmbeddingBagCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.new.enumerators import ShardingEnumerator
from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner
from torchrec.distributed.planner.new.types import Storage, Topology, PartitionByType
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/new/tests/test_placers.py
@@ -5,8 +5,8 @@

import torch
from torch import nn
from torchrec.distributed.embedding import EmbeddingBagCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator
from torchrec.distributed.planner.new.enumerators import ShardingEnumerator
from torchrec.distributed.planner.new.partitioners import GreedyCostPartitioner
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/new/tests/test_rankers.py
@@ -2,7 +2,7 @@

import unittest

from torchrec.distributed.embedding import (
from torchrec.distributed.embeddingbag import (
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.planner.new.calculators import EmbeddingWTCostCalculator
@@ -5,8 +5,8 @@
from unittest.mock import MagicMock, patch, call

from torch.distributed._sharding_spec import ShardMetadata, EnumerableShardingSpec
from torchrec.distributed.embedding import EmbeddingBagCollectionSharder
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.embedding_planner import EmbeddingShardingPlanner
from torchrec.distributed.planner.parameter_sharding import _rw_shard_table_rows
from torchrec.distributed.planner.types import ParameterHints
6 changes: 3 additions & 3 deletions torchrec/distributed/tests/test_model.py
@@ -5,11 +5,11 @@

import torch
import torch.nn as nn
from torchrec.distributed.embedding import (
EmbeddingBagCollectionSharder,
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.embeddingbag import (
EmbeddingBagSharder,
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig, BaseEmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
6 changes: 2 additions & 4 deletions torchrec/distributed/tests/test_model_parallel.py
@@ -13,11 +13,9 @@
import torch.nn as nn
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from hypothesis import Verbosity, given, settings
from torchrec.distributed.embedding import (
EmbeddingBagCollectionSharder,
EmbeddingBagSharder,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.embeddingbag import EmbeddingBagSharder
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
default_sharders,
7 changes: 3 additions & 4 deletions torchrec/distributed/tests/test_quant_model_parallel.py
@@ -1,22 +1,22 @@
#!/usr/bin/env python3

import copy
import os
import unittest
from typing import List

import torch
import torch.distributed as dist
from torch import nn
from torch import quantization as quant
from torchrec.distributed.embedding import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.embedding_lookup import (
GroupedEmbeddingBag,
BatchedFusedEmbeddingBag,
BatchedDenseEmbeddingBag,
QuantBatchedEmbeddingBag,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import (
QuantEmbeddingBagCollectionSharder,
)
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.tests.test_model import (
TestSparseNN,
@@ -29,7 +29,6 @@
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
from torchrec.tests.utils import get_free_port


class TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder):
8 changes: 4 additions & 4 deletions torchrec/distributed/tests/test_train_pipeline.py
@@ -9,14 +9,14 @@
import torch.distributed as dist
from torch import nn, optim
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding import (
ShardedEmbeddingBagCollection,
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embedding_types import (
SparseFeaturesList,
)
from torchrec.distributed.embeddingbag import (
ShardedEmbeddingBagCollection,
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.tests.test_model import (
TestSparseNN,
ModelInput,
4 changes: 2 additions & 2 deletions torchrec/distributed/tests/test_utils.py
@@ -10,10 +10,10 @@
import numpy as np
import torch
import torch.distributed as dist
from torchrec.distributed.embedding import (
from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
from torchrec.distributed.embeddingbag import (
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.tests.test_model import TestSparseNN
from torchrec.distributed.utils import get_unsharded_module_names
12 changes: 12 additions & 0 deletions torchrec/distributed/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

from collections import OrderedDict
from typing import List, Set, Union

import torch
@@ -13,6 +14,17 @@ def append_prefix(prefix: str, name: str) -> str:
return prefix + name


def filter_state_dict(
state_dict: "OrderedDict[str, torch.Tensor]", name: str
) -> "OrderedDict[str, torch.Tensor]":
rtn_dict = OrderedDict()
for key, value in state_dict.items():
if key.startswith(name):
# + 1 to length is to remove the '.' after the key
rtn_dict[key[len(name) + 1 :]] = value
return rtn_dict


def _get_unsharded_module_names_helper(
model: torch.nn.Module,
path: str,
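
`filter_state_dict` now lives in `torchrec.distributed.utils`, which is why `model_parallel.py` above imports it from there instead of from `embedding.py`. A small sketch of its behaviour with made-up keys: it keeps entries whose keys start with the given prefix and strips that prefix plus the trailing dot.

```python
from collections import OrderedDict

import torch
from torchrec.distributed.utils import filter_state_dict

# Hypothetical state dict with a "module." prefix, as produced by a wrapped module.
state_dict = OrderedDict(
    [
        ("module.embedding_bags.t1.weight", torch.zeros(2, 4)),
        ("module.embedding_bags.t2.weight", torch.zeros(2, 4)),
        ("other.bias", torch.zeros(2)),
    ]
)

# Keep only the entries under "module" and strip the "module." prefix.
filtered = filter_state_dict(state_dict, "module")
assert list(filtered.keys()) == [
    "embedding_bags.t1.weight",
    "embedding_bags.t2.weight",
]
```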
Empty file added torchrec/examples/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions torchrec/examples/dlrm/README.MD
@@ -0,0 +1,11 @@
# Running

## Torchx
We recommend using [torchx](https://pytorch.org/torchx/main/quickstart.html) to run.
Here we use the [DDP builtin](https://pytorch.org/torchx/main/components/distributed.html).

1. `pip install torchx`
2. (optional) set up a Slurm or Kubernetes cluster
3. run the DLRM script:
   a. locally: `torchx run dist.ddp -j 1x2 --script dlrm_main.py`
   b. remotely: `torchx run -s slurm dist.ddp -j 1x8 --script dlrm_main.py`
Empty file.