# Tutorial-Extra 4.1: Learning Rate Warmup in Transformer Training

Author: [Erik Syniawa](mailto:erik.syniawa@informatik.tu-chemnitz.de)

## 1. Understanding Learning Rate Warmup

Learning rate warmup is a critical technique in training Transformer models, including Vision Transformers (ViT). It involves starting with a very small learning rate and gradually increasing it to a predefined "base" learning rate over a specified number of steps or epochs before applying any decay schedule.

### Why Warmup is important for Transformers

Transformers are particularly sensitive to the learning rate during initial training for several reasons:

1. **Random Initialization Sensitivity**: Transformers start with randomly initialized attention mechanisms that need time to form meaningful patterns.

2. **Gradient Instability**: Early in training, large gradients combined with high learning rates can cause the model to diverge.

3. **Self-Attention Stabilization**: The self-attention mechanism needs time to learn which features are relevant before taking large optimization steps.

4. **Complex Loss Landscape**: Transformer architectures create a complex loss landscape where early optimization steps need to be more cautious.

5. **High Parameter Count**: With millions of parameters, Transformers benefit from a more controlled start to training.

## 2. Learning Rate Dynamics: Within and Across Epochs

Let's visualize how learning rates change throughout training with warmup and subsequent scheduling.

### Key Components of a Learning Rate Schedule

1. **Warmup Phase**: Gradual increase from a small initial learning rate to the base learning rate
2. **Decay Phase**: Gradual decrease following a specific schedule (e.g., cosine, linear, step)
3. **Minimum Learning Rate**: The lower bound below which the learning rate won't decrease

### Visualization of Learning Rate Changes


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, ChainedScheduler
import torch.optim as optim

# Create a model and optimizer
class DummyModel(torch.nn.Module):
    """Minimal one-layer network whose only job is to supply trainable parameters."""

    def __init__(self):
        super().__init__()
        # Single affine layer: 10 input features -> 2 output features.
        self.linear = torch.nn.Linear(in_features=10, out_features=2)

    def forward(self, x):
        """Run ``x`` through the linear layer and return the result."""
        projected = self.linear(x)
        return projected

# Build a throwaway model and an AdamW optimizer over its parameters.
# The literal lr=1e-3 here has the same value as base_lr defined just below.
model = DummyModel()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Define parameters
total_epochs = 100       # number of entries recorded in epoch_lrs
warmup_epochs = 20       # scheduler steps spent in the linear warmup
steps_per_epoch = 50     # not used in this setup code; passed to lr_within_epoch when plotting
base_lr = 1e-3           # LR reached at the end of warmup
min_lr = 1e-6            # passed to the cosine scheduler as eta_min
warmup_start_lr = 1e-6   # LR at the very start of warmup

# Create schedulers
# LinearLR is given start_factor = warmup_start_lr / base_lr and end_factor = 1.0,
# spread over warmup_epochs scheduler steps.
warmup_scheduler = LinearLR(
    optimizer, 
    start_factor=warmup_start_lr/base_lr,
    end_factor=1.0, 
    total_iters=warmup_epochs
)

# Cosine annealing configured over the full run (T_max=total_epochs); this
# instance is the one used by the ChainedScheduler branch below.
cosine_scheduler = CosineAnnealingLR(
    optimizer, 
    T_max=total_epochs, 
    eta_min=min_lr
)

# Pick the combining scheduler based on the installed PyTorch version.
# NOTE(review): this compares torch.__version__ against the string '2.6';
# it presumably relies on torch's version object comparing semantically rather
# than lexicographically (e.g. for '2.10') - confirm.
if torch.__version__ >= '2.6':
    print("Using ChainedScheduler for PyTorch 2.6 and onward")
    
    # ChainedScheduler is available in PyTorch 2.6 and onward
    # It allows you to combine multiple schedulers
    # Note: ChainedScheduler is available in earlier versions of PyTorch but takes different arguments
    # NOTE(review): the version-specific statements above are the original
    # author's; verify them against the PyTorch release notes.
    
    scheduler = ChainedScheduler(
        schedulers=(warmup_scheduler, cosine_scheduler),
        optimizer=optimizer,
    )

