In [1]:
import tensorflow as tf 
import numpy as np 

print(tf.__version__)

2.1.0


In [2]:

def get_layers(layer):
    """Return ``layer.layers``, or an empty list when there is no such attribute."""
    # Objects without a `layers` attribute are treated as having no children.
    return getattr(layer, 'layers', [])

def get_mult(layer):
    """Return ``layer.lr_mult``, falling back to ``1.`` when it is not set."""
    # A missing multiplier means "use the learning rate unchanged".
    return getattr(layer, 'lr_mult', 1.)
    
def assign_mult(layer, lr_mult):
    """Set ``layer.lr_mult`` to ``lr_mult`` unless the layer already has one."""
    # An explicitly configured multiplier always wins over an inherited one.
    if hasattr(layer, 'lr_mult'):
        return
    layer.lr_mult = lr_mult
    
def get_lowest_layers(model):
    """Yield the leaf layers under ``model``, pushing ``lr_mult`` downwards.

    A node counts as a leaf when ``get_layers`` reports no children for it;
    in that case the node itself is yielded. Otherwise every child inherits
    the node's multiplier (only if the child has none of its own) and the
    child's subtree is walked recursively.
    """
    children = get_layers(model)
    inherited_mult = get_mult(model)

    if len(children) == 0:
        # Nothing nested below: this object is itself a lowest-level layer.
        yield model
        return

    for child in children:
        # Propagate the multiplier to lower layers before descending.
        assign_mult(child, inherited_mult)
        yield from get_lowest_layers(child)
    
def apply_mult_to_var(layer):
    """Tag each trainable variable of ``layer`` with the layer's multiplier.

    The multiplier returned by ``get_mult`` is stored on every variable in
    ``layer.trainable_variables`` as a float32 tensor under ``lr_mult``.

    Returns:
        The ``layer`` that was passed in.
    """
    layer_mult = get_mult(layer)
    for variable in layer.trainable_variables:
        # A fresh tensor is created for each variable.
        variable.lr_mult = tf.convert_to_tensor(layer_mult, dtype=tf.float32)

    return layer

def inject(model):
    """Enable layerwise learning rates on ``model``.

    Every leaf layer found by ``get_lowest_layers`` has its trainable
    variables tagged with the layer's ``lr_mult``. The optimizer's
    ``apply_gradients`` is then replaced by the module-level
    ``apply_gradients`` wrapper, and the original bound method is kept on the
    optimizer as ``_apply_gradients`` so the wrapper can delegate to it.

    Calling this repeatedly on the same optimizer is safe. The variable
    multipliers are refreshed on every call, but the optimizer is wrapped only
    once. Without that guard a second call would save the wrapper itself as
    ``_apply_gradients``, and the wrapper would then call itself recursively.

    Args:
        model: object exposing ``optimizer``; ``layers`` and ``lr_mult``
            attributes are read when present.

    Returns:
        The same ``model`` instance.
    """
    for layer in get_lowest_layers(model):
        apply_mult_to_var(layer)

    opt = model.optimizer
    if not getattr(opt, '_lr_mult_injected', False):
        # Move the original apply fn to a safe place, then install the wrapper.
        opt._apply_gradients = opt.apply_gradients
        opt.apply_gradients = apply_gradients.__get__(opt)
        opt._lr_mult_injected = True
    # Re-arm the one-time notice printed by the wrapper.
    opt.testing_flag = True

    return model
    
def apply_gradients(self, grads_and_vars, *args, **kwargs):
    """Scale each gradient by its variable's ``lr_mult``, then apply them.

    Meant to be bound to an optimizer by ``inject``; the optimizer's original
    method is expected to be available as ``self._apply_gradients``.

    Args:
        self: the optimizer this function has been bound to.
        grads_and_vars: iterable of ``(gradient, variable)`` pairs. It is
            consumed exactly once, so a one-shot iterator such as ``zip`` works.
        *args, **kwargs: forwarded unchanged to ``self._apply_gradients``.

    Returns:
        Whatever ``self._apply_gradients`` returns.
    """
    # One-time notice, re-armed by `inject`; tolerate the flag being absent.
    if getattr(self, 'testing_flag', False):
        print('Training with layerwise learning rates')
        self.testing_flag = False

    scaled_pairs = []
    for grad, var in grads_and_vars:
        if grad is not None:
            # Variables that were never tagged keep their gradient (mult 1).
            # The cast keeps the multiplier's dtype equal to the gradient's.
            mult = tf.cast(getattr(var, 'lr_mult', 1.), grad.dtype)
            grad = tf.math.scalar_mul(mult, grad)
        # `None` gradients are passed through untouched so the wrapped
        # optimizer can handle them in its own way.
        scaled_pairs.append((grad, var))

    return self._apply_gradients(scaled_pairs, *args, **kwargs)



