Change order of updates in metric forward to increase efficiency #126
Comments
Hi! Thanks for your contribution, great first issue!

@PyTorchLightning/core-metrics thoughts?
This makes sense to me. Another nit/question is flagged inline in the snippet below:

```python
def forward(self, *args, **kwargs):
    if not self.compute_on_step:
        with torch.no_grad():
            self.update(*args, **kwargs)
        return

    self._to_sync = self.dist_sync_on_step  # why are we resetting this on every forward btw?

    # save context before switch
    cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

    # call reset, update, compute on a single batch
    self.reset()
    self.update(*args, **kwargs)
    result = self.compute()

    # merge new and old context without recomputing update
    for attr, val in cache.items():
        setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))

    return result
```
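For readers skimming the thread, here is a self-contained toy illustrating the pattern above (the `ToyMetric` class and its dict-based state bookkeeping are invented for illustration, not the library's actual implementation). It shows why this ordering calls `update` only once per `forward` while the global state still accumulates:

```python
import torch

class ToyMetric:
    """Minimal stand-in for Metric: one summed state, plus the reordered forward."""

    def __init__(self, compute_on_step=True):
        self.compute_on_step = compute_on_step
        self._defaults = {"total": torch.tensor(0.0)}
        self._reductions = {"total": lambda old, new: old + new}
        self.reset()
        self.update_calls = 0  # instrumentation to verify update runs once per forward

    def reset(self):
        for attr, default in self._defaults.items():
            setattr(self, attr, default.clone())

    def update(self, x):
        self.update_calls += 1
        self.total = self.total + x.sum()

    def compute(self):
        return self.total

    def forward(self, x):
        if not self.compute_on_step:
            with torch.no_grad():
                self.update(x)
            return None
        # save context, evaluate on the single batch, then merge back
        cache = {attr: getattr(self, attr) for attr in self._defaults}
        self.reset()
        self.update(x)
        result = self.compute()
        for attr, val in cache.items():
            setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))
        return result

m = ToyMetric()
print(m.forward(torch.ones(3)))      # tensor(3.) -- value for this batch only
print(m.forward(torch.ones(2)))      # tensor(2.) -- value for this batch only
print(m.compute(), m.update_calls)   # tensor(5.) global state, update called twice total
```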
@maximsch2 I think the …
Yes, I will make a PR :) Probably not today, but on the weekend.
@SkafteNicki the current code is a bit problematic when it comes to the use of … For this to work fine, the …
@janvainer I agree that this is problematic. I am pretty sure that … However, it is troublesome that we do not have an explicit check. To get this correct, we would need an API change, with the simplest being that instead of mutating the states inside `update`, the user returns the per-batch values and the state reduction happens in `_wrap_update`:

```python
# accuracy example
def __init__(self, ...):
    self.add_state("correct", torch.tensor(0), dist_reduce_fx="sum")
    self.add_state("nobs", torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds, target):
    equal = (preds == target).sum()
    n_obs = preds.numel()
    return equal, n_obs  # need to return in the same order as the states are initialized

def _wrap_update(self, update):
    @functools.wraps(update)
    def wrapped_func(*args, **kwargs):
        self._computed = None
        out = update(*args, **kwargs)
        for state, step_val in zip(self._defaults.keys(), out):
            setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
        return None

    return wrapped_func
```
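To make the proposal concrete, here is a minimal, self-contained harness (the `MiniMetric` base class below is invented for illustration; it only mimics the `add_state`/`_defaults`/`_reductions` bookkeeping sketched above) showing the returned values being folded into the states by the wrapped `update`:

```python
import functools
import torch

class MiniMetric:
    """Illustrative stand-in for the Metric base class."""

    def __init__(self):
        self._defaults, self._reductions, self._computed = {}, {}, None
        # wrap the subclass update so returned values are reduced into the states
        self.update = self._wrap_update(self.update)

    def add_state(self, name, default, dist_reduce_fx):
        self._defaults[name] = default
        # simplified: 'sum' just adds the old state and the new step value
        self._reductions[name] = lambda vals: sum(vals)
        setattr(self, name, default.clone())

    def _wrap_update(self, update):
        @functools.wraps(update)
        def wrapped_func(*args, **kwargs):
            self._computed = None
            out = update(*args, **kwargs)
            for state, step_val in zip(self._defaults.keys(), out):
                setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
        return wrapped_func

class Accuracy(MiniMetric):
    def __init__(self):
        super().__init__()
        self.add_state("correct", torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("nobs", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # return per-batch values; the wrapper reduces them into the states
        return (preds == target).sum(), torch.tensor(preds.numel())

acc = Accuracy()
acc.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))
acc.update(torch.tensor([1, 1]), torch.tensor([1, 1]))
print(acc.correct / acc.nobs)  # tensor(0.8000)
```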
Thanks, what you are suggesting makes sense to me. This looks relatively nice :) The question is how big an API change this would be for existing users. Should I implement this?
Ok, I like this solution. The API change seems quite big though. Is that ok?
@janvainer as this is a fundamental API change, we need more input on it.
One potential issue here is that this prevents in-place updates, right? In https://github.com/PyTorchLightning/metrics/pull/128/files#diff-a605698e7c4a7849117d5d944263ea2218cc58795426cd2c98165794dc31365eR70-R74 I am going over each column separately to avoid constructing the full matrix, as it OOMs for us otherwise (very large number of classes).
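For context, that kind of in-place update looks roughly like the following (a hypothetical sketch, not the code from the linked PR; `confmat` and `num_classes` are illustrative names). A return-based `update` would have to materialize and return the full state delta, which is exactly what the loop avoids:

```python
import torch

# hypothetical in-place update: accumulate one confusion-matrix column at a
# time instead of building the full (num_classes x num_classes) intermediate
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
    for c in range(self.num_classes):
        # counts of true labels among samples predicted as class c
        self.confmat[:, c] += torch.bincount(
            target[preds == c], minlength=self.num_classes
        )
```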
@janvainer after going back and forth internally, here is an API change that suits the needs without breaking backwards compatibility: new …
Hi @SkafteNicki, thanks for the follow-up. The suggested API changes make sense to me! Unfortunately I do not have the capacity to look into it in the upcoming weeks, so feel free to reassign if anyone wants to pick this up (it should be possible to continue from my draft PR). I will be able to work on this some time in June, I think.
@janvainer thanks for letting us know. I am going to un-assign you and also close the PR you created; then we will see if someone feels like picking it up. Otherwise, feel free to ping me if you find the time to contribute again :]
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🚀 Feature

Refactor `Metric.forward()` to call `update` only once.

Motivation

The `update()` method in `Metric` gets computed twice in `forward()` in case `compute_on_step` is True. This means repeated computation, which can slow down execution. For example, I have a custom `SmoothL1Metric` whose `update` function calculates element-wise L1 distance (see below). The problem arises when the tensors on which the metric is computed have many dimensions and the computation itself is slow.
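A plausible sketch of such a metric (a hedged reconstruction: the state names and the use of `F.smooth_l1_loss` are assumptions, not necessarily the author's exact code):

```python
import torch
import torch.nn.functional as F
from torchmetrics import Metric  # "pytorch_lightning.metrics" at the time of this issue

class SmoothL1Metric(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # element-wise smooth-L1 distance over possibly very large tensors;
        # running this twice per forward() doubles the expensive part
        self.total += F.smooth_l1_loss(preds, target, reduction="sum")
        self.n_obs += preds.numel()

    def compute(self) -> torch.Tensor:
        return self.total / self.n_obs
```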
Suggestion

How about something like:
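A sketch of the idea, mirroring the version quoted in @maximsch2's comment above:

```python
def forward(self, *args, **kwargs):
    if not self.compute_on_step:
        with torch.no_grad():
            self.update(*args, **kwargs)
        return

    self._to_sync = self.dist_sync_on_step

    # save context before switch
    cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

    # call reset, update, compute on a single batch
    self.reset()
    self.update(*args, **kwargs)
    result = self.compute()

    # merge new and old context without recomputing update
    for attr, val in cache.items():
        setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))

    return result
```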
The code probably does not work now, but the idea should be clear. What do you think?