In [2]:
"""Interact with various pytorch learning-rate schedulers.
"""

from ipywidgets import widgets, interact
import matplotlib.pyplot as plt
import torch

In [3]:
# Load a pretrained AlexNet, placed on the GPU only when CUDA is available.
import torchvision

try:
    # torchvision >= 0.13 selects weights via `weights`; `pretrained` is deprecated there.
    model = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
except AttributeError:
    # Older torchvision has no AlexNet_Weights enum; use the legacy flag.
    model = torchvision.models.alexnet(pretrained=True)
# `.cuda()` would raise on CPU-only machines; fall back to the CPU instead.
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
@interact(first_epoch=(0,100), last_epoch=(1,100), step_size=(1, 30, 1), gamma=(0, 1, 0.1), lr='0.001', T_max=(1,10))
def draw_schedules(first_epoch=0, last_epoch=50, step_size=3, gamma=0.9, lr=0.001, T_max=1):
    """Plot the learning rate that several PyTorch schedulers apply at each epoch.

    Every scheduler drives its own optimizer. A scheduler writes its learning
    rate into its optimizer's param groups, so schedulers sharing a single
    optimizer would overwrite each other's values.

    Args:
        first_epoch: first epoch drawn (inclusive).
        last_epoch: end of the drawn epoch range (exclusive).
        step_size: passed to StepLR as its step size.
        gamma: decay factor passed to ExponentialLR and StepLR.
        lr: initial learning rate; the widget supplies it as text, so it is
            parsed with float().
        T_max: passed to CosineAnnealingLR.
    """
    try:
        base_lr = float(lr)
    except ValueError:
        # Partially typed text (e.g. '' or '1e') would otherwise raise a traceback.
        print('Invalid learning rate: {!r}'.format(lr))
        return

    # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
    scheduler_factories = {
        'ExponentialLR': lambda opt: torch.optim.lr_scheduler.ExponentialLR(opt, gamma),
        'StepLR': lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size, gamma),
        'CosineAnnealingLR': lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max),
    }

    epochs = list(range(first_epoch, last_epoch))
    _, ax = plt.subplots()
    for name, make_scheduler in scheduler_factories.items():
        # The schedule does not depend on model weights, so one throwaway
        # parameter is enough; it never receives a gradient.
        optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=base_lr)
        scheduler = make_scheduler(optimizer)
        lr_values = []
        for epoch in range(last_epoch):
            if epoch >= first_epoch:
                # Record the LR actually applied this epoch rather than get_lr(),
                # whose meaning changed across PyTorch versions.
                lr_values.append(optimizer.param_groups[0]['lr'])
            optimizer.step()  # no-op (no gradient); keeps the documented call order
            scheduler.step()  # advance one epoch; step(epoch) is deprecated
        ax.plot(epochs, lr_values, label=name)

    ax.set_xlabel('epoch')
    ax.set_ylabel('LR')
    ax.set_title('Learning Rate Schedulers')
    ax.legend()
    plt.show()
