@@ -30,23 +30,27 @@ class TensorWeightedAvgMetricComputation(RecMetricComputation):
30
30
31
31
It is a sibling to WeightedAvgMetricComputation, but it computes the weighted average of a tensor
32
32
passed in as a required input instead of the predictions tensor.
33
+
34
+ FUSED_TASKS_COMPUTATION:
35
+ This class requires all target tensors from tasks to be stacked together in RecMetrics._update().
36
+ During TensorWeightedAvgMetricComputation.update(), the target tensor is sliced into the correct
33
37
"""
34
38
35
39
def __init__ (
36
40
self ,
37
41
* args : Any ,
38
- tensor_name : Optional [str ] = None ,
39
- weighted : bool = True ,
40
- description : Optional [str ] = None ,
42
+ tasks : List [RecTaskInfo ],
41
43
** kwargs : Any ,
42
44
) -> None :
43
45
super ().__init__ (* args , ** kwargs )
44
- if tensor_name is None :
45
- raise RecMetricException (
46
- f"TensorWeightedAvgMetricComputation expects tensor_name to not be None got { tensor_name } "
47
- )
48
- self .tensor_name : str = tensor_name
49
- self .weighted : bool = weighted
46
+ self .tasks : List [RecTaskInfo ] = tasks
47
+
48
+ for task in self .tasks :
49
+ if task .tensor_name is None :
50
+ raise RecMetricException (
51
+ "TensorWeightedAvgMetricComputation expects all tasks to have tensor_name, but got None."
52
+ )
53
+
50
54
self ._add_state (
51
55
"weighted_sum" ,
52
56
torch .zeros (self ._n_tasks , dtype = torch .double ),
@@ -61,7 +65,13 @@ def __init__(
61
65
dist_reduce_fx = "sum" ,
62
66
persistent = True ,
63
67
)
64
- self ._description = description
68
+
69
+ self .weighted_mask : torch .Tensor = torch .tensor (
70
+ [task .weighted for task in self .tasks ]
71
+ ).unsqueeze (dim = - 1 )
72
+
73
+ if torch .cuda .is_available ():
74
+ self .weighted_mask = self .weighted_mask .cuda ()
65
75
66
76
def update (
67
77
self ,
@@ -71,25 +81,58 @@ def update(
71
81
weights : Optional [torch .Tensor ],
72
82
** kwargs : Dict [str , Any ],
73
83
) -> None :
74
- if (
75
- "required_inputs" not in kwargs
76
- or self . tensor_name not in kwargs [ "required_inputs" ]
77
- ) :
84
+
85
+ target_tensor : torch . Tensor
86
+
87
+ if "required_inputs" not in kwargs :
78
88
raise RecMetricException (
79
- f "TensorWeightedAvgMetricComputation expects { self . tensor_name } in the required_inputs "
89
+ "TensorWeightedAvgMetricComputation expects 'required_inputs' to exist. "
80
90
)
91
+ else :
92
+ if len (self .tasks ) > 1 :
93
+ # In FUSED mode, RecMetric._update() always creates "target_tensor" for the stacked tensor.
94
+ # Note that RecMetric._update() only stacks if the tensor_name exists in kwargs["required_inputs"].
95
+ target_tensor = cast (
96
+ torch .Tensor ,
97
+ kwargs ["required_inputs" ]["target_tensor" ],
98
+ )
99
+ elif len (self .tasks ) == 1 :
100
+ # UNFUSED_TASKS_COMPUTATION
101
+ tensor_name = self .tasks [0 ].tensor_name
102
+ if tensor_name not in kwargs ["required_inputs" ]:
103
+ raise RecMetricException (
104
+ f"TensorWeightedAvgMetricComputation expects required_inputs to contain target tensor { self .tasks [0 ].tensor_name } "
105
+ )
106
+ else :
107
+ target_tensor = cast (
108
+ torch .Tensor ,
109
+ kwargs ["required_inputs" ][tensor_name ],
110
+ )
111
+ else :
112
+ raise RecMetricException (
113
+ "TensorWeightedAvgMetricComputation expects at least one task."
114
+ )
115
+
81
116
num_samples = labels .shape [0 ]
82
- target_tensor = cast (torch .Tensor , kwargs ["required_inputs" ][self .tensor_name ])
83
117
weights = cast (torch .Tensor , weights )
118
+
119
+ # Vectorized computation using masks
120
+ weighted_values = torch .where (
121
+ self .weighted_mask , target_tensor * weights , target_tensor
122
+ )
123
+
124
+ weighted_counts = torch .where (
125
+ self .weighted_mask , weights , torch .ones_like (weights )
126
+ )
127
+
128
+ # Sum across batch dimension to Shape(n_tasks,)
129
+ weighted_sum = weighted_values .sum (dim = - 1 )
130
+ weighted_num_samples = weighted_counts .sum (dim = - 1 )
131
+
132
+ # Update states
84
133
states = {
85
- "weighted_sum" : (
86
- target_tensor * weights if self .weighted else target_tensor
87
- ).sum (dim = - 1 ),
88
- "weighted_num_samples" : (
89
- weights .sum (dim = - 1 )
90
- if self .weighted
91
- else torch .ones (weights .shape ).sum (dim = - 1 ).to (device = weights .device )
92
- ),
134
+ "weighted_sum" : weighted_sum ,
135
+ "weighted_num_samples" : weighted_num_samples ,
93
136
}
94
137
for state_name , state_value in states .items ():
95
138
state = getattr (self , state_name )
@@ -105,7 +148,6 @@ def _compute(self) -> List[MetricComputationReport]:
105
148
cast (torch .Tensor , self .weighted_sum ),
106
149
cast (torch .Tensor , self .weighted_num_samples ),
107
150
),
108
- description = self ._description ,
109
151
),
110
152
MetricComputationReport (
111
153
name = MetricName .WEIGHTED_AVG ,
@@ -114,7 +156,6 @@ def _compute(self) -> List[MetricComputationReport]:
114
156
self .get_window_state ("weighted_sum" ),
115
157
self .get_window_state ("weighted_num_samples" ),
116
158
),
117
- description = self ._description ,
118
159
),
119
160
]
120
161
@@ -126,23 +167,40 @@ class TensorWeightedAvgMetric(RecMetric):
126
167
def _get_task_kwargs (
127
168
self , task_config : Union [RecTaskInfo , List [RecTaskInfo ]]
128
169
) -> Dict [str , Any ]:
129
- if not isinstance (task_config , RecTaskInfo ):
130
- raise RecMetricException (
131
- f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not { type (task_config )} . Check the FUSED_TASKS_COMPUTATION settings."
132
- )
170
+ all_tasks = (
171
+ [task_config ] if isinstance (task_config , RecTaskInfo ) else task_config
172
+ )
133
173
return {
134
- "tensor_name" : task_config .tensor_name ,
135
- "weighted" : task_config .weighted ,
174
+ "tasks" : all_tasks ,
136
175
}
137
176
138
177
def _get_task_required_inputs (
139
178
self , task_config : Union [RecTaskInfo , List [RecTaskInfo ]]
140
179
) -> Set [str ]:
141
- if not isinstance (task_config , RecTaskInfo ):
142
- raise RecMetricException (
143
- f"TensorWeightedAvgMetric expects task_config to be RecTaskInfo not { type (task_config )} . Check the FUSED_TASKS_COMPUTATION settings."
144
- )
145
- required_inputs = set ()
146
- if task_config .tensor_name is not None :
147
- required_inputs .add (task_config .tensor_name )
148
- return required_inputs
180
+ """
181
+ Returns the required inputs for the task.
182
+
183
+ FUSED_TASKS_COMPUTATION:
184
+ - Given two tasks with the same tensor_name, assume the same tensor reference
185
+ - For a given tensor_name, assume all tasks have the same weighted flag
186
+ """
187
+ all_tasks = (
188
+ [task_config ] if isinstance (task_config , RecTaskInfo ) else task_config
189
+ )
190
+
191
+ required_inputs : dict [str , bool ] = {}
192
+ for task in all_tasks :
193
+ if task .tensor_name is not None :
194
+ if (
195
+ task .tensor_name in required_inputs
196
+ and task .weighted is not required_inputs [task .tensor_name ]
197
+ ):
198
+ existing_weighted_flag = required_inputs [task .tensor_name ]
199
+ raise RecMetricException (
200
+ f"This target tensor was already registered as weighted={ existing_weighted_flag } . "
201
+ f"This target tensor cannot be re-registered with weighted={ task .weighted } "
202
+ )
203
+ else :
204
+ required_inputs [str (task .tensor_name )] = task .weighted
205
+
206
+ return set (required_inputs .keys ())
0 commit comments