
Change order of updates in metric forward to increase efficiency #126

Closed
janvainer opened this issue Mar 24, 2021 · 16 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

@janvainer commented Mar 24, 2021

🚀 Feature

Refactor Metric.forward() to call update only once.

Motivation

The update() method of a Metric gets called twice in forward() when compute_on_step is True.
This means repeated computation, which can slow down execution. For example, I have a custom SmoothL1Metric whose update function calculates an element-wise L1 distance (see below). The problem arises when the tensors on which the metric is computed have many dimensions and the computation itself is slow.

class SmoothL1Metric(Metric):
    def __init__(self, mask_dim, dist_sync_on_step: bool = False, compute_on_step: bool = True):
        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=compute_on_step)
        self.loss = torch.nn.SmoothL1Loss(reduction="sum")
        self.mask_dim = mask_dim

        self.add_state("sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("numel", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, input, target, lens):
        mask = get_mask(input, lens, self.mask_dim).type(input.dtype)
        # this is a heavy computation that should not be executed twice
        self.sum += self.loss(input * mask, target * mask)
        self.numel += mask.sum()

    def compute(self):
        return self.sum / self.numel

Suggestion

How about something like:

def forward(self, *args, **kwargs):

    if self.compute_on_step:
        self._to_sync = self.dist_sync_on_step

        # save context before switch
        cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

        # call reset, update, compute, on single batch
        self.reset()
        self.update(*args, **kwargs)
        self._forward_cache = self.compute()

        # merge new and old context without recomputing update
        for attr, val in cache.items():
            setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))
    else:
        with torch.no_grad():
            self.update(*args, **kwargs)
        self._forward_cache = None

    return self._forward_cache

The code probably does not work now, but the idea should be clear. What do you think?

@janvainer added the enhancement and help wanted labels on Mar 24, 2021
@github-actions commented

Hi! Thanks for your contribution, great first issue!

@Borda (Member) commented Mar 24, 2021

@PyTorchLightning/core-metrics thoughts?

@maximsch2 (Contributor) commented

This makes sense to me. Another nit/question: self._forward_cache doesn't necessarily belong on self, since it can also be a heavy object we don't want to store long-term. I would suggest something like this:

def forward(self, *args, **kwargs):

    if not self.compute_on_step:
        with torch.no_grad():
            self.update(*args, **kwargs)
        return

    self._to_sync = self.dist_sync_on_step # why are we resetting this on every forward btw?

    # save context before switch
    cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}

    # call reset, update, compute, on single batch
    self.reset()
    self.update(*args, **kwargs)
    result = self.compute()

    # merge new and old context without recomputing update
    for attr, val in cache.items():
        setattr(self, attr, self._reductions[attr](val, getattr(self, attr)))

    return result

@janvainer (Author) commented

@maximsch2 I think the self._forward_cache is used in tests.

@SkafteNicki (Member) commented

self._forward_cache is also used internally for logging in Lightning, see this file:
https://github.com/PyTorchLightning/pytorch-lightning/blob/555a6fea212e340f0b3a9684829e6027e4ba27c0/pytorch_lightning/core/step_result.py#L303
so it cannot be removed.
@janvainer I really like this suggestion. I cannot say for certain whether the change will work for all metrics (I am sure it will work for the majority). Could you try locally and see if you can get all tests passing with this change?

@janvainer (Author) commented

Yes, I will make a PR :) Probably not today, but over the weekend.

@janvainer (Author) commented

@SkafteNicki the current code is a bit problematic when it comes to the use of the dist_reduce_fx argument.
The reduction function should mimic what happens in the update function, but the current framework does not enforce that.
For example, in the tests, the DummyMetric uses dist_reduce_fx=None, yet in update there is self.x += x, so the reduction used for distributed sync and the one implied by update are not aligned. This is problematic if we want to use only a single update call in forward, because we don't know what reduction is used inside update. Check out the draft PR #141.
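
For illustration, the mismatch looks roughly like this (a sketch of the pattern described above, not the actual test code):

class DummyMetric(Metric):
    def __init__(self):
        super().__init__()
        # no distributed reduction is registered for this state...
        self.add_state("x", default=torch.tensor(0.0), dist_reduce_fx=None)

    def update(self, x):
        # ...yet update accumulates with a sum, so the two are not aligned
        self.x += x

    def compute(self):
        return self.x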

For this to work properly, the update method should calculate a single-step value of the metric, and the internal state should then be updated with the provided reduction function. What do you think? I am afraid this suggestion requires an API change, which will probably not be approved. Maybe there is some other way to accomplish what this issue suggests, but I don't see it yet. Any ideas?

@SkafteNicki (Member) commented

@janvainer I agree that this is problematic. I am pretty sure that dist_reduce_fx is aligned with the operations for all "real" implemented metrics (else tests should fail), and the dummy metric is a special case where we have not aligned them. With that in mind, we could make the change.

However, it is troublesome that we do not have an explicit check. To get this correct, we would need an API change; the simplest option is that instead of the user doing the operations in the update call, they return the values, and we then call dist_reduce_fx internally. Something like:

# accuracy example
def __init__(self, ...):
    self.add_state("correct", torch.tensor(0), dist_reduce_fx = 'sum')
    self.add_state("nobs", torch.tensor(0), dist_reduce_fx = 'sum')

def update(self, preds, target):
    equal = (preds==target).sum()
    n_obs = preds.numel()
    return equal, n_obs  # need to return in the same order as the states are initialized

def _wrap_update(self, update):

    @functools.wraps(update)
    def wrapped_func(*args, **kwargs):
        self._computed = None
        out = update(*args, **kwargs)
        for state, step_val in zip(self._defaults.keys(), out):
            setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))
        return None
    return wrapped_func

@janvainer (Author) commented

Thanks, what you are suggesting makes sense to me. This looks relatively nice :) The question is how big an API change this would be for existing users. Should I implement this?

def __init__(self, ...):
    self.add_state("l1_distances", torch.tensor(0.0), dist_reduce_fx="sum")
    self.add_state("numel", torch.tensor(0), dist_reduce_fx="sum")

def update(self, prediction, target, mask):
    distances = l1_dist(prediction, target)
    numel = mask.sum()
    return distances, numel

def forward(self, *args, **kwargs):
    ...
    update_results = self.update(*args, **kwargs)
    # Update old state with new results with reduction functions
    ...
    # return single-step metric value (potentially differentiable scalar)
    return self.compute(*update_results)

@staticmethod
def compute(sum_of_distances, total_numel):
    return sum_of_distances / total_numel

@janvainer (Author) commented

OK, I like this solution. The API change seems quite big though. Is that okay?

@SkafteNicki (Member) commented

@janvainer as this is a fundamental API change we need more input on it.
@PyTorchLightning/core-metrics any opinions?

@maximsch2 (Contributor) commented

One potential issue here is that this prevents in-place updates, right?

In https://github.com/PyTorchLightning/metrics/pull/128/files#diff-a605698e7c4a7849117d5d944263ea2218cc58795426cd2c98165794dc31365eR70-R74 I'm going over each column separately to avoid constructing the full matrix, as it OOMs for us otherwise (very large number of classes).
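
The pattern in question looks roughly like this (a generic sketch with assumed names such as confmat and num_classes, not the code from the linked PR): the state tensor is mutated piecewise so that a full per-batch intermediate is never allocated, which only works if update is allowed to modify the state in place.

def update(self, preds, target):
    # accumulate one true-class slice at a time so a full
    # (num_classes x num_classes) intermediate is never materialized
    for c in range(self.num_classes):
        counts = torch.bincount(preds[target == c], minlength=self.num_classes)
        self.confmat[c] += counts.to(self.confmat.dtype)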

@SkafteNicki (Member) commented

@janvainer after some going back and forth internally, here is an API change that suits the needs without breaking backwards compatibility:

new inplace arg

The basic idea is that the base Metric class gets a new fifth argument called inplace, which defaults to True. The flag indicates whether metric states are updated in-place inside the update method, or are instead returned from the update method so that we internally do the reduction using some variation of:

out = update(*args, **kwargs)
for state, step_val in zip(self._defaults.keys(), out):
    setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))

If the flag is True, nothing about the current code changes, so it stays backward compatible. The forward method would then look something like:

def forward(self, *args, **kwargs):
    if not self.inplace:
        return self.fast_forward(*args, **kwargs)
    # insert whatever is in forward now

Positive:

  • faster computation for metrics that can be calculated with inplace=False
  • still supports metrics that require in-place updates, such as the one @maximsch2 mentions
  • backwards compatible with all users' custom metrics

Negative:

  • more complex code base

The initial PR should implement the API changes in Metric and maybe redo a single metric such as Accuracy using the new, faster API. Then in follow-up PRs we can begin converting the remaining metrics.
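
A rough sketch of what the non-inplace path could look like, combining the reduction loop above with a compute that accepts explicit state values (as in the earlier example); none of this is final API, and fast_forward is just the placeholder name used above:

def fast_forward(self, *args, **kwargs):
    # update returns the per-batch state values instead of mutating self,
    # so a single call is enough
    step_states = self.update(*args, **kwargs)

    # the single-batch value is computed directly from the per-batch states,
    # no reset or second update needed
    batch_value = self.compute(*step_states)

    # fold the per-batch values into the accumulated states using the
    # registered dist_reduce_fx for each state
    for state, step_val in zip(self._defaults.keys(), step_states):
        setattr(self, state, self._reductions[state]([getattr(self, state), step_val]))

    return batch_value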

@janvainer (Author) commented

Hi @SkafteNicki, thanks for the follow-up. The suggested API changes make sense to me! Unfortunately, I do not have the capacity to look into it in the upcoming weeks, so feel free to reassign if anyone wants to pick this up (it should be possible to continue from my draft PR). I think I will be able to work on this some time in June.

@SkafteNicki (Member) commented

@janvainer thanks for letting us know. I am going to un-assign you and also close the PR you created; then we will see if someone feels like picking it up. Otherwise, feel free to ping me if you find the time to contribute again :]
Thanks!

stale bot commented Jun 25, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Jun 25, 2021
stale bot closed this as completed Jul 4, 2021