# if you are using an older version of PyTorch, you may need to use the following instead:
elif torch.__version__ < '2.6':
    from torch.optim.lr_scheduler import SequentialLR
    print("Using SequentialLR for older PyTorch versions")
    
    # Rebinds cosine_scheduler: in this branch the cosine phase only covers the
    # epochs that remain after warmup.
    cosine_scheduler = CosineAnnealingLR(
        optimizer, 
        T_max=total_epochs - warmup_epochs,  # adjust T_max for the remaining epochs 
        eta_min=min_lr)
    
    # Warmup scheduler first, cosine scheduler from the warmup_epochs milestone on.
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_epochs],)
# SequentialLR is deprecated in PyTorch 2.6 and onward due to epoch-based scheduling (scheduler.step(epoch) instead of scheduler.step())
# NOTE(review): the deprecation claim above is the original author's and is not
# verifiable from this file - confirm against the PyTorch documentation.

# Track learning rates across epochs
# Entry 0 is seeded with warmup_start_lr rather than read back from the scheduler.
epoch_lrs = []
epoch_lrs.append(warmup_start_lr)

# total_epochs - 1 iterations, so epoch_lrs ends up with total_epochs entries.
for epoch in range(1, total_epochs):  # Start from 1 since we already recorded epoch 0
    # One dummy forward/backward/optimizer step on random data so that
    # optimizer.step() runs before scheduler.step().
    dummy_input = torch.randn(32, 10)
    dummy_output = model(dummy_input)
    dummy_loss = dummy_output.sum()
    optimizer.zero_grad()
    dummy_loss.backward()
    optimizer.step()
    
    scheduler.step()
    epoch_lrs.append(scheduler.get_last_lr()[0])  # Get the current learning rate

# Function to simulate learning rate changes within an epoch
def lr_within_epoch(epoch, steps_per_epoch):
    """Simulate the per-step learning rate inside one epoch for plotting.

    Reads the module-level ``warmup_epochs``, ``warmup_start_lr`` and
    ``epoch_lrs``.

    Args:
        epoch: Index into ``epoch_lrs`` of the epoch to simulate.
        steps_per_epoch: Number of values to produce.

    Returns:
        A list of ``steps_per_epoch`` learning rates. During warmup the values
        ramp linearly from the previous recorded LR toward the current one
        (the end value itself is not included); afterwards the list is flat.
    """
    # After warmup the learning rate is treated as constant within the epoch.
    if epoch >= warmup_epochs:
        return [epoch_lrs[epoch]] * steps_per_epoch

    # Warmup: choose the two endpoints of the ramp.
    if epoch == 0:
        # Epoch 0 begins exactly at warmup_start_lr.
        lo, hi = warmup_start_lr, epoch_lrs[0]
    else:
        lo, hi = epoch_lrs[epoch - 1], epoch_lrs[epoch]

    span = hi - lo
    ramp = []
    for i in range(steps_per_epoch):
        ramp.append(lo + span * (i / steps_per_epoch))
    return ramp

# Plot learning rate across epochs
# One figure, two stacked panels: the whole schedule on top, and the simulated
# per-step learning rate inside a handful of epochs underneath.
plt.figure(figsize=(12, 8))

# ---- Top panel: learning rate across all epochs ----
plt.subplot(2, 1, 1)
plt.plot(range(total_epochs), epoch_lrs)
plt.axvline(
    x=warmup_epochs,
    color='r',
    linestyle='--',
    label=f'End of Warmup ({warmup_epochs} epochs)',
)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Learning Rate', fontsize=12)
plt.title('Learning Rate Schedule Across Training', fontsize=14)
plt.yscale('log')
# Leave a little headroom on both ends so the full range stays visible.
plt.ylim(warmup_start_lr * 0.8, base_lr * 1.2)
plt.grid(True, which="both", ls="--", alpha=0.7)

# Keep the existing ticks and add one exactly at warmup_start_lr.
plt.yticks([*plt.yticks()[0], warmup_start_lr])

# Point an arrow at the first recorded value.
plt.annotate(
    f'Start: {warmup_start_lr:.1e}',
    xy=(0, warmup_start_lr),
    xytext=(5, warmup_start_lr*2),
    arrowprops={'arrowstyle': '->'},
)

plt.legend()

