This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Scheduler with Warmup (#1184)
Summary:
Pull Request resolved: #1184

Current implementations of warmup in pytext either do warmup with optional inverse square root decay (TODO) or use polynomial decay (TODO). However, in my experiments I have noticed that for large-batch training a warmup period is helpful with other schedulers as well, especially when trying to mimic the results of small-batch training on large batches.

This diff adds support for `SchedulerWithWarmup`. Underneath, it holds two schedulers: a WarmupScheduler and any other scheduler. After `warmup_steps`, it switches from warmup to the specified scheduler.

This allows combinations such as warmup with exponential decay.

Since this scheduler is built on top of the existing warmup scheduler, any new features added to that scheduler will directly apply here as well.
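
For intuition, the sketch below traces the learning rate produced by linear warmup followed by exponential decay. The concrete values (`base_lr=0.1`, `warmup_steps=500`, `gamma=0.95`) and the assumption that warmup is stepped per batch while the decay is stepped per epoch are illustrative, not taken from this diff.

```python
# Illustrative sketch only: the LR trajectory of linear warmup followed by
# per-epoch exponential decay. All concrete values are assumptions.
def lr_at(step, epochs_since_warmup, base_lr=0.1, warmup_steps=500, gamma=0.95):
    if step < warmup_steps:
        # warmup phase: ramp linearly from ~0 to base_lr, stepped per batch
        return base_lr * (step + 1) / warmup_steps
    # after warmup: multiply by gamma once per epoch
    return base_lr * gamma ** epochs_since_warmup
```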

Sample Config

```
"SchedulerWithWarmup": {
  "warmup_scheduler": {
    "warmup_steps": 500
  },
  "scheduler": {
    "ExponentialLR": {
      "gamma": 0.95
    }
  }
}
```
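
For reference, roughly the same schedule as the config above can be sketched with stock PyTorch schedulers outside of pytext. This is an approximation for illustration only; the placeholder model, the loop bounds, and the choice to apply the decay once per epoch are assumptions, not pytext's implementation.

```python
# Approximate the warmup-then-ExponentialLR behaviour with plain torch.optim
# schedulers: linear warmup per batch for warmup_steps, then decay per epoch.
import torch
from torch.optim.lr_scheduler import ExponentialLR, LambdaLR

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_steps = 500
# Construct the decay scheduler first so that constructing the warmup scheduler
# afterwards leaves the optimizer at the small warmed-up initial LR.
decay = ExponentialLR(optimizer, gamma=0.95)
warmup = LambdaLR(optimizer, lr_lambda=lambda s: min(1.0, (s + 1) / warmup_steps))

global_step = 0
for epoch in range(10):
    for batch in range(100):
        optimizer.step()  # forward/backward elided for brevity
        if global_step < warmup_steps:
            warmup.step()  # warmup phase: step per batch
        global_step += 1
    if global_step >= warmup_steps:
        decay.step()  # after warmup: exponential decay per epoch
```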

Reviewed By: ArmenAg

Differential Revision: D18838272

fbshipit-source-id: 1b1b107552f2f8f38ed8cc319b9b64096d0bc07c
Akshat Shrivastava authored and facebook-github-bot committed Dec 6, 2019
1 parent ca5c251 commit a56c761
1 changed file: pytext/optimizer/scheduler.py (69 additions, 2 deletions)
@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import math
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from pytext.config import ConfigBase
-from pytext.config.component import Component, ComponentType
+from pytext.config.component import Component, ComponentType, create_scheduler
 from pytext.optimizer import Optimizer
 from torch.optim.lr_scheduler import (
     CosineAnnealingLR as TorchCosineAnnealingLR,
@@ -470,3 +470,70 @@ def step_batch(self):
         self.current_steps += 1
         # update optimizer.param_groups's learning rate
         self.step()
+
+
+class SchedulerWithWarmup(_LRScheduler, BatchScheduler):
+    """
+    Wraps another scheduler with a warmup phase. After `warmup_steps` defined in
+    warmup_scheduler.warmup_steps, the scheduler will switch to use the specified
+    scheduler in `scheduler`.
+    `warmup_scheduler`: is the configuration for the WarmupScheduler, that warms up
+    learning rate over `warmup_steps` linearly.
+    `scheduler`: is the main scheduler that will be applied after the warmup phase
+    (once `warmup_steps` have passed)
+    """
+
+    class Config(BatchScheduler.Config):
+        # the definition of the warmup scheduler for the warmup phase
+        warmup_scheduler: WarmupScheduler.Config = WarmupScheduler.Config()
+
+        # the definition of the main scheduler to apply once the warmup phase
+        # has passed
+        scheduler: Union[
+            ExponentialLR.Config,
+            CosineAnnealingLR.Config,
+            ReduceLROnPlateau.Config,
+            LmFineTuning.Config,
+            CyclicLR.Config,
+        ]
+
+    @classmethod
+    def from_config(cls, config: Config, optimizer: Optimizer):
+        warmup_scheduler = create_scheduler(config.warmup_scheduler, optimizer)
+        scheduler = create_scheduler(config.scheduler, optimizer)
+        return cls(
+            optimizer, warmup_scheduler, scheduler, config.warmup_scheduler.warmup_steps
+        )
+
+    def prepare(self, train_iter, total_epochs):
+        super().prepare(train_iter, total_epochs)
+        self.warmup_scheduler.prepare(train_iter, total_epochs)
+        self.scheduler.prepare(train_iter, total_epochs)
+
+    def __init__(self, optimizer, warmup_scheduler, scheduler, switch_steps):
+        self.optimizer = optimizer
+        self.warmup_scheduler = warmup_scheduler
+        self.scheduler = scheduler
+        self.switch_steps = switch_steps
+        self.curr_steps = 0
+
+    def step_batch(self):
+        if self.curr_steps < self.switch_steps:
+            self.curr_steps += 1
+            return self.warmup_scheduler.step_batch()
+        else:
+            return self.scheduler.step_batch()
+
+    def step_epoch(self, metrics, epoch):
+        if self.curr_steps < self.switch_steps:
+            return self.warmup_scheduler.step_epoch(metrics=metrics, epoch=epoch)
+        else:
+            return self.scheduler.step_epoch(metrics=metrics, epoch=None)
+
+    def get_lr(self):
+        if self.curr_steps < self.switch_steps:
+            return self.warmup_scheduler.get_lr()
+        else:
+            return self.scheduler.get_lr()

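To make the switching behaviour concrete, here is a hypothetical smoke test using duck-typed stand-ins instead of real pytext schedulers. The `_Stub` class and the returned strings are made up purely for illustration; only the `SchedulerWithWarmup` constructor and `step_batch` delegation shown in the diff above are relied on.

```python
from pytext.optimizer.scheduler import SchedulerWithWarmup


class _Stub:
    """Made-up stand-in that just reports which scheduler was stepped."""

    def __init__(self, name):
        self.name = name

    def step_batch(self):
        return self.name


sched = SchedulerWithWarmup(
    optimizer=None,  # unused by step_batch in this illustration
    warmup_scheduler=_Stub("warmup"),
    scheduler=_Stub("main"),
    switch_steps=2,
)
assert sched.step_batch() == "warmup"  # batch 1: still in the warmup phase
assert sched.step_batch() == "warmup"  # batch 2: still in the warmup phase
assert sched.step_batch() == "main"    # batch 3: switched to the main scheduler
```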