
Commit 995c12d

jeffkbkim authored and facebook-github-bot committed
TWA Fused Tasks (#3317)
Summary: TensorWeightedAvgMetric currently does not support FUSED_TASKS computation. With this patch, TWA supports FUSED_TASKS mode. Updated unit tests and created new ones for FUSED mode.

Differential Revision: D77958663
1 parent a29e47a commit 995c12d
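
For orientation, a minimal sketch of constructing TensorWeightedAvgMetric in FUSED_TASKS_COMPUTATION mode after this change. The task and tensor names are made up, and the import paths and constructor arguments are assumptions based on the usual torchrec.metrics layout, not values taken from this commit:

from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
from torchrec.metrics.tensor_weighted_avg import TensorWeightedAvgMetric

# Two hypothetical tasks, each with its own target tensor; one weighted, one not.
tasks = [
    RecTaskInfo(
        name="task_a",
        label_name="label_a",
        prediction_name="prediction_a",
        weight_name="weight_a",
        tensor_name="target_tensor_a",
        weighted=True,
    ),
    RecTaskInfo(
        name="task_b",
        label_name="label_b",
        prediction_name="prediction_b",
        weight_name="weight_b",
        tensor_name="target_tensor_b",
        weighted=False,
    ),
]

# Assumed RecMetric-style constructor arguments (world_size, my_rank, batch_size, ...).
metric = TensorWeightedAvgMetric(
    world_size=1,
    my_rank=0,
    batch_size=32,
    tasks=tasks,
    compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
)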

File tree

4 files changed: +389 -128 lines changed


torchrec/metrics/rec_metric.py

Lines changed: 21 additions & 0 deletions
@@ -623,6 +623,27 @@ def _update(
                labels, torch.Tensor
            )

+           # Metrics such as TensorWeightedAvgMetric will have tensors that we also need to stack.
+           # Stack in task order: (n_tasks, batch_size)
+           if "required_inputs" in kwargs:
+               target_tensors: list[torch.Tensor] = []
+               for task in self._tasks:
+                   if (
+                       task.tensor_name
+                       and task.tensor_name in kwargs["required_inputs"]
+                   ):
+                       target_tensors.append(
+                           kwargs["required_inputs"][task.tensor_name]
+                       )
+
+               if target_tensors:
+                   stacked_tensor = torch.stack(target_tensors)
+
+                   # Reshape the stacked_tensor to size([len(self._tasks), self._batch_size])
+                   stacked_tensor = stacked_tensor.view(len(self._tasks), -1)
+                   assert isinstance(stacked_tensor, torch.Tensor)
+                   kwargs["required_inputs"]["target_tensor"] = stacked_tensor
+
            predictions = (
                # Reshape the predictions to size([len(self._tasks), self._batch_size])
                predictions.view(len(self._tasks), -1)
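
The stacking added above can be exercised in isolation. A minimal sketch, assuming two hypothetical tasks whose tensor_name values are already present in required_inputs (names here are made up for illustration):

import torch

# Hypothetical per-task target tensors, keyed by each task's tensor_name.
required_inputs = {
    "target_tensor_a": torch.rand(4, dtype=torch.double),
    "target_tensor_b": torch.rand(4, dtype=torch.double),
}
task_tensor_names = ["target_tensor_a", "target_tensor_b"]  # task order

# Stack in task order, then reshape to (n_tasks, batch_size), mirroring the added code.
target_tensors = [required_inputs[name] for name in task_tensor_names]
stacked_tensor = torch.stack(target_tensors).view(len(task_tensor_names), -1)
required_inputs["target_tensor"] = stacked_tensor  # consumed later by the fused computation

print(stacked_tensor.shape)  # torch.Size([2, 4])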

torchrec/metrics/tensor_weighted_avg.py

Lines changed: 98 additions & 40 deletions
@@ -30,23 +30,27 @@ class TensorWeightedAvgMetricComputation(RecMetricComputation):

    It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor
    passed in as a required input instead of the predictions tensor.
+
+   FUSED_TASKS_COMPUTATION:
+       This class requires all target tensors from tasks to be stacked together in RecMetrics._update().
+       During TensorWeightedAvgMetricComputation.update(), the target tensor is sliced into the correct
    """

    def __init__(
        self,
        *args: Any,
-       tensor_name: Optional[str] = None,
-       weighted: bool = True,
-       description: Optional[str] = None,
+       tasks: List[RecTaskInfo],
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
-       if tensor_name is None:
-           raise RecMetricException(
-               f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got {tensor_name}"
-           )
-       self.tensor_name: str = tensor_name
-       self.weighted: bool = weighted
+       self.tasks: List[RecTaskInfo] = tasks
+
+       for task in self.tasks:
+           if task.tensor_name is None:
+               raise RecMetricException(
+                   "TensorWeightedAvgMetricComputation expects all tasks to have tensor_name, but got None."
+               )
+
        self._add_state(
            "weighted_sum",
            torch.zeros(self._n_tasks, dtype=torch.double),
@@ -61,7 +65,13 @@ def __init__(
            dist_reduce_fx="sum",
            persistent=True,
        )
-       self._description = description
+
+       self.weighted_mask: torch.Tensor = torch.tensor(
+           [task.weighted for task in self.tasks]
+       ).unsqueeze(dim=-1)
+
+       if torch.cuda.is_available():
+           self.weighted_mask = self.weighted_mask.cuda()

    def update(
        self,
@@ -71,25 +81,58 @@ def update(
        weights: Optional[torch.Tensor],
        **kwargs: Dict[str, Any],
    ) -> None:
-       if (
-           "required_inputs" not in kwargs
-           or self.tensor_name not in kwargs["required_inputs"]
-       ):
+
+       target_tensor: torch.Tensor
+
+       if "required_inputs" not in kwargs:
            raise RecMetricException(
-               f"TensorWeightedAvgMetricComputation expects {self.tensor_name} in the required_inputs"
+               "TensorWeightedAvgMetricComputation expects 'required_inputs' to exist."
            )
+       else:
+           if len(self.tasks) > 1:
+               # In FUSED mode, RecMetric._update() always creates "target_tensor" for the stacked tensor.
+               # Note that RecMetric._update() only stacks if the tensor_name exists in kwargs["required_inputs"].
+               target_tensor = cast(
+                   torch.Tensor,
+                   kwargs["required_inputs"]["target_tensor"],
+               )
+           elif len(self.tasks) == 1:
+               # UNFUSED_TASKS_COMPUTATION
+               tensor_name = self.tasks[0].tensor_name
+               if tensor_name not in kwargs["required_inputs"]:
+                   raise RecMetricException(
+                       f"TensorWeightedAvgMetricComputation expects required_inputs to contain target tensor {self.tasks[0].tensor_name}"
+                   )
+               else:
+                   target_tensor = cast(
+                       torch.Tensor,
+                       kwargs["required_inputs"][tensor_name],
+                   )
+           else:
+               raise RecMetricException(
+                   "TensorWeightedAvgMetricComputation expects at least one task."
+               )
+
        num_samples = labels.shape[0]
-       target_tensor = cast(torch.Tensor, kwargs["required_inputs"][self.tensor_name])
        weights = cast(torch.Tensor, weights)
+
+       # Vectorized computation using masks
+       weighted_values = torch.where(
+           self.weighted_mask, target_tensor * weights, target_tensor
+       )
+
+       weighted_counts = torch.where(
+           self.weighted_mask, weights, torch.ones_like(weights)
+       )
+
+       # Sum across batch dimension to Shape(n_tasks,)
+       weighted_sum = weighted_values.sum(dim=-1)
+       weighted_num_samples = weighted_counts.sum(dim=-1)
+
+       # Update states
        states = {
-           "weighted_sum": (
-               target_tensor * weights if self.weighted else target_tensor
-           ).sum(dim=-1),
-           "weighted_num_samples": (
-               weights.sum(dim=-1)
-               if self.weighted
-               else torch.ones(weights.shape).sum(dim=-1).to(device=weights.device)
-           ),
+           "weighted_sum": weighted_sum,
+           "weighted_num_samples": weighted_num_samples,
        }
        for state_name, state_value in states.items():
            state = getattr(self, state_name)
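
For intuition, the vectorized update above can be sketched standalone with made-up shapes (3 tasks, batch of 4); this is an illustration, not the metric's API. Weighted tasks accumulate tensor * weight and the sum of weights, while unweighted tasks accumulate the raw tensor and a per-sample count of ones.

import torch

weighted_mask = torch.tensor([True, False, True]).unsqueeze(dim=-1)  # (n_tasks, 1)
target_tensor = torch.rand(3, 4, dtype=torch.double)                 # (n_tasks, batch_size)
weights = torch.rand(3, 4, dtype=torch.double)                       # (n_tasks, batch_size)

# torch.where broadcasts the (n_tasks, 1) bool mask across the batch dimension.
weighted_values = torch.where(weighted_mask, target_tensor * weights, target_tensor)
weighted_counts = torch.where(weighted_mask, weights, torch.ones_like(weights))

weighted_sum = weighted_values.sum(dim=-1)           # (n_tasks,)
weighted_num_samples = weighted_counts.sum(dim=-1)   # (n_tasks,)
weighted_avg = weighted_sum / weighted_num_samples   # per-task (weighted or plain) average

print(weighted_avg.shape)  # torch.Size([3])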
@@ -105,7 +148,6 @@ def _compute(self) -> List[MetricComputationReport]:
                    cast(torch.Tensor, self.weighted_sum),
                    cast(torch.Tensor, self.weighted_num_samples),
                ),
-               description=self._description,
            ),
            MetricComputationReport(
                name=MetricName.WEIGHTED_AVG,
@@ -114,7 +156,6 @@ def _compute(self) -> List[MetricComputationReport]:
                    self.get_window_state("weighted_sum"),
                    self.get_window_state("weighted_num_samples"),
                ),
-               description=self._description,
            ),
        ]

@@ -126,23 +167,40 @@ class TensorWeightedAvgMetric(RecMetric):
    def _get_task_kwargs(
        self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
    ) -> Dict[str, Any]:
-       if not isinstance(task_config, RecTaskInfo):
-           raise RecMetricException(
-               f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
-           )
+       all_tasks = (
+           [task_config] if isinstance(task_config, RecTaskInfo) else task_config
+       )
        return {
-           "tensor_name": task_config.tensor_name,
-           "weighted": task_config.weighted,
+           "tasks": all_tasks,
        }

    def _get_task_required_inputs(
        self, task_config: Union[RecTaskInfo, List[RecTaskInfo]]
    ) -> Set[str]:
-       if not isinstance(task_config, RecTaskInfo):
-           raise RecMetricException(
-               f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not {type(task_config)}. Check the FUSED_TASKS_COMPUTATION settings."
-           )
-       required_inputs = set()
-       if task_config.tensor_name is not None:
-           required_inputs.add(task_config.tensor_name)
-       return required_inputs
+       """
+       Returns the required inputs for the task.
+
+       FUSED_TASKS_COMPUTATION:
+           - Given two tasks with the same tensor_name, assume the same tensor reference
+           - For a given tensor_name, assume all tasks have the same weighted flag
+       """
+       all_tasks = (
+           [task_config] if isinstance(task_config, RecTaskInfo) else task_config
+       )
+
+       required_inputs: dict[str, bool] = {}
+       for task in all_tasks:
+           if task.tensor_name is not None:
+               if (
+                   task.tensor_name in required_inputs
+                   and task.weighted is not required_inputs[task.tensor_name]
+               ):
+                   existing_weighted_flag = required_inputs[task.tensor_name]
+                   raise RecMetricException(
+                       f"This target tensor was already registered as weighted={existing_weighted_flag}. "
+                       f"This target tensor cannot be re-registered with weighted={task.weighted}"
+                   )
+               else:
+                   required_inputs[str(task.tensor_name)] = task.weighted
+
+       return set(required_inputs.keys())
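
A minimal sketch of the dedup rule described in the docstring above; the tensor names and task specs are hypothetical. Each tensor_name is registered once, and re-registering it with a conflicting weighted flag raises, matching the validation in _get_task_required_inputs:

# Hypothetical (tensor_name, weighted) pairs drawn from a task list.
task_specs = [("watch_time", True), ("watch_time", True), ("quality_score", False)]

required_inputs: dict[str, bool] = {}
for tensor_name, weighted in task_specs:
    if tensor_name in required_inputs and required_inputs[tensor_name] is not weighted:
        raise ValueError(
            f"{tensor_name} was already registered as weighted={required_inputs[tensor_name]}; "
            f"it cannot be re-registered with weighted={weighted}"
        )
    required_inputs[tensor_name] = weighted

print(set(required_inputs.keys()))  # {'watch_time', 'quality_score'}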

torchrec/metrics/test_utils/__init__.py

Lines changed: 28 additions & 7 deletions
@@ -31,6 +31,8 @@
    Dict[str, torch.Tensor],
]

+cpu_device: torch.device = torch.device("cpu")
+

def gen_test_batch(
    batch_size: int,
@@ -45,6 +47,7 @@ def gen_test_batch(
    mask: Optional[torch.Tensor] = None,
    n_classes: Optional[int] = None,
    seed: Optional[int] = None,
+   device: torch.device = cpu_device,
) -> Dict[str, torch.Tensor]:
    if seed is not None:
        torch.manual_seed(seed)
@@ -65,14 +68,14 @@ def gen_test_batch(
    else:
        weight = torch.rand(batch_size, dtype=torch.double)
    test_batch = {
-       label_name: label,
-       prediction_name: prediction,
-       weight_name: weight,
-       tensor_name: torch.rand(batch_size, dtype=torch.double),
+       label_name: label.to(device),
+       prediction_name: prediction.to(device),
+       weight_name: weight.to(device),
+       tensor_name: torch.rand(batch_size, dtype=torch.double).to(device),
    }
    if mask_tensor_name is not None:
        if mask is None:
-           mask = torch.ones(batch_size, dtype=torch.double)
+           mask = torch.ones(batch_size, dtype=torch.double).to(device)
        test_batch[mask_tensor_name] = mask

    return test_batch
@@ -240,6 +243,7 @@ def rec_metric_value_test_helper(
    n_classes: Optional[int] = None,
    zero_weights: bool = False,
    zero_labels: bool = False,
+   device: torch.device = cpu_device,
    **kwargs: Any,
) -> Tuple[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], ...]]:
    tasks = gen_test_tasks(task_names)
@@ -263,6 +267,7 @@ def rec_metric_value_test_helper(
            n_classes=n_classes,
            weight_value=weight_value,
            label_value=label_value,
+           device=device,
        )
        for task in tasks
    ]
@@ -293,7 +298,8 @@ def get_target_rec_metric_value(
        compute_on_all_ranks=compute_on_all_ranks,
        should_validate_update=should_validate_update,
        **kwargs,
-   )
+   ).to(device)
+
    for i in range(nsteps):
        # Get required_inputs_list from the target metric
        required_inputs_list = list(target_metric_obj.get_required_inputs())
@@ -381,6 +387,7 @@ def rec_metric_gpu_sync_test_launcher(
    entry_point: Callable[..., None],
    batch_size: int = BATCH_SIZE,
    batch_window_size: int = BATCH_WINDOW_SIZE,
+   device: torch.device = cpu_device,
    **kwargs: Dict[str, Any],
) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
@@ -402,6 +409,8 @@ def rec_metric_gpu_sync_test_launcher(
            batch_size,
            batch_window_size,
            kwargs.get("n_classes", None),
+           False,
+           device,
        )


@@ -419,6 +428,7 @@ def sync_test_helper(
    batch_window_size: int = BATCH_WINDOW_SIZE,
    n_classes: Optional[int] = None,
    zero_weights: bool = False,
+   device: torch.device = cpu_device,
    **kwargs: Dict[str, Any],
) -> None:
    rank = int(os.environ["RANK"])
@@ -444,7 +454,7 @@ def sync_test_helper(
        window_size=batch_window_size * world_size,
        # pyre-ignore[6]: Incompatible parameter type
        **kwargs,
-   )
+   ).to(device)

    weight_value: Optional[torch.Tensor] = None

@@ -458,6 +468,7 @@ def sync_test_helper(
            n_classes=n_classes,
            weight_value=weight_value,
            seed=42, # we set seed because of how test metric places tensors on ranks
+           device=device,
        )
        for task in tasks
    ]
@@ -575,6 +586,7 @@ def rec_metric_value_test_launcher(
    n_classes: Optional[int] = None,
    zero_weights: bool = False,
    zero_labels: bool = False,
+   device: torch.device = cpu_device,
    **kwargs: Any,
) -> None:
    with tempfile.TemporaryDirectory() as tmpdir:
@@ -600,9 +612,13 @@ def rec_metric_value_test_launcher(
            n_classes=n_classes,
            zero_weights=zero_weights,
            zero_labels=zero_labels,
+           device=device,
            **kwargs,
        )

+       is_time_dependent = kwargs.pop("is_time_dependent", False)
+       time_dependent_metric = kwargs.pop("time_dependent_metric", None)
+
        pet.elastic_launch(lc, entrypoint=entry_point)(
            target_clazz,
            target_compute_mode,
@@ -616,6 +632,9 @@ def rec_metric_value_test_launcher(
            n_classes,
            test_nsteps,
            zero_weights,
+           is_time_dependent,
+           time_dependent_metric,
+           device,
        )


@@ -644,6 +663,7 @@ def metric_test_helper(
    zero_weights: bool = False,
    is_time_dependent: bool = False,
    time_dependent_metric: Optional[Dict[Type[RecMetric], str]] = None,
+   device: torch.device = cpu_device,
    **kwargs: Any,
) -> None:
    rank = int(os.environ["RANK"])
@@ -670,6 +690,7 @@ def metric_test_helper(
        is_time_dependent=is_time_dependent,
        time_dependent_metric=time_dependent_metric,
        zero_weights=zero_weights,
+       device=device,
        **kwargs,
    )