# ---- Bottom panel: learning rate within selected epochs ----
plt.subplot(2, 1, 2)
within_epochs = [0, 5, warmup_epochs, 50, total_epochs-1]  # epochs to inspect
for epoch in within_epochs:
    lrs = lr_within_epoch(epoch, steps_per_epoch)
    # The first epoch after warmup gets its own label.
    if epoch == warmup_epochs:
        label = 'Transition'
    elif epoch < warmup_epochs:
        label = 'Warmup'
    else:
        label = 'Decay'
    plt.plot(range(steps_per_epoch), lrs, label=f'Epoch {epoch} ({label})')

plt.xlabel('Step within Epoch', fontsize=12)
plt.ylabel('Learning Rate', fontsize=12)
plt.title('Learning Rate Within Selected Epochs', fontsize=14)
plt.grid(True, which="both", ls="--", alpha=0.7)
plt.legend()

plt.tight_layout()
plt.show()


## 3. Different Types of Warmup and Decay Schedules

### Common Warmup Strategies

1. **Linear Warmup**: The most common approach, where the learning rate increases linearly from initial to base value (`LinearLR`). You can also use `StepLR` for a stepwise increase with a given ratio ($\gamma > 1.0$).
2. **Exponential Warmup**: Learning rate increases exponentially, starting slow and accelerating (`LambdaLR`).

### Common Decay Strategies After Warmup

1. **Cosine Decay**: Follows a cosine curve, decreasing more slowly at the beginning and end (`CosineAnnealingLR`).
2. **Linear Decay**: Linear decrease from base learning rate to minimum (`LinearLR`).
3. **Step Decay**: Learning rate drops by a factor at fixed intervals (`StepLR` with $\gamma < 1.0$) or at specific milestones (`MultiStepLR`).
4. **Exponential Decay**: Learning rate decreases by a multiplicative factor each step (`ExponentialLR`).

Let's visualize some of these combinations:


In [None]:
def get_lr_schedule(schedule_type: str, 
                    total_steps: int, 
                    warmup_steps: int, 
                    base_lr: float = 1.0, 
                    min_lr: float = 1e-7, 
                    warmup_start_lr: float = 1e-6):
    """Compute a learning-rate curve made of a warmup phase and a decay phase.

    ``schedule_type`` names both phases, e.g. ``'linear_warmup_cosine'``:

    * warmup: ``'linear_warmup'`` or ``'exp_warmup'``
    * decay: ``'cosine'``, ``'linear'``, ``'step'`` or ``'exp'``

    The decay keyword is looked up only in what remains of ``schedule_type``
    once the warmup token has been removed. Without that, the ``linear`` in
    ``linear_warmup`` would be mistaken for a request for linear decay and
    ``'linear_warmup_step'`` / ``'linear_warmup_exp'`` would never reach their
    own decay branch.

    Args:
        schedule_type: Name combining a warmup token and a decay keyword.
        total_steps: Length of the returned schedule.
        warmup_steps: Number of leading steps that belong to the warmup phase.
        base_lr: Learning rate reached at the end of warmup.
        min_lr: Target learning rate of the cosine, linear and exp decays.
        warmup_start_lr: Learning rate at step 0 of the warmup.

    Returns:
        A list with ``total_steps`` learning-rate values.

    Raises:
        ValueError: If a phase that actually occurs (given ``total_steps`` and
            ``warmup_steps``) is not named in ``schedule_type``.
    """
    # Warmup kind; same precedence as a plain if/elif on the two tokens.
    if 'linear_warmup' in schedule_type:
        warmup_kind = 'linear'
    elif 'exp_warmup' in schedule_type:
        warmup_kind = 'exp'
    else:
        warmup_kind = None

    # Strip the warmup tokens so their 'linear' / 'exp' substrings cannot be
    # picked up as the decay keyword.
    decay_spec = schedule_type.replace('linear_warmup', '').replace('exp_warmup', '')
    decay_kind = next(
        (kind for kind in ('cosine', 'linear', 'step', 'exp') if kind in decay_spec),
        None,
    )

    decay_steps = total_steps - warmup_steps
    has_warmup_phase = total_steps > 0 and warmup_steps > 0
    has_decay_phase = total_steps > 0 and decay_steps > 0

    # Fail loudly instead of hitting an unbound (or stale) ``lr`` in the loop.
    if has_warmup_phase and warmup_kind is None:
        raise ValueError(
            f"schedule_type {schedule_type!r} names no warmup; "
            "expected 'linear_warmup' or 'exp_warmup'"
        )
    if has_decay_phase and decay_kind is None:
        raise ValueError(
            f"schedule_type {schedule_type!r} names no decay; "
            "expected one of 'cosine', 'linear', 'step', 'exp'"
        )

    # The exponential decay rate does not depend on the step: compute it once.
    if has_decay_phase and decay_kind == 'exp':
        decay_rate = np.power(min_lr / base_lr, 1 / decay_steps)

    lrs = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Warmup phase
            if warmup_kind == 'linear':
                lr = warmup_start_lr + (base_lr - warmup_start_lr) * (step / warmup_steps)
            else:
                # Exponential warmup: geometric interpolation from
                # warmup_start_lr towards base_lr.
                warmup_factor = np.exp(np.log(base_lr / warmup_start_lr) * step / warmup_steps)
                lr = warmup_start_lr * warmup_factor
        else:
            # Decay phase; progress runs from 0 towards 1 over the remaining steps.
            progress = (step - warmup_steps) / decay_steps

            if decay_kind == 'cosine':
                lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + np.cos(np.pi * progress))
            elif decay_kind == 'linear':
                lr = base_lr - (base_lr - min_lr) * progress
            elif decay_kind == 'step':
                # Step decay at 1/3 and 2/3 of the decay period.
                if progress < 1/3:
                    lr = base_lr
                elif progress < 2/3:
                    lr = base_lr * 0.1
                else:
                    lr = base_lr * 0.01
            else:
                # Exponential decay
                lr = base_lr * (decay_rate ** (step - warmup_steps))

        lrs.append(lr)

    return lrs

