 # Gradient monitoring

In [1]:
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
from tqdm import tqdm

In [2]:
# Synthetic regression data: two uniform features and a noiseless linear target.
SEED = 0  # fixed so a fresh kernel regenerates exactly the same samples
rng = np.random.default_rng(SEED)

n = 10_000
x = rng.random((n, 2))  # shape (n, 2), values drawn uniformly from [0, 1)
x1, x2 = x[:, 0], x[:, 1]
y = 3*x1 + 2*x2  # target: y = 3*x1 + 2*x2

In [3]:
def get_mlp(layers_n, initializer, act='relu', last_act=None):
    """Build a fully connected ``keras.Sequential`` model named ``'MLP'``.

    Args:
        layers_n: Sequence of layer widths ``[input_dim, hidden_1, ..., output_dim]``.
            At least two entries are required; fewer make the unpacking below
            raise ``ValueError``.
        initializer: Kernel initializer used for every Dense layer. When it is
            an initializer object, each layer receives its own copy rebuilt
            from the object's config (see ``fresh_initializer``); any other
            value (e.g. a string identifier) is forwarded unchanged.
        act: Activation of the hidden layers.
        last_act: Activation forwarded to the output layer.

    Returns:
        The model, with hidden layers named ``Dense_<i>`` and a final layer
        named ``Output``.
    """
    def fresh_initializer():
        # Recent Keras versions make an initializer instance return the same
        # values on every call, so sharing one instance would start all
        # same-shape kernels identical. Rebuilding from config gives each
        # layer an independent draw from the same distribution. An explicit
        # integer seed is part of the config and is therefore preserved.
        if hasattr(initializer, 'get_config') and hasattr(type(initializer), 'from_config'):
            try:
                return type(initializer).from_config(initializer.get_config())
            except TypeError:
                # Config does not round-trip through the constructor
                # (custom initializer); fall back to the shared instance.
                return initializer
        return initializer

    l_first, *ls, l_last = layers_n

    model = keras.Sequential(name='MLP')

    model.add(layers.Input(shape=(l_first,)))
    for i, l in enumerate(ls):
        model.add(layers.Dense(
            l,
            activation=act,
            kernel_initializer=fresh_initializer(),
            name=f'Dense_{i}',
        ))

    model.add(layers.Dense(
        l_last,
        activation=last_act,
        kernel_initializer=fresh_initializer(),
        name='Output',
    ))

    return model

In [4]:
def get_features_extractor(model):
    """Return a model that exposes the output of every layer of ``model``.

    The returned ``keras.Model`` takes the same input as ``model`` and
    produces a dict mapping each layer's name to that layer's output.
    """
    per_layer_outputs = {}
    for lyr in model.layers:
        per_layer_outputs[lyr.name] = lyr.output
    return keras.Model(inputs=model.input, outputs=per_layer_outputs)

In [35]:
# --- Experiment configuration ---------------------------------------------
activation='sigmoid'
stddev = 1e-3  # only used by the commented-out RandomNormal alternative below
# initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=stddev)
# initializer = tf.keras.initializers.HeNormal()
initializer = tf.keras.initializers.VarianceScaling(16)
hidden_layers_size = 100
hidden_layers_num = 100
layers_n = [2] + [hidden_layers_size]*hidden_layers_num  + [1]
batch_size = 1
# Each run of this cell logs into its own timestamped folder.
inner_folder = datetime.now().strftime('%Y_%m_%d__%H_%M_%S')
tensorboard_log_folder = f'./tensorboard_logs/{inner_folder}'
num_batches = 4

# --- Model and optimizer --------------------------------------------------
mlp = get_mlp(layers_n, initializer, act=activation)
feature_extractor = get_features_extractor(mlp)
lr = 1e-3

optimizer = tf.keras.optimizers.SGD(learning_rate=lr)

# --- Training with per-layer output / gradient histograms -----------------
# One summary writer for the whole run, instead of constructing a new one on
# every batch.
summary_writer = tf.summary.create_file_writer(tensorboard_log_folder)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size).take(num_batches)

