# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""SAM algorithm and optimizer class."""

from __future__ import annotations

import logging
from typing import Optional

import torch

from composer.core import Algorithm, Event, State
from composer.loggers import Logger
from composer.utils import ensure_tuple

log = logging.getLogger(__name__)

__all__ = ['SAM', 'SAMOptimizer']


class SAMOptimizer(torch.optim.Optimizer):
    """Wraps an optimizer with sharpness-aware minimization (`Foret et al, 2020 <https://arxiv.org/abs/2010.01412>`_).

    See :class:`.SAM` for details, and the illustrative usage sketch after this class definition.

    Implementation based on https://github.com/davda54/sam

    Args:
        base_optimizer (torch.optim.Optimizer): The optimizer to apply SAM to.
        rho (float, optional): The SAM neighborhood size. Must be greater than 0. Default: ``0.05``.
        epsilon (float, optional): A small value added to the gradient norm for numerical stability.
            Default: ``1.0e-12``.
        interval (int, optional): SAM will run once per ``interval`` steps. A value of 1 will
            cause SAM to run every step. Steps on which SAM runs take roughly twice as much
            time to complete. Default: ``1``.
    """

    def __init__(self,
                 base_optimizer: torch.optim.Optimizer,
                 rho: float = 0.05,
                 epsilon: float = 1.0e-12,
                 interval: int = 1,
                 **kwargs):
        if rho < 0:
            raise ValueError(f'Invalid rho, should be non-negative: {rho}')
        self.base_optimizer = base_optimizer
        self.global_step = 0
        self.interval = interval
        self._step_supports_amp_closure = True  # Flag for the Composer trainer: step() accepts an AMP-aware closure
        defaults = {'rho': rho, 'epsilon': epsilon, **kwargs}
        super().__init__(self.base_optimizer.param_groups, defaults)
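    # Note: passing ``self.base_optimizer.param_groups`` to the parent constructor means the
    # wrapper and the wrapped optimizer share the same parameter-group dicts, with ``rho`` and
    # ``epsilon`` merged in as per-group hyperparameters. A change such as a scheduler updating
    # ``lr`` on one optimizer is therefore visible to both.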

    @torch.no_grad()
    def sub_e_w(self):
        for group in self.param_groups:
            for p in group['params']:
                if 'e_w' not in self.state[p]:
                    continue
                e_w = self.state[p]['e_w']  # retrieve stale e(w)
                p.sub_(e_w)  # get back to "w" from "w + e(w)"
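    # Note: sub_e_w() is the failure path. step() calls it when the second closure does not
    # return a loss (with Composer's AMP closure this typically signals non-finite gradients),
    # so parameters are not left stranded at the perturbed point "w + e(w)".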

    @torch.no_grad()
    def first_step(self):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + group['epsilon'])
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]['e_w'] = e_w
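    # Note: first_step() computes the inner (ascent) step of SAM,
    # e(w) = rho * grad / (||grad||_2 + epsilon), and caches e(w) in the optimizer state so
    # that second_step() or sub_e_w() can later undo the perturbation exactly. For example,
    # with rho=0.05 and a gradient norm of 2.0, each parameter moves by 0.025 * its gradient.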

    @torch.no_grad()
    def second_step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or 'e_w' not in self.state[p]:
                    continue
                p.sub_(self.state[p]['e_w'])  # get back to "w" from "w + e(w)"
        self.base_optimizer.step()  # do the actual "sharpness-aware" update
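    # Note: by the time base_optimizer.step() runs, the parameters have been restored to w,
    # but p.grad still holds the gradients evaluated at the perturbed point "w + e(w)"; this
    # is the sharpness-aware gradient that distinguishes SAM from a plain optimizer step.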

    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, 'Sharpness Aware Minimization requires closure, but it was not provided'
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
        loss = None

        if (self.global_step + 1) % self.interval == 0:
            # Compute gradient at (w) per-GPU, and do not sync
            loss = closure(ddp_sync=False)  # type: ignore
            if loss:
                self.first_step()  # Compute e(w) and set weights to (w + e(w)) separately per-GPU
                if closure():  # Compute gradient at (w + e(w))
                    self.second_step()  # Reset weights to (w) and step base optimizer
                else:
                    self.sub_e_w()  # If second forward-backward closure fails, reset weights to (w)
        else:
            loss = closure()
            if loss:
                self.base_optimizer.step()

        self.global_step += 1
        return loss
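    # Note: because the check is ``(self.global_step + 1) % self.interval == 0``, SAM runs on
    # 0-indexed steps interval - 1, 2 * interval - 1, and so on. With interval=2, for instance,
    # steps 1, 3, 5, ... take the two-pass SAM path while steps 0, 2, 4, ... are plain
    # base-optimizer steps, trading some of SAM's benefit for lower overhead.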

    def _grad_norm(self):
        norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2) for group in self.param_groups for p in group['params'] if p.grad is not None
            ]),
            p='fro',
        )
        return norm
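

# Illustrative sketch (not part of this module's API): using SAMOptimizer outside the Composer
# Trainer requires a closure that re-runs the full forward-backward pass, since step() invokes
# it twice on SAM steps. The names ``model``, ``loss_fn``, ``inputs``, and ``targets`` below
# are hypothetical placeholders:
#
#     base = torch.optim.SGD(model.parameters(), lr=0.1)
#     opt = SAMOptimizer(base_optimizer=base, rho=0.05)
#
#     def closure(ddp_sync=True):  # ``ddp_sync`` is accepted (and ignored) for single-GPU use
#         opt.zero_grad()
#         loss = loss_fn(model(inputs), targets)
#         loss.backward()
#         return loss
#
#     opt.step(closure)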


class SAM(Algorithm):
    """Adds sharpness-aware minimization (`Foret et al, 2020 <https://arxiv.org/abs/2010.01412>`_)
    by wrapping an existing optimizer with a :class:`.SAMOptimizer`. SAM can improve model generalization
    and provide robustness to label noise.

    Runs on :attr:`.Event.INIT`.

    Args:
        rho (float, optional): The neighborhood size parameter of SAM. Must be greater than 0.
            Default: ``0.05``.
        epsilon (float, optional): A small value added to the gradient norm for numerical stability.
            Default: ``1e-12``.
        interval (int, optional): SAM will run once per ``interval`` steps. A value of 1 will
            cause SAM to run every step. Steps on which SAM runs take roughly twice as much
            time to complete. Default: ``1``.

    Example:
        .. testcode::

            from composer.algorithms import SAM

            algorithm = SAM(rho=0.05, epsilon=1.0e-12, interval=1)
            trainer = Trainer(
                model=model,
                train_dataloader=train_dataloader,
                eval_dataloader=eval_dataloader,
                max_duration="1ep",
                algorithms=[algorithm],
                optimizers=[optimizer],
            )
    """

    def __init__(
        self,
        rho: float = 0.05,
        epsilon: float = 1.0e-12,
        interval: int = 1,
    ):
        """__init__ is constructed from the same fields as in hparams."""
        self.rho = rho
        self.epsilon = epsilon
        self.interval = interval

    def match(self, event: Event, state: State) -> bool:
        return event == Event.INIT

    def apply(self, event: Event, state: State, logger: Optional[Logger]) -> Optional[int]:
        assert state.optimizers is not None
        state.optimizers = tuple(
            SAMOptimizer(
                base_optimizer=optimizer,
                rho=self.rho,
                epsilon=self.epsilon,
                interval=self.interval,
            ) for optimizer in ensure_tuple(state.optimizers))
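    # Note: ensure_tuple() lets ``state.optimizers`` be either a single optimizer or a sequence;
    # after INIT, every entry is a SAMOptimizer wrapping the original. Because each wrapper
    # shares its base optimizer's parameter groups, learning-rate schedulers should continue to
    # see and update the same groups.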