# Settings for the comparison plot
total_steps, warmup_steps = 500, 50
schedules = [
    'linear_warmup_cosine',
    'linear_warmup_linear',
    'linear_warmup_step',
    'exp_warmup_cosine',
]

plt.figure(figsize=(12, 8))

# One curve per schedule name, all sharing the same step counts.
for schedule in schedules:
    lrs = get_lr_schedule(schedule, total_steps, warmup_steps)
    plt.plot(lrs, label=schedule, alpha=0.7)

# Dashed vertical marker where warmup hands over to decay.
plt.axvline(x=warmup_steps, color='r', linestyle='--', label='End of Warmup')
plt.title('Different Warmup and Decay Schedules', fontsize=14)
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Learning Rate', fontsize=12)
plt.legend()
plt.grid(True, which="both", ls="--", alpha=0.7)
plt.tight_layout()
plt.show()


## 4. Per-Step vs. Per-Epoch Updates

It's important to understand when to update your learning rate scheduler:

### Per-Epoch Updates
- The scheduler is stepped once per epoch after all batches have been processed
- More common in standard training setups
- Easier to implement and track
- Works well when epochs have consistent numbers of batches

```python
# Per-epoch update
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Training steps
        ...
    
    # Update learning rate once per epoch
    scheduler.step()
```

### Per-Step Updates
- The scheduler is stepped after each batch/optimization step
- More precise control over learning rate changes
- Better for very large datasets where epochs contain many batches
- Common in large-scale Transformer pretraining (e.g., BERT, GPT)

```python
# Per-step update
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Training steps
        ...
        
        # Update learning rate after each batch
        scheduler.step()
```

> If you're using per-step updates, you'll need to adjust your warmup period to be in steps rather than epochs!

## 5. Recommended Warmup Settings for Different Scenarios


| Scenario | Warmup | Base LR | Min LR | Decay | See |
|----------|--------|---------|--------|-------|-----|
| **ViT on small datasets** | 5% of epochs | 1e-4 to 3e-4 | 1e-6 | Cosine | Steiner et al. (2021) |
| **ViT on ImageNet** | 10-20 epochs | 3e-3 to 5e-3 | 1e-5 | Cosine | Dosovitskiy et al. (2020) |
| **Self-supervised ViT** | 10% of epochs | 5e-4 | 1e-6 | Cosine | Caron et al. (2021) |
| **Fine-tuning pretrained ViT** | 2-5% of epochs | 1e-5 to 5e-5 | 1e-7 | Linear | Touvron et al. (2021) |



