feat: Added custom weight decay for normalization layers (#162)
* feat: Added norm weight decay to trainer

* feat: Added norm weight decay support to training scripts

* chore: Updated mypy

* refactor: Refactored norm param split

* refactor: Removed unused import

* fix: Fixed param split

* fix: Fixed implementation

* fix: Fixed norm wd
frgfm committed Nov 7, 2021
1 parent ee903c4 commit cd3a624
Showing 6 changed files with 99 additions and 29 deletions.
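For context, here is a minimal usage sketch of the new option. It is not part of the commit: it assumes ClassificationTrainer is importable from holocron.trainer, that its constructor takes (model, train_loader, val_loader, criterion, optimizer) in the order used by the reference scripts, and it builds a tiny synthetic dataset so the snippet stands alone.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from holocron.trainer import ClassificationTrainer

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
images, targets = torch.rand(16, 3, 32, 32), torch.randint(0, 10, (16,))
loader = DataLoader(TensorDataset(images, targets), batch_size=4)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
trainer = ClassificationTrainer(model, loader, loader, nn.CrossEntropyLoss(), optimizer)

# Normalization parameters get their own weight decay (none here),
# while every other parameter keeps the optimizer default of 1e-2
trainer.fit_n_epochs(num_epochs=1, lr=1e-3, norm_weight_decay=0.)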
57 changes: 43 additions & 14 deletions holocron/trainer/core.py
@@ -18,7 +18,7 @@
from torch.utils.data import DataLoader
from torchvision.ops.boxes import box_iou

from .utils import freeze_bn, freeze_model
from .utils import freeze_bn, freeze_model, split_normalization_params

__all__ = ['Trainer', 'ClassificationTrainer', 'BinaryClassificationTrainer', 'SegmentationTrainer', 'DetectionTrainer']

@@ -53,7 +53,7 @@ def __init__(
self.epoch = 0
self.min_loss = math.inf
self.gpu = gpu
self._params: Optional[ContiguousParams] = None
self._params: Optional[List[ContiguousParams]] = None
self.lr_recorder: List[float] = []
self.loss_recorder: List[float] = []
self.set_device(gpu)
@@ -170,16 +170,33 @@ def _get_loss(self, x: Tensor, target: Tensor) -> Tensor:
out = self.model(x)
return self.criterion(out, target)

def _set_params(self) -> None:
self._params = ContiguousParams([p for p in self.model.parameters() if p.requires_grad])
def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
if norm_weight_decay is None:
self._params = [ContiguousParams([p for p in self.model.parameters() if p.requires_grad])]
else:
self._params = [
ContiguousParams(_params) if len(_params) > 0 else None
for _params in split_normalization_params(self.model)
]

def _reset_opt(self, lr: float) -> None:
def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
"""Reset the target params of the optimizer"""
self.optimizer.defaults['lr'] = lr
self.optimizer.state = defaultdict(dict)
self.optimizer.param_groups = []
self._set_params()
self.optimizer.add_param_group(dict(params=self._params.contiguous())) # type: ignore[union-attr]
self._set_params(norm_weight_decay)
# Split the params if the norm layers need a custom WD
if norm_weight_decay is None:
self.optimizer.add_param_group(
dict(params=self._params[0].contiguous()) # type: ignore[index]
)
else:
wd_groups = [norm_weight_decay, self.optimizer.defaults.get('weight_decay', 0)]
for _params, _wd in zip(self._params, wd_groups): # type: ignore[arg-type]
if _params:
self.optimizer.add_param_group(
dict(params=_params.contiguous(), weight_decay=_wd)
)

@torch.inference_mode()
def evaluate(self):
@@ -202,7 +219,8 @@ def fit_n_epochs(
num_epochs: int,
lr: float,
freeze_until: Optional[str] = None,
sched_type: str = 'onecycle'
sched_type: str = 'onecycle',
norm_weight_decay: Optional[float] = None,
) -> None:
"""Train the model for a given number of epochs
@@ -211,11 +229,12 @@
lr (float): learning rate to be used by the scheduler
freeze_until (str, optional): last layer to freeze
sched_type (str, optional): type of scheduler to use
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
"""

self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(lr)
self._reset_opt(lr, norm_weight_decay)
# Scheduler
self._reset_scheduler(lr, num_epochs, sched_type)

@@ -227,7 +246,8 @@

self._fit_epoch(mb)
# Check whether ops invalidated the buffer
self._params.assert_buffer_is_valid() # type: ignore[union-attr]
for _group in self._params: # type: ignore[union-attr]
_group.assert_buffer_is_valid()
eval_metrics = self.evaluate()

# master bar
@@ -246,14 +266,16 @@ def lr_find(
freeze_until: Optional[str] = None,
start_lr: float = 1e-7,
end_lr: float = 1,
num_it: int = 100
norm_weight_decay: Optional[float] = None,
num_it: int = 100,
) -> None:
"""Gridsearch the optimal learning rate for the training
Args:
freeze_until (str, optional): last layer to freeze
start_lr (float, optional): initial learning rate
end_lr (float, optional): final learning rate
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
num_it (int, optional): number of iterations to perform
"""

@@ -262,7 +284,7 @@

self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(start_lr)
self._reset_opt(start_lr, norm_weight_decay)
gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
scheduler = MultiplicativeLR(self.optimizer, lambda step: gamma)

@@ -323,18 +345,25 @@ def plot_recorder(self, beta: float = 0.95, block: bool = True) -> None:
plt.grid(True, linestyle='--', axis='x')
plt.show(block=block)

def check_setup(self, freeze_until: Optional[str] = None, lr: float = 3e-4, num_it: int = 100) -> bool:
def check_setup(
self,
freeze_until: Optional[str] = None,
lr: float = 3e-4,
norm_weight_decay: Optional[float] = None,
num_it: int = 100,
) -> bool:
"""Check whether you can overfit one batch
Args:
freeze_until (str, optional): last layer to freeze
lr (float, optional): learning rate to be used for training
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
num_it (int, optional): number of iterations to perform
"""

self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(lr)
self._reset_opt(lr, norm_weight_decay)

x, target = next(iter(self.train_loader))
x, target = self.to_cuda(x, target)
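Taken together, the core.py changes make the trainer keep one ContiguousParams buffer per parameter split and register each buffer as its own param group when norm_weight_decay is set. The sketch below mirrors that logic outside the trainer; it is an illustration rather than code from this commit, and it assumes ContiguousParams comes from the contiguous_params package that the trainer already relies on.

import torch
from torch import nn
from contiguous_params import ContiguousParams
from holocron.trainer.utils import split_normalization_params

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Flatten())
norm_split, other_split = split_normalization_params(model)
norm_buf, other_buf = ContiguousParams(norm_split), ContiguousParams(other_split)

# One param group per split, each with its own weight decay
optimizer = torch.optim.SGD(
    [dict(params=norm_buf.contiguous(), weight_decay=0.),
     dict(params=other_buf.contiguous(), weight_decay=5e-4)],
    lr=0.1, momentum=0.9,
)

# Same sanity check fit_n_epochs now runs on every split after each epoch
for buf in (norm_buf, other_buf):
    buf.assert_buffer_is_valid()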
41 changes: 36 additions & 5 deletions holocron/trainer/utils.py
@@ -3,15 +3,15 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Optional
from typing import List, Optional, Tuple

from torch.nn import Module
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ['freeze_bn', 'freeze_model']
__all__ = ['freeze_bn', 'freeze_model', 'split_normalization_params']


def freeze_bn(mod: Module) -> Module:
def freeze_bn(mod: nn.Module) -> nn.Module:
"""Prevents parameters and stats from updating in BatchNorm layers that are frozen
Args:
@@ -31,7 +31,11 @@ def freeze_bn(mod: Module) -> Module:
return mod


def freeze_model(model: Module, last_frozen_layer: Optional[str] = None, frozen_bn_stat_update: bool = False) -> Module:
def freeze_model(
model: nn.Module,
last_frozen_layer: Optional[str] = None,
frozen_bn_stat_update: bool = False
) -> nn.Module:
"""Freeze a specific range of model layers
Args:
@@ -60,3 +64,30 @@ def freeze_model(model: Module, last_frozen_layer: Optional[str] = None, frozen_
model = freeze_bn(model)

return model


def split_normalization_params(
model: nn.Module,
norm_classes: Optional[List[type]] = None,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
# Borrowed from https://github.com/pytorch/vision/blob/main/torchvision/ops/_utils.py
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

for t in norm_classes:
if not issubclass(t, nn.Module):
raise ValueError(f"Class {t} is not a subclass of nn.Module.")

classes = tuple(norm_classes)

norm_params: List[nn.Parameter] = []
other_params: List[nn.Parameter] = []
for module in model.modules():
if next(module.children(), None):
other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
elif isinstance(module, classes):
norm_params.extend(p for p in module.parameters() if p.requires_grad)
else:
other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params
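As a quick illustration of what the split returns (again a sketch, not part of the commit), parameters owned by normalization layers land in the first list and everything else in the second, which is what lets them receive separate weight decays:

from torch import nn
from holocron.trainer.utils import split_normalization_params

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, bias=False),
    nn.BatchNorm2d(8),
    nn.GroupNorm(2, 8),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)
norm_params, other_params = split_normalization_params(model)
print(len(norm_params), len(other_params))  # 4 3 -> BN/GN weights and biases vs. the conv and linear parameters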
9 changes: 6 additions & 3 deletions references/classification/train.py
@@ -183,19 +183,20 @@ def main(args):

if args.lr_finder:
print("Looking for optimal LR")
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100), norm_weight_decay=args.norm_weight_decay)
trainer.plot_recorder()
return

if args.check_setup:
print("Checking batch overfitting")
is_ok = trainer.check_setup(args.freeze_until, args.lr, num_it=min(len(train_loader), 100))
is_ok = trainer.check_setup(args.freeze_until, args.lr, norm_weight_decay=args.norm_weight_decay,
num_it=min(len(train_loader), 100))
print(is_ok)
return

print("Start training")
start_time = time.time()
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
print(f"Training time {total_time_str}")

@@ -220,6 +221,8 @@ def parse_args():
parser.add_argument('--sched', default='onecycle', type=str, help='Scheduler to be used')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('--norm-wd', default=None, type=float, dest='norm_weight_decay',
help='weight decay of norm parameters')
parser.add_argument('--mixup-alpha', default=0, type=float, help='Mixup alpha factor')
parser.add_argument("--lr-finder", dest='lr_finder', action='store_true', help="Should you run LR Finder")
parser.add_argument("--check-setup", dest='check_setup', action='store_true', help="Check your training setup")
9 changes: 6 additions & 3 deletions references/detection/train.py
@@ -168,19 +168,20 @@ def main(args):

if args.lr_finder:
print("Looking for optimal LR")
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
trainer.lr_find(args.freeze_until, norm_weight_decay=args.norm_weight_decay, num_it=min(len(train_loader), 100))
trainer.plot_recorder()
return

if args.check_setup:
print("Checking batch overfitting")
is_ok = trainer.check_setup(args.freeze_until, args.lr, num_it=min(len(train_loader), 100))
is_ok = trainer.check_setup(args.freeze_until, args.lr, norm_weight_decay=args.norm_weight_decay,
num_it=min(len(train_loader), 100))
print(is_ok)
return

print("Start training")
start_time = time.time()
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
print(f"Training time {total_time_str}")

@@ -203,6 +204,8 @@ def parse_args():
parser.add_argument('--sched', default='onecycle', type=str, help='scheduler')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('--norm-wd', default=None, type=float, dest='norm_weight_decay',
help='weight decay of norm parameters')
parser.add_argument("--lr-finder", dest='lr_finder', action='store_true', help="Should you run LR Finder")
parser.add_argument("--check-setup", dest='check_setup', action='store_true', help="Check your training setup")
parser.add_argument('--output-file', default='./model.pth', help='path where to save')
9 changes: 6 additions & 3 deletions references/segmentation/train.py
@@ -208,19 +208,20 @@ def main(args):

if args.lr_finder:
print("Looking for optimal LR")
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
trainer.lr_find(args.freeze_until, norm_weight_decay=args.norm_weight_decay, num_it=min(len(train_loader), 100))
trainer.plot_recorder()
return

if args.check_setup:
print("Checking batch overfitting")
is_ok = trainer.check_setup(args.freeze_until, args.lr, num_it=min(len(train_loader), 100))
is_ok = trainer.check_setup(args.freeze_until, args.lr, norm_weight_decay=args.norm_weight_decay,
num_it=min(len(train_loader), 100))
print(is_ok)
return

print("Start training")
start_time = time.time()
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
print(f"Training time {total_time_str}")

@@ -245,6 +246,8 @@ def parse_args():
parser.add_argument('--sched', default='onecycle', type=str, help='scheduler')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('--norm-wd', default=None, type=float, dest='norm_weight_decay',
help='weight decay of norm parameters')
parser.add_argument("--lr-finder", dest='lr_finder', action='store_true', help="Should you run LR Finder")
parser.add_argument("--check-setup", dest='check_setup', action='store_true', help="Check your training setup")
parser.add_argument('--output-file', default='./model.pth', help='path where to save')
3 changes: 2 additions & 1 deletion test/test_trainer.py
@@ -96,7 +96,8 @@ def _test_trainer(
with pytest.raises(ValueError):
learner.lr_find(freeze_until, num_it=num_it + 1)

learner.lr_find(freeze_until, num_it=num_it)
# Test norm weight decay
learner.lr_find(freeze_until, norm_weight_decay=5e-4, num_it=num_it)
assert len(learner.lr_recorder) == len(learner.loss_recorder)
learner.plot_recorder(block=False)

