In [22]:
"""Interact with various pytorch learning-rate schedulers.

This notebook lets you graph the behavior of several different
learning-rate (LR) schedulers, which may help you choose which
LR decay policy you want to use in your training.

The notebook also introduces a Shifted LR policy, which holds the LR
constant for 'shift' epochs before applying the "real" policy.
"""

from ipywidgets import widgets, interact
import matplotlib.pyplot as plt
import torch
import torch.optim.lr_scheduler as lr_scheduler

In [23]:
# This example uses AlexNet. Only its parameters are needed (to build the
# optimizers that the LR schedulers drive), so no pretrained weights are
# downloaded and the model stays on the CPU: the schedulers only read and
# write the optimizer's learning rates, never the weight values or their
# device. This also lets the notebook run on machines without a GPU.
import torchvision
model = torchvision.models.alexnet()

In [108]:
class ShiftedLR(lr_scheduler._LRScheduler):
    """Hybrid of a constant LR policy and another LR scheduler.

    For the first ``shift`` epochs every parameter group keeps its base LR.
    Thereafter the wrapped ``scheduler`` is applied. It is stepped with the
    epoch index relative to the shift (``epoch - shift``), so it starts from
    its own epoch 0.

    Args:
        optimizer: optimizer whose learning rates are scheduled.
        shift: number of initial epochs during which the LR is held constant.
        scheduler: LR scheduler applied once ``shift`` epochs have elapsed.
            It should wrap the same ``optimizer``, because after the shift it
            is this scheduler's ``step()`` that writes the LR.
        last_epoch: index of the last epoch; ``-1`` starts from scratch.
    """

    def __init__(self, optimizer, shift, scheduler, last_epoch=-1):
        # Set these before calling the base constructor: it performs an
        # initial step(), and our step()/get_lr() read both attributes.
        self.shift = shift
        self.scheduler = scheduler
        super(ShiftedLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the per-group learning rates for ``self.last_epoch``."""
        if self.last_epoch < self.shift:
            # Return a copy so callers cannot mutate base_lrs in place.
            return list(self.base_lrs)
        # In newer PyTorch, many schedulers compute get_lr() relative to the
        # current LR and warn when it is called outside their own step().
        # get_last_lr() returns the LR they last applied. Older versions only
        # provide get_lr(), which is closed-form there.
        get_last_lr = getattr(self.scheduler, 'get_last_lr', None)
        if get_last_lr is not None:
            return list(get_last_lr())
        return self.scheduler.get_lr()

    def step(self, epoch=None):
        """Advance to ``epoch`` (default: the next epoch) and update the optimizer's LRs."""
        super(ShiftedLR, self).step(epoch)
        if self.last_epoch >= self.shift:
            self.scheduler.step(self.last_epoch - self.shift)
            # The base step() cached the LR before the wrapped scheduler
            # updated it. Refresh that cache so get_last_lr() (newer PyTorch)
            # reports the LR actually applied to the optimizer.
            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
 

In [109]:
# plt.rcParams["figure.figsize"] = (10, 10)

@interact(shift=(0, 100),
          first_epoch=(0, 100),
          last_epoch=(1, 200),
          step_size=(1, 30, 1),
          gamma=(0, 1, 0.1),
          lr='0.001',
          T_max=(1, 10))
def draw_schedules(shift=0, first_epoch=0, last_epoch=50, step_size=3, gamma=0.9, lr=0.001, T_max=1):
    """Plot the LR produced by several schedulers for epochs in [first_epoch, last_epoch).

    Args:
        shift: extra epochs (beyond ``first_epoch``) for which ShiftedLR
            holds the LR constant before applying ExponentialLR.
        first_epoch: first epoch plotted (inclusive).
        last_epoch: end of the plotted epoch range (exclusive).
        step_size: StepLR period, in epochs.
        gamma: decay factor passed to ExponentialLR, StepLR and ShiftedLR's
            inner ExponentialLR.
        lr: base learning rate; the text widget passes it as a string.
        T_max: ``T_max`` argument of CosineAnnealingLR.
    """
    try:
        lr = float(lr)
    except ValueError:
        # The text widget re-runs this on every keystroke; report bad input
        # instead of raising a traceback.
        print('Invalid learning rate: {!r}'.format(lr))
        return

    # Each scheduler gets its own optimizer. Schedulers write the LR into
    # optimizer.param_groups, so sharing a single optimizer would let them
    # overwrite each other's values.
    factories = {
        'ExponentialLR': lambda opt: lr_scheduler.ExponentialLR(opt, gamma),
        'StepLR': lambda opt: lr_scheduler.StepLR(opt, step_size, gamma),
        'CosineAnnealingLR': lambda opt: lr_scheduler.CosineAnnealingLR(opt, T_max),
        'ShiftedLR': lambda opt: ShiftedLR(opt, first_epoch + shift,
                                           lr_scheduler.ExponentialLR(opt, gamma)),
    }
    schedulers = {}
    for name, factory in factories.items():
        optimizer = torch.optim.SGD(model.parameters(), lr, momentum=0.9, weight_decay=0.0001)
        schedulers[name] = (optimizer, factory(optimizer))

    # http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html
    epochs = list(range(first_epoch, last_epoch))
    lr_values = {name: [] for name in schedulers}
    for epoch in epochs:
        for name, (optimizer, scheduler) in schedulers.items():
            scheduler.step(epoch)
            # Record the LR actually applied to the (single) parameter group.
            # In newer PyTorch, scheduler.get_lr() called outside step() is
            # deprecated and may return the following epoch's value.
            lr_values[name].append(optimizer.param_groups[0]['lr'])

    fig, ax = plt.subplots()
    for name, values in lr_values.items():
        ax.plot(epochs, values, label=name)
    ax.set_ylabel('LR')
    ax.set_xlabel('epoch')
    ax.set_title('Learning Rate Schedulers')
    ax.legend()
    plt.show()
