In [None]:
#@title License
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Table of contents

>[Table of contents](#scrollTo=jin3IPhsiW4L)

>[External packages](#scrollTo=BMj0796JsHrw)

>[Preamble](#scrollTo=jxHrA25jhYSN)

>[Main implementations](#scrollTo=yvGcuHtshUpV)

>>[CPU variants](#scrollTo=r-RIw7vofKoE)

>>[GPU variants](#scrollTo=VnUiCbPnfOIN)

>[Main experiments](#scrollTo=KzJPcoufhM1Z)

>>[Gradient norm benchmarks on CPU/GPU](#scrollTo=Fnhbfud2jyxz)

>>>[Experiment executors](#scrollTo=JCeNkaOHpGhE)

>>>[Experiment printers](#scrollTo=Vui2tk-5FoxF)

>>>[CPU benchmarks](#scrollTo=uVhXmlY3q42W)

>>>[GPU benchmarks](#scrollTo=-NoJnMfBi3ye)

>>[End-to-end benchmarks on GPU](#scrollTo=OlznSp5TEfz1)

>>>[Opacus gradient samplers](#scrollTo=t1j9ACzhIMPs)

>>>[Experiment executors](#scrollTo=d2w5jl286m6P)

>>>[GPU benchmarks](#scrollTo=4-hvSid56o9I)

>>>>[Both naive and fast clipping](#scrollTo=xFN5JIrfBpWe)

>>>>[Fast clipping only](#scrollTo=4FJw63nw3tvg)

>[Additional experiments](#scrollTo=UkF7cVd1fjLQ)

>>[Gradient norm benchmarks on CPU/GPU](#scrollTo=tJqYxkKvfz9Y)

>>>[GPU benchmarks](#scrollTo=USkSHgYMgq2l)



# External packages

In [None]:
!pip install opacus

Collecting opacus
  Downloading opacus-1.5.3-py3-none-any.whl.metadata (8.4 kB)
Collecting numpy<2.0,>=1.15 (from opacus)
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0->opacus)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0->opacus)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0->opacus)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0->opacus)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadat

In [None]:
!pip install nvidia-ml-py3

Collecting nvidia-ml-py3
  Downloading nvidia-ml-py3-7.352.0.tar.gz (19 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: nvidia-ml-py3
  Building wheel for nvidia-ml-py3 (setup.py) ... [?25l[?25hdone
  Created wheel for nvidia-ml-py3: filename=nvidia_ml_py3-7.352.0-py3-none-any.whl size=19172 sha256=475611eaefa7e6e3781d2fb94c64240ff590c3ebb327d6f680903ed2e9813395
  Stored in directory: /root/.cache/pip/wheels/47/50/9e/29dc79037d74c3c1bb4a8661fb608e8674b7e4260d6a3f8f51
Successfully built nvidia-ml-py3
Installing collected packages: nvidia-ml-py3
Successfully installed nvidia-ml-py3-7.352.0


# Preamble

In [None]:
import math
import re
import subprocess
import numpy as np
import time
import cupy as cp
from collections import defaultdict
from typing import Sequence, Mapping, Callable, Any, List, Dict, Optional
import torch
import torch.nn as nn
from functools import lru_cache
from numba import cuda
import psutil
import os
import tracemalloc
import platform
import distro
import gc

In [None]:
import nvidia_smi
import opacus
from opacus.grad_sample.utils import register_norm_sampler
from opacus.utils.per_sample_gradients_utils import clone_module
from opacus.grad_sample import GradSampleModule, GradSampleModuleFastGradientClipping
from opacus.optimizers import DPOptimizer, DPOptimizerFastGradientClipping
from opacus.utils.fast_gradient_clipping_utils import DPLossFastGradientClipping

In [None]:
#@title Check GPU availability

@lru_cache()
def is_gpu_available() -> bool:
    """Report whether the NVIDIA management library can be initialised.

    The probe is an `nvmlInit()` / `nvmlShutdown()` round trip; any exception
    raised along the way is interpreted as "no usable GPU". The answer is
    memoised, so the probe runs at most once per kernel session.

    Returns:
        True if the init/shutdown round trip succeeded, False otherwise.
    """
    try:
        nvidia_smi.nvmlInit()
        nvidia_smi.nvmlShutdown()
    except Exception:
        gpu_found = False
    else:
        gpu_found = True
    return gpu_found

is_gpu_available()

True

In [None]:
def get_processor_name():
    """Return a human-readable CPU model name for the current machine.

    The lookup depends on `platform.system()`:
      * "Windows": the value of `platform.processor()`.
      * "Darwin": the output of `sysctl -n machdep.cpu.brand_string`.
      * "Linux": the first "model name" entry of `/proc/cpuinfo`.

    Returns:
        The processor name as a `str` with surrounding whitespace removed, or
        an empty string when the platform is not recognised or no model name
        could be found.
    """
    system = platform.system()
    if system == "Windows":
        return platform.processor()
    if system == "Darwin":
        # `sysctl` lives in /usr/sbin, which may be missing from PATH. Extend a
        # *copy* of the environment so repeated calls do not keep growing the
        # process-wide PATH.
        env = dict(os.environ)
        env["PATH"] = env.get("PATH", "") + os.pathsep + "/usr/sbin"
        # Without shell=True the command must be an argument list; a single
        # string would be looked up as the executable name and fail.
        command = ["sysctl", "-n", "machdep.cpu.brand_string"]
        return subprocess.check_output(command, env=env).decode().strip()
    if system == "Linux":
        # Read the file directly instead of spawning `cat` through a shell.
        try:
            with open("/proc/cpuinfo") as cpuinfo:
                for line in cpuinfo:
                    if "model name" in line:
                        return re.sub(".*model name.*:", "", line, count=1).strip()
        except OSError:
            # Best effort: fall through to the "unknown" result below.
            pass
    return ""

In [None]:
#@title Check system info
def get_size(bytes, suffix="B"):
  """Format a byte count using decimal (power-of-1000) unit prefixes.

  Args:
    bytes: The quantity to format (number of bytes).
    suffix: Unit suffix appended after the prefix, e.g. "B".

  Returns:
    A string such as "13.61GB" with two decimals. Values of 1000 PB or more
    are still expressed in the largest known unit ("P") rather than falling
    off the end of the unit table and returning None.
  """
  factor = 1000
  units = ["", "K", "M", "G", "T", "P"]
  for unit in units[:-1]:
    if bytes < factor:
      return f"{bytes:.2f}{unit}{suffix}"
    bytes /= factor
  # Largest unit: always return a string, however big the value is.
  return f"{bytes:.2f}{units[-1]}{suffix}"

# Print a short description of the host the benchmarks run on: operating
# system, distribution, CPU model, total RAM and GPU name.
uname = platform.uname()
print(f"System: {uname.system}")
print(f"Distro: {distro.name()} {distro.version()}")
# print(f"Node Name: {uname.node}")
# print(f"Release: {uname.release}")
# print(f"Version: {uname.version}")
# print(f"Machine: {uname.machine}")
print(f"Processor: {get_processor_name()}")
# # number of cores
# print("Physical cores:", psutil.cpu_count(logical=False))
# print("Total cores:", psutil.cpu_count(logical=True))
svmem = psutil.virtual_memory()
# `get_size` divides by 1000 per unit, so RAM is reported in decimal units.
print(f"Total RAM: {get_size(svmem.total)}")
# NOTE(review): presumably this raises on a machine without a CUDA device;
# the notebook appears to assume a GPU runtime — confirm before running on CPU.
print(f"GPU: {torch.cuda.get_device_name()}")

System: Linux
Distro: Ubuntu 22.04
Processor:  Intel(R) Xeon(R) CPU @ 2.00GHz
Total RAM: 13.61GB
GPU: Tesla T4


# Main implementations


## CPU variants

In [None]:
"""This is a library for fast and memory efficient gradient norm computation.

The functions operate on a single Convolutional Neural Network (CNN) layer and
take as input one sample of the batch and the corresponding partial gradient.
The sample is assumed to be a 2D matrix whose rows correspond to the 1D vectors
of the different input channels. The partial gradient is assumed to be a 2D
matrix whose rows correspond to the 1D vectors of the different output channels.
The other two args are the kernel size and stride of the layer.
"""

def _check_value_and_shape_of_arguments(
    input_matrix: np.ndarray,
    partial_gradient: np.ndarray,
    kernel_size: int,
    stride: int,
):
  """Checks the arguments of the functions in this library."""
  if input_matrix.ndim != 2:
    raise ValueError("input_matrix must be a 2D matrix")
  if partial_gradient.ndim != 2:
    raise ValueError("partial_gradient must be a 2D matrix")
  if input_matrix.shape[1] == 0:
    raise ValueError("input_matrix must be non-empty")
  if partial_gradient.shape[1] == 0:
    raise ValueError("partial_gradient must be non-empty")
  if kernel_size <= 0:
    raise ValueError("kernel_size must be a positive integer")
  if stride <= 0:
    raise ValueError("stride must be a positive integer")
  if (
      # This is the formula for the output dimension of a CNN layer.
      math.floor((input_matrix.shape[1] - kernel_size) / stride + 1)
      != partial_gradient.shape[1]
  ):
    raise ValueError(
        "Number of columns of partial_gradient must be equal to the"
        " output dimension of the layer"
    )


def in_place_fast_grad_norm(
    input_matrix: np.ndarray,
    partial_gradient: np.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Squared l_2 gradient norm of one sample, one channel pair at a time.

  The unfolded input is never materialised: each input channel is exposed as a
  sliding-window *view*, which is contracted against every output channel of
  the partial gradient in turn.

  With x the input matrix, g the partial gradient and U(x[i]) the matrix whose
  rows are the kernel windows of the i-th input channel, the returned value is
  res = sum_{i in n_in} sum_{j in n_out} ||U(x[i])^T g[j]||^2.

  Args:
    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.
    partial_gradient: 2D matrix whose rows are 1D vectors of the partial
      gradient across the output channels.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix, partial_gradient, kernel_size, stride
  )

  total = 0
  for channel in input_matrix:
    # All length-`kernel_size` windows of the channel, thinned to every
    # `stride`-th one and transposed; both steps only create views.
    windows_t = np.lib.stride_tricks.sliding_window_view(
        channel, window_shape=(kernel_size,)
    )[::stride].T
    for grad_row in partial_gradient:
      # Gradient w.r.t. the kernel weights of this (input, output) pair.
      weight_grad = np.tensordot(windows_t, grad_row, (1, 0))
      total += np.sum(np.square(weight_grad))

  return total


def in_place_ghost_norm(
    input_matrix: np.ndarray,
    partial_gradient: np.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Squared l_2 gradient norm of one sample via the ghost norm trick.

  Neither Gram matrix of the ghost norm identity is materialised; their entries
  are computed one pair of output positions at a time and immediately folded
  into the running result.

  With x the input matrix, g the partial gradient and U(x[i]) the matrix whose
  rows are the kernel windows of the i-th input channel, the returned value is
  res = <sum_{i in n_in} U(x[i]) U(x[i])^T, sum_{j in n_out} g[j] (g[j])^T>,
  where <,> is the Frobenius inner product.

  Args:
    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.
    partial_gradient: 2D matrix whose rows are 1D vectors of the partial
      gradient across the output channels.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix, partial_gradient, kernel_size, stride
  )

  positions = range(partial_gradient.shape[1])
  total = 0

  for p in positions:
    p_start = p * stride
    p_stop = p_start + kernel_size
    for q in positions:
      q_start = q * stride
      q_stop = q_start + kernel_size
      # Entry (p, q) of sum_i U(x[i]) U(x[i])^T: inner product of the kernel
      # windows at positions p and q, accumulated over the input channels.
      window_gram = sum(
          np.dot(channel[p_start:p_stop], channel[q_start:q_stop])
          for channel in input_matrix
      )
      # Entry (p, q) of sum_k g[k] g[k]^T, accumulated over output channels.
      grad_gram = sum(
          np.dot(grad_row[p], grad_row[q]) for grad_row in partial_gradient
      )
      total += window_gram * grad_gram

  return total


def in_place_norm_fft(
    input_matrix: np.ndarray,
    partial_gradient: np.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Squared l_2 gradient norm of one sample via the Fast Fourier Transform.

  Each product U(x[i])^T g[j] is recast as the product of a circulant matrix
  with a zero-padded vector, which is evaluated with FFTs. Only one padded
  buffer is allocated and it is reused for every output channel.

  With x the input matrix, g the partial gradient and U(x[i]) the matrix whose
  rows are the kernel windows of the i-th input channel, the returned value is
  res = sum_{i in n_in} sum_{j in n_out} ||R P U'(x[i])^T g'[j]||^2, where R
  selects specific entries of the vector it is applied to, P is an appropriate
  permutation matrix, U'(x[i]) is a circulant matrix defined from U(x[i]), and
  g'[j] is g[j] padded appropriately.

  Args:
    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.
    partial_gradient: 2D matrix whose rows are 1D vectors of the partial
      gradient across the output channels.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix, partial_gradient, kernel_size, stride
  )

  input_width = input_matrix.shape[1]
  output_width = partial_gradient.shape[1]

  # One past the last slot that receives a gradient entry; entries are placed
  # `stride` apart, every other slot of the buffer stays zero.
  fill_stop = (output_width - 1) * stride + 1

  # Allocated once; the same slots are overwritten for every output channel.
  padded_grad = np.zeros(input_width, dtype=partial_gradient.dtype)

  total = 0
  for grad_row in partial_gradient:
    padded_grad[:fill_stop:stride] = grad_row
    grad_spectrum = np.fft.fft(padded_grad)

    for channel in input_matrix:
      # FFT of the first column of the circulant matrix built from U(x[i]).
      channel_spectrum = np.fft.fft(np.flip(channel))
      product = np.flip(
          np.fft.ifft(np.multiply(channel_spectrum, grad_spectrum))
      )
      # Only the first `kernel_size` entries correspond to kernel weights.
      total += np.sum(np.real(product[0:kernel_size] ** 2))

  return total


def _unfold(
    input_matrix: np.ndarray,
    kernel_size: int,
    stride: int,
    output_dimension: int,
):
  """Unfolds the input matrix.

  Args:
    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.
    output_dimension: output dimension of the layer.

  Returns:
    A 2D matrix whose rows correspond to the different kernel windows of
    all the input channels.
  """

  input_channels = input_matrix.shape[0]

  # Create slices of the input_matrix and stack them
  # Create an array of indices to slice
  indices = (
      np.arange(kernel_size)[None, :]
      + np.arange(output_dimension)[:, None] * stride
  )

  # Extract slices for each channel
  slices = np.stack(
      [input_matrix[j, indices] for j in range(input_channels)], axis=-1
  )
  res = slices.reshape(output_dimension, -1)
  return res


def naive_fast_grad_norm(
    input_matrix: np.ndarray,
    partial_gradient: np.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Squared l_2 gradient norm of one sample with an explicit unfolded input.

  Unlike the in-place variants, this function instantiates the unfolded input
  matrix U (one row per output position, holding that position's kernel windows
  of all input channels) and the full gradient matrix U^T g^T. It can be faster
  than the in-place variants but needs more memory.

  The returned value is the squared Frobenius norm of U^T g^T, which equals
  sum_{i in n_in} sum_{j in n_out} ||U(x[i])^T g[j]||^2.

  Args:
    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.
    partial_gradient: 2D matrix whose rows are 1D vectors of the partial
      gradient across the output channels.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix, partial_gradient, kernel_size, stride
  )

  output_width = partial_gradient.shape[1]
  windows = _unfold(input_matrix, kernel_size, stride, output_width)

  # Full per-sample gradient w.r.t. the kernel weights.
  weight_grad = windows.T @ partial_gradient.T

  # Sum of squared entries, i.e. the squared Frobenius norm.
  return np.einsum("ij,ij->", weight_grad, weight_grad)


def naive_ghost_norm(
    input_matrix: np.ndarray,
    partial_gradient: np.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Squared l_2 gradient norm of one sample via explicit Gram matrices.

  Unlike the in-place variants, this function instantiates the unfolded input
  matrix U (one row per output position, holding that position's kernel windows
  of all input channels) as well as both Gram matrices of the ghost norm
  identity. It can be faster than the in-place variants but needs more memory.

  The returned value is <U U^T, g^T g>, where <,> is the Frobenius inner
  product and g is the partial gradient (rows are output channels).

  Args:
    input_matrix: 2D matrix whose rows are 1D vectors of the input to the layer.
    partial_gradient: 2D matrix whose rows are 1D vectors of the partial
      gradient across the output channels.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix, partial_gradient, kernel_size, stride
  )

  output_width = partial_gradient.shape[1]
  windows = _unfold(input_matrix, kernel_size, stride, output_width)

  # Both Gram matrices are indexed by pairs of output positions.
  window_gram = windows @ windows.T
  grad_gram = partial_gradient.T @ partial_gradient

  # Contracting both axes yields the Frobenius inner product.
  return np.tensordot(window_gram, grad_gram, axes=[[0, 1], [0, 1]])


## GPU variants

In [None]:
def cp_in_place_fast_grad_norm(
    input_matrix: cp.ndarray,
    partial_gradient: cp.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """CuPy counterpart of `in_place_fast_grad_norm`.

  Each input channel is exposed as a sliding-window view and contracted against
  every output channel of the partial gradient; the squared entries are summed
  into a one-element CuPy accumulator whose dtype is that of `input_matrix`.

  Args:
    input_matrix: 2D CuPy array whose rows are 1D vectors of the input to the
      layer.
    partial_gradient: 2D CuPy array whose rows are 1D vectors of the partial
      gradient across the output channels.
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a Python scalar.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  # The validator only inspects `.ndim` and `.shape`.
  _check_value_and_shape_of_arguments(
      input_matrix, partial_gradient, kernel_size, stride
  )

  accumulator = cp.zeros((1,), dtype=input_matrix.dtype)

  for channel in input_matrix:
    # Every `stride`-th length-`kernel_size` window of the channel, transposed.
    windows_t = cp.lib.stride_tricks.sliding_window_view(
        channel, window_shape=(kernel_size,)
    )[::stride].T
    for grad_row in partial_gradient:
      weight_grad = cp.tensordot(windows_t, grad_row, (1, 0))
      accumulator += cp.sum(cp.square(weight_grad))

  # `.item()` converts the single accumulated element to a Python scalar.
  return accumulator.item()

def cp_in_place_ghost_norm(
    input_matrix_cp: cp.ndarray,
    partial_gradient_cp: cp.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Computes the gradient norm squared with the ghost norm trick using CuPy.

  The value returned is <sum_i U(x[i]) U(x[i])^T, sum_k g[k] g[k]^T>, where
  U(x[i]) is the matrix of kernel windows of the i-th input channel, g is the
  partial gradient and <,> is the Frobenius inner product. The kernel windows
  are exposed through a strided view of the input rather than an explicit
  unfolded copy.

  Like the CPU variant, inputs whose trailing elements are not covered by any
  kernel window (possible when stride > 1) are accepted; those elements are
  simply never read.

  Args:
    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).
    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation, or if
      the input is too short for the requested windows.
  """
  # The validator only inspects `.ndim` and `.shape`.
  _check_value_and_shape_of_arguments(
      input_matrix_cp, partial_gradient_cp, kernel_size, stride
  )

  number_input_channels, input_dim = input_matrix_cp.shape
  _, output_dimension = partial_gradient_cp.shape

  # The strided view below reads up to column
  # (output_dimension - 1) * stride + kernel_size - 1 of every row. A longer
  # input is fine; a shorter one would read out of bounds, so guard against it.
  required_input_dim = (output_dimension - 1) * stride + kernel_size
  if input_dim < required_input_dim:
    raise ValueError(
        f"Input dimension {input_dim} is smaller than the "
        f"{required_input_dim} elements required by output_dim, "
        f"kernel_size, and stride."
    )

  # grad_gram[j1, j2] = sum_k g[k][j1] * g[k][j2]
  grad_gram = cp.matmul(partial_gradient_cp.T, partial_gradient_cp)
  # Shape: (output_dimension, output_dimension)

  # patches[i, j, k] = input_matrix_cp[i, j * stride + k], built as a view:
  # moving one patch forward skips `stride` elements, moving one kernel element
  # forward skips a single element.
  shape = (number_input_channels, output_dimension, kernel_size)
  row_stride, element_stride = input_matrix_cp.strides
  strides = (row_stride, stride * element_stride, element_stride)
  patches = cp.lib.stride_tricks.as_strided(
      input_matrix_cp, shape=shape, strides=strides
  )

  # window_gram[j1, j2] = sum_i dot(patches[i, j1, :], patches[i, j2, :])
  window_gram = cp.einsum(
      'ijk,ilk->jl', patches, patches, optimize='optimal'
  )
  # Shape: (output_dimension, output_dimension)

  # Frobenius inner product of the two Gram matrices.
  res_cp = cp.sum(window_gram * grad_gram)

  return res_cp.item()

import cupy as cp

def cp_unfold(
    input_matrix_cp: cp.ndarray,
    kernel_size: int,
    stride: int,
    output_dimension: int,
) -> cp.ndarray:
  """Unfolds the input matrix into kernel windows using a CuPy strided view.

  Args:
    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).
    kernel_size: kernel size of the layer.
    stride: stride of the layer.
    output_dimension: output dimension of the layer.

  Returns:
    A 2D CuPy matrix of shape (output_dimension, n_in * kernel_size). Row p
    holds the p-th kernel window of every input channel, one channel after the
    other. Note that this column order (channel first, kernel offset second)
    differs from the CPU `_unfold`, which interleaves channels; the two
    matrices are column permutations of each other.

  Raises:
    ValueError: if the input rows are too short to contain
      `output_dimension` windows of size `kernel_size` at the given stride.
  """
  number_input_channels = input_matrix_cp.shape[0]
  input_dim = input_matrix_cp.shape[1]

  # `as_strided` performs no bounds checking, so make sure the last window
  # still lies inside each row before building the view.
  required_input_dim = (output_dimension - 1) * stride + kernel_size
  if input_dim < required_input_dim:
    raise ValueError(
        f"Input dimension {input_dim} is smaller than the "
        f"{required_input_dim} elements required by output_dimension, "
        f"kernel_size, and stride."
    )

  # patches_view[i, j, k] = input_matrix_cp[i, j * stride + k]
  shape_view = (number_input_channels, output_dimension, kernel_size)
  stride_i = input_matrix_cp.strides[0]           # between channels
  stride_j = stride * input_matrix_cp.strides[1]  # between window starts
  stride_k = input_matrix_cp.strides[1]           # within a window
  patches_view = cp.lib.stride_tricks.as_strided(
      input_matrix_cp,
      shape=shape_view,
      strides=(stride_i, stride_j, stride_k),
  )

  # (n_in, output_dim, kernel_size) -> (output_dim, n_in, kernel_size)
  # -> (output_dim, n_in * kernel_size).
  unfolded_matrix = patches_view.transpose(1, 0, 2).reshape(output_dimension, -1)

  return unfolded_matrix

import cupy as cp
import cupy.fft

# `_check_value_and_shape_of_arguments` is defined in the "CPU variants" cell
# above. It only reads `.ndim` and `.shape`, so it can be called on CuPy arrays
# directly.

def cp_in_place_norm_fft(
    input_matrix_cp: cp.ndarray,
    partial_gradient_cp: cp.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """Computes the gradient norm squared using batched CuPy FFTs.

  This is the vectorised counterpart of `in_place_norm_fft`: the padded partial
  gradients and the flipped input channels are transformed in two batched FFTs,
  multiplied for every (input channel, output channel) pair via broadcasting,
  and transformed back with one batched inverse FFT.

  Floating-point inputs keep their precision (float64 stays float64). Inputs of
  any other dtype are promoted to float64 before the transform.

  Args:
    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).
    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation, or if
      the padded gradient does not fit in the input dimension.
  """
  # The validator only inspects `.ndim` and `.shape`.
  _check_value_and_shape_of_arguments(
      input_matrix_cp, partial_gradient_cp, kernel_size, stride
  )

  _, input_dimension = input_matrix_cp.shape
  number_output_channels, output_dimension = partial_gradient_cp.shape

  # Only non-floating inputs need a conversion. A "can this be cast to
  # complex64" test must not be used here: it is False for float64 and would
  # silently halve the precision of double-precision inputs.
  if input_matrix_cp.dtype.kind not in "fc":
    input_matrix_cp = input_matrix_cp.astype(cp.float64)
  if partial_gradient_cp.dtype.kind not in "fc":
    partial_gradient_cp = partial_gradient_cp.astype(cp.float64)

  # Zero-padded gradients: entry t of g[j] is placed at column t * stride.
  padded_partial_gradient_batch = cp.zeros(
      (number_output_channels, input_dimension), dtype=partial_gradient_cp.dtype
  )
  upper_bound = (output_dimension - 1) * stride + 1
  if upper_bound > input_dimension:
    raise ValueError("Input dimension too small for output dim/stride/kernel.")
  padded_partial_gradient_batch[:, :upper_bound:stride] = partial_gradient_cp

  # Batched FFT of the padded gradients. Shape: (n_out, input_dim)
  partial_derivative_fft_batch = cp.fft.fft(padded_partial_gradient_batch, axis=1)

  # Batched FFT of the flipped input channels. Shape: (n_in, input_dim)
  flipped_input = cp.flip(input_matrix_cp, axis=1)
  vector_in_fft_batch = cp.fft.fft(flipped_input, axis=1)

  # Pairwise products via broadcasting. Shape: (n_in, n_out, input_dim)
  multiplied_ffts = (
      vector_in_fft_batch[:, None, :] * partial_derivative_fft_batch[None, :, :]
  )

  # Back to the signal domain. Shape: (n_in, n_out, input_dim)
  temp_batch = cp.fft.ifft(multiplied_ffts, axis=2)

  # After flipping, the first `kernel_size` entries are the kernel-weight
  # gradients. Shape: (n_in, n_out, kernel_size)
  temp_flipped_batch = cp.flip(temp_batch, axis=2)
  relevant_part = temp_flipped_batch[:, :, :kernel_size]

  # Sum of squares of the real part.
  res_cp = cp.sum(cp.real(relevant_part) ** 2)

  return res_cp.item()

import cupy as cp

def cp_naive_fast_grad_norm(
    input_matrix_cp: cp.ndarray,
    partial_gradient_cp: cp.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """CuPy counterpart of `naive_fast_grad_norm`.

  The unfolded input U is obtained from `cp_unfold`, the full gradient matrix
  U^T g^T is instantiated, and the sum of its squared entries is returned.
  Memory inefficient, but potentially faster for some GPU/problem sizes.

  Args:
    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).
    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix_cp, partial_gradient_cp, kernel_size, stride
  )

  # Unfolded input, shape (output_dim, n_in * kernel_size).
  windows = cp_unfold(
      input_matrix_cp, kernel_size, stride, partial_gradient_cp.shape[1]
  )

  # Gradient w.r.t. the kernel weights, shape (n_in * kernel_size, n_out).
  weight_grad = cp.matmul(windows.T, partial_gradient_cp.T)

  # Squared Frobenius norm, converted to a Python scalar.
  squared_norm = cp.sum(weight_grad * weight_grad)
  return squared_norm.item()

import cupy as cp

# `_check_value_and_shape_of_arguments` is defined in the "CPU variants" cell
# above and `cp_unfold` is defined earlier in this cell; both are reused below.

def cp_naive_ghost_norm(
    input_matrix_cp: cp.ndarray,
    partial_gradient_cp: cp.ndarray,
    kernel_size: int,
    stride: int,
) -> float:
  """CuPy counterpart of `naive_ghost_norm`.

  The unfolded input U is obtained from `cp_unfold`; both Gram matrices U U^T
  and g^T g are instantiated and their Frobenius inner product is returned.
  Memory inefficient, but potentially faster for some GPU/problem sizes.

  Args:
    input_matrix_cp: 2D CuPy matrix (n_in, input_dim).
    partial_gradient_cp: 2D CuPy matrix (n_out, output_dim).
    kernel_size: kernel size of the layer.
    stride: stride of the layer.

  Returns:
    l_2 norm squared of the gradient as a float.

  Raises:
    ValueError: if the arguments fail the shared shape/value validation.
  """
  _check_value_and_shape_of_arguments(
      input_matrix_cp, partial_gradient_cp, kernel_size, stride
  )

  # Unfolded input, shape (output_dim, n_in * kernel_size).
  windows = cp_unfold(
      input_matrix_cp, kernel_size, stride, partial_gradient_cp.shape[1]
  )

  # Both Gram matrices have shape (output_dim, output_dim).
  window_gram = cp.matmul(windows, windows.T)
  grad_gram = cp.matmul(partial_gradient_cp.T, partial_gradient_cp)

  # Elementwise product summed over all entries = Frobenius inner product.
  inner_product = cp.sum(window_gram * grad_gram)
  return inner_product.item()


# Main experiments

## Gradient norm benchmarks on CPU/GPU

In [None]:
#@title Global constants
# Shared benchmark settings.
# NOTE(review): presumably forwarded to `run_experiments` as `num_repeats`,
# `stride` and `n_channel` respectively — confirm against the benchmark cells.
NUM_REPEATS = 5  # trials per configuration
STRIDE = 1  # kernel stride
N_CHANNEL = 3  # number of channels

### Experiment executors

In [None]:
def compute_setting_params(d, stride, setting_number):
  """Computes (d_in, d_out, d_k) for one of the three benchmark settings.

  Args:
    d: Base dimension; used directly as the input dimension d_in.
    stride: Stride of the layer.
    setting_number: Selects the kernel size d_k:
      1 -> d_k = d // 2, 2 -> d_k = d - 13, 3 -> d_k = d.

  Returns:
    The tuple (d_in, d_out, d_k), where
    d_out = floor((d_in - d_k) / stride + 1).

  Raises:
    ValueError: If `setting_number` is not 1, 2 or 3.
  """
  if setting_number == 1:
    d_k = d // 2
  elif setting_number == 2:
    d_k = d - 13
  elif setting_number == 3:
    d_k = d
  else:
    # Raise (not return) so an unknown setting fails here with a clear
    # message rather than as an unpacking error in the caller.
    raise ValueError(f"Unknown setting number {setting_number}")

  d_in = d
  d_out = math.floor((d_in - d_k) / stride + 1) # output dimension
  return d_in, d_out, d_k

In [None]:
# Nested containers returned by run_experiments:
# method name -> base dimension d -> one value per trial.
RuntimeMap = Mapping[str, Mapping[int, Sequence[float]]]  # trial runtimes
MemoryMap = Mapping[str, Mapping[int, Sequence[float]]]  # trial memory readings
ResultMap = Mapping[str, Mapping[int, Sequence[float]]]  # values returned by the benchmarked functions

def run_experiments(
    functions_to_test: Mapping[str, Callable[..., Any]],
    d_list: Sequence[int],
    n_channel: int,
    stride: int,
    num_repeats: int,
    setting_number: int,
    print_results = False,
    aggregation_function = np.median,
) -> tuple[
    RuntimeMap,
    MemoryMap,  # CPU
    MemoryMap,  # GPU
    ResultMap,
]:
  """Benchmarks runtime and memory of gradient-norm functions across dimensions.

  For every base dimension in `d_list` and every entry of `functions_to_test`,
  `num_repeats` trials are run on freshly sampled float32 inputs. The method
  name drives dispatch: names containing "np" receive NumPy inputs and all
  other names receive CuPy inputs; names starting with "cp_" additionally get
  CUDA stream synchronization around the timed call and VRAM tracking.

  A trial that raises an `Exception` is recorded with infinite runtime, peak
  RAM and VRAM delta, and contributes no entry to the results map.
  `KeyboardInterrupt` is deliberately not swallowed so a long sweep can be
  aborted.

  Args:
    functions_to_test: Mapping of which functions to benchmark
    d_list: List of base dimensions.
    n_channel: The number of input/output channels.
    stride: Kernel stride when computing convolutions.
    num_repeats: Number of experiment trials for a given base dimension.
    setting_number: Which setting in our paper that we are running
      (values: {1, 2, 3}).
    print_results: Whether to log results to the terminal.
    aggregation_function: Aggregation function for trial runtimes when printing
      to the terminal.

  Returns:
    A tuple (runtimes, peak_rams, delta_vrams, results). Each element is a
    mapping from method name to another mapping from base dimension to the
    list of per-trial runtimes (seconds) / tracemalloc peaks (bytes) /
    NVML used-memory deltas (bytes, "cp_" methods only) / function results.
  """
  n_in = n_channel
  n_out = n_channel

  runtimes = defaultdict(lambda: defaultdict(list))
  peak_rams = defaultdict(lambda: defaultdict(list))
  delta_vrams = defaultdict(lambda: defaultdict(list))
  results = defaultdict(lambda: defaultdict(list))
  d_list = [int(d) for d in d_list]

  # Use float32 for better GPU performance generally
  dtype_np = np.float32
  dtype_cp = cp.float32

  nvidia_smi.nvmlInit()
  try:
    assert nvidia_smi.nvmlDeviceGetCount() == 1  # assume single GPU
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

    for d in d_list:
      print(f"Running d = {d}...")
      # The layer geometry depends only on (d, stride, setting_number), so it
      # is computed once per base dimension instead of once per trial.
      dim_in, dim_out, kernel_size = compute_setting_params(
          d, stride, setting_number
      )

      for name, function in functions_to_test.items():
        is_cupy = name.startswith("cp_")

        for _ in range(num_repeats):
          if is_cupy:
            # Start from an empty pool so the VRAM delta reflects this trial.
            cp.get_default_memory_pool().free_all_blocks()
            start_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used

          # Create random data on CPU
          input_matrix = np.random.rand(n_in, dim_in).astype(dtype_np)
          partial_gradient = np.random.rand(n_out, dim_out).astype(dtype_np)

          if "np" not in name:
            # Transfer data to GPU (only for methods that consume CuPy inputs)
            input_matrix = cp.asarray(input_matrix, dtype=dtype_cp)
            partial_gradient = cp.asarray(partial_gradient, dtype=dtype_cp)

          function_args = [input_matrix, partial_gradient, kernel_size, stride]

          # Synchronize for accurate timing if CuPy function
          if is_cupy:
            cp.cuda.Stream.null.synchronize()

          # Tracing starts before the clock so its setup cost is not timed.
          tracemalloc.start()
          start_time = time.perf_counter()
          try:
            results[name][d].append(function(*function_args))
            # Wait for outstanding GPU work before stopping the clock.
            if is_cupy:
              cp.cuda.Stream.null.synchronize()
            end_time = time.perf_counter()
            _, peak_local_ram = tracemalloc.get_traced_memory()
            if is_cupy:
              end_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used
              cp.get_default_memory_pool().free_all_blocks()
          except Exception:
            # Failed trials are recorded as "infinite" cost rather than
            # aborting the whole sweep.
            end_time = float("inf")
            peak_local_ram = float("inf")
            end_vram = float("inf")
          finally:
            tracemalloc.stop()

          runtimes[name][d].append(end_time - start_time)
          peak_rams[name][d].append(peak_local_ram)

          if is_cupy:
            delta_vrams[name][d].append(end_vram - start_vram)

      if print_results:
        # --- Compare Results and Speedups (Example: compare FFT versions) ---
        print("\n[Time in ms]")
        for name in functions_to_test.keys():
          print(f"{name:<15} {aggregation_function(runtimes[name][d])*1e3:.6f}")
        print("\n[Peak RAM in MB]")
        for name in functions_to_test.keys():
          print(f"{name:<15} {aggregation_function(peak_rams[name][d])/1e6:.6f}")
        print("\n[Delta VRAM in MB]")
        for name in functions_to_test.keys():
          print(f"{name:<15} {aggregation_function(delta_vrams[name][d])/1e6:.6f}")
        print("\n")
  finally:
    # Always release NVML, including when a trial or the assert above raises.
    nvidia_smi.nvmlShutdown()

  return runtimes, peak_rams, delta_vrams, results

### Experiment printers

In [None]:
def print_table(
    experiment_results: ResultMap | RuntimeMap | MemoryMap,
    dvalues: Sequence[int],
    aggregation_function = np.median,
    sep=","
):
  """Prints aggregated per-method results as a `sep`-delimited table.

  The first row is the header "d" followed by every value of `dvalues`. Each
  following row is a method name followed by `aggregation_function` applied to
  that method's values at each d. Every cell, including the last one of a row,
  is terminated by `sep`.

  Args:
    experiment_results: Mapping from method name to a mapping from d to the
      list of values recorded for that d.
    dvalues: The d values (columns) to print, in order.
    aggregation_function: Reduces the list of values of one cell to the single
      value that is printed.
    sep: Cell terminator.
  """
  def emit_row(label, cells):
    # `cells` may be a lazy generator, so each value is computed right before
    # it is printed.
    print(f"{label}", end=sep)
    for cell in cells:
      print(f"{cell}", end=sep)
    print()

  emit_row("d", dvalues)
  for method, per_dimension in experiment_results.items():
    emit_row(
        method,
        (aggregation_function(per_dimension[d]) for d in dvalues),
    )

### CPU benchmarks

In [None]:
# Base dimensions for the CPU sweep (each one doubles the previous).
CPU_DVALUES = [100, 200, 400, 800, 1_600, 3_200, 6_400, 12_800, 25_600]

# Which methods to benchmark, keyed by the label used in the result tables.
CPU_FUNCTIONS = dict(
    np_fft=in_place_norm_fft,
    np_naive_fast=naive_fast_grad_norm,
    np_naive_ghost=naive_ghost_norm,
)

# Setting 1: d_k = d // 2
cpu_runtimes, cpu_peak_rams, cpu_delta_vrams, cpu_results = run_experiments(
    functions_to_test=CPU_FUNCTIONS,
    d_list=CPU_DVALUES,
    n_channel=N_CHANNEL,
    stride=STRIDE,
    num_repeats=NUM_REPEATS,
    setting_number=1,
)

# Comma-separated tables, for copying into Excel
print()
print("[Runtime in seconds]")
print_table(cpu_runtimes, CPU_DVALUES)
print()
print("[Peak RAM in bytes]")
print_table(cpu_peak_rams, CPU_DVALUES)

Running d = 100...
Running d = 200...
Running d = 400...
Running d = 800...
Running d = 1600...
Running d = 3200...
Running d = 6400...
Running d = 12800...
Running d = 25600...

[Runtime in seconds]
d,100,200,400,800,1600,3200,6400,12800,25600,
np_fft,0.0008511543273925781,0.0005800724029541016,0.0010218620300292969,0.0012087821960449219,0.0016374588012695312,0.002469301223754883,0.003998756408691406,0.0074689388275146484,0.01009988784790039,
np_naive_fast,0.0002238750457763672,0.0003342628479003906,0.0010647773742675781,0.0037941932678222656,0.013909578323364258,0.06909012794494629,0.27854347229003906,1.0922863483428955,24.85206389427185,
np_naive_ghost,0.0005457401275634766,0.0010247230529785156,0.0019216537475585938,0.0066487789154052734,0.027704954147338867,0.1853630542755127,1.1202032566070557,7.788317680358887,78.77303791046143,

[Peak RAM in bytes]
d,100,200,400,800,1600,3200,6400,12800,25600,
np_fft,9192.0,17592.0,34456.0,68120.0,135320.0,269720.0,538520.0,1076120.0,2151320.0,

### GPU benchmarks

In [None]:
# Base dimensions for the GPU sweep (each one doubles the previous).
GPU_DVALUES = [
    4_000, 8_000, 16_000, 32_000, 64_000,
    128_000, 256_000, 512_000, 1_024_000,
]

# Which methods to benchmark, keyed by the label used in the result tables.
GPU_FUNCTIONS = dict(
    cp_fft=cp_in_place_norm_fft,
    cp_naive_fast=cp_naive_fast_grad_norm,
    cp_naive_ghost=cp_naive_ghost_norm,
)

# Setting 1: d_k = d // 2
gpu_runtimes, gpu_peak_rams, gpu_delta_vrams, gpu_results = run_experiments(
    functions_to_test=GPU_FUNCTIONS,
    d_list=GPU_DVALUES,
    n_channel=N_CHANNEL,
    stride=STRIDE,
    num_repeats=NUM_REPEATS,
    setting_number=1,
)

# Comma-separated tables, for copying into Excel
print()
print("[Runtime in seconds]")
print_table(gpu_runtimes, GPU_DVALUES)
print()
print("[Peak RAM in bytes]")
print_table(gpu_peak_rams, GPU_DVALUES)
print()
print("[Delta VRAM in bytes]")
print_table(gpu_delta_vrams, GPU_DVALUES)

Running d = 4000...
Running d = 8000...
Running d = 16000...
Running d = 32000...
Running d = 64000...
Running d = 128000...
Running d = 256000...
Running d = 512000...
Running d = 1024000...

[Runtime in seconds]
d,4000,8000,16000,32000,64000,128000,256000,512000,1024000,
cp_fft,0.0020003318786621094,0.0021271705627441406,0.002496480941772461,0.0028612613677978516,0.0026824474334716797,0.0025663375854492188,0.0029802322387695312,0.007040739059448242,0.012893199920654297,
cp_naive_fast,0.0016911029815673828,0.0029211044311523438,0.008950233459472656,0.03448319435119629,0.45615363121032715,inf,inf,inf,inf,
cp_naive_ghost,0.019704341888427734,0.10161733627319336,0.9627914428710938,8.625009536743164,inf,inf,inf,inf,inf,

[Peak RAM in bytes]
d,4000,8000,16000,32000,64000,128000,256000,512000,1024000,
cp_fft,8303.0,7103.0,7167.0,6815.0,6815.0,8519.0,8639.0,8015.0,7175.0,
cp_naive_fast,5844.0,5724.0,5668.0,5668.0,43301.0,inf,inf,inf,inf,
cp_naive_ghost,6876.0,8470.0,50407.0,51151.0,inf,inf,i

In [None]:
#@title Debug only
# =================================================================================
# GPU_DVALUES = [
#     64_000,
#     128_000,
#     256_000,
#     512_000,
#     1_024_000,
# ]

# # Which methods to benchmark.
# GPU_FUNCTIONS = {
#   "cp_fft": cp_in_place_norm_fft,
# }

# # d_k = d / 2
# gpu_runtimes, gpu_peak_rams, gpu_delta_vrams, gpu_results = run_experiments(
#     GPU_FUNCTIONS,
#     GPU_DVALUES,
#     N_CHANNEL,
#     STRIDE,
#     NUM_REPEATS,
#     setting_number=1,
# )

# print("\n[Runtime in seconds]")
# print_table(gpu_runtimes, GPU_DVALUES)
# print("\n[Peak RAM in bytes]")
# print_table(gpu_peak_rams, GPU_DVALUES)
# print("\n[Delta VRAM in bytes]")
# print_table(gpu_delta_vrams, GPU_DVALUES)

## End-to-end benchmarks on GPU

### Opacus gradient samplers

In [None]:
@register_norm_sampler(nn.Conv1d)
def compute_conv1d_norm_sample(
    layer: nn.Conv1d,
    inputs: List[torch.Tensor],
    backprops: torch.Tensor,
) -> Dict[nn.Parameter, torch.Tensor]:
  """Per-sample gradient norms of a Conv1d weight, computed with FFTs on CuPy.

  Registered as the norm sampler for `nn.Conv1d`. Every sample of the batch is
  processed (the previous version only looked at batch element 0 and therefore
  returned a single norm regardless of the batch size).

  Only `layer.stride[0]` and `layer.kernel_size[0]` are read from the layer.
  Both tensors are converted with `.numpy()`, so they are expected to live on
  the CPU.

  Args:
    layer: The convolution layer whose weight gradient norm is requested.
    inputs: List whose first element holds the layer inputs, indexed as
      (batch, input channels, input dimension).
    backprops: Backpropagated gradients w.r.t. the layer output, indexed as
      (batch, output channels, output dimension).

  Returns:
    A dict mapping `layer.weight` to a 1D tensor with one gradient norm per
    sample of the batch.
  """
  stride = layer.stride[0]
  kernel_size = layer.kernel_size[0]
  input_batch_cp = cp.asarray(inputs[0].numpy())
  backprops_cp = cp.asarray(backprops.numpy())
  batch_size, _, input_dimension = input_batch_cp.shape
  _, number_output_channels, output_dimension = backprops_cp.shape

  if not cp.can_cast(input_batch_cp.dtype, cp.complex64):
    input_batch_cp = input_batch_cp.astype(cp.float32) # Use float32 base
  if not cp.can_cast(backprops_cp.dtype, cp.complex64):
    backprops_cp = backprops_cp.astype(cp.float32)

  # Scatter the backprops onto a zero signal of the input length, one entry
  # every `stride` positions.
  padded_partial_gradient_batch = cp.zeros(
      (batch_size, number_output_channels, input_dimension),
      dtype=backprops_cp.dtype,
  )
  upper_bound = (output_dimension - 1) * stride + 1
  padded_partial_gradient_batch[:, :, :upper_bound:stride] = backprops_cp
  partial_derivative_fft_batch = cp.fft.fft(
      padded_partial_gradient_batch, axis=-1
  )

  flipped_input = cp.flip(input_batch_cp, axis=-1)
  vector_in_fft_batch = cp.fft.fft(flipped_input, axis=-1)

  # Pair every input channel with every output channel, per sample:
  # (batch, n_in, 1, d) * (batch, 1, n_out, d) -> (batch, n_in, n_out, d).
  multiplied_ffts = (
      vector_in_fft_batch[:, :, None, :]
      * partial_derivative_fft_batch[:, None, :, :]
  )
  temp_batch = cp.fft.ifft(multiplied_ffts, axis=-1)
  temp_flipped_batch = cp.flip(temp_batch, axis=-1)
  relevant_part = temp_flipped_batch[..., :kernel_size]

  # Reduce over everything except the batch axis: one squared norm per sample.
  squared_norms_cp = cp.sum(cp.real(relevant_part) ** 2, axis=(1, 2, 3))
  norms = torch.asarray(squared_norms_cp.tolist())

  return {layer.weight: torch.sqrt(norms)}

In [None]:
#@title Unit test
# n_in = 3
# n_out = 4
# d = 5
# d_out = 7
# batch_size = 1
# layer1 = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=d, stride=1)
# inputs1 = torch.randn(batch_size, n_in, d)
# backprops1 = torch.randn(batch_size, n_out, d)
# compute_conv1d_norm_sample(layer1, inputs1, backprops1)

### Experiment executors

In [None]:
class SampleConv1dModule(nn.Module):
  """Minimal model: one bias-free Conv1d with `n` channels in and out, stride 1."""

  def __init__(self, n, d_k):
    """Builds the layer with `n` input/output channels and kernel size `d_k`."""
    super().__init__()
    self.conv1d = nn.Conv1d(
        in_channels=n,
        out_channels=n,
        kernel_size=d_k,
        stride=1,
        bias=False,
    )

  def forward(self, x):
    """Applies the convolution to `x` and returns the result."""
    return self.conv1d(x)

In [None]:
def run_normal_dp_sgd(
   num_iterations,
   num_channels,
   kernel_size,
   input_dim,
   batch_size = 1,
   noise_multiplier = 1.0,
   max_grad_norm = 1.0,
   criterion = torch.nn.MSELoss(reduction="none"),
):
  """Times forward/backward passes of a Conv1d model wrapped in GradSampleModule.

  Args:
    num_iterations: Number of timed forward/backward iterations.
    num_channels: Number of input and output channels of the Conv1d layer.
    kernel_size: Kernel size of the Conv1d layer.
    input_dim: Length of each random input signal.
    batch_size: Number of samples in the random input batch.
    noise_multiplier: Passed to DPOptimizer.
    max_grad_norm: Passed to DPOptimizer.
    criterion: Loss applied to (output, random target). The result is averaged
      with `torch.mean`, so both elementwise and already-reduced losses work.
      The default uses `reduction="none"`; the former `reduce="None"` passed a
      truthy string to the deprecated boolean `reduce` argument.

  Returns:
    Elapsed time of the iteration loop in seconds.
  """
  input_data = torch.rand((batch_size, num_channels, input_dim))
  sample_module = SampleConv1dModule(num_channels, kernel_size)
  model_normal = GradSampleModule(clone_module(sample_module))
  optimizer_normal = torch.optim.SGD(model_normal.parameters(), lr=1)
  optimizer_normal = DPOptimizer(
      optimizer_normal,
      noise_multiplier=noise_multiplier,
      max_grad_norm=max_grad_norm,
      expected_batch_size=batch_size,
  )

  # NOTE(review): optimizer_normal.step() is never called in the loop below,
  # so only the forward/backward passes are timed — confirm this is intended.
  # perf_counter is monotonic and meant for measuring short durations.
  t0 = time.perf_counter()
  for _ in range(num_iterations):
    optimizer_normal.zero_grad()
    output_normal = model_normal(input_data)
    target_data = torch.rand_like(output_normal)
    loss_normal = torch.mean(criterion(output_normal, target_data))
    loss_normal.backward()
  return time.perf_counter() - t0



In [None]:
def run_gc_dp_sgd(
    num_iterations,
    num_channels,
    kernel_size,
    input_dim,
    batch_size = 1,
    noise_multiplier = 1.0,
    max_grad_norm = 1.0,
    criterion = torch.nn.MSELoss(reduction="none"),
):
  """Times the two-pass ghost-clipping loop on a Conv1d model.

  Each iteration does a first backward pass on the mean loss with hooks
  enabled, reads the clipping coefficients from the model, and then does a
  second backward pass on the coefficient-weighted per-sample losses with
  hooks disabled.

  Args:
    num_iterations: Number of timed iterations.
    num_channels: Number of input and output channels of the Conv1d layer.
    kernel_size: Kernel size of the Conv1d layer.
    input_dim: Length of each random input signal.
    batch_size: Number of samples in the random input batch.
    noise_multiplier: Passed to DPOptimizerFastGradientClipping.
    max_grad_norm: Passed to the model wrapper and the optimizer wrapper.
    criterion: Loss applied to (output, random target). An elementwise result
      is averaged over all non-batch dimensions to get one loss per sample; a
      result with at most one dimension is used as is. The default uses
      `reduction="none"`; the former `reduce="None"` passed a truthy string to
      the deprecated boolean `reduce` argument and silently produced a scalar.

  Returns:
    Elapsed time of the iteration loop in seconds.
  """
  input_data = torch.rand((batch_size, num_channels, input_dim))
  sample_module = SampleConv1dModule(num_channels, kernel_size)
  model_gc = GradSampleModuleFastGradientClipping(
      clone_module(sample_module),
      max_grad_norm=max_grad_norm,
      use_ghost_clipping=True,
  )
  optimizer_gc = torch.optim.SGD(model_gc.parameters(), lr=1)
  optimizer_gc = DPOptimizerFastGradientClipping(
      optimizer_gc,
      noise_multiplier=noise_multiplier,
      max_grad_norm=max_grad_norm,
      expected_batch_size=batch_size,
  )

  # perf_counter is monotonic and meant for measuring short durations.
  t0 = time.perf_counter()
  for _ in range(num_iterations):
    model_gc.enable_hooks()
    output_gc = model_gc(input_data)
    target_data = torch.rand_like(output_gc)
    first_loss_per_sample = criterion(output_gc, target_data)
    if first_loss_per_sample.dim() > 1:
      # Collapse the elementwise losses to a single value per sample so they
      # line up with the per-sample clipping coefficients.
      first_loss_per_sample = first_loss_per_sample.flatten(
          start_dim=1
      ).mean(dim=1)
    first_loss = torch.mean(first_loss_per_sample)
    first_loss.backward(retain_graph=True)
    optimizer_gc.zero_grad()
    coeff = model_gc.get_clipping_coef()
    second_loss_per_sample = coeff * first_loss_per_sample
    second_loss = torch.sum(second_loss_per_sample)
    model_gc.disable_hooks()
    second_loss.backward()
  return time.perf_counter() - t0

In [None]:
def run_all_dp_sgd_executors(
    num_iterations,
    num_channels,
    kernel_size,
    input_dim,
    batch_size = 1,
    noise_multiplier = 1.0,
    max_grad_norm = 1.0,
):
  """Runs the normal and the ghost-clipping DP-SGD executors back to back.

  Before each executor, Python garbage is collected and the PyTorch and CuPy
  GPU caches are released. An executor that raises an `Exception` is reported
  with infinite time and VRAM; `KeyboardInterrupt` is not swallowed.

  Args:
    num_iterations: Forwarded to both executors.
    num_channels: Forwarded to both executors.
    kernel_size: Forwarded to both executors.
    input_dim: Forwarded to both executors.
    batch_size: Forwarded to both executors.
    noise_multiplier: Forwarded to both executors.
    max_grad_norm: Forwarded to both executors.

  Returns:
    The tuple (dp_sgd_time, gc_dp_sgd_time, dp_sgd_vram, gc_dp_sgd_vram):
    times as returned by the executors, VRAM as the change in NVML "used"
    memory (bytes) between just before and just after each executor.
  """
  def _time_and_vram(executor):
    """Runs one executor from a clean memory state; returns (time, VRAM delta)."""
    gc.collect()
    torch.cuda.empty_cache()
    cp.get_default_memory_pool().free_all_blocks()
    try:
      handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
      start_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used
      elapsed = executor(
          num_iterations,
          num_channels,
          kernel_size,
          input_dim,
          batch_size,
          noise_multiplier,
          max_grad_norm,
      )
      end_vram = nvidia_smi.nvmlDeviceGetMemoryInfo(handle).used
      return elapsed, end_vram - start_vram
    except Exception:
      # Failed runs are reported as infinite cost instead of aborting the
      # sweep.
      return float("inf"), float("inf")

  nvidia_smi.nvmlInit()
  try:
    assert nvidia_smi.nvmlDeviceGetCount() == 1  # assume single GPU
    dp_sgd_time, dp_sgd_vram = _time_and_vram(run_normal_dp_sgd)
    gc_dp_sgd_time, gc_dp_sgd_vram = _time_and_vram(run_gc_dp_sgd)
  finally:
    # Pair every nvmlInit with a shutdown, also when the assert fails.
    nvidia_smi.nvmlShutdown()

  return dp_sgd_time, gc_dp_sgd_time, dp_sgd_vram, gc_dp_sgd_vram

### GPU benchmarks

#### Both naive and fast clipping

In [None]:
# Input dimensions for the end-to-end sweep; the kernel size is d // 2.
GPU_E2E_DVALUES = [500, 1_000, 2_000, 4_000, 8_000]

# One entry per d: runtimes and VRAM deltas of normal vs. ghost-clipping DP-SGD.
dp_sgd_times, gc_dp_sgd_times = [], []
dp_sgd_vrams, gc_dp_sgd_vrams = [], []

for d in GPU_E2E_DVALUES:
  print(f"d = {d}")
  t0, t1, m0, m1 = run_all_dp_sgd_executors(
      num_iterations=5,
      num_channels=1,
      kernel_size=d // 2,
      input_dim=d,
      batch_size=128,
  )
  dp_sgd_times.append(t0)
  dp_sgd_vrams.append(m0)
  gc_dp_sgd_times.append(t1)
  gc_dp_sgd_vrams.append(m1)

In [None]:
# Comma-separated rows (every cell ends with a comma): header, then one row
# per method.
print("d", end=",")
for d in GPU_E2E_DVALUES:
  print(f"{d},", end="")

print()
print("dp_sgd_time", end=",")
for t in dp_sgd_times:
  print(f"{t},", end="")

print()
print("gc_dp_sgd_time", end=",")
for t in gc_dp_sgd_times:
  print(f"{t},", end="")

#### Fast clipping only

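Naive DP-SGD runs out of memory (OOM) at d = 16K, so only the fast (ghost) clipping variant is benchmarked at the larger dimensions below.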

In [None]:
# Larger input dimensions, run with the ghost-clipping executor only.
GPU_E2E_EXT_DVALUES = [16_000, 32_000, 64_000]

gc_dp_sgd_ext_times = []
for d in GPU_E2E_EXT_DVALUES:
  print(f"d = {d}")
  t = run_gc_dp_sgd(
      num_iterations=5,
      num_channels=1,
      kernel_size=d // 2,
      input_dim=d,
      batch_size=128,
  )
  gc_dp_sgd_ext_times += [t]

In [None]:
# Comma-separated rows (every cell ends with a comma): header, then runtimes.
print("d", end=",")
for d in GPU_E2E_EXT_DVALUES:
  print(f"{d},", end="")

print()
print("gc_dp_sgd_time", end=",")
for t in gc_dp_sgd_ext_times:
  print(f"{t},", end="")

# Additional experiments

## Gradient norm benchmarks on CPU

### CPU benchmarks

#### Setting 1

In [None]:
# Base dimensions for Setting 1.
D_LIST1 = [400, 800, 1_600, 3_200, 6_400]

# Which methods to benchmark, keyed by the label used in the result tables.
FUNCTIONS1 = dict(
    np_fft=in_place_norm_fft,
    np_fast_grad=in_place_fast_grad_norm,
    # np_in_place=in_place_ghost_norm,
    np_naive_fast=naive_fast_grad_norm,
    np_naive_ghost=naive_ghost_norm,
)

# Setting 1: d_k = d_in / 2
print("Setting 1\n")
(
    cpu_runtimes_ex1,
    cpu_peak_rams_ex1,
    cpu_delta_vrams_ex1,
    cpu_results_ex1,
) = run_experiments(
    functions_to_test=FUNCTIONS1,
    d_list=D_LIST1,
    n_channel=N_CHANNEL,
    stride=STRIDE,
    num_repeats=NUM_REPEATS,
    setting_number=1,
)

Setting 1

Running d = 400...
Running d = 800...
Running d = 1600...
Running d = 3200...
Running d = 6400...


In [None]:
# Comma-separated tables, for copying into Excel
print()
print("[Runtime in seconds]")
print_table(experiment_results=cpu_runtimes_ex1, dvalues=D_LIST1)
print()
print("[Peak RAM in bytes]")
print_table(experiment_results=cpu_peak_rams_ex1, dvalues=D_LIST1)


[Runtime in seconds]
d,400,800,1600,3200,6400,
np_fft,0.0006990432739257812,0.0012874603271484375,0.0015435218811035156,0.0024542808532714844,0.003905773162841797,
np_fast_grad,0.004522562026977539,0.005242824554443359,0.007205486297607422,0.013261556625366211,0.08928394317626953,
np_naive_fast,0.0010628700256347656,0.0037775039672851562,0.014162540435791016,0.07788252830505371,0.29234766960144043,
np_naive_ghost,0.0019783973693847656,0.006479740142822266,0.029323101043701172,0.19436955451965332,1.1241252422332764,

[Peak RAM in bytes]
d,400,800,1600,3200,6400,
np_fft,34456.0,68120.0,135320.0,269720.0,538520.0,
np_fast_grad,164319.0,646143.0,2569799.0,10255743.0,40993047.0,
np_naive_fast,1287720.0,5134216.0,20507016.0,81973968.0,327792291.0,
np_naive_ghost,1287720.0,5134216.0,20507168.0,81973088.0,327788552.0,


#### Setting 2

In [None]:
# Base dimensions for Setting 2.
D_LIST2 = [64_000, 128_000, 256_000, 512_000, 1_024_000]

# Which methods to benchmark, keyed by the label used in the result tables.
FUNCTIONS2 = dict(
    np_fft=in_place_norm_fft,
    np_fast_grad=in_place_fast_grad_norm,
    np_in_place=in_place_ghost_norm,
    np_naive_fast=naive_fast_grad_norm,
    np_naive_ghost=naive_ghost_norm,
)

# Setting 2: d_k = d_in - 13
print("Setting 2\n")
(
    cpu_runtimes_ex2,
    cpu_peak_rams_ex2,
    cpu_delta_vrams_ex2,
    cpu_results_ex2,
) = run_experiments(
    functions_to_test=FUNCTIONS2,
    d_list=D_LIST2,
    n_channel=N_CHANNEL,
    stride=STRIDE,
    num_repeats=NUM_REPEATS,
    setting_number=2,
)

Setting 2

Running d = 64000...
Running d = 128000...
Running d = 256000...
Running d = 512000...
Running d = 1024000...


In [None]:
# Comma-separated tables, for copying into Excel
print()
print("[Runtime in seconds]")
print_table(experiment_results=cpu_runtimes_ex2, dvalues=D_LIST2)
print()
print("[Peak RAM in bytes]")
print_table(experiment_results=cpu_peak_rams_ex2, dvalues=D_LIST2)


[Runtime in seconds]
d,64000,128000,256000,512000,1024000,
np_fft,0.04720807075500488,0.06013035774230957,0.14534854888916016,0.32793593406677246,0.8023104667663574,
np_fast_grad,0.029539108276367188,0.02435159683227539,0.04384875297546387,0.20948457717895508,0.27891993522644043,
np_in_place,0.044385433197021484,0.05373072624206543,0.06265926361083984,0.11540675163269043,0.18042659759521484,
np_naive_fast,0.019695281982421875,0.04124641418457031,0.08883976936340332,0.17839574813842773,0.38510656356811523,
np_naive_ghost,0.023937225341796875,0.03459930419921875,0.06670308113098145,0.16159534454345703,0.3193166255950928,

[Peak RAM in bytes]
d,64000,128000,256000,512000,1024000,
np_fft,5377312.0,10753312.0,21511520.0,43019901.0,86037512.0,
np_fast_grad,3842419.0,7682123.0,15363098.0,30729138.0,61453472.0,
np_in_place,852.0,3553.0,3108.0,3388.0,8316.0,
np_naive_fast,28667680.0,57339680.0,114684040.0,229378768.0,458757838.0,
np_naive_ghost,28667528.0,57339528.0,114683888.0,229378824.0,458

#### Setting 3

In [None]:
# Setting 3 keeps the base dimension fixed and sweeps the channel count.
D_LIST3 = [10]
N_LIST = [40, 80, 160, 320, 640]

# Which methods to benchmark, keyed by the label used in the result tables.
FUNCTIONS3 = dict(
    np_fft=in_place_norm_fft,
    # np_fast_grad=in_place_fast_grad_norm,
    np_in_place=in_place_ghost_norm,
    np_naive_fast=naive_fast_grad_norm,
    np_naive_ghost=naive_ghost_norm,
)

# Setting 3: variable n
print("Setting 3\n")
runtimes_ex3, peak_rams_ex3 = [], []
for n in N_LIST:
  print(f"n = {n}")
  # Only the runtimes and peak RAM are kept; the other two maps are discarded.
  cpu_runtimes_ex3, cpu_peak_rams_ex3, _, _ = run_experiments(
      functions_to_test=FUNCTIONS3,
      d_list=D_LIST3,
      n_channel=n,
      stride=STRIDE,
      num_repeats=NUM_REPEATS,
      setting_number=3,
  )
  runtimes_ex3.append(cpu_runtimes_ex3)
  peak_rams_ex3.append(cpu_peak_rams_ex3)

Setting 3

n = 40
Running d = 10...
n = 80
Running d = 10...
n = 160
Running d = 10...
n = 320
Running d = 10...
n = 640
Running d = 10...


In [None]:
def print_channel_table(
    experiment_results: Sequence[ResultMap | RuntimeMap | MemoryMap],
    n_values: Sequence[int],
    d_value: int,
    aggregation_function = np.median,
    sep=","
):
  """Prints a `sep`-delimited table with one column per channel count.

  The first row is the header "n" followed by every value of `n_values`. Each
  following row is a method name followed by `aggregation_function` applied to
  that method's values at `d_value`, one cell per entry of
  `experiment_results`. Every cell, including the last one of a row, is
  terminated by `sep`.

  Args:
    experiment_results: One map per channel count (same order as `n_values`),
      each mapping method name -> d -> list of values. The method names are
      taken from the first map. If the sequence is empty, only the header row
      is printed.
    n_values: Channel counts printed in the header row.
    d_value: The d whose values are aggregated in every cell.
    aggregation_function: Reduces the list of values of one cell to the single
      value that is printed.
    sep: Cell terminator.
  """
  print("n", end=sep)
  for n in n_values:
    print(f"{n}", end=sep)
  print()

  # Nothing to tabulate; avoid indexing into an empty sequence below.
  if len(experiment_results) == 0:
    return

  methods = list(experiment_results[0].keys())
  for m in methods:
    print(m, end=sep)
    for e in experiment_results:
      print(f"{aggregation_function(e[m][d_value])}", end=sep)
    print()

In [None]:
# Comma-separated tables, for copying into Excel
print()
print("[Runtime in seconds]")
print_channel_table(
    experiment_results=runtimes_ex3, n_values=N_LIST, d_value=D_LIST3[0]
)
print()
print("[Peak RAM in bytes]")
print_channel_table(
    experiment_results=peak_rams_ex3, n_values=N_LIST, d_value=D_LIST3[0]
)


[Runtime in seconds]
n,40,80,160,320,640,
np_fft,0.11008262634277344,0.29825496673583984,1.061675786972046,4.766016721725464,19.153496503829956,
np_in_place,0.003210306167602539,0.0034322738647460938,0.006872653961181641,0.013518571853637695,0.034094810485839844,
np_naive_fast,0.0008285045623779297,0.0008981227874755859,0.0024089813232421875,0.00498199462890625,0.016141414642333984,
np_naive_ghost,0.0010898113250732422,0.0008485317230224609,0.0014989376068115234,0.002501964569091797,0.0048220157623291016,

[Peak RAM in bytes]
n,40,80,160,320,640,
np_fft,5696.0,11087.0,50924.0,52227.0,56192.0,
np_in_place,540.0,540.0,540.0,668.0,668.0,
np_naive_fast,100476.0,294076.0,1065276.0,4143676.0,16444476.0,
np_naive_ghost,15992.0,31960.0,63320.0,126552.0,251992.0,
