1 change: 1 addition & 0 deletions install-requirements.txt
@@ -1,4 +1,5 @@
fbgemm-gpu
tensordict
torchmetrics==1.0.3
tqdm
pyre-extensions
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,6 +7,7 @@ numpy
pandas
pyre-extensions
scikit-build
tensordict
torchmetrics==1.0.3
torchx
tqdm
@@ -9,6 +9,8 @@

#!/usr/bin/env python3

from typing import Dict, List

import click

import torch
@@ -82,9 +84,10 @@ def op_bench(
)

def _func_to_benchmark(
kjt: KeyedJaggedTensor,
kjts: List[Dict[str, KeyedJaggedTensor]],
model: torch.nn.Module,
) -> torch.Tensor:
kjt = kjts[0]["feature"]
return model.forward(kjt.values(), kjt.offsets())

# breakpoint() # import fbvscode; fbvscode.set_trace()
@@ -108,8 +111,8 @@ def _func_to_benchmark(

result = benchmark_func(
name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
bench_inputs=inputs, # pyre-ignore
prof_inputs=inputs, # pyre-ignore
bench_inputs=[{"feature": inputs}],
prof_inputs=[{"feature": inputs}],
num_benchmarks=10,
num_profiles=10,
profile_dir=".",
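The benchmark entry point in this file now feeds `benchmark_func` a list of per-iteration dicts instead of a bare KeyedJaggedTensor, and the benchmarked callable unpacks the `"feature"` key. A minimal sketch of that wrapping with a toy KJT; the `"feature"` key comes from the diff, the toy values are only illustrative:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Toy single-key KJT: a batch of two samples with 2 and 1 ids respectively.
kjt = KeyedJaggedTensor(
    keys=["feature"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)

# bench_inputs / prof_inputs are now lists of {name: KJT} dicts rather than a raw KJT.
bench_inputs = [{"feature": kjt}]

# The benchmarked callable pulls the KJT back out before calling the embedding op.
unpacked = bench_inputs[0]["feature"]
assert torch.equal(unpacked.values(), torch.tensor([1, 2, 3]))
```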
5 changes: 4 additions & 1 deletion torchrec/distributed/benchmark/benchmark_utils.py
@@ -374,11 +374,14 @@ def get_inputs(

if train:
sparse_features_by_rank = [
model_input.idlist_features for model_input in model_input_by_rank
model_input.idlist_features
for model_input in model_input_by_rank
if isinstance(model_input.idlist_features, KeyedJaggedTensor)
]
inputs_batch.append(sparse_features_by_rank)
else:
sparse_features = model_input_by_rank[0].idlist_features
assert isinstance(sparse_features, KeyedJaggedTensor)
inputs_batch.append([sparse_features])

# Transpose if train, as inputs_by_rank is currently in [B X R] format
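Because `ModelInput.idlist_features` is now typed as `Union[KeyedJaggedTensor, TensorDict]`, the training branch of `get_inputs` keeps only the KJT-typed inputs via an `isinstance` guard. A small sketch of that filtering pattern; `keep_kjt_inputs` is a hypothetical helper mirroring the comprehension, and the toy inputs are illustrative:

```python
from typing import List, Union

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def keep_kjt_inputs(
    features_by_rank: List[Union[KeyedJaggedTensor, TensorDict]],
) -> List[KeyedJaggedTensor]:
    # Only KJT-typed sparse features are collected on this code path;
    # TensorDict inputs are simply skipped.
    return [f for f in features_by_rank if isinstance(f, KeyedJaggedTensor)]


kjt = KeyedJaggedTensor(
    keys=["f0"], values=torch.tensor([7, 8]), lengths=torch.tensor([1, 1])
)
# batch_size defaults to empty in recent tensordict releases, matching the diff's usage.
td = TensorDict(source={"f0": torch.tensor([7, 8])})
assert len(keep_kjt_inputs([kjt, td])) == 1
```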
4 changes: 3 additions & 1 deletion torchrec/distributed/test_utils/infer_utils.py
@@ -262,6 +262,7 @@ def model_input_to_forward_args_kjt(
Optional[torch.Tensor],
]:
kjt = mi.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
return (
kjt._keys,
kjt._values,
@@ -289,7 +290,8 @@ def model_input_to_forward_args(
]:
idlist_kjt = mi.idlist_features
idscore_kjt = mi.idscore_features
assert idscore_kjt is not None
assert isinstance(idlist_kjt, KeyedJaggedTensor)
assert isinstance(idscore_kjt, KeyedJaggedTensor)
return (
mi.float_features,
idlist_kjt._keys,
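The added `assert isinstance(...)` lines mainly narrow the new `Union[KeyedJaggedTensor, TensorDict]` field for the type checker before KJT-private attributes such as `_keys` and `_values` are read. A compact sketch of the idea; `kjt_fields` is a hypothetical helper, not one of this file's functions:

```python
from typing import List, Tuple, Union

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def kjt_fields(
    features: Union[KeyedJaggedTensor, TensorDict],
) -> Tuple[List[str], torch.Tensor]:
    # The assert narrows the Union so pyre/mypy allow access to KJT-only
    # attributes; a TensorDict reaching this point fails loudly here instead
    # of raising a confusing AttributeError further down.
    assert isinstance(features, KeyedJaggedTensor)
    return features._keys, features._values
```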
131 changes: 92 additions & 39 deletions torchrec/distributed/test_utils/test_model.py
@@ -14,6 +14,7 @@

import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrec.distributed.embedding_tower_sharding import (
EmbeddingTowerCollectionSharder,
EmbeddingTowerSharder,
@@ -46,8 +47,8 @@
@dataclass
class ModelInput(Pipelineable):
float_features: torch.Tensor
idlist_features: KeyedJaggedTensor
idscore_features: Optional[KeyedJaggedTensor]
idlist_features: Union[KeyedJaggedTensor, TensorDict]
idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
label: torch.Tensor

@staticmethod
@@ -76,11 +77,13 @@ def generate(
randomize_indices: bool = True,
device: Optional[torch.device] = None,
max_feature_lengths: Optional[List[int]] = None,
input_type: str = "kjt",
) -> Tuple["ModelInput", List["ModelInput"]]:
"""
Returns a global (single-rank training) batch
and a list of local (multi-rank training) batches of world_size.
"""

batch_size_by_rank = [batch_size] * world_size
if variable_batch_size:
batch_size_by_rank = [
@@ -199,11 +202,26 @@ def _validate_pooling_factor(
)
global_idlist_lengths.append(lengths)
global_idlist_indices.append(indices)
global_idlist_kjt = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(global_idlist_indices),
lengths=torch.cat(global_idlist_lengths),
)

if input_type == "kjt":
global_idlist_input = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(global_idlist_indices),
lengths=torch.cat(global_idlist_lengths),
)
elif input_type == "td":
dict_of_nt = {
k: torch.nested.nested_tensor_from_jagged(
values=values,
lengths=lengths,
)
for k, values, lengths in zip(
idlist_features, global_idlist_indices, global_idlist_lengths
)
}
global_idlist_input = TensorDict(source=dict_of_nt)
else:
raise ValueError(f"For IdList features, unknown input type {input_type}")

for idx in range(len(idscore_ind_ranges)):
ind_range = idscore_ind_ranges[idx]
@@ -245,16 +263,25 @@ def _validate_pooling_factor(
global_idscore_lengths.append(lengths)
global_idscore_indices.append(indices)
global_idscore_weights.append(weights)
global_idscore_kjt = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(global_idscore_indices),
lengths=torch.cat(global_idscore_lengths),
weights=torch.cat(global_idscore_weights),

if input_type == "kjt":
global_idscore_input = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(global_idscore_indices),
lengths=torch.cat(global_idscore_lengths),
weights=torch.cat(global_idscore_weights),
)
if global_idscore_indices
else None
)
if global_idscore_indices
else None
)
elif input_type == "td":
assert (
len(idscore_features) == 0
), "TensorDict does not support weighted features"
global_idscore_input = None
else:
raise ValueError(f"For weighted features, unknown input type {input_type}")

if randomize_indices:
global_float = torch.rand(
@@ -303,36 +330,57 @@ def _validate_pooling_factor(
weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
)

local_idlist_kjt = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(local_idlist_indices),
lengths=torch.cat(local_idlist_lengths),
)
if input_type == "kjt":
local_idlist_input = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(local_idlist_indices),
lengths=torch.cat(local_idlist_lengths),
)

local_idscore_kjt = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(local_idscore_indices),
lengths=torch.cat(local_idscore_lengths),
weights=torch.cat(local_idscore_weights),
local_idscore_input = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(local_idscore_indices),
lengths=torch.cat(local_idscore_lengths),
weights=torch.cat(local_idscore_weights),
)
if local_idscore_indices
else None
)
elif input_type == "td":
dict_of_nt = {
k: torch.nested.nested_tensor_from_jagged(
values=values,
lengths=lengths,
)
for k, values, lengths in zip(
idlist_features, local_idlist_indices, local_idlist_lengths
)
}
local_idlist_input = TensorDict(source=dict_of_nt)
assert (
len(idscore_features) == 0
), "TensorDict does not support weighted features"
local_idscore_input = None

else:
raise ValueError(
f"For weighted features, unknown input type {input_type}"
)
if local_idscore_indices
else None
)

local_input = ModelInput(
float_features=global_float[r * batch_size : (r + 1) * batch_size],
idlist_features=local_idlist_kjt,
idscore_features=local_idscore_kjt,
idlist_features=local_idlist_input,
idscore_features=local_idscore_input,
label=global_label[r * batch_size : (r + 1) * batch_size],
)
local_inputs.append(local_input)

return (
ModelInput(
float_features=global_float,
idlist_features=global_idlist_kjt,
idscore_features=global_idscore_kjt,
idlist_features=global_idlist_input,
idscore_features=global_idscore_input,
label=global_label,
),
local_inputs,
@@ -623,8 +671,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":

def record_stream(self, stream: torch.Stream) -> None:
self.float_features.record_stream(stream)
self.idlist_features.record_stream(stream)
if self.idscore_features is not None:
if isinstance(self.idlist_features, KeyedJaggedTensor):
self.idlist_features.record_stream(stream)
if isinstance(self.idscore_features, KeyedJaggedTensor):
self.idscore_features.record_stream(stream)
self.label.record_stream(stream)

@@ -1753,10 +1802,12 @@ def forward(
if self._preproc_module is not None:
modified_input = self._preproc_module(modified_input)
elif self._run_preproc_inline:
idlist_features = modified_input.idlist_features
assert isinstance(idlist_features, KeyedJaggedTensor)
modified_input.idlist_features = KeyedJaggedTensor.from_lengths_sync(
modified_input.idlist_features.keys(),
modified_input.idlist_features.values(),
modified_input.idlist_features.lengths(),
idlist_features.keys(),
idlist_features.values(),
idlist_features.lengths(),
)

modified_idlist_features = self.preproc_nonweighted(
@@ -1820,6 +1871,8 @@ def forward(self, input: ModelInput) -> ModelInput:
)

# stride will be same but features will be joined
assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
modified_input.idlist_features = KeyedJaggedTensor.concat(
[modified_input.idlist_features, self._extra_input.idlist_features]
)
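The bulk of this file's change is in `ModelInput.generate`: with `input_type="td"` the per-key jagged indices are packed into nested (jagged) tensors and collected in a `TensorDict`, while `input_type="kjt"` keeps the existing single `KeyedJaggedTensor`; weighted (id-score) features are rejected on the TensorDict path. A toy side-by-side sketch of the two containers built from the same values and lengths, assuming a PyTorch build with `torch.nested.nested_tensor_from_jagged` and the `tensordict` package installed:

```python
import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

keys = ["feature_0", "feature_1"]
# Per-key jagged ids for a batch of two samples.
values = [torch.tensor([10, 11, 12]), torch.tensor([20, 21])]
lengths = [torch.tensor([2, 1]), torch.tensor([1, 1])]

# input_type="kjt": one flat KJT holding all keys.
kjt_input = KeyedJaggedTensor(
    keys=keys,
    values=torch.cat(values),
    lengths=torch.cat(lengths),
)

# input_type="td": one nested/jagged tensor per key inside a TensorDict.
td_input = TensorDict(
    source={
        key: torch.nested.nested_tensor_from_jagged(values=v, lengths=l)
        for key, v, l in zip(keys, values, lengths)
    }
)

# Both containers expose the same feature names, just in different layouts.
assert kjt_input.keys() == keys
assert set(td_input.keys()) == set(keys)
```

Callers opt in with `ModelInput.generate(..., input_type="td")`; the updated `_generate_data` in the pipeline benchmark below simply forwards its own `input_type` argument.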
3 changes: 3 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -1969,6 +1969,7 @@ def test_sharded_quant_fp_ebc_tw(
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = torch.rand(
kjt._values.size(0), dtype=torch.float, device=local_device
@@ -2149,6 +2150,7 @@ def test_sharded_quant_mc_ec_rw(
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = None
inputs.append(
@@ -2285,6 +2287,7 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
)
inputs = []
kjt = model_inputs[0].idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = torch.rand(
kjt._values.size(0), dtype=torch.float, device=local_device
12 changes: 10 additions & 2 deletions torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -75,6 +75,11 @@ def _gen_pipelines(
default=100,
help="Total number of sparse embeddings to be used.",
)
@click.option(
"--ratio_features_weighted",
default=0.4,
help="percentage of features weighted vs unweighted",
)
@click.option(
"--dim_emb",
type=int,
@@ -132,6 +137,7 @@ def _gen_pipelines(
def main(
world_size: int,
n_features: int,
ratio_features_weighted: float,
dim_emb: int,
n_batches: int,
batch_size: int,
@@ -149,8 +155,9 @@ def main(
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())

num_features = n_features // 2
num_weighted_features = n_features // 2
num_weighted_features = int(n_features * ratio_features_weighted)
num_features = n_features - num_weighted_features

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
@@ -257,6 +264,7 @@ def _generate_data(
world_size=world_size,
num_float_features=num_float_features,
pooling_avg=pooling_factor,
input_type=input_type,
)[1]
for i in range(num_batches)
]
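Previously this benchmark split `n_features` evenly between unweighted and weighted tables; the new `--ratio_features_weighted` option makes the split configurable. A quick check of the new arithmetic at the defaults shown in the diff (100 features, ratio 0.4):

```python
# Defaults from the diff: 100 total sparse features, 40% weighted.
n_features = 100
ratio_features_weighted = 0.4

num_weighted_features = int(n_features * ratio_features_weighted)
num_features = n_features - num_weighted_features

assert (num_weighted_features, num_features) == (40, 60)
```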
@@ -301,7 +301,11 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)

data = [i.idlist_features for i in local_model_inputs]
data = [
i.idlist_features
for i in local_model_inputs
if isinstance(i.idlist_features, KeyedJaggedTensor)
]
dataloader = iter(data)
pipeline = TrainPipelinePT2(
model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
1 change: 1 addition & 0 deletions torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
@@ -169,6 +169,7 @@ def generate_kjt(
randomize_indices=True,
device=device,
)[0]
assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
return global_input.idlist_features

