
# Gradient Checkpointing — Deep Explanation

Gradient checkpointing is a technique that **reduces GPU memory usage by not storing some intermediate activations during the forward pass**.

Instead, during the backward pass, PyTorch **recomputes** those activations on the fly.

You **trade more compute** for **less memory**.

---

# Why do we need it?

Normally, for backpropagation, PyTorch keeps **every intermediate activation that the backward pass needs**, because the gradients depend on them.

Backprop step for a layer with input x, weights W, and output y:

$$
\nabla W = \frac{\partial \mathcal{L}}{\partial y} \cdot \frac{\partial y}{\partial W}
$$

Inside this, the term
$$
\frac{\partial y}{\partial W}
$$
depends on **the layer's input activation x**, so **x must be stored** until the backward pass reaches that layer.
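
As a concrete instance, assume the layer is a plain linear layer (this choice is only for illustration). Its gradients can then be written out, and the input x appears explicitly in the weight gradient:

$$
y = W x + b, \qquad
\frac{\partial \mathcal{L}}{\partial W} = \frac{\partial \mathcal{L}}{\partial y} \, x^{\top}, \qquad
\frac{\partial \mathcal{L}}{\partial x} = W^{\top} \frac{\partial \mathcal{L}}{\partial y}
$$

The weight gradient needs x, while the gradient passed to the previous layer needs only W. That is why the input of each layer has to be either stored or recomputed.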

For a deep model:

* Stored activations grow with depth, batch size, and input size, and can easily reach many GB.
* Even when the weights fit, activations often dominate training memory.

Gradient checkpointing solves this by:

* **Saving fewer activations**
* **Recomputing the forward pass for selected layers** during backward

---

# What exactly happens?

### Without checkpointing

Forward pass stores activations:

```
x → layer1 → a1   (save a1)
a1 → layer2 → a2  (save a2)
a2 → layer3 → a3  (save a3)
```

Memory usage:

* a1, a2, and a3 all stay in memory until backward reaches them.
* Total activation memory grows linearly with depth.

### With checkpointing

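Here layer1 and layer2 form one checkpointed segment, and layer3 runs normally:

```
x → layer1 → a1   (NOT saved; only the segment input x is kept)
a1 → layer2 → a2  (kept: segment output, needed as layer3's input)
a2 → layer3 → a3  (layer3 is not checkpointed)
```

Backward pass:

* To compute gradients for layer2 and layer1, PyTorch **re-runs the segment's forward from x** to reconstruct a1, then backpropagates through the segment.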

---

# Memory savings

Approximate:

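* Without checkpointing:
  Memory ≈ activations of all layers
* With checkpointing:
  Memory ≈ inputs kept at segment boundaries + activations of the one segment currently being recomputed

Rule of thumb:

**With n layers split into about √n equal segments, activation memory drops from O(n) to roughly O(√n).** Actual savings depend on where the boundaries fall and on how large each layer's activations are.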
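
The trade-off between segment boundaries and segment size can be sketched with a toy count. This is a simplified model, not a measurement: it assumes every layer stores one activation of equal size, and the function name `peak_stored_activations` is made up for this illustration.

```python
import math

def peak_stored_activations(num_layers, segment_size):
    """Toy count: one kept input per segment, plus the activations of the
    single segment that is being recomputed during backward."""
    num_segments = math.ceil(num_layers / segment_size)
    return num_segments + segment_size

num_layers = 100
# Try every segment size and keep the one with the smallest peak.
best_segment = min(
    range(1, num_layers + 1),
    key=lambda k: peak_stored_activations(num_layers, k),
)
print(best_segment, peak_stored_activations(num_layers, best_segment))  # 10 20
```

Under these assumptions, 100 layers need about 20 stored activations instead of 100, with the best segment size near √100 = 10.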

---

# Cost?

Extra compute:

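* During backward, each checkpointed segment runs its forward pass a second time.
* If everything is checkpointed, that is about one extra forward pass per step. The backward phase becomes roughly **1.5× slower**, which is typically around **1.3×** for the whole training step; measured slowdowns vary with the model and hardware.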
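
A back-of-the-envelope estimate of the overhead, under the assumption (a common approximation, not a measurement) that a backward pass B costs about twice a forward pass F, and that every layer is checkpointed:

$$
\frac{C_{\text{checkpointed}}}{C_{\text{plain}}}
= \frac{F + (F + B)}{F + B}
= \frac{1 + (1 + 2)}{1 + 2}
= \frac{4}{3} \approx 1.33
$$

Checkpointing only part of the model lowers the extra cost proportionally.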

---

# When is checkpointing useful?

Use it when:

* You run out of GPU memory
* Your model is deep (Transformers, ConvNeXt, ViT, ResNets, diffusion models)
* Your inputs are high resolution (for example 512–1024 px or more)
* You want a larger batch size than would otherwise fit

Do NOT use it when:

* The model is small and already fits in memory comfortably
* You are already compute-bound and memory is not the limiting factor
* You need maximum throughput

---

# How to apply checkpointing



The example below shows **two versions of the same model**:

1. **Normal forward** (no checkpointing)
2. **Checkpointed forward** (same layers, with the Linear + ReLU blocks checkpointed)

Both versions use the same architecture:

* Linear → ReLU → Linear → ReLU → Linear

Only the second version wraps the two Linear + ReLU blocks in checkpoints.

---

# Version 1: Normal forward (no checkpointing)

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
        self.layer3 = nn.Linear(1024, 10)
        self.act = nn.ReLU()

    def forward(self, x):
        # Normal forward with standard activation flow
        x = self.layer1(x)
        x = self.act(x)
        
        x = self.layer2(x)
        x = self.act(x)
        
        x = self.layer3(x)
        return x
```

This forward pass stores **all** intermediate activations for backward.

---

# Version 2: Forward with gradient checkpointing

Checkpointing is applied to the *Linear + ReLU* blocks.

Important: `checkpoint` takes a callable plus its inputs and re-runs that callable during backward, so each block is wrapped in its own method.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class MyModelCheckpointed(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 1024)
        self.layer2 = nn.Linear(1024, 1024)
        self.layer3 = nn.Linear(1024, 10)
        self.act = nn.ReLU()

    def forward_block1(self, x):
        x = self.layer1(x)
        x = self.act(x)
        return x

    def forward_block2(self, x):
        x = self.layer2(x)
        x = self.act(x)
        return x

    def forward(self, x):
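        # Checkpoint the expensive blocks: each call keeps only the block's
        # input and recomputes the rest during backward.
        # use_reentrant=False is the variant the PyTorch docs recommend; it
        # also works when x itself does not require grad.
        x = checkpoint.checkpoint(self.forward_block1, x, use_reentrant=False)
        x = checkpoint.checkpoint(self.forward_block2, x, use_reentrant=False)

        # Last layer is cheap, so usually not checkpointed
        x = self.layer3(x)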
        return x
```
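
Checkpointing should change only memory and compute, not the result. A quick way to check this is to compare gradients with and without it. The sketch below is a minimal check, assuming a PyTorch version where `checkpoint` accepts `use_reentrant`; the helper `grads_for` and the small layer sizes are illustrative choices, not part of the model above.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

torch.manual_seed(0)

# Small stand-ins for the Linear -> ReLU blocks (sizes are arbitrary).
block1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
block2 = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
head = nn.Linear(32, 10)
params = [*block1.parameters(), *block2.parameters(), *head.parameters()]

def grads_for(x, use_checkpoint):
    """Run forward + backward and return a copy of every parameter gradient."""
    for p in params:
        p.grad = None
    if use_checkpoint:
        h = checkpoint.checkpoint(block1, x, use_reentrant=False)
        h = checkpoint.checkpoint(block2, h, use_reentrant=False)
    else:
        h = block2(block1(x))
    head(h).sum().backward()
    return [p.grad.clone() for p in params]

x = torch.randn(8, 32)
plain = grads_for(x, use_checkpoint=False)
ckpt = grads_for(x, use_checkpoint=True)

# True when both runs produced the same gradients for all parameters.
grads_match = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print(grads_match)
```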

---

# Notes

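1. Group several operations into one function (`forward_block1`, `forward_block2`) before checkpointing.
   Checkpointing each operation on its own is legal, but saves almost nothing:

   ```
   x = checkpoint.checkpoint(self.layer1, x, use_reentrant=False)
   x = checkpoint.checkpoint(self.act, x, use_reentrant=False)
   ```

   Each call keeps its own input, so every intermediate tensor is still stored. Memory is saved only for tensors created inside a checkpointed function.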

2. You decide which blocks to checkpoint.
   Usually, you checkpoint large or repeated layers, not the last linear classifier.

3. Activations inside checkpointed blocks are not saved.
   They are recomputed during backward.

---


# Important rule: The function must be “pure”

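A checkpointed function runs twice (once in forward, once during backward), and both runs must produce the same result. So it:

* must not depend on or modify global state that changes between the two runs
* may use random operations such as dropout, because `checkpoint` saves and restores the RNG state by default (`preserve_rng_state=True`); randomness that bypasses PyTorch's RNG breaks recomputation
* must not modify its inputs in place

Side effects also happen twice. For example, BatchNorm running statistics are typically updated in both runs.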
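
The dropout case can be checked directly. The sketch below assumes the default `preserve_rng_state=True` and runs on CPU; the helper `weight_grad` and the fixed seed are illustrative. Both runs start from the same seed, so they should draw the same dropout mask, and the recomputation inside `checkpoint` should reuse it.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.ReLU())
block.train()  # dropout is only active in training mode
x = torch.randn(4, 16)

def weight_grad(use_checkpoint):
    """Return the Linear weight gradient for one forward/backward pass."""
    block.zero_grad()
    torch.manual_seed(123)  # same starting RNG state for both runs
    if use_checkpoint:
        out = checkpoint.checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return block[0].weight.grad.clone()

# True when the recomputed forward used the same dropout mask.
same_grad = torch.allclose(weight_grad(False), weight_grad(True))
print(same_grad)
```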

---

# Example for larger models (Transformers)

Typical structure:

```python
def forward(self, x):
    for block in self.transformer_blocks:
        x = checkpoint.checkpoint(block, x, use_reentrant=False)
    x = self.final_layer(x)
    return x
```

This is where checkpointing tends to pay off most: each Transformer block produces many large intermediate tensors (attention scores, MLP hidden states), and only the block's input has to be kept.
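
A runnable version of this pattern might look like the sketch below. `TinyBlock` and `TinyStack` are hypothetical stand-ins (a residual MLP rather than a full attention block), and the `use_checkpoint` flag is an assumed design choice so that checkpointing can be switched off and is skipped outside training.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class TinyBlock(nn.Module):
    """Hypothetical stand-in for a Transformer block: pre-norm residual MLP."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.mlp(self.norm(x))

class TinyStack(nn.Module):
    def __init__(self, dim=64, depth=6, num_classes=10, use_checkpoint=True):
        super().__init__()
        self.transformer_blocks = nn.ModuleList(
            [TinyBlock(dim) for _ in range(depth)]
        )
        self.final_layer = nn.Linear(dim, num_classes)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.transformer_blocks:
            # Evaluation normally has no backward pass, so checkpoint
            # only in training mode.
            if self.use_checkpoint and self.training:
                x = checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return self.final_layer(x)

model = TinyStack()
model.train()
out = model(torch.randn(2, 5, 64))  # (batch, tokens, dim)
out.sum().backward()
```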

---

# How many layers to checkpoint?

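A good guideline:

* Checkpoint the **repeated middle blocks** of deep networks, where most of the activation memory sits.
* Layers that are usually not worth checkpointing:

  * the very first layer, if it is cheap and its activations are small
  * the very last layer (backward starts there, so its activations would be recomputed immediately and peak memory barely changes)
  * layers with small activation sizes

Checkpointing every block uniformly is also common and simpler to implement. Example of a selective split for a 12-layer Transformer: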

```
No checkpoint: Layer 1
Checkpoint: Layers 2–11
No checkpoint: Layer 12
```
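
The split above could be expressed as a set of block indices. In this sketch, `SelectiveStack` and its `checkpoint_ids` argument are hypothetical names, and the blocks are simple Linear + GELU stand-ins rather than real Transformer layers.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class SelectiveStack(nn.Module):
    def __init__(self, dim=32, depth=12, checkpoint_ids=None):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(depth)]
        )
        if checkpoint_ids is None:
            # 0-based indices 1..depth-2, i.e. "layers 2-11" of a 12-layer stack.
            checkpoint_ids = range(1, depth - 1)
        self.checkpoint_ids = set(checkpoint_ids)

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if i in self.checkpoint_ids and self.training:
                x = checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = SelectiveStack()
model.train()
model(torch.randn(4, 32)).sum().backward()
```

Passing a different `checkpoint_ids` (for example every second index) makes it easy to compare memory and speed for several splits.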

---

# Advanced: checkpointing a *sequence* of layers (more efficient)

Instead of wrapping each layer individually, you can group them:

```python
def forward(self, x):
    def block1(x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

    x = checkpoint.checkpoint(block1, x, use_reentrant=False)
    x = self.layer3(x)
    return x
```

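Grouping keeps only the segment's input instead of one tensor per layer, so more memory is saved, and `checkpoint` is called fewer times. The trade-off is a larger transient peak, because all activations inside the segment exist at once while it is recomputed.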
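
For a plain chain of layers, PyTorch also provides `checkpoint_sequential`, which does the grouping automatically. The sketch below assumes a recent PyTorch 2.x release where this function accepts `use_reentrant`; the layer sizes and segment count are arbitrary.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

# Eight Linear + ReLU layers in one chain.
layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(32, 32), nn.ReLU()) for _ in range(8)]
)
x = torch.randn(4, 32, requires_grad=True)

# Split the chain into 2 segments of 4 layers each. Depending on the
# PyTorch version, the final segment may be run without checkpointing.
out = checkpoint.checkpoint_sequential(layers, 2, x, use_reentrant=False)
out.sum().backward()
```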

---

# When checkpointing gives the best memory reduction

Checkpointing helps most when activations, rather than weights, dominate memory:

* ViT (very good benefit)
* MLP-Mixer
* ConvNeXt
* ResNet-50 and deeper
* UNet (diffusion, medical imaging)
* Any deep stack of similar layers (roughly 10 or more)
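
Whether a given model benefits is easiest to settle by measuring. The sketch below compares peak CUDA memory for one training step with and without checkpointing, using `torch.cuda.max_memory_allocated()`. The helper `peak_memory_mb`, the block layout, and the sizes are assumptions for illustration; without a GPU the helper simply returns `None`.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

def peak_memory_mb(use_checkpoint, depth=8, dim=1024, batch=2048):
    """Peak CUDA memory in MiB for one training step, or None without a GPU."""
    if not torch.cuda.is_available():
        return None
    device = torch.device("cuda")
    blocks = nn.ModuleList([
        nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim), nn.GELU())
        for _ in range(depth)
    ]).to(device)
    x = torch.randn(batch, dim, device=device)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # start measuring from here
    for block in blocks:
        if use_checkpoint:
            x = checkpoint.checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    x.sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

plain_mb = peak_memory_mb(use_checkpoint=False)
ckpt_mb = peak_memory_mb(use_checkpoint=True)
if plain_mb is not None:
    print(f"plain: {plain_mb:.0f} MiB, checkpointed: {ckpt_mb:.0f} MiB")
```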
