Fixes torch DDP distributed metric computation for AUROC #3234
Conversation
…t for distributed metric computation
ludwig/features/base_feature.py
Outdated
@@ -352,7 +353,10 @@ def get_metrics(self):
    try:
        metric_vals[metric_name] = get_scalar_from_ludwig_metric(metric_fn)
    except Exception as e:
-       logger.error(f"Caught exception computing metric: {metric_name}. Exception: {e}")
+       logger.error(
This should just be:
logger.exception(f"Caught exception computing metric: {metric_name}")
Then you get the stack trace for free.
Hm, is that true? We're not re-raising the exception, so it seems like it would get swallowed. With the original line of code,
logger.error(f"Caught exception computing metric: {metric_name}. Exception: {e}")
I was only seeing the exception string, but not the stack trace. With the suggested line of code, the exception is not included in the log message, so I would expect it to get swallowed entirely.
logger.exception != logger.error
Ah, didn't catch that. Okay will try that. Thanks!
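For reference, the difference between the two calls can be shown in a few lines. This is a minimal sketch; the failing `get_scalar_from_ludwig_metric` stub here is hypothetical, standing in for any metric computation that raises:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def get_scalar_from_ludwig_metric(metric_fn):
    # Hypothetical stub that always fails, to trigger the except block.
    raise ValueError("metric state is empty")

metric_name = "roc_auc"
try:
    get_scalar_from_ludwig_metric(None)
except Exception as e:
    # logger.error emits only the formatted message; the traceback is
    # lost unless you interpolate the exception into the string yourself.
    logger.error(f"Caught exception computing metric: {metric_name}. Exception: {e}")
    # logger.exception emits the same message at ERROR level *and*
    # appends the active traceback, with no need to re-raise.
    logger.exception(f"Caught exception computing metric: {metric_name}")
```

Because `logger.exception` reads the in-flight exception from `sys.exc_info()`, it only behaves this way when called from inside an `except` block.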
ludwig/utils/horovod_utils.py
Outdated
# This is to match the output of the torchmetrics gather_all_tensors function
# and ensures that the return value is usable by torchmetrics.compute downstream.
if len(result.shape) >= 2:
Interesting, so my understanding is that the DDP allgather contract is:
tensor[*shape] -> [tensor[*shape]] * world_size
While Horovod's is:
tensor[*shape] -> tensor[world_size, *shape]
In the previous implementation, we split the tensor along the rank dimension (dim=0) into [tensor[*shape]] * world_size to match the DDP format. But with this change, it seems that if the input is a scalar, meaning the allgathered output is a 1D vector, then we just return it as-is (unless it's a bool, in which case we iterate over it and turn it into a list of bool tensors).
Seems like the output format is potentially inconsistent then, right? It could be a list of tensors or it could be a single tensor. Am I missing something?
The other aspect here is the difference between casting to a list vs calling split. Running this in a terminal, I get:
>>> t
tensor([[1., 1., 2.],
[2., 3., 1.],
[5., 2., 4.]])
>>> list(t)
[tensor([1., 1., 2.]), tensor([2., 3., 1.]), tensor([5., 2., 4.])]
>>> t.split(1, dim=0)
(tensor([[1., 1., 2.]]), tensor([[2., 3., 1.]]), tensor([[5., 2., 4.]]))
So it seems like the difference here is that split preserves the rank of the tensor, while list removes the dimension that's being split upon entirely. So that does seem to better align with the DDP format.
As such, is the right thing to do here just to change lines 77 and 78 to:
gathered = _HVD.allgather(result)
gathered_result = list(gathered)
In other words, why only do this when len(result.shape) >= 2?
Otherwise, it makes sense, as I do believe DDP allgather does not add an extra dimension to the tensors being gathered, as the code seemed to be doing previously.
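To make the two contracts concrete, here's a small runnable sketch; the `hvd_to_ddp_format` helper name is illustrative, not a function in the codebase:

```python
import torch

def hvd_to_ddp_format(gathered: torch.Tensor) -> list:
    # Horovod allgather stacks per-rank tensors along dim 0:
    #   tensor[*shape] -> tensor[world_size, *shape]
    # torchmetrics' gather_all_tensors (DDP) instead returns:
    #   [tensor[*shape]] * world_size
    # list(t) iterates over dim 0 and drops that dimension, unlike
    # t.split(1, dim=0), which keeps a leading dimension of size 1.
    return list(gathered)

# Simulated output of a 3-rank Horovod allgather of shape-(3,) tensors.
gathered = torch.tensor([[1., 1., 2.], [2., 3., 1.], [5., 2., 4.]])

per_rank = hvd_to_ddp_format(gathered)
print(len(per_rank))              # 3 per-rank tensors
print(tuple(per_rank[0].shape))   # rank dimension removed: (3,)

split_result = gathered.split(1, dim=0)
print(tuple(split_result[0].shape))  # rank dimension kept: (1, 3)
```

This is the list-vs-split distinction from the terminal session above, packaged as the conversion the reviewer is describing.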
Oh interesting, so double-checking this, it seems like the guard I've added is pretty much always triggered. I believe this is because of these two lines:
https://github.com/ludwig-ai/ludwig/blob/master/ludwig/utils/horovod_utils.py#L63-L66
https://github.com/ludwig-ai/ludwig/blob/master/ludwig/utils/horovod_utils.py#L73-L75
Tensors always have a rank of at least one because of the first code block, and the second code block ensures that the rank is at least two. I'll remove the guard since it is not necessary.
LGTM
Should we add this for 0.7.3?
This PR ensures that AUROC can be computed for the torch DDP strategy. Before this change, ludwig.modules.metric_modules.BinaryAUROCMetric was not being instantiated correctly, meaning that the override for BinaryAUROCMetric.update was not called. This in turn meant that the target was not being correctly cast into a boolean. The fix was directly subclassing torchmetrics.classification.BinaryAUROC. Before this change, we were subclassing torchmetrics.AUROC, which overrides the __new__ method and messes up class inheritance.
Once class inheritance was working correctly, we unveiled an issue with horovod_utils.all_gather_tensors. The issue was previously hidden because of the broken class inheritance, which meant that we were actually using torchmetrics' all_gather_tensors function for binary AUROC computation.
The proposed fix for horovod_utils.all_gather_tensors was to ensure that the returned list of tensors matched the shape of those returned by torchmetrics' all_gather_tensors.