2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yaml
@@ -18,7 +18,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: 3.12
architecture: x64
packages: |
ufmt==2.5.1
7 changes: 0 additions & 7 deletions examples/retrieval/knn_index.py
@@ -21,7 +21,6 @@ def get_index(
num_subquantizers: int,
bits_per_code: int,
device: Optional[torch.device] = None,
# pyre-ignore[11]
) -> Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ]:
"""
returns a FAISS IVFPQ index, placed on the device passed in
@@ -39,25 +38,19 @@

"""
if device is not None and device.type == "cuda":
# pyre-fixme[16]
res = faiss.StandardGpuResources()
# pyre-fixme[16]
config = faiss.GpuIndexIVFPQConfig()
# pyre-ignore[16]
index = faiss.GpuIndexIVFPQ(
res,
embedding_dim,
num_centroids,
num_subquantizers,
bits_per_code,
# pyre-fixme[16]
faiss.METRIC_L2,
config,
)
else:
# pyre-fixme[16]
quantizer = faiss.IndexFlatL2(embedding_dim)
# pyre-fixme[16]
index = faiss.IndexIVFPQ(
quantizer,
embedding_dim,
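The else-branch above is cut off by the diff view; for reference, a sketch of the standard faiss CPU IVFPQ construction it follows (illustrative sizes, not the exact code from this file; requires faiss and numpy installed):

import numpy as np
import faiss

embedding_dim, num_centroids, num_subquantizers, bits_per_code = 64, 256, 8, 8
quantizer = faiss.IndexFlatL2(embedding_dim)
index = faiss.IndexIVFPQ(
    quantizer, embedding_dim, num_centroids, num_subquantizers, bits_per_code
)

# IVFPQ indexes must be trained on representative vectors before vectors are added.
training_vectors = np.random.rand(10_000, embedding_dim).astype(np.float32)
index.train(training_vectors)
index.add(training_vectors)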
2 changes: 1 addition & 1 deletion examples/retrieval/modules/two_tower.py
@@ -169,7 +169,6 @@ class TwoTowerRetrieval(nn.Module):

def __init__(
self,
# pyre-ignore[11]
faiss_index: Union[faiss.GpuIndexIVFPQ, faiss.IndexIVFPQ],
query_ebc: EmbeddingBagCollection,
candidate_ebc: EmbeddingBagCollection,
@@ -222,6 +221,7 @@ def forward(self, query_kjt: KeyedJaggedTensor) -> torch.Tensor:
(batch_size, self.k), device=self.device, dtype=torch.int64
)
query_embedding = query_embedding.to(torch.float32) # required by faiss
# pyre-ignore[19]
self.faiss_index.search(query_embedding, self.k, distances, candidates)

# candidate lookup
11 changes: 7 additions & 4 deletions examples/retrieval/two_tower_retrieval.py
@@ -128,12 +128,9 @@ def infer(
retrieval_sd = None
if load_dir is not None:
load_dir = load_dir.rstrip("/")
# pyre-ignore[16]
index = faiss.index_cpu_to_gpu(
# pyre-ignore[16]
faiss.StandardGpuResources(),
faiss_device_idx,
# pyre-ignore[16]
faiss.read_index(f"{load_dir}/faiss.index"),
)
two_tower_sd = torch.load(f"{load_dir}/model.pt", weights_only=True)
@@ -158,7 +155,13 @@ def infer(
index.add(embeddings)

retrieval_model = TwoTowerRetrieval(
index, ebcs[0], ebcs[1], layer_sizes, k, device, dtype=torch.float16
index, # pyre-ignore[6]
ebcs[0],
ebcs[1],
layer_sizes,
k,
device,
dtype=torch.float16,
)

constraints = {}
1 change: 0 additions & 1 deletion examples/retrieval/two_tower_train.py
@@ -227,7 +227,6 @@ def train(
model, dtype=torch.qint8, inplace=True
)
torch.save(quant_model.state_dict(), f"{save_dir}/model.pt")
# pyre-ignore[16]
faiss.write_index(index, f"{save_dir}/faiss.index")


2 changes: 1 addition & 1 deletion torchrec/distributed/benchmark/base.py
@@ -849,7 +849,7 @@ class BenchFuncConfig:
world_size: int
num_profiles: int
num_benchmarks: int
profile_dir: str
profile_dir: str = ""
device_type: str = "cuda"
pre_gpu_load: int = 0
export_stacks: bool = False
72 changes: 6 additions & 66 deletions torchrec/distributed/benchmark/benchmark_train_pipeline.py
@@ -20,8 +20,8 @@
See benchmark_pipeline_utils.py for step-by-step instructions.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Type
from dataclasses import dataclass
from typing import List, Optional

import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -37,8 +37,8 @@
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.distributed.test_utils.model_config import (
BaseModelConfig,
create_model_config,
generate_sharded_model_and_optimizer,
ModelSelectionConfig,
)
from torchrec.distributed.test_utils.model_input import ModelInput

@@ -49,7 +49,6 @@
from torchrec.distributed.test_utils.pipeline_config import PipelineConfig
from torchrec.distributed.test_utils.sharding_config import PlannerConfig
from torchrec.distributed.test_utils.table_config import EmbeddingTablesConfig
from torchrec.distributed.test_utils.test_model import TestOverArchLarge
from torchrec.distributed.train_pipeline import TrainPipeline
from torchrec.distributed.types import ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -94,11 +93,11 @@ class RunOptions(BenchFuncConfig):
"""

world_size: int = 2
batch_size: int = 1024 * 32
num_float_features: int = 10
num_batches: int = 10
sharding_type: ShardingType = ShardingType.TABLE_WISE
input_type: str = "kjt"
name: str = ""
profile_dir: str = ""
num_benchmarks: int = 5
num_profiles: int = 2
num_poolings: Optional[List[float]] = None
@@ -113,39 +112,6 @@
export_stacks: bool = False


@dataclass
class ModelSelectionConfig:
model_name: str = "test_sparse_nn"

# Common config for all model types
batch_size: int = 1024 * 32
batch_sizes: Optional[List[int]] = None
num_float_features: int = 10
feature_pooling_avg: int = 10
use_offsets: bool = False
dev_str: str = ""
long_kjt_indices: bool = True
long_kjt_offsets: bool = True
long_kjt_lengths: bool = True
pin_memory: bool = True

# TestSparseNN specific config
embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
max_feature_lengths: Optional[Dict[str, int]] = None
over_arch_clazz: Type[nn.Module] = TestOverArchLarge
postproc_module: Optional[nn.Module] = None
zch: bool = False

# DeepFM specific config
hidden_layer_size: int = 20
deep_fm_dimension: int = 5

# DLRM specific config
dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])


# single-rank runner
def runner(
rank: int,
@@ -303,35 +269,9 @@ def main(
pipeline_config: PipelineConfig,
input_config: ModelInputConfig,
planner_config: PlannerConfig,
model_config: Optional[BaseModelConfig] = None,
) -> None:
tables, weighted_tables, *_ = table_config.generate_tables()

if model_config is None:
model_config = create_model_config(
model_name=model_selection.model_name,
batch_size=model_selection.batch_size,
batch_sizes=model_selection.batch_sizes,
num_float_features=model_selection.num_float_features,
feature_pooling_avg=model_selection.feature_pooling_avg,
use_offsets=model_selection.use_offsets,
dev_str=model_selection.dev_str,
long_kjt_indices=model_selection.long_kjt_indices,
long_kjt_offsets=model_selection.long_kjt_offsets,
long_kjt_lengths=model_selection.long_kjt_lengths,
pin_memory=model_selection.pin_memory,
embedding_groups=model_selection.embedding_groups,
feature_processor_modules=model_selection.feature_processor_modules,
max_feature_lengths=model_selection.max_feature_lengths,
over_arch_clazz=model_selection.over_arch_clazz,
postproc_module=model_selection.postproc_module,
zch=model_selection.zch,
hidden_layer_size=model_selection.hidden_layer_size,
deep_fm_dimension=model_selection.deep_fm_dimension,
dense_arch_layer_sizes=model_selection.dense_arch_layer_sizes,
over_arch_layer_sizes=model_selection.over_arch_layer_sizes,
)

model_config = model_selection.create_model_config()
# launch trainers
run_multi_process_func(
func=runner,
2 changes: 2 additions & 0 deletions torchrec/distributed/benchmark/yaml/sparse_data_dist_base.yml
@@ -10,6 +10,8 @@ RunOptions:
# export_stacks: True # enable this to export stack traces
PipelineConfig:
pipeline: "sparse"
ModelInputConfig:
feature_pooling_avg: 10
EmbeddingTablesConfig:
num_unweighted_features: 100
num_weighted_features: 100
76 changes: 58 additions & 18 deletions torchrec/distributed/test_utils/model_config.py
@@ -18,7 +18,7 @@

import copy
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
@@ -31,6 +31,7 @@
from torchrec.distributed.planner.planners import HeteroEmbeddingShardingPlanner
from torchrec.distributed.sharding_plan import get_default_sharders
from torchrec.distributed.test_utils.test_model import (
TestOverArchLarge,
TestSparseNN,
TestTowerCollectionSparseNN,
TestTowerSparseNN,
@@ -51,8 +52,9 @@ class BaseModelConfig(ABC):
and requires each concrete implementation to provide its own generate_model method.
"""

# Common parameters for all model types
num_float_features: int # we assume all model arch has a single dense feature layer
## Common parameters for all model types, please do not set default values here
# we assume all model arch has a single dense feature layer
num_float_features: int

@abstractmethod
def generate_model(
@@ -80,12 +82,12 @@ def generate_model(
class TestSparseNNConfig(BaseModelConfig):
"""Configuration for TestSparseNN model."""

embedding_groups: Optional[Dict[str, List[str]]]
feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
max_feature_lengths: Optional[Dict[str, int]]
over_arch_clazz: Type[nn.Module]
postproc_module: Optional[nn.Module]
zch: bool
embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None
max_feature_lengths: Optional[Dict[str, int]] = None
over_arch_clazz: Type[nn.Module] = TestOverArchLarge
postproc_module: Optional[nn.Module] = None
zch: bool = False

def generate_model(
self,
@@ -113,8 +115,8 @@ def generate_model(
class TestTowerSparseNNConfig(BaseModelConfig):
"""Configuration for TestTowerSparseNN model."""

embedding_groups: Optional[Dict[str, List[str]]]
feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None

def generate_model(
self,
@@ -138,8 +140,8 @@ def generate_model(
class TestTowerCollectionSparseNNConfig(BaseModelConfig):
"""Configuration for TestTowerCollectionSparseNN model."""

embedding_groups: Optional[Dict[str, List[str]]]
feature_processor_modules: Optional[Dict[str, torch.nn.Module]]
embedding_groups: Optional[Dict[str, List[str]]] = None
feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None

def generate_model(
self,
@@ -163,8 +165,8 @@ def generate_model(
class DeepFMConfig(BaseModelConfig):
"""Configuration for DeepFM model."""

hidden_layer_size: int
deep_fm_dimension: int
hidden_layer_size: int = 20
deep_fm_dimension: int = 5

def generate_model(
self,
@@ -189,8 +191,8 @@ def generate_model(
class DLRMConfig(BaseModelConfig):
"""Configuration for DLRM model."""

dense_arch_layer_sizes: List[int]
over_arch_layer_sizes: List[int]
dense_arch_layer_sizes: List[int] = field(default_factory=lambda: [20, 128])
over_arch_layer_sizes: List[int] = field(default_factory=lambda: [5, 1])

def generate_model(
self,
@@ -213,7 +215,9 @@ def generate_model(

# pyre-ignore[2]: Missing parameter annotation
def create_model_config(model_name: str, **kwargs) -> BaseModelConfig:

"""
deprecated function, please use ModelSelectionConfig.create_model_config instead
"""
model_configs = {
"test_sparse_nn": TestSparseNNConfig,
"test_tower_sparse_nn": TestTowerSparseNNConfig,
@@ -309,3 +313,39 @@ def generate_sharded_model_and_optimizer(
optimizer = optimizer_class(dense_params, **optimizer_kwargs)

return sharded_model, optimizer


@dataclass
class ModelSelectionConfig:
model_name: str = "test_sparse_nn"
model_config: Dict[str, Any] = field(
default_factory=lambda: {"num_float_features": 10}
)

def get_model_config_class(self) -> Type[BaseModelConfig]:
match self.model_name:
case "test_sparse_nn":
return TestSparseNNConfig
case "test_tower_sparse_nn":
return TestTowerSparseNNConfig
case "test_tower_collection_sparse_nn":
return TestTowerCollectionSparseNNConfig
case "deepfm":
return DeepFMConfig
case "dlrm":
return DLRMConfig
case _:
raise ValueError(f"Unknown model name: {self.model_name}")

def create_model_config(self) -> BaseModelConfig:
config_class = self.get_model_config_class()
valid_field_names = {field.name for field in fields(config_class)}
filtered_kwargs = {
k: v for k, v in self.model_config.items() if k in valid_field_names
}
# pyre-ignore[45]: Invalid class instantiation
return config_class(**filtered_kwargs)

def create_test_model(self, **kwargs: Any) -> nn.Module:
model_config = self.create_model_config()
return model_config.generate_model(**kwargs)
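
Since create_model_config() above is now documented as deprecated in favor of this class, a short usage sketch of the new entry point (model choice and values are illustrative, not taken from this PR):

from torchrec.distributed.test_utils.model_config import ModelSelectionConfig

# Pick a model family by name; per-model kwargs travel as a plain dict, and any
# key the selected config class does not declare is filtered out before instantiation.
selection = ModelSelectionConfig(
    model_name="dlrm",
    model_config={
        "num_float_features": 10,
        "dense_arch_layer_sizes": [20, 128],
        "over_arch_layer_sizes": [5, 1],
    },
)
model_config = selection.create_model_config()  # returns a DLRMConfig instance

This is the same object that main() in benchmark_train_pipeline.py now consumes via model_selection.create_model_config().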
2 changes: 1 addition & 1 deletion torchrec/distributed/test_utils/test_model.py
@@ -2364,7 +2364,7 @@ def __init__(
if device is None:
device = torch.device("cpu")
if max_sequence_length is None:
max_sequence_length = 10
max_sequence_length = 20
if dense_arch_out_size is None:
dense_arch_out_size = DENSE_LAYER_OUT_SIZE
if over_arch_out_size is None:
2 changes: 1 addition & 1 deletion torchrec/modules/object_pool_lookups.py
@@ -413,7 +413,7 @@ def update(self, ids: torch.Tensor, values: JaggedTensor) -> None:
.sum(axis=1)
)
key_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(key_lengths)
padded_values = torch.ops.fbgemm.jagged_to_padded_dense(
padded_values: torch.Tensor = torch.ops.fbgemm.jagged_to_padded_dense(
values.values(),
[key_offsets],
[self._bit_dims],
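For reviewers unfamiliar with the op whose result gets the new annotation, a small self-contained sketch of what torch.ops.fbgemm.jagged_to_padded_dense returns (shapes and values are illustrative; assumes fbgemm_gpu is installed, which registers the operator namespace):

import torch
import fbgemm_gpu  # noqa: F401  (registers torch.ops.fbgemm)

# Two jagged rows of lengths 2 and 3, each element carrying a dense dim of 2.
values = torch.arange(10, dtype=torch.float32).reshape(5, 2)
lengths = torch.tensor([2, 3])
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)  # tensor([0, 2, 5])

# Pad every row out to length 3, filling missing positions with 0.0.
padded = torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 0.0)
print(padded.shape)  # torch.Size([2, 3, 2])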