feat: Added custom weight decay for normalization layers #162

Merged: 8 commits, Nov 7, 2021
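
In short, the Trainer API gains an optional norm_weight_decay argument that applies a dedicated weight decay to normalization-layer parameters (batch norm, layer norm, group norm), while the optimizer's regular weight decay keeps applying to every other parameter. A minimal usage sketch of the new keyword, assuming trainer is an already-built Trainer subclass instance (e.g. a ClassificationTrainer); everything except the norm_weight_decay arguments is a placeholder value:

# trainer is any holocron Trainer built the usual way; its construction is unchanged by this PR
trainer.lr_find(norm_weight_decay=0., num_it=100)                # LR search with zero decay on norm params
trainer.check_setup(lr=3e-4, norm_weight_decay=0., num_it=100)   # overfit-one-batch sanity check
trainer.fit_n_epochs(num_epochs=10, lr=3e-4, sched_type='onecycle', norm_weight_decay=0.)
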
57 changes: 43 additions & 14 deletions holocron/trainer/core.py
@@ -18,7 +18,7 @@
from torch.utils.data import DataLoader
from torchvision.ops.boxes import box_iou

from .utils import freeze_bn, freeze_model
from .utils import freeze_bn, freeze_model, split_normalization_params

__all__ = ['Trainer', 'ClassificationTrainer', 'BinaryClassificationTrainer', 'SegmentationTrainer', 'DetectionTrainer']

@@ -53,7 +53,7 @@ def __init__(
self.epoch = 0
self.min_loss = math.inf
self.gpu = gpu
self._params: Optional[ContiguousParams] = None
self._params: Optional[List[ContiguousParams]] = None
self.lr_recorder: List[float] = []
self.loss_recorder: List[float] = []
self.set_device(gpu)
@@ -170,16 +170,33 @@ def _get_loss(self, x: Tensor, target: Tensor) -> Tensor:
out = self.model(x)
return self.criterion(out, target)

def _set_params(self) -> None:
self._params = ContiguousParams([p for p in self.model.parameters() if p.requires_grad])
def _set_params(self, norm_weight_decay: Optional[float] = None) -> None:
if norm_weight_decay is None:
self._params = [ContiguousParams([p for p in self.model.parameters() if p.requires_grad])]
else:
self._params = [
ContiguousParams(_params) if len(_params) > 0 else None
for _params in split_normalization_params(self.model)
]

def _reset_opt(self, lr: float) -> None:
def _reset_opt(self, lr: float, norm_weight_decay: Optional[float] = None) -> None:
"""Reset the target params of the optimizer"""
self.optimizer.defaults['lr'] = lr
self.optimizer.state = defaultdict(dict)
self.optimizer.param_groups = []
self._set_params()
self.optimizer.add_param_group(dict(params=self._params.contiguous())) # type: ignore[union-attr]
self._set_params(norm_weight_decay)
# Split params if norm layers need custom WD
if norm_weight_decay is None:
self.optimizer.add_param_group(
dict(params=self._params[0].contiguous()) # type: ignore[index]
)
else:
wd_groups = [norm_weight_decay, self.optimizer.defaults.get('weight_decay', 0)]
for _params, _wd in zip(self._params, wd_groups): # type: ignore[arg-type]
if _params:
self.optimizer.add_param_group(
dict(params=_params.contiguous(), weight_decay=_wd)
)

@torch.inference_mode()
def evaluate(self):
@@ -202,7 +219,8 @@ def fit_n_epochs(
num_epochs: int,
lr: float,
freeze_until: Optional[str] = None,
sched_type: str = 'onecycle'
sched_type: str = 'onecycle',
norm_weight_decay: Optional[float] = None,
) -> None:
"""Train the model for a given number of epochs

@@ -211,11 +229,12 @@
lr (float): learning rate to be used by the scheduler
freeze_until (str, optional): last layer to freeze
sched_type (str, optional): type of scheduler to use
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
"""

self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(lr)
self._reset_opt(lr, norm_weight_decay)
# Scheduler
self._reset_scheduler(lr, num_epochs, sched_type)

@@ -227,7 +246,8 @@

self._fit_epoch(mb)
# Check whether ops invalidated the buffer
self._params.assert_buffer_is_valid() # type: ignore[union-attr]
for _group in self._params: # type: ignore[union-attr]
_group.assert_buffer_is_valid()
eval_metrics = self.evaluate()

# master bar
@@ -246,14 +266,16 @@ def lr_find(
freeze_until: Optional[str] = None,
start_lr: float = 1e-7,
end_lr: float = 1,
num_it: int = 100
norm_weight_decay: Optional[float] = None,
num_it: int = 100,
) -> None:
"""Gridsearch the optimal learning rate for the training

Args:
freeze_until (str, optional): last layer to freeze
start_lr (float, optional): initial learning rate
end_lr (float, optional): final learning rate
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
num_it (int, optional): number of iterations to perform
"""

@@ -262,7 +284,7 @@

self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(start_lr)
self._reset_opt(start_lr, norm_weight_decay)
gamma = (end_lr / start_lr) ** (1 / (num_it - 1))
scheduler = MultiplicativeLR(self.optimizer, lambda step: gamma)

@@ -323,18 +345,25 @@ def plot_recorder(self, beta: float = 0.95, block: bool = True) -> None:
plt.grid(True, linestyle='--', axis='x')
plt.show(block=block)

def check_setup(self, freeze_until: Optional[str] = None, lr: float = 3e-4, num_it: int = 100) -> bool:
def check_setup(
self,
freeze_until: Optional[str] = None,
lr: float = 3e-4,
norm_weight_decay: Optional[float] = None,
num_it: int = 100,
) -> bool:
"""Check whether you can overfit one batch

Args:
freeze_until (str, optional): last layer to freeze
lr (float, optional): learning rate to be used for training
norm_weight_decay (float, optional): weight decay to apply to normalization parameters
num_it (int, optional): number of iterations to perform
"""

self.model = freeze_model(self.model.train(), freeze_until)
# Update param groups & LR
self._reset_opt(lr)
self._reset_opt(lr, norm_weight_decay)

x, target = next(iter(self.train_loader))
x, target = self.to_cuda(x, target)
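
Condensed, the updated _set_params and _reset_opt above boil down to the following pattern: split the trainable parameters into normalization vs. other parameters, wrap each non-empty split into a contiguous buffer, then register one optimizer param group per split with its own weight decay. The snippet below is a standalone sketch, not an excerpt from the PR: the toy model, the SGD optimizer and the norm_weight_decay value are placeholders, and ContiguousParams is assumed to be importable from the contiguous_params package the trainer already depends on.

import torch
from torch import nn
from contiguous_params import ContiguousParams  # assumed: same dependency the trainer uses
from holocron.trainer.utils import split_normalization_params

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
norm_weight_decay = 0.  # dedicated decay for normalization parameters

# One split for norm params, one for everything else (same ordering as split_normalization_params)
optimizer.param_groups = []
splits = [ContiguousParams(ps) if len(ps) > 0 else None for ps in split_normalization_params(model)]
wd_groups = [norm_weight_decay, optimizer.defaults.get('weight_decay', 0)]
for params, wd in zip(splits, wd_groups):
    if params:
        optimizer.add_param_group(dict(params=params.contiguous(), weight_decay=wd))
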
41 changes: 36 additions & 5 deletions holocron/trainer/utils.py
@@ -3,15 +3,15 @@
# This program is licensed under the Apache License version 2.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.

from typing import Optional
from typing import List, Optional, Tuple

from torch.nn import Module
from torch import nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ['freeze_bn', 'freeze_model']
__all__ = ['freeze_bn', 'freeze_model', 'split_normalization_params']


def freeze_bn(mod: Module) -> Module:
def freeze_bn(mod: nn.Module) -> nn.Module:
"""Prevents parameter and stats from updating in Batchnorm layers that are frozen

Args:
@@ -31,7 +31,11 @@ def freeze_bn(mod: Module) -> Module:
return mod


def freeze_model(model: Module, last_frozen_layer: Optional[str] = None, frozen_bn_stat_update: bool = False) -> Module:
def freeze_model(
model: nn.Module,
last_frozen_layer: Optional[str] = None,
frozen_bn_stat_update: bool = False
) -> nn.Module:
"""Freeze a specific range of model layers

Args:
@@ -60,3 +64,30 @@ def freeze_model(model: Module, last_frozen_layer: Optional[str] = None, frozen_
model = freeze_bn(model)

return model


def split_normalization_params(
model: nn.Module,
norm_classes: Optional[List[type]] = None,
) -> Tuple[List[nn.Parameter], List[nn.Parameter]]:
# Borrowed from https://github.com/pytorch/vision/blob/main/torchvision/ops/_utils.py
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]

for t in norm_classes:
if not issubclass(t, nn.Module):
raise ValueError(f"Class {t} is not a subclass of nn.Module.")

classes = tuple(norm_classes)

norm_params: List[nn.Parameter] = []
other_params: List[nn.Parameter] = []
for module in model.modules():
if next(module.children(), None):
other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
elif isinstance(module, classes):
norm_params.extend(p for p in module.parameters() if p.requires_grad)
else:
other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params
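
Outside the trainer, the new helper can also be used on its own to build optimizer parameter groups. A hypothetical example (the toy model and hyper-parameter values are made up for illustration):

import torch
from torch import nn
from holocron.trainer.utils import split_normalization_params

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
norm_params, other_params = split_normalization_params(model)
# The BatchNorm weight and bias land in norm_params, the conv weight and bias in other_params
optimizer = torch.optim.SGD(
    [
        {"params": other_params, "weight_decay": 1e-4},  # regular decay on conv weights/biases
        {"params": norm_params, "weight_decay": 0.},     # no decay on normalization parameters
    ],
    lr=0.1,
)
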
9 changes: 6 additions & 3 deletions references/classification/train.py
@@ -183,19 +183,20 @@ def main(args):

if args.lr_finder:
print("Looking for optimal LR")
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100), norm_weight_decay=args.norm_weight_decay)
trainer.plot_recorder()
return

if args.check_setup:
print("Checking batch overfitting")
is_ok = trainer.check_setup(args.freeze_until, args.lr, num_it=min(len(train_loader), 100))
is_ok = trainer.check_setup(args.freeze_until, args.lr, norm_weight_decay=args.norm_weight_decay,
num_it=min(len(train_loader), 100))
print(is_ok)
return

print("Start training")
start_time = time.time()
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
print(f"Training time {total_time_str}")

@@ -220,6 +221,8 @@ def parse_args():
parser.add_argument('--sched', default='onecycle', type=str, help='Scheduler to be used')
parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('--norm-wd', default=None, type=float, dest='norm_weight_decay',
help='weight decay of norm parameters')
parser.add_argument('--mixup-alpha', default=0, type=float, help='Mixup alpha factor')
parser.add_argument("--lr-finder", dest='lr_finder', action='store_true', help="Should you run LR Finder")
parser.add_argument("--check-setup", dest='check_setup', action='store_true', help="Check your training setup")
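
In the reference scripts, the same behaviour is exposed through the new --norm-wd flag: for instance, passing --wd 1e-4 --norm-wd 0 keeps the regular weight decay on convolution and linear weights while disabling it on normalization parameters.
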
9 changes: 6 additions & 3 deletions references/detection/train.py
@@ -168,19 +168,20 @@ def main(args):

if args.lr_finder:
print("Looking for optimal LR")
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
trainer.lr_find(args.freeze_until, norm_weight_decay=args.norm_weight_decay, num_it=min(len(train_loader), 100))
trainer.plot_recorder()
return

if args.check_setup:
print("Checking batch overfitting")
is_ok = trainer.check_setup(args.freeze_until, args.lr, num_it=min(len(train_loader), 100))
is_ok = trainer.check_setup(args.freeze_until, args.lr, norm_weight_decay=args.norm_weight_decay,
num_it=min(len(train_loader), 100))
print(is_ok)
return

print("Start training")
start_time = time.time()
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
print(f"Training time {total_time_str}")

@@ -203,6 +204,8 @@ def parse_args():
parser.add_argument('--sched', default='onecycle', type=str, help='scheduler')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('--norm-wd', default=None, type=float, dest='norm_weight_decay',
help='weight decay of norm parameters')
parser.add_argument("--lr-finder", dest='lr_finder', action='store_true', help="Should you run LR Finder")
parser.add_argument("--check-setup", dest='check_setup', action='store_true', help="Check your training setup")
parser.add_argument('--output-file', default='./model.pth', help='path where to save')
9 changes: 6 additions & 3 deletions references/segmentation/train.py
@@ -208,19 +208,20 @@ def main(args):

if args.lr_finder:
print("Looking for optimal LR")
trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
trainer.lr_find(args.freeze_until, norm_weight_decay=args.norm_weight_decay, num_it=min(len(train_loader), 100))
trainer.plot_recorder()
return

if args.check_setup:
print("Checking batch overfitting")
is_ok = trainer.check_setup(args.freeze_until, args.lr, num_it=min(len(train_loader), 100))
is_ok = trainer.check_setup(args.freeze_until, args.lr, norm_weight_decay=args.norm_weight_decay,
num_it=min(len(train_loader), 100))
print(is_ok)
return

print("Start training")
start_time = time.time()
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched, norm_weight_decay=args.norm_weight_decay)
total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
print(f"Training time {total_time_str}")

@@ -245,6 +246,8 @@ def parse_args():
parser.add_argument('--sched', default='onecycle', type=str, help='scheduler')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay')
parser.add_argument('--norm-wd', default=None, type=float, dest='norm_weight_decay',
help='weight decay of norm parameters')
parser.add_argument("--lr-finder", dest='lr_finder', action='store_true', help="Should you run LR Finder")
parser.add_argument("--check-setup", dest='check_setup', action='store_true', help="Check your training setup")
parser.add_argument('--output-file', default='./model.pth', help='path where to save')
3 changes: 2 additions & 1 deletion test/test_trainer.py
@@ -96,7 +96,8 @@ def _test_trainer(
with pytest.raises(ValueError):
learner.lr_find(freeze_until, num_it=num_it + 1)

learner.lr_find(freeze_until, num_it=num_it)
# Test norm weight decay
learner.lr_find(freeze_until, norm_weight_decay=5e-4, num_it=num_it)
assert len(learner.lr_recorder) == len(learner.loss_recorder)
learner.plot_recorder(block=False)

Expand Down