
About the training loss #21

Closed
guozhiyao opened this issue Apr 23, 2021 · 4 comments
@guozhiyao

During the training process, should the first loss be larger than the second loss? In my case it is the opposite.

@guozhiyao
Author

I was training a face recognition model with SAM (the backbone is a ResNet and the loss is ArcFace). First, the backbone loads a pretrained model and I train the classifier while the backbone is frozen. Finally, I train the whole model with SAM (see the two-pass sketch after this comment). But something weird happens:

  1. When I freeze the BN running statistics as you recommend (which should be the correct way), the first loss is larger than the second loss, the features produced by the backbone blow up (up to 10^9) in the second iteration, and the model cannot converge.
  2. When I let the BN running statistics update (which is wrong), the first loss is smaller than the second loss and the backbone features stay in a normal range, but the model still cannot converge.

Hope to get your reply, thanks.
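
For reference, a minimal sketch of a SAM two-pass step with the BN running statistics frozen during the second pass. The toy model, random data, and the freeze_bn_running_stats helper below are illustrative assumptions, not the training script from this issue; only the first_step / second_step interface is taken from this repository's snippet later in the thread.

import torch
import torch.nn as nn
from sam import SAM  # this repository's optimizer (first_step / second_step)

# toy model and loss, only for illustration
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.9)

def freeze_bn_running_stats(model, freeze):
    # With momentum set to 0 a BN layer keeps normalizing with the batch
    # statistics but leaves running_mean / running_var untouched.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            if freeze:
                m.backup_momentum = m.momentum
                m.momentum = 0
            elif hasattr(m, "backup_momentum"):
                m.momentum = m.backup_momentum

model.train()
for _ in range(3):  # dummy loop with random data
    input = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16,))

    # first forward-backward pass: BN running statistics update as usual
    freeze_bn_running_stats(model, freeze=False)
    loss_fn(model(input), label).backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass at the perturbed weights:
    # keep using batch statistics, but do not update the running statistics again
    freeze_bn_running_stats(model, freeze=True)
    loss_fn(model(input), label).backward()
    optimizer.second_step(zero_grad=True)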

@davda54
Owner

davda54 commented Apr 26, 2021

Most likely, the BN freezing won't make a significant difference, so I would advise you not to focus on that until you fix the convergence issue. I would expect the two losses to be of similar magnitude, but I don't see a problem if one is slightly larger than the other.

Does your model converge with a standard optimizer? Have you tried different hyperparameters?

@guozhiyao
Author

My model converges with a standard optimizer, but not with SAM.
I tested a few more times and the loss became a little more normal. Here is my take:
When we call model.eval(), BN normalizes the input with running_mean and running_var instead of the statistics of the current batch, which makes the output of the second forward pass differ from the first one. So I changed the process as follows:

# first forward-backward pass
loss_fn(model(input), label).backward()
optimizer.first_step(zero_grad=True)

# second forward-backward pass
bn_bak = save_bn_running(model)  # back up running_mean / running_var of every BN layer
model.train()  # keep BN in train mode so this pass also uses the batch statistics
loss_fn(model(input), label).backward()
optimizer.second_step(zero_grad=True)
reset_bn_running(model, bn_bak)  # restore the backed-up BN statistics

Before the second forward pass I save the running_mean and running_var of every BN layer and keep the model in train mode, so BN uses the current batch statistics, consistent with the first forward pass. After the second step I restore the saved running_mean and running_var, so the second forward pass does not modify the BN statistics.
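
For completeness, one possible implementation of the two helpers used in the snippet above, assuming standard torch.nn BatchNorm layers with track_running_stats enabled (the names save_bn_running and reset_bn_running come from the snippet; the bodies are only a sketch):

import torch.nn as nn

def save_bn_running(model):
    # Back up the running statistics of every BN layer before the second forward pass.
    backup = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm) and m.track_running_stats:
            backup[name] = (m.running_mean.clone(),
                            m.running_var.clone(),
                            m.num_batches_tracked.clone())
    return backup

def reset_bn_running(model, backup):
    # Restore the saved statistics after the second step, so that overall the
    # second forward pass leaves the BN running statistics unchanged.
    for name, m in model.named_modules():
        if name in backup:
            mean, var, n = backup[name]
            m.running_mean.copy_(mean)
            m.running_var.copy_(var)
            m.num_batches_tracked.copy_(n)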

@stale

stale bot commented May 18, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale label May 18, 2021
@stale stale bot closed this as completed May 27, 2021