forked from davda54/sam

SAM: Sharpness-Aware Minimization (PyTorch)


lionsheep0724/sam

 
 


(Adaptive) SAM Optimizer

Sharpness-Aware Minimization for Efficiently Improving Generalization

~ in PyTorch ~



SAM simultaneously minimizes loss value and loss sharpness. In particular, it seeks parameters that lie in neighborhoods having uniformly low loss. SAM improves model generalization and yields SoTA performance for several datasets. Additionally, it provides robustness to label noise on par with that provided by SoTA procedures that specifically target learning with noisy labels.
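In symbols, following the SAM paper, SAM minimizes the worst-case loss inside a ρ-ball around the weights. The inner maximization is approximated to first order. The formulas below use the ℓ2 neighborhood and omit the paper's weight-decay term:

$$
\min_{w} \; L^{\mathrm{SAM}}(w), \qquad L^{\mathrm{SAM}}(w) = \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon) = \underbrace{\Big[\max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon) - L(w)\Big]}_{\text{sharpness}} + L(w)
$$

$$
\hat{\epsilon}(w) = \rho \, \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2}, \qquad \nabla_w L^{\mathrm{SAM}}(w) \approx \nabla_w L(w)\big|_{w + \hat{\epsilon}(w)}
$$

The decomposition shows why the loss value and the sharpness are minimized together. The approximation also explains the cost of an update: it needs two gradient evaluations, one at $w$ to build $\hat{\epsilon}$ and one at $w + \hat{\epsilon}$.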

This is an unofficial repository for Sharpness-Aware Minimization for Efficiently Improving Generalization and ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks. Implementation-wise, the SAM class is a light wrapper that computes the regularized "sharpness-aware" gradient, which is then used by the underlying optimizer (such as SGD with momentum). This repository also includes a simple WRN for CIFAR-10. As a proof of concept, it beats the performance of SGD with momentum on this dataset.
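The following framework-free sketch shows the computation such a wrapper performs for a single update. It makes three assumptions: plain SGD without momentum as the base optimizer, a hypothetical two-parameter quadratic loss, and illustrative helper names (`grad`, `norm`, `sam_step`) that are not part of this repository. The small `1e-12` term only guards against division by zero.

```python
import math

# Hypothetical toy loss: L(w) = 0.5 * (A * w0**2 + B * w1**2),
# with one sharp direction (A) and one flat direction (B).
A, B = 10.0, 1.0

def grad(w):
    """Gradient of the toy quadratic loss at w = [w0, w1]."""
    return [A * w[0], B * w[1]]

def norm(v):
    """Euclidean (L2) norm of a list of floats."""
    return math.sqrt(sum(x * x for x in v))

def sam_step(w, lr=0.05, rho=0.05):
    """One SAM update with plain SGD as the base optimizer.

    Returns the new weights and the number of gradient evaluations used.
    """
    g = grad(w)                                    # 1st gradient evaluation, at w
    scale = rho / (norm(g) + 1e-12)
    e_w = [scale * gi for gi in g]                 # eps_hat = rho * g / ||g||
    w_adv = [wi + ei for wi, ei in zip(w, e_w)]    # perturbed weights w + eps_hat
    g_sam = grad(w_adv)                            # 2nd gradient evaluation, at w + eps_hat
    # the descent step starts from the ORIGINAL weights w, not from w + eps_hat
    new_w = [wi - lr * gi for wi, gi in zip(w, g_sam)]
    return new_w, 2

w = [1.0, 1.0]
w_next, n_grad_evals = sam_step(w)
print(w_next, n_grad_evals)
```

Plain SGD with the same learning rate would move `w0` to exactly 0.5. SAM moves it slightly further, because the second gradient is taken at the perturbed point, which lies higher on the sharp side of the loss.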

Loss landscape with and without SAM

ResNet loss landscape at the end of training with and without SAM. Sharpness-aware updates lead to a significantly wider minimum, which then leads to better generalization properties.


Usage

It should be straightforward to use SAM in your training pipeline. Just keep in mind that training will run about twice as slow, because SAM needs two forward-backward passes to estimate the "sharpness-aware" gradient. If you're using gradient clipping, make sure to change only the magnitude of the gradients, not their direction (for example, clip by the global norm rather than element-wise by value).

import torch
from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # the optimizer class (not an instance) used for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for inputs, targets in data:

  # first forward-backward pass
  loss = loss_function(model(inputs), targets)  # use this loss for any training statistics
  loss.backward()
  optimizer.first_step(zero_grad=True)
  
  # second forward-backward pass
  loss_function(model(inputs), targets).backward()  # make sure to do a full forward pass
  optimizer.second_step(zero_grad=True)
...
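Here is why the gradient-clipping caveat about magnitude versus direction matters. Clamping each component separately can rotate the gradient, while rescaling by the global norm only shrinks it. The pure-Python comparison below uses hypothetical function names. In PyTorch, the two variants roughly correspond to `torch.nn.utils.clip_grad_norm_` (direction-preserving) and `torch.nn.utils.clip_grad_value_` (element-wise).

```python
import math

def clip_by_value(g, clip):
    """Clamp each component to [-clip, clip]; this can change the direction."""
    return [max(-clip, min(clip, x)) for x in g]

def clip_by_norm(g, max_norm):
    """Rescale the whole vector if its L2 norm exceeds max_norm; the direction is kept."""
    n = math.sqrt(sum(x * x for x in g))
    factor = min(1.0, max_norm / (n + 1e-6))
    return [factor * x for x in g]

def cosine(u, v):
    """Cosine similarity between two vectors (1.0 means same direction)."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

g = [3.0, 0.4]
by_value = clip_by_value(g, 1.0)   # [1.0, 0.4]: the direction has changed
by_norm = clip_by_norm(g, 1.0)     # about [0.991, 0.132]: same direction, shorter
print(cosine(g, by_value), cosine(g, by_norm))
```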

Alternative usage with a single closure-based step function. This alternative offers a similar API to native PyTorch optimizers like LBFGS (kindly suggested by @rmcavoy):

import torch
from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # the optimizer class (not an instance) used for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for inputs, targets in data:
  def closure():
    # second forward-backward pass, run by SAM at the perturbed weights
    loss = loss_function(model(inputs), targets)
    loss.backward()
    return loss

  # first forward-backward pass
  loss = loss_function(model(inputs), targets)
  loss.backward()
  optimizer.step(closure)
  optimizer.zero_grad()
...
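The call order behind `step(closure)` can be modeled without PyTorch. The caller computes the gradient at w. The step then perturbs the weights to w + ε̂, lets the closure recompute the gradient there, restores w, and applies the base update. `ToySAM` below is a hypothetical stand-in that works on lists of floats with plain SGD as the base update. It is not this repository's implementation.

```python
import math

class ToySAM:
    """Hypothetical, framework-free model of the first_step / second_step / step(closure) split.

    `params` is a list of floats updated in place; `grads` is a list that the
    closure overwrites with the gradient at the current params.
    """

    def __init__(self, params, grads, lr=0.1, rho=0.05):
        self.params, self.grads = params, grads
        self.lr, self.rho = lr, rho
        self.old_params = None

    def first_step(self):
        n = math.sqrt(sum(g * g for g in self.grads)) + 1e-12
        self.old_params = list(self.params)            # remember the original weights w
        for i, g in enumerate(self.grads):
            self.params[i] += self.rho * g / n         # move to w + eps_hat

    def second_step(self):
        for i, g in enumerate(self.grads):
            # restore w, then apply an SGD update with the gradient taken at w + eps_hat
            self.params[i] = self.old_params[i] - self.lr * g

    def step(self, closure):
        assert closure is not None, "SAM needs a closure that recomputes the gradient"
        self.first_step()
        closure()            # second forward-backward pass, at w + eps_hat
        self.second_step()

# usage on the toy loss L(w) = sum(w_i**2), whose gradient is 2 * w
params, grads = [1.0, -2.0], [0.0, 0.0]
calls = {"n": 0}

def closure():
    calls["n"] += 1
    for i, w in enumerate(params):
        grads[i] = 2.0 * w

closure()                    # first forward-backward pass, done by the caller
opt = ToySAM(params, grads)
opt.step(closure)
print(params, calls["n"])
```

The closure runs twice per update in total: once by the caller and once inside `step`. This is the same two-pass cost as the explicit `first_step`/`second_step` API.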

Training tips

  • @hjq133: The suggested usage can potentially cause problems if you use batch normalization. The running statistics are computed in both forward passes, but they should be computed only for the first one. A possible solution is to use these two functions (kindly suggested by @evanatyourservice and @slala2121) to bypass the running statistics during the second pass:
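from torch.nn.modules.batchnorm import _BatchNorm  # common base class of BatchNorm1d/2d/3d and SyncBatchNorm

def disable_bn(model):
  # call right before the second forward pass; note that eval() also makes
  # BatchNorm normalize with its running statistics instead of batch statistics
  for module in model.modules():
    if isinstance(module, _BatchNorm):
      module.eval()

def enable_bn(model):
  # call before the next first forward pass to put every module back into training mode
  model.train()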
  • @evanatyourservice: If you plan to train on multiple GPUs, the paper states that "To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update." This can be achieved by the following code:
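import torch.distributed as dist

def reduce_gradients_from_all_accelerators(model):
  # average .grad over all processes; assumes torch.distributed is initialized and the model
  # is NOT wrapped in DistributedDataParallel (DDP would already all-reduce in the first backward)
  for p in model.parameters():
    if p.grad is not None:
      dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
      p.grad /= dist.get_world_size()

for inputs, targets in data:  # each process iterates over its own sub-batch
  # first forward-backward pass (local gradients only, so each accelerator builds its own perturbation)
  loss = loss_function(model(inputs), targets)
  loss.backward()
  optimizer.first_step(zero_grad=True)
  
  # second forward-backward pass
  loss_function(model(inputs), targets).backward()
  reduce_gradients_from_all_accelerators(model)  # <- this is the important line
  optimizer.second_step(zero_grad=True)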
  • @evanatyourservice: Adaptive SAM reportedly performs better than the original SAM. The ASAM paper suggests using a higher rho for the adaptive updates (about 10x larger).

  • @mlaves: LR scheduling should either be applied to the base optimizer, as in the line below, or you should use SAM with a single step call (with a closure):

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=200)  # call scheduler.step() once per epoch
  • @AlbertoSabater: Integration with PyTorch Lightning: with manual optimization enabled (self.automatic_optimization = False in the module's __init__), you can write the training_step function as:
def training_step(self, batch, batch_idx):
    optimizer = self.optimizers()

    # first forward-backward pass (compute_loss is your own method)
    loss_1 = self.compute_loss(batch)
    self.manual_backward(loss_1)
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    loss_2 = self.compute_loss(batch)
    self.manual_backward(loss_2)
    optimizer.second_step(zero_grad=True)

    return loss_1
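The per-accelerator rule quoted from the paper can be simulated on one machine to show what it computes. Each shard builds its own perturbation from its own gradient, and only the resulting SAM gradients are averaged. In this pure-Python sketch the "accelerators" are shards of a hypothetical 1-D least-squares problem, and all helper names are illustrative.

```python
def shard_grad(w, shard):
    """Mean gradient of 0.5 * (w * x - y)**2 over one shard (w is a scalar)."""
    return sum((w * x - y) * x for x, y in shard) / len(shard)

def sam_grad(w, shard, rho):
    """SAM gradient for one shard: perturb along this shard's own gradient."""
    g = shard_grad(w, shard)
    eps = rho * g / (abs(g) + 1e-12)   # in 1-D, the L2 norm is abs(g)
    return shard_grad(w + eps, shard)

def per_accelerator_sam_grad(w, shards, rho):
    """Average of the sub-batch SAM gradients, as described in the paper."""
    return sum(sam_grad(w, s, rho) for s in shards) / len(shards)

shards = [[(1.0, 2.0), (2.0, 3.0)], [(1.0, -1.0), (3.0, 0.5)]]
w, rho = 0.5, 0.1
g_avg = per_accelerator_sam_grad(w, shards, rho)
g_global = sam_grad(w, [p for s in shards for p in s], rho)   # one perturbation for the whole batch
print(g_avg, g_global)
```

The two results differ, so the choice matters. With `DistributedDataParallel`, the first backward pass already averages gradients across processes, which gives the global variant. To get the per-accelerator variant, gradient synchronization must be skipped for that first pass, for example with DDP's `no_sync()` context manager.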

Documentation

SAM.__init__

Argument Description
params (iterable) iterable of parameters to optimize or dicts defining parameter groups
base_optimizer (torch.optim.Optimizer class) class, not instance, of the underlying optimizer that does the "sharpness-aware" update (e.g. torch.optim.SGD)
rho (float, optional) size of the neighborhood for computing the max loss (default: 0.05)
adaptive (bool, optional) set this argument to True if you want to use an experimental implementation of element-wise Adaptive SAM (default: False)
**kwargs keyword arguments passed to the __init__ method of base_optimizer
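With `adaptive=True`, the perturbation is rescaled element-wise by the parameter magnitudes. This sketch assumes the ASAM normalization operator T_w = |w| (the ASAM paper also adds a small constant, omitted here), which gives ε̂_i = ρ · w_i² · g_i / ‖w ⊙ g‖₂. It is a pure-Python illustration of that formula with hypothetical names, not this repository's code.

```python
import math

def asam_perturbation(w, g, rho):
    """Element-wise adaptive perturbation with T_w = |w| (assumption: no extra constant)."""
    tg = [abs(wi) * gi for wi, gi in zip(w, g)]               # T_w * grad
    n = math.sqrt(sum(x * x for x in tg)) + 1e-12
    return [rho * wi * wi * gi / n for wi, gi in zip(w, g)]   # rho * T_w**2 * grad / ||T_w grad||

def sam_perturbation(w, g, rho):
    """Plain SAM perturbation for comparison: rho * grad / ||grad||."""
    n = math.sqrt(sum(x * x for x in g)) + 1e-12
    return [rho * gi / n for gi in g]

w = [0.01, 10.0]   # one tiny and one large weight
g = [1.0, 1.0]
eps_sam = sam_perturbation(w, g, rho=0.05)
eps_asam = asam_perturbation(w, g, rho=2.0)

# In the rescaled coordinates eps / |w|, the ASAM perturbation has norm rho
rescaled_norm = math.sqrt(sum((e / abs(wi)) ** 2 for e, wi in zip(eps_asam, w)))
print(eps_sam, eps_asam, rescaled_norm)
```

Plain SAM perturbs both weights by the same absolute amount, which is several times larger than the tiny weight itself. ASAM's neighborhood is measured relative to the weights instead. This is consistent with ASAM using a rho on a different scale (2.0 versus 0.05 in the Experiments section).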

SAM.first_step

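Performs the first optimization step, which moves the weights to a first-order approximation of the highest-loss point in the local rho-neighborhood. The original weights are kept so that the second step can update them.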

Argument Description
zero_grad (bool, optional) set to True if you want to automatically zero-out all gradients after this step (default: False)

SAM.second_step

Performs the second optimization step that updates the original weights with the gradient from the (locally) highest point in the loss landscape.

Argument Description
zero_grad (bool, optional) set to True if you want to automatically zero-out all gradients after this step (default: False)

SAM.step

Performs both optimization steps in a single call. This function is an alternative to explicitly calling SAM.first_step and SAM.second_step.

Argument Description
closure (callable) closure that does an additional full forward and backward pass on the optimized model; the argument defaults to None, but a closure must be provided for the second pass

Experiments

I've verified that SAM works on a simple WRN 16-8 model trained on CIFAR-10; you can replicate the experiment by running train.py. The Wide-ResNet is enhanced only by label smoothing and the most basic image augmentations plus Cutout, so the errors are higher than those in the SAM paper. In principle, you could get even lower errors by training for longer (1800 epochs instead of 200), because SAM should be less prone to overfitting. SAM uses rho=0.05, while ASAM uses rho=2.0, as suggested by its authors.

Optimizer Test error rate
SGD + momentum 3.20 %
SAM + SGD + momentum 2.86 %
ASAM + SGD + momentum 2.55 %
