
Quantizers are not DDP/AMP compliant #10

Closed
danieltudosiu opened this issue Dec 7, 2021 · 7 comments

Comments

@danieltudosiu

Hi Lucidrains,

Thanks for the amazing work you do implementing all those papers!

Is there a plan to make the quantizers compliant with:

  • DDP - They need an all-gather before computing anything, so the updates are exactly the same across all ranks
  • AMP - In my experience, if AMP touches the quantizers it screws up the gradient magnitudes, making them NaN/overflow

If you want I can have a go at it.

@lucidrains
Owner

@danieltudosiu Hi Daniel! Not yet, but that would be great! Always welcoming contributors :)

@lucidrains
Owner

@danieltudosiu do you want to see if https://github.com/lucidrains/vector-quantize-pytorch/releases/tag/0.4.8 fixes the AMP issue?

@lucidrains
Owner

as for DDP, i'm guessing we just need an all-reduce at these two lines? https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L153-L155

@danieltudosiu
Author

danieltudosiu commented Dec 10, 2021

> as for DDP, i'm guessing we just need an all-reduce at these two lines? https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L153-L155

One reduction should happen here, but only for the summation of the one-hot encodings (embed_onehot.sum(0)).

And one here for the summation of the embeddings (embed_sum).

> @danieltudosiu do you want to see if https://github.com/lucidrains/vector-quantize-pytorch/releases/tag/0.4.8 fixes the AMP issue?

Regarding the AMP part, I am not actively using this codebase, since we are close to finishing the project and have a more barebones implementation of our own. I was just flagging the issues so that I can move to this library after the project ;).

But from a quick look, I would say it should work. In our case, we just used the decorator to disable AMP entirely, and given my experience with VQ logic, I would say that is a good default (maybe even without the option to enable AMP at all).
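
For concreteness, here is a minimal sketch of that decorator approach, assuming an AMP training loop around a VQ layer; the Quantizer class below is purely illustrative and not this library's implementation:

    import torch
    from torch.cuda.amp import autocast

    # illustrative VQ layer, not the library's class
    class Quantizer(torch.nn.Module):
        def __init__(self, dim, codebook_size):
            super().__init__()
            self.codebook = torch.nn.Parameter(torch.randn(codebook_size, dim))

        @autocast(enabled=False)        # run the whole quantization step outside AMP, in fp32
        def forward(self, x):
            x = x.float()               # inputs may arrive as fp16 from a surrounding autocast region
            distances = torch.cdist(x, self.codebook)   # (n, codebook_size) code distances in fp32
            indices = distances.argmin(dim=-1)          # nearest-code lookup
            quantized = self.codebook[indices]
            # straight-through estimator so gradients still reach the encoder
            return x + (quantized - x).detach()

The encoder and decoder can still run under autocast; only the codebook math is pinned to full precision.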

@lucidrains
Owner

@danieltudosiu got it! thanks for your input :)

@danieltudosiu
Author

@lucidrains just to be clear, the all-reduce should be something like this:

    import torch.distributed as dist  # public API; torch.distributed.distributed_c10d is an internal module

    # only reduce when a process group is initialized, so single-GPU runs are unaffected
    if dist.is_initialized():
        # sum the per-rank statistics so every rank applies the same EMA update
        dist.all_reduce(tensor=encodings_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(tensor=dw, op=dist.ReduceOp.SUM)

Where encodings_sum is your embed_onehot.sum(0) and dw is your embed_sum.
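
To put that in context, here is a hedged sketch of an EMA codebook update with both reductions in place; the function name and surrounding update logic are a simplified stand-in, not the library's exact code:

    import torch
    import torch.distributed as dist

    def ema_codebook_update(codebook_avg, cluster_size, flatten, embed_onehot, decay=0.99):
        # flatten: (n, dim) encoder outputs, embed_onehot: (n, codebook_size) one-hot code assignments
        encodings_sum = embed_onehot.sum(0)     # (codebook_size,) per-rank counts per code
        dw = flatten.t() @ embed_onehot         # (dim, codebook_size) per-rank sum of assigned vectors

        # make the statistics identical on every rank before the EMA step
        if dist.is_initialized():
            dist.all_reduce(encodings_sum)      # op defaults to dist.ReduceOp.SUM
            dist.all_reduce(dw)

        # in-place exponential moving averages of the codebook statistics
        cluster_size.mul_(decay).add_(encodings_sum, alpha=1 - decay)
        codebook_avg.mul_(decay).add_(dw, alpha=1 - decay)

Because the reduction happens before the EMA step, every replica applies the same update and the codebooks stay in sync without any extra gradient communication.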

@lucidrains
Owner

@danieltudosiu hey yup! i think SUM is the default anyway :)

https://github.com/lucidrains/vector-quantize-pytorch#ddp
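
As a quick sanity check that SUM is indeed the default reduction, a small self-contained example with a single-process gloo group (the address and port are just placeholders for local testing):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    t = torch.ones(4)
    dist.all_reduce(t)                      # same as passing op=dist.ReduceOp.SUM
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(t)                                # with world_size == 1 the values are unchanged

    dist.destroy_process_group()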
