Using dataloader with fixed batch size #7

Closed
xplip opened this issue Nov 28, 2021 · 4 comments
xplip commented Nov 28, 2021

Hi, thanks for providing this codebase!

For a while now I've been using Opacus to experiment with DP-SGD and RoBERTa, but I wanted to check out your PrivacyEngine, mainly because of the training speed and memory optimizations. With Opacus, I always trained with their UniformWithReplacementSampler for accurate RDP accounting, whereas as far as I can tell, you're training with fixed-size batches in your examples. I'm wondering if there's a reason the UniformWithReplacementSampler isn't needed in your codebase, and whether the uniform sampler is compatible with your modified PrivacyEngine, given that the optimizer would need to handle variations in batch size.

@lxuechen (Owner)

Hi,

Thanks for the interest and for raising a good point!

> With Opacus, I always trained with their UniformWithReplacementSampler for accurate RDP accounting, whereas as far as I can tell, you're training with fixed-size batches in your examples.

Yes, most examples in Opacus use UniformWithReplacementSampler, which effectively does Poisson sampling. This is consistent with the subsampled Gaussian mechanism.
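For concreteness, here's a minimal sketch of what that sampler does (import path as of Opacus 0.x; the toy dataset and expected batch size are just for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from opacus.utils.uniform_sampler import UniformWithReplacementSampler

dataset = TensorDataset(torch.randn(1000, 16))
expected_batch_size = 64
sample_rate = expected_batch_size / len(dataset)

# Each example is included in each batch independently with probability
# `sample_rate`, so realized batch sizes fluctuate around the expectation.
batch_sampler = UniformWithReplacementSampler(
    num_samples=len(dataset), sample_rate=sample_rate
)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (x,) in loader:
    print(x.shape[0])  # varies from batch to batch; occasionally much larger
```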

Strictly speaking, there's a mismatch between the batch selection rule used in our code (fixed batch size and non-overlapping batches within one epoch) and the privacy accounting procedure. To actually attain the prescribed privacy guarantee, one should use Poisson sampling under the current privacy accounting procedure.

However, Poisson sampling yields batches of non-uniform sizes, some of which can be large enough to cause memory issues. Gradient accumulation (e.g., the example here) partially addresses this problem, but not entirely: even with a smaller sampling rate for the micro-batches, there's still a non-zero probability of drawing extremely large micro-batches.

Past works have therefore used the usual batch selection rule (as in non-private learning) as a proxy for the true Poisson-sampled performance; see Appendix D.4 of this paper, which also shows that the difference in performance is minor. We follow this convention here.

> whether the uniform sampler is compatible with your modified PrivacyEngine

UniformWithReplacementSampler should be compatible with my codebase, since I don't see any obvious roadblocks, but I haven't tested it. I will do so by the end of this week.

One caveat: with the approach in the Opacus example, there's still a non-zero probability of selecting huge micro-batches, which would cause OOM issues. The alternative is to manually break each Poisson-sampled batch into micro-batches of fixed size.
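A rough sketch of that second option (nothing here is from this repo; split_into_micro_batches and compute_loss are hypothetical names):

```python
import torch

def split_into_micro_batches(batch, micro_batch_size):
    """Break one Poisson-sampled batch into chunks of a fixed maximum size.

    Only the chunk size is capped here; the overall batch is still the
    Poisson sample, so the privacy accounting is unaffected.
    """
    for start in range(0, batch.shape[0], micro_batch_size):
        yield batch[start:start + micro_batch_size]

# Usage (pseudocode-level): accumulate per-micro-batch gradients, then
# take one optimizer step per Poisson-sampled batch.
#   for (batch,) in loader:
#       for micro in split_into_micro_batches(batch, micro_batch_size=16):
#           loss = compute_loss(model, micro)  # hypothetical helper
#           loss.backward()
#       optimizer.step(); optimizer.zero_grad()
```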

I have actually been working on refining this part of the codebase, and there is an alternative solution: one can still use fixed-size batches, but each batch is an independent and uniform sample over all possible batches. Note that the usual loop-over-batches-across-the-dataset approach doesn't satisfy this, since two consecutive batches within one epoch can't be independent due to sampling without replacement. The privacy accounting procedure, however, needs to be slightly modified (see this for code, and Theorem 27 in this paper for the theory).
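To illustrate the batch selection rule (a sketch only, not the eventual implementation; the function name is made up):

```python
import torch

def independent_fixed_size_batches(num_examples, batch_size, num_steps, generator=None):
    """Yield `num_steps` index batches of exactly `batch_size` elements.

    Each batch is a uniform draw (without replacement) over all
    size-`batch_size` subsets, and batches are independent across steps --
    unlike epoch-based iteration, where batches within an epoch are
    disjoint and hence dependent.
    """
    for _ in range(num_steps):
        yield torch.randperm(num_examples, generator=generator)[:batch_size]
```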

Hope this addresses your concerns and helps with whatever you're working on!

Chen


xplip commented Nov 30, 2021

Thank you so much for this extensive reply! The OOM issues you mentioned are exactly what I encountered when using the Poisson sampler, which is why I had to train with batch size 1 and a huge number of gradient accumulation steps (which is obviously much slower). I wasn't aware of Appendix D.4 in particular, so this is going to save me a lot of time in future experiments, and it definitely cleared up my confusion around this topic.

Phillip

@lxuechen (Owner)

Thanks for the question, and I'm glad that it helped! I'm closing this issue for now, but feel free to re-open if there are other questions.

The refinement I mentioned, using fixed-size batches with the alternative accounting procedure, will be checked in soon.


xplip commented Dec 6, 2021

Sorry for having to reopen this, but I do have two more (perhaps related) questions after all and would really appreciate it if you could help clarify them.

  1. When using the automated sigma search (based on a specified target epsilon and N epochs), the final epsilon computed by the PrivacyEngine after training for N epochs is always much higher than the target epsilon, so it seems that the sigma chosen by get_sigma_from_rdp is too low. This also happens when I run the sentence classification and the table2text examples in the repo. E.g., instead of my target epsilon of 8, I end up with something like epsilon 10-11. How did you get your final epsilon to match the target epsilon in the experiments in your paper? (To be precise about what I mean by the sigma search, see the sketch after this list.)

  2. How do you compute the converted epsilon from composed tradeoff functions when, say, training SST-2 with the default hyperparameters from the examples? Do you reduce num_compositions=1000 in _eps_from_glw to something much lower than 1000, because the script only runs for ~400 optimization steps and would otherwise always throw the "Numerical composition of tradeoff functions failed! Double check privacy parameters." error?
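For clarity, this is the kind of sigma search I have in mind, written as a generic bisection; eps_from_sigma is a hypothetical stand-in for the accountant, not your actual get_sigma_from_rdp:

```python
def search_sigma(eps_from_sigma, target_eps, lo=0.1, hi=50.0, tol=1e-3):
    """Find (approximately) the smallest sigma whose accounted epsilon
    is <= target_eps.

    Assumes eps_from_sigma(sigma) is monotonically decreasing in sigma
    and that eps_from_sigma(hi) <= target_eps to begin with.
    """
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if eps_from_sigma(mid) > target_eps:
            lo = mid  # too little noise: epsilon overshoots the target
        else:
            hi = mid  # enough noise: try a smaller sigma
    return hi
```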