In [5]:


def build_simple_model(opt, loss = 'binary_crossentropy'):
    """Build and compile a small nested ``Sequential`` model.

    Layout: ``Dense(10, relu)`` with ``input_shape=(1,)``, then a nested
    ``Sequential`` of two ``Dense(5, relu)`` layers, then ``Dense(1, sigmoid)``.

    Args:
        opt: optimizer passed to ``model.compile``.
        loss: loss passed to ``model.compile``.

    Returns:
        The compiled outer model.
    """
    dense = tf.keras.layers.Dense
    relu = tf.nn.relu

    # The nested block is created first, exactly as before, so layer
    # construction order stays the same.
    inner_layers = [dense(5, activation=relu), dense(5, activation=relu)]
    sub_model = tf.keras.Sequential(inner_layers)

    input_layer = dense(10, activation=relu, input_shape=(1,))
    output_layer = dense(1, activation=tf.nn.sigmoid)
    model = tf.keras.Sequential([input_layer, sub_model, output_layer])

    model.compile(loss=loss, optimizer=opt)
    return model

def test_lr_mult(model, do_inject = True):
    """Fit ``model`` on an all-ones toy dataset and return the fit result.

    Args:
        model: compiled model to train.
        do_inject: when true, ``inject`` is applied to the model first.

    Returns:
        The value returned by ``model.fit``.
    """
    if do_inject:
        inject(model)

    # 256 samples with one feature; inputs and targets are all ones.
    data_shape = (256, 1)
    x = np.ones(data_shape, dtype=np.float32)
    y = np.ones(data_shape, dtype=np.float32)

    fit_options = dict(batch_size=32, epochs=5, verbose=0)
    return model.fit(x, y, **fit_options)


def h_to_list(h):
    """Return the ``'loss'`` entry of the ``history`` mapping stored on ``h``."""
    recorded = vars(h)['history']
    return recorded['loss']

def test_zero_lr_mult(model_fn = build_simple_model
                      , opts = ['adam', 'sgd']
                      , losses = ['binary_crossentropy', 'MSE']):
    """Check that a zero multiplier leaves the loss unchanged across epochs.

    For every optimizer/loss combination two fresh models are trained: one
    with ``lr_mult = 0`` set on the model itself, and one with ``lr_mult = 0``
    set on each of its top-level layers. In both cases every recorded loss
    value must be identical.
    """
    def _assert_constant_loss(model):
        # Train with injection and require a single distinct loss value.
        history = test_lr_mult(model)
        distinct_losses = set(vars(history)['history']['loss'])
        assert len(distinct_losses) == 1, 'WITH 0 LR ALL LOSSES SHOULD BE IDENTICAL'

    combos = ((opt, loss) for opt in opts for loss in losses)
    for opt, loss in combos:
        # Zero multiplier on the whole model.
        frozen_model = model_fn(opt, loss)
        frozen_model.lr_mult = 0
        _assert_constant_loss(frozen_model)

        # Zero multiplier on every top-level layer.
        frozen_layers_model = model_fn(opt, loss)
        for layer in frozen_layers_model.layers:
            layer.lr_mult = 0
        _assert_constant_loss(frozen_layers_model)
            

def test_some_lr_mult(model_fn = build_simple_model
                      , opts = ['adam', 'sgd']
                      , losses = ['binary_crossentropy', 'MSE']):
    """Check that training still reduces the loss when some weights can learn.

    For every optimizer/loss combination three fresh models are trained: one
    with no multiplier set, one with ``lr_mult = 1`` on the model, and one
    with ``lr_mult = 0`` on the first layer only. Each time the first recorded
    loss must be greater than the last one.
    """
    def _leave_untouched(model):
        pass

    def _unit_mult_on_model(model):
        model.lr_mult = 1

    def _freeze_first_layer(model):
        model.layers[0].lr_mult = 0

    # Applied in this order to a freshly built model each time.
    scenarios = (_leave_untouched, _unit_mult_on_model, _freeze_first_layer)

    for opt in opts:
        for loss in losses:
            for configure in scenarios:
                model = model_fn(opt, loss)
                configure(model)
                epoch_losses = h_to_list(test_lr_mult(model))
                assert epoch_losses[0] > epoch_losses[-1], 'LOSS SHOULD HAVE DECREASED'
            
    
# Run both checks against the small nested model; each call raises
# AssertionError on failure and is silent apart from the injection notice.
test_zero_lr_mult(model_fn=build_simple_model)
test_some_lr_mult(model_fn=build_simple_model)

Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
Training with layerwise learning rates
