In README.md, there is a typo in the Usage section.
Below is the code.
Usage

```python
from sam import SAM

model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

for input, output in data:
    loss = loss_function(output, model(input))  # use this loss for any training statistics
    loss.backward()
    optimizer.first_step(zero_grad=True)
    loss_function(output, model(input)).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)
```
I think the arguments of loss_function need to be reordered.
In PyTorch, the first argument of a loss function is the model output, and the second argument is the label (target).
So I think `loss_function(model(input), output)` is the correct call.
Some people, like me, may be confused by this order.
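A minimal sketch of the convention being described, assuming a standard classification setup: built-in PyTorch criteria such as `torch.nn.CrossEntropyLoss` follow the signature `criterion(input, target)`, where `input` is the model output (e.g. logits) and `target` is the label. The tensor shapes below are illustrative, not taken from the README.

```python
import torch
import torch.nn as nn

# PyTorch loss functions use the signature criterion(input, target):
# `input` is the model output (logits here), `target` is the label.
criterion = nn.CrossEntropyLoss()

logits = torch.randn(4, 10)          # stand-in for model(input): batch of 4, 10 classes
labels = torch.randint(0, 10, (4,))  # stand-in for the labels yielded by the data loader

# Correct order: model output first, labels second.
loss = criterion(logits, labels)
print(loss.item())

# The reversed order, criterion(labels, logits), would fail here because
# integer class labels cannot be treated as logits.
```

This is why the README snippet reads backwards if `output` in `for input, output in data` is the label: the label ends up in the `input` slot of the criterion.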
Thanks for sharing your great code.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.