Skip to content

Commit

Permalink
Migrate remaining LR schedulers (#1448)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#1448

Test Plan: Imported from OSS

Reviewed By: alexeib

Differential Revision: D25092150

Pulled By: myleott

fbshipit-source-id: fd066a0eba388bb0c344082a8fa1132974d53d40
  • Loading branch information
myleott authored and facebook-github-bot committed Nov 20, 2020
1 parent b3f0183 commit bf71f14
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 158 deletions.
6 changes: 5 additions & 1 deletion fairseq/dataclass/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,11 @@ def get_default(f):
if isinstance(val, tuple):
val = list(val)

if getattr(v.type, "__origin__", None) is List:
if (
getattr(v.type, "__origin__", None) is List
# skip interpolation
and not (isinstance(val, str) and val.startswith("${"))
):
# if type is int but val is float, then we will crash later - try to convert here
t_args = v.type.__args__
if len(t_args) == 1:
Expand Down
35 changes: 13 additions & 22 deletions fairseq/optim/lr_scheduler/cosine_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,26 @@
from dataclasses import dataclass, field
from typing import List

from fairseq.dataclass import FairseqDataclass
from omegaconf import II, DictConfig
from omegaconf import II

from . import FairseqLRScheduler, register_lr_scheduler
from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class CosineConfig(FairseqDataclass):
class CosineLRScheduleConfig(FairseqDataclass):
warmup_updates: int = field(
default=0,
metadata={"help": "warmup the learning rate linearly for the first N updates"},
)
warmup_init_lr: float = field(
default=-1,
metadata={
"help": "initial learning rate during warmup phase; default is args.lr"
"help": "initial learning rate during warmup phase; default is cfg.lr"
},
)
max_lr: float = field(
default=1.0, metadata={"help": "max learning rate, must be more than args.lr"}
default=1.0, metadata={"help": "max learning rate, must be more than cfg.lr"}
)
t_mult: float = field(
default=1.0, metadata={"help": "factor to grow the length of each period"}
Expand All @@ -38,13 +38,12 @@ class CosineConfig(FairseqDataclass):
lr_shrink: float = field(
default=0.1, metadata={"help": "shrink factor for annealing"}
)
# TODO common var for parent class
lr: List[float] = II("optimization.lr")
max_update: int = II("optimization.max_update")


@register_lr_scheduler("cosine", dataclass=CosineConfig)
class CosineSchedule(FairseqLRScheduler):
@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig)
class CosineLRSchedule(FairseqLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details.
Expand All @@ -55,7 +54,7 @@ class CosineSchedule(FairseqLRScheduler):
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
lr = lrs[update_num]
After warmup::
Expand All @@ -67,9 +66,7 @@ class CosineSchedule(FairseqLRScheduler):
after every iteration.
"""

def __init__(
self, cfg: DictConfig, fairseq_optimizer
):
def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
super().__init__(cfg, fairseq_optimizer)
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
raise ValueError(
Expand All @@ -78,11 +75,7 @@ def __init__(
)

warmup_end_lr = cfg.max_lr
lr = (
cfg.lr[0]
if isinstance(cfg.lr, Collection)
else cfg.lr
)
lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
if cfg.warmup_init_lr < 0:
cfg.warmup_init_lr = lr

Expand All @@ -100,10 +93,8 @@ def __init__(
self.period = cfg.max_update - cfg.warmup_updates

if cfg.warmup_updates > 0:
# linearly warmup for the first args.warmup_updates
self.lr_step = (
warmup_end_lr - cfg.warmup_init_lr
) / cfg.warmup_updates
# linearly warmup for the first cfg.warmup_updates
self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates
else:
self.lr_step = 1

Expand Down
3 changes: 1 addition & 2 deletions fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from argparse import Namespace

from fairseq.dataclass.utils import gen_parser_from_dataclass

from .. import FairseqOptimizer
from fairseq.optim import FairseqOptimizer


class FairseqLRScheduler(object):
Expand Down
65 changes: 36 additions & 29 deletions fairseq/optim/lr_scheduler/fixed_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,44 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import LegacyFairseqLRScheduler, register_lr_scheduler
from dataclasses import dataclass, field
from typing import Optional, List
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler

@register_lr_scheduler("fixed")
class FixedSchedule(LegacyFairseqLRScheduler):
"""Decay the LR on a fixed schedule."""

def __init__(self, args, optimizer):
super().__init__(args, optimizer)
@dataclass
class FixedLRScheduleConfig(FairseqDataclass):
force_anneal: Optional[int] = field(
default=None,
metadata={"help": "force annealing at specified epoch"},
)
lr_shrink: float = field(
default=0.1,
metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"},
)
warmup_updates: int = field(
default=0,
metadata={"help": "warmup the learning rate linearly for the first N updates"},
)
lr: List[float] = II("optimization.lr")


# set defaults
args.warmup_updates = getattr(args, "warmup_updates", 0) or 0
@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig)
class FixedLRSchedule(FairseqLRScheduler):
"""Decay the LR on a fixed schedule."""

self.lr = args.lr[0]
if args.warmup_updates > 0:
self.warmup_factor = 1.0 / args.warmup_updates
def __init__(self, cfg: FixedLRScheduleConfig, optimizer):
super().__init__(cfg, optimizer)

self.lr = cfg.lr[0]
if cfg.warmup_updates > 0:
self.warmup_factor = 1.0 / cfg.warmup_updates
else:
self.warmup_factor = 1

@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch (epochs start at 1)')
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
# fmt: on

def state_dict(self):
return {"lr": self.lr}

Expand All @@ -42,14 +49,14 @@ def load_state_dict(self, state_dict):
self.lr = state_dict["lr"]

def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
lrs = self.cfg.lr
if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
else:
# annneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (
epoch + 1 - self.args.force_anneal
next_lr = lrs[-1] * self.cfg.lr_shrink ** (
epoch + 1 - self.cfg.force_anneal
)
return next_lr

Expand All @@ -61,8 +68,8 @@ def step_begin_epoch(self, epoch):

def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates:
self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates)
if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates:
self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
else:
self.optimizer.set_lr(self.lr)
Expand Down
35 changes: 13 additions & 22 deletions fairseq/optim/lr_scheduler/inverse_square_root_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,28 @@
from dataclasses import dataclass, field
from typing import List

from fairseq.dataclass import FairseqDataclass
from omegaconf import II, DictConfig
from omegaconf import II

from . import FairseqLRScheduler, register_lr_scheduler
from fairseq.dataclass import FairseqDataclass
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class InverseSquareRootScheduleConfig(FairseqDataclass):
class InverseSquareRootLRScheduleConfig(FairseqDataclass):
warmup_updates: int = field(
default=4000,
metadata={"help": "warmup the learning rate linearly for the first N updates"},
)
warmup_init_lr: float = field(
default=-1,
metadata={
"help": "initial learning rate during warmup phase; default is args.lr"
"help": "initial learning rate during warmup phase; default is cfg.lr"
},
)
# TODO common vars at parent class
lr: List[float] = II("optimization.lr")


@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig)
@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig)
class InverseSquareRootSchedule(FairseqLRScheduler):
"""Decay the LR based on the inverse square root of the update number.
Expand All @@ -40,36 +39,28 @@ class InverseSquareRootSchedule(FairseqLRScheduler):
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates)
lr = lrs[update_num]
After warmup::
decay_factor = args.lr * sqrt(args.warmup_updates)
decay_factor = cfg.lr * sqrt(cfg.warmup_updates)
lr = decay_factor / sqrt(update_num)
"""

def __init__(self, cfg: DictConfig, optimizer):
def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer):
super().__init__(cfg, optimizer)
if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with inverse_sqrt."
" Consider --lr-scheduler=fixed instead."
)
warmup_end_lr = (
cfg.lr[0]
if isinstance(cfg.lr, Collection)
else cfg.lr
)
warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
if cfg.warmup_init_lr < 0:
cfg.warmup_init_lr = (
0 if cfg.warmup_updates > 0 else warmup_end_lr
)
cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr

# linearly warmup for the first args.warmup_updates
self.lr_step = (
warmup_end_lr - cfg.warmup_init_lr
) / cfg.warmup_updates
# linearly warmup for the first cfg.warmup_updates
self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates

# then, decay prop. to the inverse square root of the update number
self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5
Expand Down
12 changes: 5 additions & 7 deletions fairseq/optim/lr_scheduler/polynomial_decay_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from . import FairseqLRScheduler, register_lr_scheduler
from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


@dataclass
class PolynomialDecayScheduleConfig(FairseqDataclass):
class PolynomialDecayLRScheduleConfig(FairseqDataclass):
warmup_updates: int = field(
default=0,
metadata={"help": "warmup the learning rate linearly for the first N updates"},
Expand All @@ -36,13 +36,11 @@ class PolynomialDecayScheduleConfig(FairseqDataclass):
lr: List[float] = II("optimization.lr")


@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayScheduleConfig)
class PolynomialDecaySchedule(FairseqLRScheduler):
@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig)
class PolynomialDecayLRSchedule(FairseqLRScheduler):
"""Decay the LR on a fixed schedule."""

cfg: PolynomialDecayScheduleConfig

def __init__(self, cfg: PolynomialDecayScheduleConfig, optimizer):
def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer):
super().__init__(cfg, optimizer)

assert cfg.total_num_update > 0
Expand Down
Loading

0 comments on commit bf71f14

Please sign in to comment.