> Note: These are general guidelines and may need to be adjusted based on your specific dataset, model, and training setup.

### 5.1 References

1. Dosovitskiy, A., et al. (2020). "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR 2021
2. Caron, M., et al. (2021). "Emerging Properties in Self-Supervised Vision Transformers." ICCV 2021
3. Touvron, H., et al. (2021). "Training data-efficient image transformers & distillation through attention." ICML 2021
4. Steiner, A., et al. (2021). "How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers." arXiv:2106.10270

## 6. Custom Warmup Scheduler Implementation

Some architectures or training setups require custom schedulers, not only for the learning rate but also for other hyperparameters. For example, DINO schedules both the teacher's temperature parameter and the optimizer's weight decay. Below is an example of a custom scheduler for weight decay that can be used in PyTorch:

 

In [None]:
import torch.optim as optim

class CosineWeightDecayScheduler:
    """Move an optimizer's weight decay from ``initial_wd`` to ``final_wd`` on a cosine curve.

    Each call to :meth:`step` writes the value for the current epoch into every
    parameter group and then advances the epoch counter. The first call
    therefore applies ``initial_wd``. Once ``total_epochs`` calls have been
    made the value is held at ``final_wd`` instead of following the cosine back
    towards ``initial_wd``.
    """

    def __init__(self, 
                 optimizer: optim.Optimizer, 
                 initial_wd: float, 
                 final_wd: float, 
                 total_epochs: int):
        """Store the schedule settings.

        Args:
            optimizer: Object whose ``param_groups`` receive the weight decay.
            initial_wd: Weight decay applied by the first ``step()`` call.
            final_wd: Weight decay reached after ``total_epochs`` calls.
            total_epochs: Length of the cosine ramp, in ``step()`` calls.

        Raises:
            ValueError: If ``total_epochs`` is not positive.
        """
        # Reject the division by zero here rather than on the first step().
        if total_epochs <= 0:
            raise ValueError(f"total_epochs must be positive, got {total_epochs}")

        self.optimizer = optimizer
        self.initial_wd = initial_wd
        self.final_wd = final_wd
        self.total_epochs = total_epochs
        self.current_epoch = 0

    def step(self):
        """Apply the weight decay for the current epoch and advance the counter.

        Returns:
            The weight decay value that was written to the optimizer.
        """
        # Clamp so that stepping past total_epochs keeps final_wd; without the
        # clamp the cosine turns around and climbs back towards initial_wd.
        epoch = min(self.current_epoch, self.total_epochs)

        # Cosine schedule from initial_wd to final_wd: factor goes 1 -> 0.
        factor = 0.5 * (1 + np.cos(np.pi * epoch / self.total_epochs))
        # Plain Python float so param_groups do not hold a NumPy scalar.
        current_wd = float(self.final_wd + (self.initial_wd - self.final_wd) * factor)

        # Update weight decay in every parameter group of the optimizer.
        for param_group in self.optimizer.param_groups:
            param_group['weight_decay'] = current_wd

        self.current_epoch += 1
        return current_wd
    
# Example usage: hyperparameters for the demonstration run
lr, initial_wd, final_wd, num_epochs = 1e-3, 1e-3, 0.1, 10

# Fresh optimizer starting at the initial weight decay, with a linear LR ramp
# from 10% to 100% of lr over the run.
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=initial_wd)
lr_scheduler = LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=num_epochs,
)

# Weight decay scheduler operating on the same optimizer
wd_scheduler = CosineWeightDecayScheduler(
    optimizer, initial_wd=initial_wd, final_wd=final_wd, total_epochs=num_epochs
)

# Skeleton of a training loop
for epoch in range(num_epochs):
    # Training code would go here...
    optimizer.step()

    # Both schedulers advance once per epoch, after the optimizer step.
    lr_scheduler.step()
    wd_scheduler.step()

    # Read back what the schedulers wrote into the first parameter group.
    current_lr, current_wd = (
        optimizer.param_groups[0][key] for key in ('lr', 'weight_decay')
    )
    print(f"Epoch {epoch}: LR={current_lr:.6f}, WD={current_wd:.6f}")

> What could be the benefits of using a custom scheduler for weight decay in your training setup? And why might this approach be suited mainly to self-supervised learning setups?