stop = False
try:
    for num_batch, (xbatch, ybatch) in tqdm(dataset.enumerate()):
        # Rebuilt on every batch: after the loop these hold the values of the
        # last processed batch only.
        outputs_dict = {}
        grads_dict = {}
        with tf.GradientTape() as tape:
            # The extractor shares layers with `mlp`, so a single forward pass
            # yields both the prediction and every intermediate activation.
            outputs = feature_extractor(xbatch)
            y_pred = tf.reshape(outputs['Output'], (-1,))
            loss = tf.keras.losses.MSE(ybatch, y_pred)
        grads = tape.gradient(loss, mlp.trainable_variables)
        # NOTE: the update is applied before the finiteness checks below, so
        # when the loop stops the weights may already contain NaN/Inf.
        optimizer.apply_gradients(zip(grads, mlp.trainable_variables))
        with summary_writer.as_default():
            for layer_name, output in outputs.items():
                outputs_dict[layer_name] = output
                if not tf.reduce_all(tf.math.is_finite(output)).numpy():
                    stop = True
                tf.summary.histogram(name=f'output/{layer_name}', data=output, step=num_batch)
            # Reversed so gradients are recorded from the output layer back
            # towards the input.
            for grad, var in reversed(list(zip(grads, mlp.trainable_variables))):
                grads_dict[var.name] = grad
                if not tf.reduce_all(tf.math.is_finite(grad)).numpy():
                    stop = True
                tf.summary.histogram(name=f'grads/{var.name}', step=num_batch, data=grad)
        if stop:
            # A non-finite activation or gradient was seen on this batch.
            print(f'\nSTOP num_batch={num_batch}')
            break
finally:
    # Release the event file even if the loop raised.
    summary_writer.close()

100%|██████████| 4/4 [00:09<00:00,  2.27s/it]


In [20]:
# Layer names whose outputs were captured on the last processed batch.
outputs_dict.keys()

