This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

question about batch implementation of IRM loss #7

Closed
weiHelloWorld opened this issue Jun 1, 2021 · 1 comment


@weiHelloWorld

Hi,

Thanks for the great work! I am trying to reproduce some results and have a question about the batch implementation of the IRM loss. In Section 3.2 and Appendix D, you suggest using the following batch implementation:

from torch.autograd import grad

def compute_penalty(losses, dummy_w):
    # One gradient per half of the minibatch (even/odd indices).
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()

I am wondering whether we can instead do the following:

def compute_penalty(losses, dummy_w):
    # Single gradient over the full minibatch.
    g = grad(losses.mean(), dummy_w, create_graph=True)[0]
    return (g ** 2).sum()

You mentioned that the former is an "unbiased estimate of the squared gradient norm", but I am not sure why that is the case. If you could provide some explanation, that would be great.

Thank you!

@igul222

igul222 commented Jun 22, 2021

If X denotes a minibatch gradient, then ||E[X]||^2 is the true squared gradient norm (i.e. what we're trying to estimate), while E[||X||^2] is the "naive" minibatch estimator (i.e. your suggested code). In general, E[||X||^2] = ||E[X]||^2 + tr(Cov(X)) ≠ ||E[X]||^2, so the naive estimator is biased upward by the minibatch variance of the gradient.

On the other hand, E[X1 · X2] = E[X1] · E[X2] = ||E[X]||^2 when X1 and X2 are independent (and identically distributed). Letting X1 and X2 denote the gradients of two different halves of the minibatch directly gives our batch-splitting estimator (Section 3.2). Hope this helps!
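A quick numerical sanity check makes the bias visible. The sketch below uses NumPy with a synthetic per-example gradient distribution (not autograd; the dimensions, batch size, and trial count are arbitrary choices for illustration): the naive estimator's mean sits above the true squared norm by the minibatch variance, while the batch-splitting estimator's mean matches it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic setup: per-example "gradients" are d-dimensional vectors drawn
# i.i.d. as mu + unit Gaussian noise, so the true squared gradient norm
# (the quantity both estimators target) is ||mu||^2.
d, n_examples, n_trials = 4, 32, 20000
mu = rng.normal(size=d)
true_sq_norm = np.dot(mu, mu)

naive, split = [], []
for _ in range(n_trials):
    per_example = mu + rng.normal(size=(n_examples, d))  # noisy gradients
    g = per_example.mean(axis=0)           # full-minibatch gradient
    naive.append(np.dot(g, g))             # naive estimator: ||X||^2
    g1 = per_example[0::2].mean(axis=0)    # gradient on even-indexed half
    g2 = per_example[1::2].mean(axis=0)    # gradient on odd-indexed half
    split.append(np.dot(g1, g2))           # batch-splitting: X1 . X2

print(f"true ||E[X]||^2     : {true_sq_norm:.4f}")
print(f"naive estimator mean: {np.mean(naive):.4f}")  # overshoots by ~d/n
print(f"split estimator mean: {np.mean(split):.4f}")  # matches the truth
```

With unit per-example noise, the naive estimator's upward bias here is tr(Cov)/n = d/n_examples = 0.125, and shrinking the batch makes it worse, which is why the paper splits the batch instead of squaring one gradient.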
