From 888e43582f3258ec3348c1f708e6d58d43827f94 Mon Sep 17 00:00:00 2001 From: Pooja Agarwal Date: Mon, 9 Dec 2024 04:14:02 +0000 Subject: [PATCH 1/2] Adds fast gradient norm calculation for the embedding layer. The algorithm is described in the 'A Unified Fast Gradient Clipping Framework for DP-SGD' paper: https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf. This reduces the memory needed to run DP-SGD over embedding layers, significantly reducing OOMs over large embedding layers. --- opacus/grad_sample/__init__.py | 1 + opacus/grad_sample/embedding.py | 26 ++- opacus/grad_sample/embedding_norm_sample.py | 150 ++++++++++++ ...mple_module_fast_gradient_clipping_test.py | 164 +++++++++++++ .../embedding_norm_sample_test.py | 218 ++++++++++++++++++ 5 files changed, 557 insertions(+), 2 deletions(-) create mode 100644 opacus/grad_sample/embedding_norm_sample.py create mode 100644 opacus/tests/grad_samples/embedding_norm_sample_test.py diff --git a/opacus/grad_sample/__init__.py b/opacus/grad_sample/__init__.py index a44f26648..3e65c3317 100644 --- a/opacus/grad_sample/__init__.py +++ b/opacus/grad_sample/__init__.py @@ -17,6 +17,7 @@ from .dp_multihead_attention import compute_sequence_bias_grad_sample # noqa from .dp_rnn import compute_rnn_linear_grad_sample # noqa from .embedding import compute_embedding_grad_sample # noqa +from .embedding_norm_sample import compute_embedding_norm_sample # noqa from .grad_sample_module import GradSampleModule, create_or_accumulate_grad_sample from .grad_sample_module_fast_gradient_clipping import ( # noqa GradSampleModuleFastGradientClipping, diff --git a/opacus/grad_sample/embedding.py b/opacus/grad_sample/embedding.py index 9e206a6a8..846cd9075 100644 --- a/opacus/grad_sample/embedding.py +++ b/opacus/grad_sample/embedding.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict +from typing import Dict, List +from opacus.grad_sample import embedding_norm_sample import torch import torch.nn as nn -from .utils import register_grad_sampler +from .utils import register_grad_sampler, register_norm_sampler @register_grad_sampler(nn.Embedding) @@ -82,3 +83,24 @@ def compute_embeddingbag_gradsampler(layer, inputs, backprops): ret[layer.weight] = gsm return ret + + +@register_norm_sampler(nn.Embedding) +def compute_embedding_norm_sample( + layer: nn.Embedding, + activations: List[torch.Tensor], + backprops: torch.Tensor, +) -> Dict[nn.Parameter, torch.Tensor]: + """Computes gradient norms for ``nn.Embedding`` layer. + + Args: + layer: Layer + activations: Activations + backprops: Backpropagations + + Returns: + A dictionary of parameter gradients + """ + return embedding_norm_sample.compute_embedding_norm_sample( + layer, activations, backprops + ) diff --git a/opacus/grad_sample/embedding_norm_sample.py b/opacus/grad_sample/embedding_norm_sample.py new file mode 100644 index 000000000..308f0028b --- /dev/null +++ b/opacus/grad_sample/embedding_norm_sample.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2024, The Opacus authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility for computing per-sample gradient norms for the embedding layer.
+
+Based on the algorithm from the paper:
+https://proceedings.neurips.cc/paper_files/paper/2023/file/a45d344b28179c8da7646bc38ff50ad8-Paper-Conference.pdf.
+"""
+from typing import Dict, List
+
+import torch
+from torch import nn
+
+
+def compute_embedding_norm_sample(
+    layer: nn.Embedding,
+    activations: List[torch.Tensor],
+    backprops: torch.Tensor,
+) -> Dict[nn.Parameter, torch.Tensor]:
+    """Computes per-sample gradient norms for an ``nn.Embedding`` layer.
+
+    Args:
+      layer: The embedding layer.
+      activations: A list whose first element is the tensor of input ids fed
+        to the embedding layer.
+      backprops: Gradients of the loss with respect to the layer output.
+
+    Returns:
+      A dictionary mapping the layer's weight parameter to its per-sample
+      gradient norms.
+
+    NOTE: Here is an example input and the expected intermediate values. This
+    is provided to help in understanding the algorithm:
+    Inputs:
+      layer: Embedding(3, 1)  # (vocab_size, embedding_dim)
+      activations: [tensor([[1, 1],
+            [2, 0],
+            [2, 0]])]
+      backprops: tensor([[0.2000],
+            [0.2000],
+            [0.3000],
+            [0.1000],
+            [0.3000],
+            [0.1000]])
+
+    Intermediate values:
+      input_ids: tensor([[1, 1],
+            [2, 0],
+            [2, 0]])
+      input_ids.shape: torch.Size([3, 2])
+      grad_values: tensor([[0.2000],
+            [0.2000],
+            [0.3000],
+            [0.1000],
+            [0.3000],
+            [0.1000]])
+      grad_values.shape: torch.Size([6, 1])
+      nrows: 3
+      ncols: 2
+      row_indices: tensor([[0],
+            [0],
+            [1],
+            [1],
+            [2],
+            [2]])
+      flattened_indices: tensor([[1],
+            [1],
+            [2],
+            [0],
+            [2],
+            [0]])
+      paired_indices: tensor([[0, 1],
+            [0, 1],
+            [1, 2],
+            [1, 0],
+            [2, 2],
+            [2, 0]])
+      unique_paired_indices: tensor([[0, 1],
+            [1, 0],
+            [1, 2],
+            [2, 0],
+            [2, 2]])
+      new_index_positions: tensor([0, 0, 2, 1, 4, 3])
+      num_unique_paired_indices: 5
+      summed_gradients: tensor([[0.4000],
+            [0.1000],
+            [0.3000],
+            [0.1000],
+            [0.3000]])
+      sqr_gradient_sum: tensor([0.1600, 0.0100, 0.0900, 0.0100, 0.0900])
+      unique_batch_ids: tensor([0, 1, 1, 2, 2])
+      result: tensor([0.1600, 0.1000, 0.1000])
+      result_sqrt: tensor([0.4000, 0.3162, 0.3162])
+    """
+    device = activations[0].device
+    input_ids = activations[0].to(device)
+    grad_values = backprops.to(device)
+
+    # Reshape input_ids preserving the batch size as the first dimension
+    input_ids = input_ids.reshape(input_ids.shape[0], -1)
+
+    # Reshape grad_values preserving the embedding dimension as the last dimension
+    grad_values = grad_values.reshape(-1, grad_values.size(-1))
+
+    # Create 1D tensor of row indices
+    nrows = input_ids.size(0)
+    ncols = input_ids.size(1)
+    row_indices = (
+        torch.repeat_interleave(torch.arange(nrows).to(device), ncols)
+        .unsqueeze(-1)
+        .to(device)
+    )
+
+    # Pair the input IDs with the row indices
+    flattened_indices = input_ids.view(-1, 1)
+    paired_indices = torch.cat([row_indices, flattened_indices], dim=1).to(device)
+
+    # Get unique paired indices and new index positions for aggregation
+    unique_paired_indices, new_index_positions = torch.unique(
+        paired_indices, dim=0, return_inverse=True, sorted=True
+    )
+
+    # Sum gradients over new index positions and compute squared gradient norms
+    num_unique_paired_indices = unique_paired_indices.size(0)
+    summed_gradients = torch.zeros(
+        num_unique_paired_indices,
grad_values.size(-1), device=device + ) + summed_gradients = summed_gradients.index_add( + 0, new_index_positions.to(device), grad_values + ) + sqr_gradient_sum = torch.sum(summed_gradients**2, dim=1) + + # Scatter add the squared sums back to their respective rows + result = torch.zeros(nrows, device=device) + unique_batch_ids = unique_paired_indices[:, 0].to(device) + result.scatter_add_(0, unique_batch_ids, sqr_gradient_sum) + + # Compute the square root for the final result (norm) + result_sqrt = torch.sqrt(result) + return {layer.weight: result_sqrt} diff --git a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py index 8029c8f4b..a86f088f5 100644 --- a/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py +++ b/opacus/tests/grad_sample_module_fast_gradient_clipping_test.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import unittest import hypothesis.strategies as st import torch @@ -67,6 +68,21 @@ def forward(self, x): return x +class SampleEmbeddingModule(nn.Module): + def __init__(self, vocab_size, embedding_dim): + super(SampleEmbeddingModule, self).__init__() + self.embedding = nn.Embedding(vocab_size, embedding_dim) + + # Manually set weights for the embedding layer for testing + self.embedding.weight = nn.Parameter( + torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32) + ) + + def forward(self, x): + x = self.embedding(x) + return x + + class GradSampleModuleFastGradientClippingTest(GradSampleModuleTest): CLS = GradSampleModuleFastGradientClipping @@ -260,3 +276,151 @@ def test_gradient_calculation_fast_gradient_clipping(self, size, length, dim): logging.info(f"Diff = {diff}") msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different" assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg + + +class GradSampleModuleFastGradientClippingEmbeddingLayerTest(unittest.TestCase): + + def test_norm_calculation(self): + """ + Tests if norm calculation for embedding layer is the same between + standard (Opacus) and fast gradient clipping" + """ + vocab_size = 3 + embedding_dim = 1 + + criterion = torch.nn.CrossEntropyLoss(reduction="none") + noise_multiplier = 0.0 + input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long) + batch_size = 3 + max_grad_norm = 1.0 + sample_module = SampleEmbeddingModule(vocab_size, embedding_dim) + model_normal = GradSampleModule(clone_module(sample_module)) + optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1) + optimizer_normal = DPOptimizer( + optimizer_normal, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=batch_size, + ) + + grad_sample_module = GradSampleModuleFastGradientClipping( + clone_module(sample_module), + max_grad_norm=max_grad_norm, + use_ghost_clipping=True, + ) + optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1) + optimizer_gc = DPOptimizerFastGradientClipping( + optimizer_gc, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=batch_size, + ) + + optimizer_normal.zero_grad() + output_normal = model_normal(input_data) + target_data = torch.rand_like(output_normal) + + loss_normal = torch.mean(criterion(output_normal, target_data), dim=0) + loss_normal.backward() + all_norms_normal = torch.stack( + [ + torch.stack([g.norm() for g in param.grad_sample], dim=0) + for param in model_normal.parameters() + ], + dim=0, + ) + flat_norms_normal = 
torch.cat([p.flatten() for p in all_norms_normal]) + + grad_sample_module.enable_hooks() + output_gc = grad_sample_module(input_data) + + first_loss_per_sample = criterion(output_gc, target_data) + first_loss = torch.mean(first_loss_per_sample) + first_loss.backward(retain_graph=True) + + optimizer_gc.zero_grad() + coeff = grad_sample_module.get_clipping_coef() + second_loss_per_sample = coeff * first_loss_per_sample + second_loss = torch.sum(second_loss_per_sample) + grad_sample_module.disable_hooks() + second_loss.backward() + + all_norms_gc = [param._norm_sample for param in grad_sample_module.parameters()] + flat_norms_gc = torch.cat([p.flatten() for p in all_norms_gc]) + + diff = flat_norms_normal - flat_norms_gc + + logging.info(f"Diff = {diff}") + msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different" + assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg + + def test_gradient_calculation(self): + """Tests if gradients for embedding layer are the same between standard + (Opacus) and fast gradient clipping.""" + + noise_multiplier = 0.0 + vocab_size = 3 + embedding_dim = 1 + batch_size = 3 + input_data = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long) + max_grad_norm = 1.0 + criterion = torch.nn.CrossEntropyLoss() + + sample_module = SampleEmbeddingModule(vocab_size, embedding_dim) + model_normal = GradSampleModule(clone_module(sample_module)) + grad_sample_module = GradSampleModuleFastGradientClipping( + clone_module(sample_module), + max_grad_norm=max_grad_norm, + use_ghost_clipping=True, + ) + + optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1) + optimizer_normal = DPOptimizer( + optimizer_normal, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=batch_size, + ) + + optimizer_gc = torch.optim.SGD(grad_sample_module.parameters(), lr=1) + optimizer_gc = DPOptimizerFastGradientClipping( + optimizer_gc, + noise_multiplier=noise_multiplier, + max_grad_norm=max_grad_norm, + expected_batch_size=batch_size, + ) + + criterion_gc = DPLossFastGradientClipping( + grad_sample_module, optimizer_gc, criterion + ) + + optimizer_normal.zero_grad() + output_normal = model_normal(input_data) + target_data = torch.tensor([[[0.1], [0.1]], [[0.2], [0.3]], [[0.2], [0.3]]]) + loss_normal = torch.mean(criterion(output_normal, target_data), dim=0) + loss_normal.backward() + optimizer_normal.step() + + all_grads_normal = [param.summed_grad for param in model_normal.parameters()] + flat_grads_normal = torch.cat([p.flatten() for p in all_grads_normal]) + + optimizer_gc.zero_grad() + grad_sample_module.enable_hooks() + output_gc = grad_sample_module(input_data) + + loss_gc = criterion_gc(output_gc, target_data) + loss_gc.backward() + optimizer_gc.step() + + all_grads_gc = [param.grad for param in grad_sample_module.parameters()] + flat_grads_gc = torch.cat([p.flatten() for p in all_grads_gc]) + diff = torch.tensor( + [ + (g_gc - g_normal).norm() + for (g_gc, g_normal) in zip(flat_grads_gc, flat_grads_normal) + ] + ) + + logging.info(f"Diff = {diff}") + msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different" + assert torch.allclose(flat_grads_normal, flat_grads_gc, atol=1e-3), msg diff --git a/opacus/tests/grad_samples/embedding_norm_sample_test.py b/opacus/tests/grad_samples/embedding_norm_sample_test.py new file mode 100644 index 000000000..97aeec826 --- /dev/null +++ b/opacus/tests/grad_samples/embedding_norm_sample_test.py @@ -0,0 +1,218 @@ 
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from opacus.grad_sample import embedding_norm_sample
+import torch
+import torch.nn as nn
+
+
+class TestComputeEmbeddingNormSample(unittest.TestCase):
+
+    def test_compute_embedding_norm_sample(self):
+        # Define the embedding layer
+        embedding_dim = 1
+        vocab_size = 3
+        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
+
+        # Manually set weights for the embedding layer for testing
+        embedding_layer.weight = nn.Parameter(
+            torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
+        )
+
+        # Example input ids (activations). Shape: [3, 2]
+        input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
+
+        # Example gradients with respect to the embedding output (backprops).
+        # Shape: [6, 1]
+        grad_values = torch.tensor(
+            [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]], dtype=torch.float32
+        )
+
+        # Simulate backprop through embedding layer
+        backprops = grad_values
+
+        # Wrap input_ids in a list as expected by the norm sample function
+        activations = [input_ids]
+
+        # Call the function under test
+        result = embedding_norm_sample.compute_embedding_norm_sample(
+            embedding_layer, activations, backprops
+        )
+
+        # Expected norms
+        expected_norms = torch.tensor([0.4000, 0.3162, 0.3162], dtype=torch.float32)
+
+        # Extract the result for the embedding layer weight parameter
+        computed_norms = result[embedding_layer.weight]
+
+        # Verify the computed norms match the expected norms
+        torch.testing.assert_close(computed_norms, expected_norms, atol=1e-4, rtol=1e-4)
+
+    def test_compute_embedding_norm_sample_with_non_one_embedding_dim(self):
+        # Define the embedding layer
+        embedding_dim = 2
+        vocab_size = 3
+        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
+
+        # Manually set weights for the embedding layer for testing; the shape
+        # matches (vocab_size, embedding_dim) = (3, 2)
+        embedding_layer.weight = nn.Parameter(
+            torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]], dtype=torch.float32)
+        )
+
+        # Example input ids (activations). Shape: [6, 1, 1].
+        input_ids = torch.tensor(
+            [[[1]], [[1]], [[2]], [[0]], [[2]], [[0]]], dtype=torch.long
+        )
+
+        # Example gradients per input id, with embedding_dim=2.
+ # Shape: [6, 1, 1, 2] + grad_values = torch.tensor( + [ + [[[0.2, 0.2]]], + [[[0.2, 0.2]]], + [[[0.3, 0.3]]], + [[[0.1, 0.1]]], + [[[0.3, 0.3]]], + [[[0.1, 0.1]]], + ], + dtype=torch.float32, + ) + + # Simulate backprop through embedding layer + backprops = grad_values + + # Wrap input_ids in a list as expected by the grad norm function + activations = [input_ids] + + # Call the function under test + result = embedding_norm_sample.compute_embedding_norm_sample( + embedding_layer, activations, backprops + ) + + # Expected output based on the example + expected_norms = torch.tensor( + [0.2828, 0.2828, 0.4243, 0.1414, 0.4243, 0.1414], dtype=torch.float32 + ) + + # Extract the result for the embedding layer weight parameter + computed_norms = result[embedding_layer.weight] + + # Verify the computed norms match the expected norms + torch.testing.assert_close(computed_norms, expected_norms, atol=1e-4, rtol=1e-4) + + def test_compute_embedding_norm_sample_with_extra_activations_per_example(self): + # Define the embedding layer + embedding_dim = 1 + vocab_size = 10 + embedding_layer = nn.Embedding(vocab_size, embedding_dim) + + # Manually set weights for the embedding layer for testing + embedding_layer.weight = nn.Parameter( + torch.tensor( + [ + [0.1], + [0.2], + [0.3], + [0.4], + [0.5], + [0.6], + [0.7], + [0.8], + [0.9], + [1.0], + ], + dtype=torch.float32, + ) + ) + + # Example input ids with 6 activations per sample, shape: [5, 6, 1] + input_ids = torch.tensor( + [ + [[0], [0], [0], [0], [0], [0]], + [[1], [0], [0], [0], [0], [0]], + [[2], [3], [4], [5], [6], [7]], + [[4], [3], [0], [0], [0], [0]], + [[8], [7], [9], [6], [5], [0]], + ], + dtype=torch.long, + ) + + # Example gradients per input id, with embedding_dim=1. + # Shape: [5, 6, 1, 1] + backprops = torch.tensor( + [ + [ + [[0.0025]], + [[0.0025]], + [[0.0025]], + [[0.0025]], + [[0.0025]], + [[0.0025]], + ], + [ + [[-0.0014]], + [[-0.0014]], + [[-0.0014]], + [[-0.0014]], + [[-0.0014]], + [[-0.0014]], + ], + [ + [[-0.0002]], + [[-0.0002]], + [[-0.0002]], + [[-0.0002]], + [[-0.0002]], + [[-0.0002]], + ], + [ + [[0.0019]], + [[0.0019]], + [[0.0019]], + [[0.0019]], + [[0.0019]], + [[0.0019]], + ], + [ + [[-0.0016]], + [[-0.0016]], + [[-0.0016]], + [[-0.0016]], + [[-0.0016]], + [[-0.0016]], + ], + ], + dtype=torch.float32, + ) + + # Wrap input_ids in a list as expected by the function + activations = [input_ids] + + # Call the function we want to test + result = embedding_norm_sample.compute_embedding_norm_sample( + embedding_layer, activations, backprops + ) + + # Expected output based on the example + expected_norms = torch.tensor( + [0.0150, 0.0071, 0.0005, 0.0081, 0.0039], dtype=torch.float32 + ) + print("expected_norms: ", expected_norms) + computed_norms = result[embedding_layer.weight] + + # Verify the computed norms match the expected norms + torch.testing.assert_close(computed_norms, expected_norms, atol=1e-4, rtol=1e-4) From b9f521aa76e1d7208b6c00233ad7c67202ac805c Mon Sep 17 00:00:00 2001 From: Pooja Agarwal Date: Mon, 9 Dec 2024 06:00:55 +0000 Subject: [PATCH 2/2] Minor updates to Opacus: a) Avoid compile time error on variable args returned for ghost vs non-ghost clipping techniques. The approach can be further enhanced, adding a quick fix to handle the error. b) Allow providing epsilon_tolerance dynamically. 
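
For context, the intended end-to-end usage of (a) is sketched below. This is
illustrative only: model, optimizer and data_loader are stand-in objects, and
it assumes make_private() also accepts the plain criterion (as in current
Opacus releases that support ghost clipping).

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from opacus import PrivacyEngine

    # Toy placeholders for the caller's own objects
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data_loader = DataLoader(
        TensorDataset(torch.randn(8, 4), torch.randint(2, (8,))), batch_size=4
    )

    privacy_engine = PrivacyEngine()
    # With grad_sample_mode="ghost", the wrapped criterion is returned as well
    model, optimizer, criterion, data_loader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        criterion=torch.nn.CrossEntropyLoss(),
        noise_multiplier=1.0,
        max_grad_norm=1.0,
        grad_sample_mode="ghost",
    )

    for x, y in data_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)  # DPLossFastGradientClipping wrapper
        loss.backward()                # runs the two backward passes internally
        optimizer.step()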
---
 opacus/privacy_engine.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/opacus/privacy_engine.py b/opacus/privacy_engine.py
index 1af891c48..6fb36928c 100644
--- a/opacus/privacy_engine.py
+++ b/opacus/privacy_engine.py
@@ -289,7 +289,7 @@ def make_private(
         noise_generator=None,
         grad_sample_mode: str = "hooks",
         **kwargs,
-    ) -> Tuple[GradSampleModule, DPOptimizer, DataLoader]:
+    ):
         """
         Add privacy-related responsibilities to the main PyTorch training objects:
         model, optimizer, and the data loader.
@@ -339,12 +339,15 @@
                 details
 
         Returns:
-            Tuple of (model, optimizer, data_loader).
+            Tuple of (model, optimizer, criterion (if grad_sample_mode="ghost"), data_loader).
             Model is a wrapper around the original model that also computes per sample
             gradients
             Optimizer is a wrapper around the original optimizer that also does
             gradient clipping and noise addition to the gradients
+            Criterion is a wrapper around the original criterion that does two
+            backward passes under the hood. Returned only if grad_sample_mode is
+            "ghost".
             DataLoader is a brand new DataLoader object, constructed to behave as
             equivalent to the original data loader, possibly with updated
             sampling mechanism. Points to the same dataset object.
@@ -472,17 +475,23 @@ make_private_with_epsilon
                 details
 
         Returns:
-            Tuple of (model, optimizer, data_loader).
+            Tuple of (model, optimizer, criterion (if grad_sample_mode="ghost"), data_loader).
             Model is a wrapper around the original model that also computes per sample
             gradients
             Optimizer is a wrapper around the original optimizer that also does
             gradient clipping and noise addition to the gradients
+            Criterion is a wrapper around the original criterion that does two
+            backward passes under the hood. Returned only if grad_sample_mode is
+            "ghost".
             DataLoader is a brand new DataLoader object, constructed to behave as
             equivalent to the original data loader, possibly with updated
             sampling mechanism. Points to the same dataset object.
         """
         sample_rate = 1 / len(data_loader)
+        epsilon_tolerance = kwargs.get(
+            "epsilon_tolerance", 0.01
+        )  # same default as in get_noise_multiplier
 
         if len(self.accountant) > 0:
             warnings.warn(
@@ -502,6 +511,7 @@
                 sample_rate=sample_rate,
                 epochs=epochs,
                 accountant=self.accountant.mechanism(),
+                epsilon_tolerance=epsilon_tolerance,
                 **kwargs,
             ),
             max_grad_norm=max_grad_norm,
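
Usage sketch for the epsilon_tolerance passthrough in (b). This is
illustrative only: the toy model, optimizer and data loader are placeholders,
and the keyword is simply forwarded via **kwargs to get_noise_multiplier,
whose default tolerance stays 0.01.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from opacus import PrivacyEngine

    # Toy placeholders for the caller's own objects
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data_loader = DataLoader(
        TensorDataset(torch.randn(8, 4), torch.randint(2, (8,))), batch_size=4
    )

    privacy_engine = PrivacyEngine(accountant="rdp")
    model, optimizer, data_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=data_loader,
        target_epsilon=3.0,
        target_delta=1e-5,
        epochs=5,
        max_grad_norm=1.0,
        epsilon_tolerance=0.05,  # looser tolerance -> faster noise search
    )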