This is an unofficial Keras/TensorFlow 2 implementation of *Surrogate Gap Minimization Improves Sharpness-Aware Training* (GSAM).
The proposed Sharpness-Aware Minimization (SAM) improves generalization by minimizing a perturbed loss, defined as the maximum loss within a neighborhood in parameter space. Surrogate Gap Guided Sharpness-Aware Minimization (GSAM) is a novel improvement over SAM with negligible computational overhead. Conceptually, GSAM consists of two steps: 1) a gradient descent step, as in SAM, to minimize the perturbed loss, and 2) an ascent step in the orthogonal direction (after gradient decomposition) to minimize the surrogate gap without affecting the perturbed loss. Empirically, GSAM consistently improves generalization (e.g., +3.2% over SAM and +5.4% over AdamW on ImageNet top-1 accuracy for ViT-B/32). The authors' official implementation is written in JAX.
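For intuition, here is a minimal sketch of what a single GSAM update step looks like in plain TensorFlow 2. The function `gsam_step`, its arguments, and the simplifying assumptions (every trainable variable receives a gradient, no distribution strategy) are illustrative only and are not part of this package's API; the `GSAM` wrapper shown below handles all of this internally.

```python
import tensorflow as tf

def gsam_step(model, optimizer, loss_fn, x, y, rho=0.05, alpha=0.01):
    """Illustrative single GSAM update step (not this package's API)."""
    weights = model.trainable_variables

    # 1) Gradient of the original (unperturbed) loss at the current weights.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, weights)
    grad_norm = tf.linalg.global_norm(grads)

    # 2) Move to the worst-case neighbour: w + rho * g / ||g||.
    eps = [g * (rho / (grad_norm + 1e-12)) for g in grads]
    for w, e in zip(weights, eps):
        w.assign_add(e)

    # 3) Gradient of the perturbed (SAM) loss.
    with tf.GradientTape() as tape:
        perturbed_loss = loss_fn(y, model(x, training=True))
    perturbed_grads = tape.gradient(perturbed_loss, weights)

    # Restore the original weights.
    for w, e in zip(weights, eps):
        w.assign_sub(e)

    # 4) Decompose the original gradient into a component parallel to the
    #    perturbed gradient and an orthogonal remainder, then combine:
    #    update = g_perturbed - alpha * g_orthogonal.
    dot = tf.add_n([tf.reduce_sum(g * gp) for g, gp in zip(grads, perturbed_grads)])
    gp_sq = tf.add_n([tf.reduce_sum(gp * gp) for gp in perturbed_grads]) + 1e-12
    updates = []
    for g, gp in zip(grads, perturbed_grads):
        g_parallel = (dot / gp_sq) * gp
        g_orth = g - g_parallel
        updates.append(gp - alpha * g_orth)

    optimizer.apply_gradients(zip(updates, weights))
    return loss
```

The key point is the last step: the original gradient is split into components parallel and orthogonal to the perturbed gradient, and only the orthogonal part (scaled by alpha) is used to ascend the original loss, which shrinks the surrogate gap without changing the perturbed loss to first order.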
Install with:

```bash
pip install git+https://github.com/mortfer/keras-gsam.git
```
```python
from gsam import GSAM

# Wrap a keras.Model instance and specify the rho and alpha hyperparameters
gsam_model = GSAM(model, rho=0.05, alpha=0.01)
```
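Before training, the wrapped model still needs to be compiled. The optimizer, loss, and metrics below are placeholders chosen for illustration, assuming `GSAM` exposes the standard `keras.Model` compile/fit interface; see `gsam_comparison.ipynb` for the exact setup used in this repository.

```python
import tensorflow as tf

# Placeholder compilation step: the optimizer and loss are illustrative,
# assuming GSAM is compiled like a regular keras.Model.
gsam_model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```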
You can also schedule rho over training, in the same way you would schedule the learning rate:
```python
from gsam.callbacks import RhoScheduler, CosineAnnealingScheduler
from tensorflow.keras.callbacks import LearningRateScheduler

callbacks = [
    LearningRateScheduler(CosineAnnealingScheduler(T_max=max_epochs, eta_max=1e-3, eta_min=0), verbose=1),
    RhoScheduler(CosineAnnealingScheduler(T_max=max_epochs, eta_max=0.1, eta_min=0.01), verbose=1),
]

gsam_model.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    batch_size=batch_size,
    epochs=max_epochs,
)
```
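If you need something other than cosine annealing, a custom schedule callable can be sketched as below. This is an illustrative assumption: it presumes `RhoScheduler`, like Keras' `LearningRateScheduler`, accepts any callable that maps the epoch index to the new value.

```python
# Hypothetical custom schedule: start at rho = 0.1 and halve it every 30 epochs.
# Assumes RhoScheduler calls the schedule with the epoch index (and, like
# LearningRateScheduler, possibly the current value) and applies the returned rho.
def step_rho_schedule(epoch, rho=None):
    return 0.1 * (0.5 ** (epoch // 30))

rho_callback = RhoScheduler(step_rho_schedule, verbose=1)
```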
A complete example of how to use GSAM can be found in `gsam_comparison.ipynb`.
Results obtained:
| Method | Val accuracy (%) |
|---|---|
| Vanilla | 80.52 |
| SAM | 82.33 |
| GSAM | 83.04 |