
clip_grad_norm_ from fairscale downcasts to bf16 before all reduce #1092

Open
glample opened this issue Nov 2, 2022 · 3 comments

@glample commented Nov 2, 2022

Copied from: https://github.com/fairinternal/xlformers/issues/117

Shouldn't we remove the .to(dtype=parameters[0].dtype) from this line?

local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]), p).to(dtype=parameters[0].dtype) # type: ignore

It seems weird (and it results in inaccuracies) to convert partial gradient norms to fp16/bf16 before summing them.
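
Concretely, the proposal amounts to dropping the trailing cast so the combined norm stays in fp32 until after the reduction. A sketch of the before/after (untested, just the line above reformatted for readability):

```python
# current: partial norms are computed in fp32, but the result is immediately
# cast back to the parameter dtype (e.g. bf16) before the cross-rank reduce
local_norm = torch.norm(
    torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]),
    p,
).to(dtype=parameters[0].dtype)  # type: ignore

# proposed: keep the result in fp32; any cast back to the parameter dtype
# could happen after the all_reduce instead
local_norm = torch.norm(
    torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]),
    p,
)
```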

Context:

We use fairscale's clip_grad_norm_, which calculates grad norms via:

def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:

which downcasts to param dtype via:

local_norm = torch.norm(torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]), p).to(dtype=parameters[0].dtype) # type: ignore

before the allreduce:

dist.all_reduce(total_norm, group=self.process_group)
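
To make the concern concrete, here is a small single-process sketch (not the fairscale code path, no distributed setup needed): each simulated "rank" computes its local grad norm in fp32, and we compare combining the partial norms in fp32 against downcasting them to bf16 first, as the line above does before the all_reduce.

```python
import torch

# Simulate 8 ranks, each holding a shard of gradients (illustration only).
torch.manual_seed(0)
per_rank_grads = [torch.randn(10_000) for _ in range(8)]

# Each "rank" computes its local L2 norm in fp32.
local_norms_fp32 = torch.stack([g.norm(2) for g in per_rank_grads])

# Reference: keep everything in fp32 -- sum of squares across ranks, then the
# square root (what the all_reduce of norm**p followed by the p-th root does).
total_fp32 = local_norms_fp32.pow(2).sum().sqrt()

# With the current code, each partial norm is downcast to bf16 before the
# reduce; the quantization loss happens at this cast (the real code also does
# the reduction itself in the low-precision dtype).
local_norms_bf16 = local_norms_fp32.to(torch.bfloat16).float()
total_bf16_first = local_norms_bf16.pow(2).sum().sqrt()

print(f"fp32 until the end : {total_fp32.item():.4f}")
print(f"bf16 before reduce : {total_bf16_first.item():.4f}")
```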

Spotted by looking at how unusually even the grad norms look at each training step:

"g_norm": 5.6875
"g_norm": 11.1875
"g_norm": 23.0
"g_norm": 45.25
"g_norm": 89.5
"g_norm": 176.0
"g_norm": 360.0
"g_norm": 704.0
"g_norm": 720.0
"g_norm": 724.0
"g_norm": 728.0
"g_norm": 716.0
"g_norm": 724.0
"g_norm": 728.0
"g_norm": 752.0
"g_norm": 736.0
"g_norm": 728.0
"g_norm": 728.0
"g_norm": 736.0
"g_norm": 728.0
"g_norm": 728.0
"g_norm": 724.0
"g_norm": 724.0
"g_norm": 724.0
"g_norm": 732.0
"g_norm": 764.0
"g_norm": 720.0
"g_norm": 728.0
"g_norm": 728.0
"g_norm": 740.0
"g_norm": 732.0
"g_norm": 736.0
"g_norm": 704.0
"g_norm": 700.0
"g_norm": 728.0
"g_norm": 740.0
"g_norm": 724.0
"g_norm": 752.0
"g_norm": 712.0
"g_norm": 716.0
"g_norm": 724.0
"g_norm": 744.0
"g_norm": 728.0
"g_norm": 736.0
"g_norm": 720.0
"g_norm": 716.0
"g_norm": 724.0
"g_norm": 716.0
"g_norm": 720.0
"g_norm": 712.0
"g_norm": 744.0
"g_norm": 724.0
"g_norm": 708.0
"g_norm": 708.0
"g_norm": 716.0
"g_norm": 704.0
"g_norm": 712.0
"g_norm": 724.0
"g_norm": 708.0
"g_norm": 708.0
"g_norm": 728.0
"g_norm": 720.0
"g_norm": 724.0
"g_norm": 716.0
"g_norm": 712.0
"g_norm": 704.0
"g_norm": 700.0
"g_norm": 688.0
"g_norm": 692.0
"g_norm": 696.0
"g_norm": 732.0
"g_norm": 620.0
"g_norm": 1168.0
"g_norm": 1152.0
"g_norm": 1144.0
"g_norm": 1112.0
"g_norm": 1128.0
"g_norm": 1136.0
"g_norm": 1128.0
"g_norm": 1128.0
"g_norm": 1104.0
"g_norm": 1112.0
"g_norm": 1088.0
"g_norm": 1112.0
"g_norm": 1112.0
"g_norm": 1120.0
"g_norm": 1112.0
"g_norm": 1064.0
"g_norm": 1040.0
"g_norm": 1024.0
"g_norm": 1056.0
"g_norm": 1032.0
"g_norm": 1032.0
"g_norm": 1024.0
"g_norm": 1048.0
"g_norm": 1016.0
"g_norm": 1040.0
"g_norm": 1016.0
"g_norm": 936.0
"g_norm": 828.0
"g_norm": 764.0
"g_norm": 732.0
"g_norm": 692.0
"g_norm": 676.0
"g_norm": 1376.0
"g_norm": 1360.0
"g_norm": 1328.0
"g_norm": 1360.0
"g_norm": 1360.0
"g_norm": 1312.0
"g_norm": 1328.0
"g_norm": 1264.0
"g_norm": 1304.0
"g_norm": 1280.0
"g_norm": 1296.0
"g_norm": 1224.0
"g_norm": 1256.0
"g_norm": 1264.0
"g_norm": 1224.0
"g_norm": 1152.0
"g_norm": 1160.0
"g_norm": 1184.0
"g_norm": 1184.0
"g_norm": 1144.0
"g_norm": 1128.0
"g_norm": 1112.0
"g_norm": 1080.0
"g_norm": 1072.0
"g_norm": 1048.0
"g_norm": 1040.0
"g_norm": 1040.0
"g_norm": 1072.0
"g_norm": 1032.0
"g_norm": 1024.0
"g_norm": 996.0
"g_norm": 976.0
"g_norm": 988.0
"g_norm": 976.0
"g_norm": 956.0
"g_norm": 988.0
"g_norm": 944.0
"g_norm": 924.0
"g_norm": 924.0
"g_norm": 904.0
"g_norm": 1840.0
"g_norm": 1872.0
"g_norm": 1816.0
"g_norm": 1760.0
"g_norm": 1752.0
"g_norm": 1808.0
@min-xu-ai (Contributor)

Yeah, it does seem like an unnecessary thing to do. It seems to be from this commit:

8dc2030

@blefaudeux, do you remember why it was done in the first place? If not, we can try removing it. @glample feel free to send a PR.

@blefaudeux (Contributor) commented Nov 12, 2022

> Yeah, it does seem like an unnecessary thing to do. It seems to be from this commit:
>
> 8dc2030
>
> @blefaudeux, do you remember why it was done in the first place? If not, we can try removing it. @glample feel free to send a PR.

Hey there, seeing this a bit late. No context from me really; I guess it was a type "fix" at some point.

edit: I don't really understand your link @min-xu-ai; the linked commit made sure that the norm was computed in fp32 locally (even if the type is fp16, for instance), but this is not what @glample is suggesting here, right? I'm a bit lost with this PR title.

edit2: OK, so the commit you point to introduced both the upcast and the cast back to the original type. I agree that the cast back can be delayed if it helps any operation; it's not crucial here.

@min-xu-ai (Contributor)

> edit2: OK, so the commit you point to introduced both the upcast and the cast back to the original type. I agree that the cast back can be delayed if it helps any operation; it's not crucial here.

Yes, that's what I meant. Thanks a lot for the context, Ben!
