# Optimizing optimizers in PyTorch

#### Runtime vs memory

Runtime and memory usage are often at odds with each other.

You're towing 512 cars from A to B. Which truck do you take? What if the only way to B had a low-clearance bridge?

Today, we focus on speeeeeeeed!

Yes, this does mean memory will take a hit. Disclaimer.

Fusion, the high level idea

# your simplest optimizer

for loop/single tensor



https://github.com/pytorch/pytorch/blob/b5ba80828f77c565bcda7558da97c792af32d517/torch/optim/adamw.py#L362

```
for i, param in enumerate(params):
   grad = grads[i] if not maximize else -grads[i]
   exp_avg = exp_avgs[i]
   exp_avg_sq = exp_avg_sqs[i]
   step_t = state_steps[i]
   # update step
   step_t += 1
   # Perform stepweight decay
   param.mul_(1 - lr * weight_decay)
   # Decay the first and second moment running average coefficient
   exp_avg.lerp_(grad, 1 - beta1)
   exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
   step = _get_value(step_t)
   bias_correction1 = 1 - beta1 ** step
   bias_correction2 = 1 - beta2 ** step
   step_size = lr / bias_correction1
   bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
   denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
   param.addcdiv_(exp_avg, denom, value=-step_size)
```

## horizontally fused optimizer

foreach



```
torch._foreach_add_(device_state_steps, 1)
# Perform stepweight decay
if weight_decay != 0:
   torch._foreach_mul_(device_params, 1 - lr * weight_decay)
# Decay the first and second moment running average coefficient
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads,
                        device_grads, 1 - beta2)
...
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(device_params, device_exp_avgs,
                        exp_avg_sq_sqrt, step_size)
```

https://github.com/pytorch/pytorch/blob/b5ba80828f77c565bcda7558da97c792af32d517/torch/optim/adamw.py#L480

# entirely fused optimizer

...fused

(thanks NVIDIA)



```
torch._fused_adamw_(
    device_params,
    device_grads,
    device_exp_avgs,
    device_exp_avg_sqs,
    device_max_exp_avg_sqs,
    device_state_steps,
    amsgrad=amsgrad,
    lr=lr,
    beta1=beta1,
    beta2=beta2,
    weight_decay=weight_decay,
    eps=eps,
    maximize=maximize,
    grad_scale=device_grad_scale,
    found_inf=device_found_inf,
)
```

The Gist: the fewer kernels you launch on CUDA, the faster.
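As a rough back-of-the-envelope sketch of that gist, the three implementations differ in how many kernel launches one optimizer step needs. The counts below are idealized, and both `ideal_kernel_launches` and the value `ops_per_step=9` are hypothetical, chosen only to show the scaling.

```python
def ideal_kernel_launches(num_params, ops_per_step):
    """Idealized launch counts per optimizer step for each implementation style."""
    return {
        # for-loop: every op is launched once per parameter tensor
        "for-loop": num_params * ops_per_step,
        # foreach: every op is launched once for the whole list (best case)
        "foreach": ops_per_step,
        # fused: the whole update runs as one kernel (best case)
        "fused": 1,
    }


counts = ideal_kernel_launches(num_params=512, ops_per_step=9)
print(counts)
```

These are best-case numbers: a horizontally fused op can still split into several launches when the list of tensors is large.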



Fusion, the nitty gritty

## starting with multi\_tensor\_apply

you know how mitochondria is the powerhouse of the cell?

multi\_tensor\_apply is the powertruck of our speedy optimizers.







multi\_tensor\_apply allows us to operate over a list of Tensors vs a single tensor.

#### Example with torch.add



```
add(
    Tensor self, Tensor other, *,
    Scalar alpha=1
) -> Tensor
```

```
_foreach_add(
    Tensor[] self, Tensor[] other, *,
    Scalar alpha=1
) -> Tensor[]
```
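The relationship between the two signatures can be sketched in plain Python. Floats stand in for tensors here, and `add` and `foreach_add` are illustrative helpers written for this sketch, not the PyTorch functions themselves. They only capture the semantics: the list version gives the same result as calling the single version elementwise over the lists.

```python
def add(self_, other, *, alpha=1):
    # Same semantics as the schema above: self + alpha * other.
    return self_ + alpha * other


def foreach_add(selfs, others, *, alpha=1):
    # Reference semantics only; the real op does this in far fewer kernel launches.
    if len(selfs) != len(others):
        raise ValueError("both lists must have the same length")
    return [add(s, o, alpha=alpha) for s, o in zip(selfs, others)]


result = foreach_add([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], alpha=2)
print(result)
```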

#### Under the hood in CUDA

```
add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
```

A simplified CUDA kernel signature, assuming we have float Tensors:

```
__device__ void add_kernel(
        float* self,
        float* other,
        float* res,
        float alpha=1) {
    ...
}
```

```
_foreach_add(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
```

How would you write this one?

#### Attempt 1: use std::vector

```
_foreach_add(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
__device__ void _foreach_add_kernel(
        std::vector<float*> self,
        std::vector<float*> other,
        std::vector<float*> res,
        float alpha=1) {
    ...
}
```

Does this work?

No, because CUDA doesn't recognize std::vectors!

#### Attempt 2: use float\*\*

```
_foreach_add(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
__device__ void _foreach_add_kernel(
        float** self,
        float** other,
        float** res,
        float alpha=1) {
    ...
}
```

Does this work?

Nope! This will cause an Illegal Memory Access (IMA) because the outer pointer \* is a CPU address!

Look again at the add\_kernel.

```
add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
__device__ void add_kernel(
        float* self,
        float* other,
        float* res,
        float alpha=1) {
    // you're going to dereference in here
    ...
}
```

When we dereference in the CUDA kernel, it is OK! The address is in GPU.

But not so our \_foreach\_add kernel. While Tensors live in GPU, the list part of TensorLists lives in CPU.

Dereferencing a CPU address from GPU => illegal memory access.

[Figure: memory diagram labeling Tensor data for a CUDA Tensor, Tensor's data\_ptr(), the address of the Tensor List, and the kernel argument space]

#### Attempt 3: pass by chonky boi (not reference)

```
struct TensorListMetadata {
  const float* addresses[3][NUM_TENSORS];
};

// (add all the addresses into the struct)

__device__ void _foreach_add_kernel(
        TensorListMetadata tlm,
        float alpha=1) {
    ...
}
```

Does this work? It passes CI! Yay!

[Figure: memory diagram (green = CUDA/GPU, purple = CPU) labeling Tensor data for a CUDA Tensor, Tensor's data\_ptr(), the address of the Tensor List, and the kernel argument space]

~ the end ~

### I actually did land a PR like this and it got reverted :(

Cuz some! how! an illegal memory access happened for some models (like timm\_efficientdet).

I minified the repro to the following and played around with N.

```
params = [torch.rand(2, 3, device="cuda") for _ in range(N)]
torch._foreach_norm(params, ord=1)
torch.cuda.synchronize()
```

Now would you join me on my (binary) search!

(yes, this example is for norm and not add, but the principle is the same!)

#### Let's binary search over N.

```
params = [torch.rand(2, 3, device="cuda") for _ in range(N)]
torch._foreach_norm(params, ord=1)
torch.cuda.synchronize()
```

- N = 500 💥
- N = 256 👌
- N = 400 👌
- N = 450 💥
- N = 425 💥
- N = 412 👌
- N = 420 👌
- N = 423 👌
- N = 424 💥

What is so special about these numbers? What could be going on here?
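The manual search above can be sketched as an ordinary bisection. This assumes failures are monotone in N (everything at or above some threshold crashes). The `fails` predicate below is a hypothetical stand-in; in a real search it would run the repro for a given N, likely in a fresh process, since an illegal memory access typically leaves the CUDA context unusable.

```python
def first_failing(lo_ok, hi_bad, fails):
    """Return the smallest n in (lo_ok, hi_bad] with fails(n) True.

    Assumes fails(lo_ok) is False, fails(hi_bad) is True, and fails is monotone.
    """
    while hi_bad - lo_ok > 1:
        mid = (lo_ok + hi_bad) // 2
        if fails(mid):
            hi_bad = mid   # crash: the threshold is at mid or below
        else:
            lo_ok = mid    # fine: the threshold is above mid
    return hi_bad


def fails(n):
    # Hypothetical stand-in for "run the repro with N = n and see if it crashes".
    return n >= 424


threshold = first_failing(lo_ok=256, hi_bad=500, fails=fails)
print(threshold)
```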

### Attempt 3: pass by chonky boi (not reference) cont.

```
struct TensorListMetadata {
  const float* addresses[3][NUM_TENSORS];
};

// (add all the addresses into the struct)

__device__ void _foreach_add_kernel(
        TensorListMetadata tlm,
        float alpha=1) {
    ...
}
```

Does this work? Only if NUM_TENSORS < 424?

Fun fact: CUDA kernel argument space has a max limit of 4KB.

So what now?
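A quick worked calculation shows how fast a by-value struct of addresses eats a 4KB budget. This sketch covers only the simplified struct from the slide (three lists of pointers), assumes 64-bit pointers, and ignores the other kernel arguments. The struct PyTorch actually uses carries extra bookkeeping, so its real cutoffs differ (for example, the 424 observed for `_foreach_norm`, which takes a single list).

```python
KERNEL_ARG_LIMIT_BYTES = 4 * 1024  # the 4KB limit quoted on the slide
POINTER_BYTES = 8                  # assumption: 64-bit device pointers


def struct_bytes(num_lists, num_tensors):
    """Size of `const float* addresses[num_lists][num_tensors]` in bytes."""
    return num_lists * num_tensors * POINTER_BYTES


def max_tensors(num_lists):
    """Largest NUM_TENSORS for which the address table alone fits in the limit."""
    return KERNEL_ARG_LIMIT_BYTES // (num_lists * POINTER_BYTES)


limit_for_add = max_tensors(num_lists=3)   # self, other, res
print(limit_for_add, struct_bytes(3, limit_for_add))
```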

#### Attempt 4: just launch more kernels; make more trips









```
struct TensorListMetadata {
  const float* addresses[3][MAX_NUM_TENSORS];
};
```

- make multiple structs
- add a chunk of the addresses to each
- launch the kernel multiple times
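The chunking idea can be sketched in a few lines of Python. Integers stand in for device addresses, `chunk_addresses` is a hypothetical helper, and `max_num_tensors=170` is only an illustrative capacity; each resulting chunk corresponds to one struct and one kernel launch.

```python
def chunk_addresses(addresses, max_num_tensors):
    """Split a list of addresses into chunks that each fit in one metadata struct."""
    return [
        addresses[i:i + max_num_tensors]
        for i in range(0, len(addresses), max_num_tensors)
    ]


fake_addresses = list(range(500))  # stand-ins for 500 tensors' data pointers
chunks = chunk_addresses(fake_addresses, max_num_tensors=170)
launches = len(chunks)             # one kernel launch per chunk
print(launches, [len(c) for c in chunks])
```

The cost of this approach is visible in `launches`: the more tensors in the list, the more trips the "fused" op makes.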


#### Attempt 4 is what we do today. But we could do better.

While we claim to horizontally fuse into 1 kernel... we often end up with more.



#### How: revisit attempt 2

We wanna turn those purple \* to green \*.

How? Move them to CUDA beforehand! (thanks Yifu Wang!)

Pack all these vectors of pointers into one tensor, then copy to CUDA.

```
// (memcpy the lists of addresses to CUDA)
__device__ void _foreach_add_kernel(
        float** self,
        float** other,
        float** res,
        float alpha=1) {
    ...
}
```

Each argument is now a pointer to a pointer, i.e., float\*\*, and the outer pointer lives in GPU memory.

We thereby avoid the 4KB constraint in the kernel argument to enable launching just one kernel. But remember, memcpy is \$\$\$!

[Figure: CPU memory vs. GPU memory diagram (purple = CPU, green = CUDA/GPU) labeling Tensor data for a CUDA Tensor, Tensor's data\_ptr(), the address of the Tensor List, and the kernel argument space]

#### Conclusion: we will be doing a mix of struct + memcpy

If it fits, use the struct.

Otherwise, take the memcpy hit.

Can we use unified memory here? (currently not supported in PyTorch)

[Figure: purple = CPU, green = CUDA/GPU]



#### Did you notice I split up the fused implementation too?

This is cuz our fastest fused impls also rely on multi\_tensor\_apply!

Whereas \_foreach\_add will call multi\_tensor\_apply with a Callable that does addition, \_fused\_adamw will call multi\_tensor\_apply with a bigger Callable.

```
AT_DISPATCH_FLOATING_TYPES_AND2(
    kHalf,
    kBFloat16,
    params[0].scalar_type(),
    "fused_adamw_kernel_cuda",
    [&]() {
      multi_tensor_apply_for_fused_optimizer<4>(
          tensor_lists,
          state_steps,
          FusedAdamMathFunctor<scalar_t, 4, ADAM_MODE::ADAMW, false>(),
          lr_ptr,
          1.0, // unused
          beta1,
          beta2,
          weight_decay,
          eps,
          maximize,
          grad_scale_ptr,
          found_inf_ptr);
    });
```

## so let's peek at FusedAdamMathFunctor

```
struct FusedAdamMathFunctor {
  C10_DEVICE __forceinline__ void operator()(
      ...
    // Locate the thread
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
    const double lr_double = lr_ptr ? *lr_ptr : lr;

    if (found_inf_ptr && *found_inf_ptr == 1) {
      ...
    }
    const auto [bias_correction1, bias_correction2_sqrt] =
        [&]() -> std::pair<double, double> {
      auto* step_count =
          reinterpret_cast<const float*>(tl.state_steps_addresses[tensor_loc]);
      const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count);
      const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count);
      const auto bias_correction2_sqrt = std::sqrt(bias_correction2);
      return {bias_correction1, bias_correction2_sqrt};
    ...
    scalar_type* args[depth];
    scalar_type r_args[depth][kILP];
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

    const bool all_aligned{
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
    if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
      for (int64_t i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (int i = 0; i < depth; i++) {
          load_store(r_args[i], args[i], 0, i_start);
        }
        adam_math<scalar_type, opmath_t, depth, adam_mode, amsgrad>(
            r_args,
            lr_double,
            ...

// Math part of the optimizer
C10_DEVICE inline void adam_math(
    ...
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    // Load values.
    opmath_t param = static_cast<opmath_t>(r_args[kParamIdx][ii]);
    opmath_t grad = static_cast<opmath_t>(r_args[kGradIdx][ii]);
    if (grad_scale_ptr) {
      grad /= (static_cast<double>(*grad_scale_ptr));
    }
    const opmath_t grad_to_store = grad;
    if (maximize) {
      grad = -grad;
    }
    opmath_t exp_avg = static_cast<opmath_t>(r_args[kExpAvgIdx][ii]);
    opmath_t exp_avg_sq = static_cast<opmath_t>(r_args[kExpAvgSqIdx][ii]);
    opmath_t max_exp_avg_sq;
    if (amsgrad) {
      max_exp_avg_sq = static_cast<opmath_t>(r_args[kMaxExpAvgSqIdx][ii]);
    }
    // Update param, grad, 1st and 2nd order momentum.
    if (weight_decay != 0) {
      if constexpr (adam_mode == ADAM_MODE::ORIGINAL) {
        grad += param * weight_decay;
      } else if constexpr (adam_mode == ADAM_MODE::ADAMW) {
        param -= lr * weight_decay * param;
      }
    }
    // todo(crcrpar): use lerp
    // ref: https://developer.nvidia.com/blog/lerp-faster-cuda/
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad;
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad;
    const opmath_t step_size = lr / bias_correction1;
    opmath_t denom;
    if (amsgrad) {
      max_exp_avg_sq = std::max(max_exp_avg_sq, exp_avg_sq);
      denom = (std::sqrt(max_exp_avg_sq) / bias_correction2_sqrt) + eps;
    ...
```

...that was very manual.

What if we could automate vertical fusion? with just 1 line?

# Enter torch.compile()

### torch.compile()'s strength is vertical fusion



#### How do I use torch.compile() with optimizers?

```
optimizer = torch.optim.AdamW(params)

@torch.compile(fullgraph=False)
def compiled_step():
    optimizer.step()
```

Now call compiled\_step() instead of optimizer.step() in your training loop. That's it!

(okay I suppose that was 2 lines)

#### Inductor will generate a lorge triton kernel

```
xpid_offset = xpid - 0
xnumel = 512
xoffset = xpid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tl.load(in_ptr1 + (x0), xmask)
tmp6 = tl.load(in_ptr2 + (x0), xmask)
tmp13 = tl.load(in_ptr3 + (x0), xmask)
tmp17 = tl.load(in_ptr4 + (0))
tmp18 = tl.broadcast_to(tmp17, [XBLOCK])
tmp2 = tmp1 - tmp0
tmp4 = tmp2 * tmp3
tmp5 = tmp0 + tmp4
tmp7 = 0.999
tmp8 = tmp6 * tmp7
tmp9 = tmp1 * tmp1
tmp10 = 0.0010000000000000000
tmp11 = tmp9 * tmp10
tmp12 = tmp8 + tmp11
tmp14 = 0.99999
tmp15 = tmp13 * tmp14
tmp16 = tl.sqrt(tmp12)
tmp19 = tl.math.pow(tmp7, tmp18)
tmp20 = 1.0
tmp21 = tmp19 - tmp20
tmp22 = -tmp21
tmp23 = tl.sqrt(tmp22)
tmp24 = tmp16 / tmp23
tmp25 = 1e-08
tmp26 = tmp24 + tmp25
tmp27 = 0.9
tmp28 = tl.math.pow(tmp27, tmp18)
tmp29 = tmp28 - tmp20
tmp30 = 0.001
tmp31 = tmp29 / tmp30
tmp32 = 1 / tmp31
tmp33 = tmp26 / tmp32
tmp34 = tmp5 / tmp33
tmp35 = tmp15 + tmp34
tl.store(out_ptr0 + (x0), tmp5, xmask)
tl.store(out_ptr3 + (x0), tmp35, xmask)
tl.store(out_ptr4 + (x0), tmp12, xmask)
```

#### When does it work? (or not work?)

- You must have CUDA capability 7.0+ for Triton
- All optimizers in pytorch/pytorch with a foreach implementation are now compilable
  - So everything except L-BFGS and SparseAdam
- Vertical fusion of any sequence of supported \_foreach\_\* ops should work!
  - try out your experimental optimizers!
  - open an issue when this isn't true

Compiled optimizers is in beta! Try it out and complain lots here!

So should you stop learning CUDA?



thanks! questions?