Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. In International Conference on Learning
Representations, 2021.
https://arxiv.org/pdf/2010.01412.pdf
Loss landscape가 flat할수록 generalizability가 강해짐.
Flatness of minima -> the flatter, the more generalizable
SAM은 loss value와 loss sharpness를 동시적으로 minimizing을 하여 flatter region의 파라미터들을 찾을 수 있게 되며, 찾은 파라미터 근방의 파라미터들도 비슷한 낮은 loss를 갖게 됨.
generalization ability 검증 방법
- 밑바닥부터 학습했을 때 generalization
- Finetuning 단계에 적용했을 때 generalization
- 1995년부터 연구가 진행됨: flat minima search -> 더 나은 generalization을 위해 loss sharpness를 penalizing하는 방법 연구
- Regularize local entropy(i.e. Entropy-SGD)
- Average weights during training(i.e. SWA)
- 기존 연구에서의 한계 : sharpness의 측정이 계산하기에 어려웠음(ex. weights의 Hessain 계산).
Our goal : population loss를 가장 최소화하는 모델 파라미터 w를 찾는 것!
- PAC(probably approximately correct) Bayesian generalization bound 기반으로 generalization bound 산정을 하여 population loss에 대한 upper bound(상한)를 정함.
- ϵ, epsilon은 perturbation을 표현한 것임.
- h는 양의 실수에서 양의 실수로 엄격하게 증가하는 함수
- weight perturb (loss가 커지게 만드는)를 줬을 때의 loss와 그렇지 않았을 때의 loss 차이를 보는 것으로 해석할 수 있음.
- perturb에 민감하면 (loss landscape가 sharp하면) 차이가 커지고 둔감하면 (loss landscape가 flat하면) 차이가 작아짐.
- 대괄호 항: S에 대해서 계산되는 training loss가 w 근처에서 얼마나 빠르게 변동(증가)를 하는가를 의미 -> 곧 sharpness를 의미함.
- 대괄호 항 뒤의 항: w에 대한 training loss와 regularization
- h term을 일반적인 L2 정규화(regularization)로 계산
- 결과적으로 위의 Sharpness-aware minimization problem을 해결하면 됨.
where ρ ≥ 0 is a hyperparameter and p ∈ [1, ∞] (경험적으로 ρ = 2로 하여 L2-norm 계산한 것이 최적의 결과를 냄)
- max에 대한 gradient를 구해야함으로 approximation을 진행함 -> 단순하게 first order Taylor 전개를 진행함.
- 여기서 두번째 항의 hessian 계산은 속도를 늦추며 성능을 낮추므로 Drop함(계산에서 제외).
- 최종적으로 구하고자 하는 gradient 전개는 위와 같음
- SVHN과 Fasion-MNIST에서도 성능 개선을 보임
- ImageNet에서도 SAM을 적용했을 때 더 높은 성능을 보였으며 과적합이 덜 발생하면서 training을 진행할 수 있었음.
놀랍게도 두번째 항을 버리는 것이 더 좋은 성능을 냈음. -> Cosine 유사도를 계산했을 때 중반 이후에 더 좋은 유사도를 보임.
p차 norm 계산에서 p=2일 때 정확도가 더 올라감.
- 배치마다 SAM 업데이트를 사용하는 방법
- GPU accelerate 마다 SAM 업데이트를 계산하여 GPU의 각 계산된 SAM 업데이트의 평균을 구하여 모델을 업데이트하는 방법 (각 accelerator는 m 사이즈의 subset을 사용함) -> m이 작을수록 더 좋은 generalization ability를 보여줌.
https://m0nads.wordpress.com/2021/04/10/sharpness-aware-minimization/
https://github.com/reyllama/paper-reviews/issues/30
https://creamnuts.github.io/short%20review/sam/
https://deepseow.tistory.com/20
[UNIST 유튜브] https://youtu.be/lcNjbOHf0uo?si=dA16_4d4c_ARwZLt
[딥러닝논문읽기모임 유튜브] https://youtu.be/iC3Y85W5tmM?si=YcO3mehbB-3_a2Yc