
the gradient of all parameters is None #283

Closed
nankepan opened this issue Apr 16, 2024 · 5 comments
Labels: help wanted (Extra attention is needed)

Comments

@nankepan commented Apr 16, 2024

[screenshot: the code location where param.grad is printed]
Hi,
I print param.grad here and find that the gradient of all parameters is None. Is this caused by using ColossalAI? How can I obtain the gradients of the parameters? Thank you.

@JThh (Collaborator) commented Apr 16, 2024

It should only be None right after optimizer.zero_grad(); booster.backward essentially calls the wrapped optimizer's backward(loss) for you. Would you mind printing the contents of loss to see whether it is NaN?
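
For anyone hitting the same thing, here is a minimal sketch of the check suggested above, assuming a standard ColossalAI Booster training loop; `model`, `criterion`, `inputs`, and `targets` are placeholder names, not taken from the original report:

```python
import torch

def train_step(booster, model, optimizer, criterion, inputs, targets):
    """One training step with the loss sanity check suggested above."""
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Inspect the loss before running backward.
    print("loss:", loss.item())
    if torch.isnan(loss) or torch.isinf(loss):
        raise RuntimeError("loss is NaN/Inf; gradients would be unusable")

    booster.backward(loss, optimizer)  # backward goes through the booster, not loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss
```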

@nankepan (Author) commented Apr 17, 2024

> It should only be None right after optimizer.zero_grad(); booster.backward essentially calls the wrapped optimizer's backward(loss) for you. Would you mind printing the contents of loss to see whether it is NaN?

Thanks for the reply. The loss is normal, but the gradient is None even before optimizer.zero_grad(), which is strange.
I trained the model; the loss decreased steadily and model performance kept improving, so the None gradients are confusing.

@zhengzangw (Collaborator)

This is because ColossalAI manages the gradients itself, so you cannot access them directly via param.grad. @ver217 Could you please help with this?

zhengzangw added the "help wanted" label on May 10, 2024
@ver217 (Member) commented Jun 24, 2024

Hi, gradients are managed by the zero optimizer and p.grad is None; this is expected behavior. If you want to check the grads manually, you can refer to https://github.com/hpcaitech/ColossalAI/blob/7f8b16635b42013b73e1cb1ffdebc07b4d71ac93/tests/test_zero/test_low_level/test_zero1_2.py#L164
Note that the grads are sharded and flat.
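
For readers who do not want to dig through the link, the pattern in that test looks roughly like the sketch below. The internal names (`_grad_store`, `get_partitioned_gradients_by_param_id`) are an assumption based on that test and may differ across ColossalAI versions, so treat the linked file as the authoritative reference, not this sketch.

```python
# Rough sketch of reading ZeRO-managed gradients after backward, following the
# pattern in the linked test. ASSUMPTION: the low-level zero optimizer exposes
# `_grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)`;
# verify against the linked test for your ColossalAI version.
def dump_sharded_grads(optimizer):
    for group_id, group in enumerate(optimizer.param_groups):
        for param in group["params"]:
            shards = optimizer._grad_store.get_partitioned_gradients_by_param_id(group_id, id(param))
            # Each rank holds only its own flat shard(s); shapes will not match param.shape.
            print(tuple(param.shape), [tuple(s.shape) for s in shards])

# Usage: call right after booster.backward(loss, optimizer) and before optimizer.step().
```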

@281LinChenjian

I also need to extract p.grad for subsequent calculations. Is there any way to get p.grad correctly? I have read the code above but still don't know how to do it.

> Hi, gradients are managed by the zero optimizer and p.grad is None; this is expected behavior. If you want to check the grads manually, you can refer to https://github.com/hpcaitech/ColossalAI/blob/7f8b16635b42013b73e1cb1ffdebc07b4d71ac93/tests/test_zero/test_low_level/test_zero1_2.py#L164 Note that the grads are sharded and flat.
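
If you only need per-parameter gradients for your own computation (rather than the exact values ZeRO stores), one version-agnostic workaround, sketched below, is to register plain PyTorch tensor hooks that snapshot each local gradient as autograd produces it, before the zero optimizer takes ownership. This is not a ColossalAI API: whether the hooks fire depends on how the chosen plugin wraps the model, and the captured values are the unreduced local gradients on each rank, so verify it on your setup. `model`, `booster`, `optimizer`, and `loss` are placeholders for your own objects.

```python
# Workaround sketch (plain PyTorch, not a ColossalAI API): capture each
# parameter's local gradient via tensor hooks during backward. Note these are
# the unreduced local grads on this rank, not the reduced/sharded ZeRO grads.
import torch

captured_grads = {}

def _make_hook(name):
    def hook(grad):
        captured_grads[name] = grad.detach().clone()
        return grad  # leave the gradient itself unchanged
    return hook

for name, param in model.named_parameters():  # `model` is the boosted model
    if param.requires_grad:
        param.register_hook(_make_hook(name))

# ... later, inside the training loop:
booster.backward(loss, optimizer)  # hooks fire during this backward pass
print({name: g.norm().item() for name, g in captured_grads.items()})
```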