dict_keys(['Dense_0', 'Dense_1', 'Dense_2', 'Dense_3', 'Dense_4', 'Dense_5', 'Dense_6', 'Dense_7', 'Dense_8', 'Dense_9', 'Dense_10', 'Dense_11', 'Dense_12', 'Dense_13', 'Dense_14', 'Dense_15', 'Dense_16', 'Dense_17', 'Dense_18', 'Dense_19', 'Dense_20', 'Dense_21', 'Dense_22', 'Dense_23', 'Dense_24', 'Dense_25', 'Dense_26', 'Dense_27', 'Dense_28', 'Dense_29', 'Dense_30', 'Dense_31', 'Dense_32', 'Dense_33', 'Dense_34', 'Dense_35', 'Dense_36', 'Dense_37', 'Dense_38', 'Dense_39', 'Dense_40', 'Dense_41', 'Dense_42', 'Dense_43', 'Dense_44', 'Dense_45', 'Dense_46', 'Dense_47', 'Dense_48', 'Dense_49', 'Dense_50', 'Dense_51', 'Dense_52', 'Dense_53', 'Dense_54', 'Dense_55', 'Dense_56', 'Dense_57', 'Dense_58', 'Dense_59', 'Dense_60', 'Dense_61', 'Dense_62', 'Dense_63', 'Dense_64', 'Dense_65', 'Dense_66', 'Dense_67', 'Dense_68', 'Dense_69', 'Dense_70', 'Dense_71', 'Dense_72', 'Dense_73', 'Dense_74', 'Dense_75', 'Dense_76', 'Dense_77', 'Dense_78', 'Dense_79', 'Dense_80', 'Dense_81', 'Dense_82', 'De

In [None]:
# Output of the first hidden layer ('Dense_0') for the last processed batch.
outputs_dict['Dense_0']

In [30]:
# Variable names with a recorded gradient, in the order they were stored
# (the training loop walks the variables in reverse, output layer first).
grads_dict.keys()

dict_keys(['Output/bias:0', 'Output/kernel:0', 'Dense_99/bias:0', 'Dense_99/kernel:0', 'Dense_98/bias:0', 'Dense_98/kernel:0', 'Dense_97/bias:0', 'Dense_97/kernel:0', 'Dense_96/bias:0', 'Dense_96/kernel:0', 'Dense_95/bias:0', 'Dense_95/kernel:0', 'Dense_94/bias:0', 'Dense_94/kernel:0', 'Dense_93/bias:0', 'Dense_93/kernel:0', 'Dense_92/bias:0', 'Dense_92/kernel:0', 'Dense_91/bias:0', 'Dense_91/kernel:0', 'Dense_90/bias:0', 'Dense_90/kernel:0', 'Dense_89/bias:0', 'Dense_89/kernel:0', 'Dense_88/bias:0', 'Dense_88/kernel:0', 'Dense_87/bias:0', 'Dense_87/kernel:0', 'Dense_86/bias:0', 'Dense_86/kernel:0', 'Dense_85/bias:0', 'Dense_85/kernel:0', 'Dense_84/bias:0', 'Dense_84/kernel:0', 'Dense_83/bias:0', 'Dense_83/kernel:0', 'Dense_82/bias:0', 'Dense_82/kernel:0', 'Dense_81/bias:0', 'Dense_81/kernel:0', 'Dense_80/bias:0', 'Dense_80/kernel:0', 'Dense_79/bias:0', 'Dense_79/kernel:0', 'Dense_78/bias:0', 'Dense_78/kernel:0', 'Dense_77/bias:0', 'Dense_77/kernel:0', 'Dense_76/bias:0', 'Dense_76/kern

In [35]:
# Show the complete gradient of one deep hidden-layer kernel without NumPy's
# "..." summarization. The threshold is raised only inside this block, so the
# global print options of later cells are left untouched.
with np.printoptions(threshold=sys.maxsize):
    # A bare expression inside `with` is not auto-displayed, hence the
    # explicit print of the repr.
    print(repr(grads_dict['Dense_91/kernel:0'].numpy()))

array([[ 0., -0.,  0., -0.,  0.,  0.,  0., -0., -0.,  0., -0., -0., -0.,
         0., -0., -0.,  0.,  0., -0., -0.,  0., -0., -0.,  0., -0.,  0.,
        -0., -0., -0., -0.,  0.,  0., -0., -0., -0.,  0., -0.,  0.,  0.,
        -0.,  0., -0., -0., -0.,  0., -0.,  0., -0., -0., -0.,  0., -0.,
         0.,  0.,  0.,  0.,  0., -0.,  0.,  0., -0., -0., -0.,  0.,  0.,
        -0.,  0., -0.,  0.,  0.,  0., -0., -0., -0., -0., -0.,  0.,  0.,
        -0.,  0., -0., -0.,  0., -0.,  0., -0., -0.,  0., -0., -0.,  0.,
        -0.,  0., -0.,  0.,  0.,  0.,  0., -0., -0.],
       [-0.,  0., -0.,  0., -0., -0., -0.,  0.,  0., -0.,  0.,  0.,  0.,
        -0.,  0.,  0., -0., -0.,  0.,  0., -0.,  0.,  0., -0.,  0., -0.,
         0.,  0.,  0.,  0., -0., -0.,  0.,  0.,  0., -0.,  0., -0., -0.,
         0., -0.,  0.,  0.,  0., -0.,  0., -0.,  0.,  0.,  0., -0.,  0.,
        -0., -0., -0., -0., -0.,  0., -0., -0.,  0.,  0.,  0., -0., -0.,
         0., -0.,  0., -0., -0., -0.,  0.,  0.,  0.,  0.,  0., -0., -0

In [9]:
# NOTE(review): the training cell assigns `initializer` itself before building
# the model, so this value is overwritten whenever that cell runs. Presumably a
# leftover experiment; confirm before relying on it.
initializer = tf.keras.initializers.GlorotNormal()