40 changes: 40 additions & 0 deletions torchrec/distributed/benchmark/yaml/sparse_data_dist_base_vbe.yml
@@ -0,0 +1,40 @@
# this is a very basic sparse data dist config
# runs on 2 ranks, showing traces with reasonable workloads
RunOptions:
  world_size: 2
  batch_size: 16384
  num_batches: 10
  num_benchmarks: 1
  num_profiles: 1
  sharding_type: table_wise
  profile_dir: "."
  name: "sparse_data_dist_base"
  # export_stacks: True  # enable this to export stack traces
PipelineConfig:
  pipeline: "sparse"
ModelInputConfig:
  feature_pooling_avg: 30
  use_variable_batch: True
EmbeddingTablesConfig:
  num_unweighted_features: 90
  num_weighted_features: 80
  embedding_feature_dim: 256
  additional_tables:
    - - name: FP16_table
        embedding_dim: 512
        num_embeddings: 100_000
        feature_names: ["additional_0_0"]
        data_type: FP16
      - name: large_table
        embedding_dim: 2048
        num_embeddings: 1_000_000
        feature_names: ["additional_0_1"]
    - []
    - - name: skipped_table
        embedding_dim: 128
        num_embeddings: 100_000
        feature_names: ["additional_2_1"]
PlannerConfig:
  additional_constraints:
    large_table:
      sharding_types: [column_wise]
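
For orientation, a minimal sketch of inspecting this file with PyYAML (illustrative only; the benchmark runner has its own loader that maps these sections onto the RunOptions / PipelineConfig / ModelInputConfig / EmbeddingTablesConfig / PlannerConfig dataclasses):

# Sketch: parse the YAML above and check a few fields. PyYAML (YAML 1.1)
# reads underscored ints such as 100_000 as 100000.
import yaml

with open("sparse_data_dist_base_vbe.yml") as f:
    cfg = yaml.safe_load(f)

assert cfg["RunOptions"]["world_size"] == 2
assert cfg["ModelInputConfig"]["use_variable_batch"] is True

# additional_tables is a list of table groups; the second group is empty ([])
for group_idx, group in enumerate(cfg["EmbeddingTablesConfig"]["additional_tables"]):
    for table in group:
        print(group_idx, table["name"], table["num_embeddings"])
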
28 changes: 26 additions & 2 deletions torchrec/distributed/test_utils/input_config.py
@@ -13,7 +13,7 @@
import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig

-from .model_input import ModelInput
+from .model_input import ModelInput, VariableBatchModelInput


@dataclass
@@ -30,6 +30,7 @@ class ModelInputConfig:
    long_kjt_offsets: bool = True
    long_kjt_lengths: bool = True
    pin_memory: bool = True
    use_variable_batch: bool = False

    def generate_batches(
        self,
@@ -47,6 +48,29 @@ def generate_batches(
"""
device = torch.device(self.device) if self.device is not None else None

if self.use_variable_batch:
return [
VariableBatchModelInput.generate(
batch_size=self.batch_size,
num_float_features=self.num_float_features,
tables=tables,
weighted_tables=weighted_tables,
use_offsets=self.use_offsets,
indices_dtype=(
torch.int64 if self.long_kjt_indices else torch.int32
),
offsets_dtype=(
torch.int64 if self.long_kjt_offsets else torch.int32
),
lengths_dtype=(
torch.int64 if self.long_kjt_lengths else torch.int32
),
device=device,
pin_memory=self.pin_memory,
)
for _ in range(self.num_batches)
]

return [
ModelInput.generate(
batch_size=self.batch_size,
@@ -61,5 +85,5 @@ def generate_batches(
                lengths_dtype=(torch.int64 if self.long_kjt_lengths else torch.int32),
                pin_memory=self.pin_memory,
            )
-            for batch_size in range(self.num_batches)
+            for _ in range(self.num_batches)
        ]
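
A usage sketch of the new flag. The `ModelInputConfig` fields outside this hunk (`batch_size`, `num_batches`, `num_float_features`, `device`, ...) are assumed from how generate_batches references them, so treat the constructor call as illustrative:

# Hypothetical usage; constructor fields hidden by the collapsed hunk are
# assumed to exist with these names.
from torchrec.distributed.test_utils.input_config import ModelInputConfig
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        name="table_0",
        embedding_dim=64,
        num_embeddings=1_000,
        feature_names=["feature_0"],
    )
]

config = ModelInputConfig(
    batch_size=32,
    num_batches=2,
    num_float_features=8,
    use_variable_batch=True,  # routes generation through VariableBatchModelInput
)
batches = config.generate_batches(tables=tables, weighted_tables=[])
assert len(batches) == 2
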
253 changes: 234 additions & 19 deletions torchrec/distributed/test_utils/model_input.py
@@ -7,8 +7,9 @@

# pyre-strict

import random
from dataclasses import dataclass
-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Dict, List, Optional, Tuple, Union

import torch
from tensordict import TensorDict
@@ -396,16 +397,12 @@ def generate(
            else torch.rand((batch_size,), device=device)
        )
        if pin_memory:
-            # all tensors in `ModelInput` should be on pinned memory otherwise
-            # the `_to_copy` (host-to-device) data transfer still blocks cpu execution
-            float_features = float_features.pin_memory()
-            label = label.pin_memory()
-            idlist_features = (
-                None if idlist_features is None else idlist_features.pin_memory()
-            )
-            idscore_features = (
-                None if idscore_features is None else idscore_features.pin_memory()
-            )
+            float_features, idlist_features, idscore_features, label = (
+                ModelInput._pin_memory(
+                    float_features, idlist_features, idscore_features, label
+                )
+            )

        return ModelInput(
            float_features=float_features,
            idlist_features=idlist_features,
@@ -523,6 +520,31 @@ def _assemble_kjt(
            lengths = None
        return KeyedJaggedTensor(features, indices, weights, lengths, offsets)

    @staticmethod
    def _pin_memory(
        float_features: torch.Tensor,
        idlist_features: Optional[KeyedJaggedTensor],
        idscore_features: Optional[KeyedJaggedTensor],
        label: torch.Tensor,
    ) -> Tuple[
        torch.Tensor,
        Optional[KeyedJaggedTensor],
        Optional[KeyedJaggedTensor],
        torch.Tensor,
    ]:
        """
        Pin memory for all tensors in `ModelInput`.

        All tensors in `ModelInput` should be in pinned memory; otherwise the
        `_to_copy` (host-to-device) data transfer still blocks CPU execution.
        """
        return (
            float_features.pin_memory(),
            # the KJT inputs are Optional; guard the pin_memory calls accordingly
            None if idlist_features is None else idlist_features.pin_memory(),
            None if idscore_features is None else idscore_features.pin_memory(),
            label.pin_memory(),
        )
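
Side note on the docstring above, as a standalone snippet (not part of this diff): `non_blocking=True` host-to-device copies only overlap with CPU execution when the source tensor is pinned; from pageable memory the copy still synchronizes.

# Standalone illustration of why pinning matters.
import torch

if torch.cuda.is_available():
    pageable = torch.randn(1 << 20)
    pinned = torch.randn(1 << 20).pin_memory()

    pageable.to("cuda", non_blocking=True)  # still blocks the host
    pinned.to("cuda", non_blocking=True)  # returns immediately, copy overlaps
    torch.cuda.synchronize()
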

    @staticmethod
    def create_standard_kjt(
        batch_size: int,
@@ -634,17 +656,210 @@ def _create_batched_standard_kjts(
        return global_kjt, local_kjts


-# @dataclass
-# class VbModelInput(ModelInput):
-#     pass
-
-#     @staticmethod
-#     def _create_variable_batch_kjt() -> KeyedJaggedTensor:
-#         pass
-
-#     @staticmethod
-#     def _merge_variable_batch_kjts(kjts: List[KeyedJaggedTensor]) -> KeyedJaggedTensor:
-#         pass
+@dataclass
+class VariableBatchModelInput(ModelInput):
+
+    float_features: torch.Tensor
+    idlist_features: Optional[KeyedJaggedTensor]
+    idscore_features: Optional[KeyedJaggedTensor]
+    label: torch.Tensor

    @classmethod
    def generate(
        cls,
        batch_size: int = 1,
        num_float_features: int = 16,
        dedup_factor: int = 2,
        tables: Optional[
            Union[
                List[EmbeddingTableConfig],
                List[EmbeddingBagConfig],
                List[EmbeddingConfig],
            ]
        ] = None,
        weighted_tables: Optional[
            Union[
                List[EmbeddingTableConfig],
                List[EmbeddingBagConfig],
                List[EmbeddingConfig],
            ]
        ] = None,
        pooling_avg: int = 10,
        tables_pooling: Optional[List[int]] = None,
        max_feature_lengths: Optional[List[int]] = None,
        use_offsets: bool = False,
        indices_dtype: torch.dtype = torch.int64,
        offsets_dtype: torch.dtype = torch.int64,
        lengths_dtype: torch.dtype = torch.int64,
        all_zeros: bool = False,
        device: Optional[torch.device] = None,
        pin_memory: bool = False,  # pin_memory is needed for training job QPS benchmark
    ) -> "VariableBatchModelInput":
"""
Returns a single batch of `VariableBatchModelInput`

Different from `ModelInput`, `batch_size` is the average batch size which
is used together with the `dedup_factor` to get the actual batch size.
"""

        float_features = torch.rand(
            (dedup_factor * batch_size, num_float_features), device=device
        )

        idlist_features = (
            VariableBatchModelInput._create_variable_batch_kjt(
                tables=tables,
                average_batch_size=batch_size,
                dedup_factor=dedup_factor,
                use_offsets=use_offsets,
                indices_dtype=indices_dtype,
                offsets_dtype=offsets_dtype,
                lengths_dtype=lengths_dtype,
                device=device,
            )
            if tables is not None and len(tables) > 0
            else None
        )

        idscore_features = (
            VariableBatchModelInput._create_variable_batch_kjt(
                tables=weighted_tables,
                average_batch_size=batch_size,
                dedup_factor=dedup_factor,
                use_offsets=use_offsets,
                indices_dtype=indices_dtype,
                offsets_dtype=offsets_dtype,
                lengths_dtype=lengths_dtype,
                device=device,
            )
            if weighted_tables is not None and len(weighted_tables) > 0
            else None
        )

        label = torch.rand((dedup_factor * batch_size), device=device)

        if pin_memory:
            float_features, idlist_features, idscore_features, label = (
                ModelInput._pin_memory(
                    float_features, idlist_features, idscore_features, label
                )
            )

        return VariableBatchModelInput(
            float_features=float_features,
            idlist_features=idlist_features,
            idscore_features=idscore_features,
            label=label,
        )
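
A minimal sketch of calling the classmethod above directly (the table config is illustrative):

# One variable-batch input: the dense batch is dedup_factor * batch_size = 16
# rows, while each KJT feature draws its own smaller, deduplicated batch size.
from torchrec.distributed.test_utils.model_input import VariableBatchModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        name="table_0", embedding_dim=16, num_embeddings=100, feature_names=["f0"]
    )
]
sample = VariableBatchModelInput.generate(
    batch_size=8,
    dedup_factor=2,
    num_float_features=4,
    tables=tables,
)
assert sample.float_features.shape == (16, 4)
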

    @staticmethod
    def _create_variable_batch_kjt(
        tables: Union[
            List[EmbeddingTableConfig], List[EmbeddingBagConfig], List[EmbeddingConfig]
        ],
        average_batch_size: int,
        dedup_factor: int,
        use_offsets: bool = False,
        indices_dtype: torch.dtype = torch.int64,
        offsets_dtype: torch.dtype = torch.int64,
        lengths_dtype: torch.dtype = torch.int64,
        device: Optional[torch.device] = None,
    ) -> KeyedJaggedTensor:
        is_weighted = bool(tables and getattr(tables[0], "is_weighted", False))

        feature_num_embeddings = {}
        for table in tables:
            for feature_name in table.feature_names:
                feature_num_embeddings[feature_name] = (
                    table.num_embeddings_post_pruning
                    if table.num_embeddings_post_pruning
                    else table.num_embeddings
                )

        keys = list(feature_num_embeddings.keys())
        lengths_per_feature = {}
        values_per_feature = {}
        strides_per_feature = {}
        inverse_indices_per_feature = {}
        weights_per_feature = {} if is_weighted else None

        for key, num_embeddings in feature_num_embeddings.items():
            # each feature draws its own deduplicated batch size
            batch_size = random.randint(1, average_batch_size * dedup_factor - 1)
            lengths = torch.randint(
                low=0,
                high=5,
                size=(batch_size,),
                dtype=lengths_dtype,
                device=device,
            )
            lengths_per_feature[key] = lengths
            lengths_sum = sum(lengths.tolist())
            values = torch.randint(
                0,
                num_embeddings,
                (lengths_sum,),
                dtype=indices_dtype,
                device=device,
            )
            values_per_feature[key] = values
            if weights_per_feature is not None:
                weights_per_feature[key] = torch.rand(
                    lengths_sum,
                    device=device,
                )
            strides_per_feature[key] = batch_size
            # maps each of the dedup_factor * average_batch_size logical rows
            # back to a row in this feature's deduplicated batch
            inverse_indices_per_feature[key] = torch.randint(
                0,
                batch_size,
                (dedup_factor * average_batch_size,),
                dtype=indices_dtype,
                device=device,
            )

        values = torch.cat(list(values_per_feature.values()))
        lengths = torch.cat(list(lengths_per_feature.values()))
        weights = (
            torch.cat(list(weights_per_feature.values()))
            if weights_per_feature is not None
            else None
        )
        inverse_indices = (
            keys,
            torch.stack(list(inverse_indices_per_feature.values())),
        )
        strides = [[stride] for stride in strides_per_feature.values()]

        if use_offsets:
            offsets = torch.cat(
                [
                    torch.tensor(
                        [0],
                        dtype=offsets_dtype,
                        device=device,
                    ),
                    # cast so the resulting offsets keep offsets_dtype even
                    # when lengths_dtype differs
                    lengths.cumsum(0).to(offsets_dtype),
                ]
            )
            return KeyedJaggedTensor(
                keys=keys,
                values=values,
                offsets=offsets,
                weights=weights,
                stride_per_key_per_rank=strides,
                inverse_indices=inverse_indices,
            )

        return KeyedJaggedTensor(
            keys=keys,
            values=values,
            lengths=lengths,
            weights=weights,
            stride_per_key_per_rank=strides,
            inverse_indices=inverse_indices,
        )
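
For intuition about the KJT this helper produces, a small hand-built example in the same shape (two features with different deduplicated batch sizes; values arbitrary):

# f0 has a deduped batch of 2, f1 of 3; inverse_indices maps each of the four
# logical rows back into every feature's deduped batch.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

kjt = KeyedJaggedTensor(
    keys=["f0", "f1"],
    values=torch.tensor([10, 11, 12, 20, 21, 22, 23]),
    lengths=torch.tensor([2, 1, 1, 2, 1]),  # f0: [2, 1]; f1: [1, 2, 1]
    stride_per_key_per_rank=[[2], [3]],  # per-feature deduped batch sizes
    inverse_indices=(
        ["f0", "f1"],
        torch.tensor([[0, 1, 1, 0], [2, 0, 1, 1]]),
    ),
)
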

