
All2All precision always in fp32 #195

Open

vchiley opened this issue Feb 21, 2023 · 1 comment

vchiley commented Feb 21, 2023

This example shows how to use TutelMoE with Torch autocast amp.

Q: Is the All2All still meant to be done in FP32?

In general, torch autocast AMP keeps the network's master weights in FP32 and downcasts weights before a layer's forward pass.
So within an autocast context (as suggested here), the matmul here will be autocast to fp16.
(Note: torch.add is done in fp32; see the list of ops which autocast. Ops not on that list follow normal type promotion and upcast to the input with the highest precision; since batched_fc1_bias is fp32, the add runs in fp32 and produces an fp32 output.)
So far everything is just standard torch.
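
To illustrate, a minimal sketch of that standard autocast behavior (assumes a CUDA device; the tensor names are illustrative, not Tutel's):

    import torch

    x = torch.randn(8, 16, device="cuda")   # activations, fp32
    w = torch.randn(16, 32, device="cuda")  # "master" weight, fp32
    b = torch.randn(32, device="cuda")      # bias, fp32 (stands in for batched_fc1_bias)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        y = torch.matmul(x, w)  # matmul is on the autocast fp16 list -> computed and returned in fp16
        z = y + b               # add is not; fp16 + fp32 promotes to fp32

    print(y.dtype, z.dtype)     # torch.float16 torch.float32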

My question is about these few lines of code. Since the expert weights are in fp32, this cast will upcast the input x to fp32.
As a result, the All2All communication is done with fp32 inputs.
Is this correct, or am I missing some other cast?
(Note: in the cast at the end of this line, x has already been cast to fp32.)

It looks like the all2all is ALWAYS done in fp32 precision, even inside an AMP autocast context manager. Was this done deliberately, or is it a bug? If the all2all were done in 16 bits, we'd save 2x the bandwidth.
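
A quick back-of-envelope check of both points, assuming a `.to(weight.dtype)`-style cast like the one in question (names and shapes below are illustrative, not Tutel's actual identifiers):

    import torch

    expert_w = torch.randn(1024, 1024)                 # expert weights kept in fp32
    x = torch.randn(4096, 1024, dtype=torch.float16)   # activations produced under autocast

    x_comm = x.to(expert_w.dtype)   # upcast to match the fp32 expert weights
    print(x_comm.dtype)             # torch.float32 -> this is what the all2all would see

    fp16_bytes = x.nelement() * x.element_size()            # 2 bytes per element
    fp32_bytes = x_comm.nelement() * x_comm.element_size()  # 4 bytes per element
    print(fp32_bytes / fp16_bytes)  # 2.0 -> an fp32 all2all moves twice the bytes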

Final note: as mentioned, inside the autocast context manager, although the all2all is done in fp32, autocast is still active, so the matmuls here are done in fp16.

Potential bug: I'm not sure this does anything... That layer should already be in fp32, and when it's run here autocast should still run it in fp16...
I THINK the right way to do this is something like this:

    def forward(self, x):
        if self.fp32_gate:
            # upcast the input so the gate runs fully in fp32
            x = x.float()
        # disable autocast when fp32_gate is set, so self.wg is not downcast to fp16
        with torch.autocast(device_type=x.device.type, enabled=not self.fp32_gate):
            out = self.wg(x)
            return out

Autocast is disabled here.
I'm pretty sure the suggested rewrite for gate autocast is correct and more understandable.
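
A self-contained sketch of how the suggested forward would behave under an outer autocast context (the `Gate` class, its constructor arguments, and the shapes are stand-ins, not Tutel's actual gate; assumes a CUDA device):

    import torch
    import torch.nn as nn

    class Gate(nn.Module):
        def __init__(self, model_dim, num_experts, fp32_gate=True):
            super().__init__()
            self.fp32_gate = fp32_gate
            self.wg = nn.Linear(model_dim, num_experts, bias=False)  # weights stay in fp32

        def forward(self, x):
            if self.fp32_gate:
                x = x.float()
            with torch.autocast(device_type=x.device.type, enabled=not self.fp32_gate):
                return self.wg(x)

    gate = Gate(16, 4).cuda()
    x = torch.randn(8, 16, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = gate(x)
    print(logits.dtype)  # torch.float32: the inner context disables autocast, so wg runs in fp32

With fp32_gate=False, the inner context is just autocast again, so the gate matmul drops back to fp16 as usual.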


vchiley commented Feb 21, 2023

In general, it seems as though if the input is not explicitly cast here (i.e. we comment out those lines)
and the input to the MoE layer is in fp16, then the all2all input is fp16 (and I'm assuming the all2all will be done in fp16); if the input to the MoE is in fp32, then the input to the all2all is in fp32 (and I'm assuming the all2all will be done in fp32).
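
A small sketch of that dtype propagation (the `cast_to_expert_dtype` flag is hypothetical and stands in for commenting out the cast; no communication is actually performed):

    import torch

    expert_w = torch.randn(64, 64)  # fp32 expert weights

    def all2all_input(x, cast_to_expert_dtype):
        # stand-in for the tensor that would be handed to the all2all primitive
        return x.to(expert_w.dtype) if cast_to_expert_dtype else x

    x_fp16 = torch.randn(8, 64, dtype=torch.float16)
    x_fp32 = torch.randn(8, 64)

    print(all2all_input(x_fp16, cast_to_expert_dtype=True).dtype)   # torch.float32 (cast kept)
    print(all2all_input(x_fp16, cast_to_expert_dtype=False).dtype)  # torch.float16 (cast removed)
    print(all2all_input(x_fp32, cast_to_expert_dtype=False).dtype)  # torch.float32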
