In [2]:
import tensorflow
from tensorflow import keras

# Short alias for the Keras backend module; the code below uses its
# get_value / set_value helpers on the optimizer's learning rate.
K = keras.backend

In [3]:
import math
import numpy as np
import matplotlib.pyplot as plt

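# Helpers for finding the maximum learning rate to use with 1-cycle scheduling:
# train briefly while growing the learning rate exponentially, then plot loss vs. rate.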

class exponentialLearningRate(keras.callbacks.Callback):
    """Callback that multiplies the learning rate by a constant after each batch.

    At the end of every training batch it records the learning rate that was
    in effect and the ``loss`` entry of ``logs``, then scales the optimizer's
    learning rate by ``factor``.

    Attributes:
        factor: multiplier applied to the learning rate after each batch.
        rates: learning rate read from the optimizer at the end of each batch.
        losses: ``logs["loss"]`` value seen at the end of each batch.
    """

    def __init__(self, factor):
        # Let the Keras base class initialise its own bookkeeping first.
        super().__init__()
        self.factor = factor
        self.rates = []
        self.losses = []

    def on_batch_end(self, batch, logs):
        """Record the current rate and loss, then scale the learning rate."""
        learning_rate = self.model.optimizer.learning_rate
        current_rate = K.get_value(learning_rate)
        self.rates.append(current_rate)
        self.losses.append(logs["loss"])
        # Scale the concrete value that was just read instead of handing
        # set_value a symbolic `variable * factor` expression.
        K.set_value(learning_rate, current_rate * self.factor)

def findLearningRate(model, X, y, epochs=1, batch_size=32, min_rate=1e-5, max_rate=10):
    """Train briefly while growing the learning rate from min_rate towards max_rate.

    The per-batch growth factor is chosen so that, over
    ``ceil(len(X) / batch_size) * epochs`` batches, the learning rate is
    multiplied by ``max_rate / min_rate`` in total.

    The model's weights and the optimizer's learning rate are put back to
    their initial values before returning, including when ``model.fit``
    raises. Nothing else is reset by this function.

    Args:
        model: a compiled Keras model (must expose ``optimizer.learning_rate``).
        X: training inputs; ``len(X)`` is used to count batches.
        y: training targets.
        epochs: number of epochs to run the sweep for.
        batch_size: batch size passed to ``model.fit``.
        min_rate: learning rate used for the first batch.
        max_rate: learning rate the sweep grows towards.

    Returns:
        A ``(rates, losses)`` tuple of per-batch lists collected by the
        ``exponentialLearningRate`` callback.
    """
    iterations = math.ceil(len(X) / batch_size) * epochs
    factor = np.exp(np.log(max_rate / min_rate) / iterations)

    init_weights = model.get_weights()
    init_lr = K.get_value(model.optimizer.learning_rate)
    elr = exponentialLearningRate(factor)

    K.set_value(model.optimizer.learning_rate, min_rate)
    try:
        model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[elr])
    finally:
        # Always undo the sweep, even if training fails part-way through,
        # so the caller's model is not left with a huge learning rate.
        K.set_value(model.optimizer.learning_rate, init_lr)
        model.set_weights(init_weights)
    return elr.rates, elr.losses

def plotCurves(rates, losses):
    """Plot loss against learning rate on a log-scaled x axis.

    Draws a horizontal line at the minimum loss and limits the y axis to the
    range from the minimum loss to the midpoint between the first loss and
    the minimum loss.

    Args:
        rates: sequence of learning rates (x values).
        losses: sequence of losses, one per entry in ``rates`` (y values).
    """
    lowest_loss = min(losses)
    lowest_rate, highest_rate = min(rates), max(rates)

    plt.figure(figsize=[13, 13])
    plt.gca().set_xscale('log')
    plt.plot(rates, losses)
    plt.xlabel('learning_rate')
    # pyplot's function is `ylabel`; `y_label` does not exist and raised
    # AttributeError before anything was drawn past this point.
    plt.ylabel('losses')
    plt.hlines(lowest_loss, lowest_rate, highest_rate)
    plt.axis([lowest_rate, highest_rate, lowest_loss, (losses[0] + lowest_loss) / 2])

In [4]:
# 1-cycle scheduling callback: sets the learning rate in on_batch_begin, i.e. before every training batch (not once per epoch)

class one_cycle(keras.callbacks.Callback):
    """1-cycle learning-rate schedule, applied before every training batch.

    The learning rate is piecewise linear in the batch counter ``iteration``:

    1. ``init_rate`` up to ``max_rate`` over the first ``half_iteration`` batches;
    2. ``max_rate`` back down to ``init_rate`` over the next ``half_iteration``
       batches;
    3. ``init_rate`` down to ``last_rate`` over the remaining batches up to
       ``iterations``.

    Once ``iterations`` batches have been seen the rate is held at
    ``last_rate``.

    Args:
        max_rate: peak learning rate reached at the end of phase 1.
        init_rate: starting rate; defaults to ``max_rate / 10`` when falsy.
        iterations: total number of training batches the schedule spans.
            Required, despite the ``None`` default kept for signature
            compatibility.
        last_iterations: number of batches reserved for phase 3; defaults to
            ``iterations // 10 + 1`` when falsy.
        last_rate: final rate; defaults to ``init_rate / 1000`` when falsy.

    Raises:
        TypeError: if ``iterations`` is not provided.
    """

    def __init__(self, max_rate, init_rate=None, iterations=None, last_iterations=None, last_rate=None):
        super().__init__()
        if iterations is None:
            # Previously this surfaced as an opaque "NoneType // int" TypeError.
            raise TypeError("one_cycle requires `iterations` (total number of training batches)")
        self.iterations = iterations
        self.max_rate = max_rate
        self.init_rate = init_rate or max_rate / 10
        self.last_iterations = last_iterations or iterations // 10 + 1
        self.half_iteration = (iterations - self.last_iterations) // 2
        self.last_rate = last_rate or self.init_rate / 1000
        # Number of batches started so far; drives the schedule.
        self.iteration = 0

    def _interpolate(self, iter1, iter2, rate1, rate2):
        """Linearly interpolate from rate1 at iter1 to rate2 at iter2 at the current iteration."""
        return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1

    def on_batch_begin(self, batch, logs):
        """Set the optimizer's learning rate for the batch about to run."""
        ramp_end = self.half_iteration
        cycle_end = 2 * self.half_iteration

        # Compare the running counter (`iteration`), not the total
        # (`iterations`), so the schedule actually advances.
        if self.iteration < ramp_end:
            # Phase 1: warm up.
            rate = self._interpolate(0, ramp_end, self.init_rate, self.max_rate)
        elif self.iteration < cycle_end:
            # Phase 2: come back down from the peak.
            rate = self._interpolate(ramp_end, cycle_end, self.max_rate, self.init_rate)
        elif self.iteration < self.iterations:
            # Phase 3: anneal towards the final rate.
            rate = self._interpolate(cycle_end, self.iterations, self.init_rate, self.last_rate)
        else:
            # Past the planned schedule: hold instead of extrapolating,
            # which could otherwise drive the rate negative.
            rate = self.last_rate

        self.iteration += 1
        K.set_value(self.model.optimizer.learning_rate, rate)
