# simple regression experiments

Using `ray.tune` to run the hyperparameter trials in parallel.


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os
# os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# from luminous_plugin_dir import set_luminous_params
from datetime import datetime
import tensorflow.keras as keras
from ray import tune
import ray

import pandas as pd
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.callbacks import ModelCheckpoint


# set_luminous_params(log_level=PluginLogLevel.WARNING)

def accum_dict_to_df(lcurves):
    """Flatten per-run learning curves into one long-format DataFrame.

    Args:
        lcurves: Mapping from an accumulator bit-width to that run's curves.
            Each value must be accepted by ``pd.DataFrame`` (for example a
            dict of equal-length lists) and every cell must expose a
            ``.numpy()`` method.

    Returns:
        A DataFrame holding every run's curves with each cell replaced by
        ``cell.numpy()``, plus an ``epoch`` column (row index + 1 within the
        run) and an ``accum_bits`` column holding the mapping key. An empty
        mapping yields an empty frame that still has those two columns.
    """
    if not lcurves:
        # pd.concat raises ValueError on an empty list; return an empty
        # frame instead so callers can still concatenate or plot it.
        return pd.DataFrame(columns=["epoch", "accum_bits"])

    frames = []
    for accum_bits, curves in lcurves.items():
        # Column-wise Series.map stands in for DataFrame.applymap, which is
        # deprecated in recent pandas; Series.map works on old and new versions.
        frame = pd.DataFrame(curves).apply(
            lambda column: column.map(lambda cell: cell.numpy())
        )
        frame["epoch"] = frame.index + 1
        frame["accum_bits"] = accum_bits
        frames.append(frame)

    return pd.concat(frames, ignore_index=True)



def shallow_linear(num_inputs, suffix=''):
    """Build a single-layer linear regressor.

    The model maps an input named ``"features" + suffix`` with
    ``num_inputs`` features through one ``Dense(1)`` layer named
    ``"prediction" + suffix``.
    """
    feature_input = keras.Input(shape=(num_inputs,), name="features" + suffix)
    prediction = layers.Dense(1, name="prediction" + suffix)(feature_input)
    return keras.Model(inputs=feature_input, outputs=prediction)

def deep_linear(num_inputs, hidden = None, suffix='', activation=None):
    """Build a stack of Dense layers ending in a one-unit output layer.

    Args:
        num_inputs: Number of input features.
        hidden: Sequence of hidden-layer widths; any falsy value (``None``
            or an empty sequence) falls back to ``[128]``.
        suffix: Appended to the input ("features") and output ("prediction")
            layer names.
        activation: Activation passed to every hidden layer; falsy values
            are passed on as ``None``. The output layer has no activation.

    Returns:
        The assembled ``keras.Model``.
    """
    layer_widths = hidden if hidden else [128]
    hidden_activation = activation if activation else None

    features = keras.Input(shape=(num_inputs,), name="features" + suffix)
    tensor = features
    for width in layer_widths:
        tensor = layers.Dense(width, activation=hidden_activation)(tensor)

    prediction = layers.Dense(1, name="prediction" + suffix)(tensor)
    return keras.Model(inputs=features, outputs=prediction)

def create_model(learning_rate, batch_size):
    """Build the trainable model, its training stream and validation set.

    A randomly initialised reference network produces the ground-truth
    targets. The returned model has the same architecture and starts from
    the reference weights plus small Gaussian noise.

    Args:
        learning_rate: Learning rate given to the Adam optimizer.
        batch_size: Number of samples in each generated training batch.

    Returns:
        A tuple ``(model, train_data, val_data)``: the compiled Keras model,
        a ``tf.data.Dataset`` built from a generator that draws fresh random
        batches labelled by the reference model, and an ``(Xval, yval)``
        tuple of tensors.
    """
    # Imports are repeated at function scope.
    # NOTE(review): presumably so the function is self-contained when Ray
    # runs it in a worker process — confirm.
    import numpy as np
    import os
#     os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers


    num_inputs = 8
    wts_scale = 1       # scale of the reference model's random weights
    err_scale = 1e-2    # scale of the noise added to get the initial weights

    hidden = [4]*8      # eight hidden layers of four units each
    act_fn = 'gelu'
    steps_per_epoch = 1  # number of batches the generator yields per pass


    #data params
    # With num_train = 0 the training slice below is empty and all
    # num_eval rows end up in the validation set.
    num_train = 0
    num_eval = 5000

    # Fixed seeds for both the NumPy and the TensorFlow random generators.
    np.random.seed(222)
    tf.random.set_seed(222)
    Xdata = tf.random.normal(shape=(num_train+num_eval,num_inputs), dtype=tf.float32)


    # The reference and the trained model share one architecture.
    model_fn = lambda : deep_linear(num_inputs, hidden=hidden, activation=act_fn)


    #create reference model and ground truth data
    with tf.device("cpu"):

        ref_model = model_fn()
        ref_model.compile()
        ref_weights = ref_model.get_weights()
        # Replace every weight array with scaled standard-normal values of the same shape.
        ref_init = [wts_scale*np.float32(np.random.randn(*x.shape)) for x in ref_weights]


        ref_model.set_weights(ref_init)
        # Ground-truth targets are the reference model's outputs on Xdata.
        ydata = ref_model(Xdata)

        # Starting weights for the trained model: reference weights plus noise.
        model_wts_init = [x + err_scale*np.float32(np.random.randn(*x.shape)) for x in ref_init]

#     with tf.device("cpu"):

        # create trained model
        model = model_fn()


        # Instantiate an optimizer.
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        # Instantiate a loss function.
        loss_fn = tf.keras.losses.MeanSquaredError()
        model.compile(
            optimizer=optimizer,
            loss=loss_fn,
            metrics=['mse']
        )
        # run fit 1x for tf-plugin to work correctly
        # (a warm-up on a single all-zero sample)
        X = np.zeros(shape=(1, num_inputs), dtype = np.float32)
        Y = np.array([0], dtype=np.float32)
        model.fit(X, Y, epochs=1, verbose=1)


        # train val split
        # Xtrain/ytrain are only referenced by the commented-out dataset code below.
        Xtrain = Xdata[0:num_train,:]
        ytrain = ydata[0:num_train,:]

        Xval = Xdata[num_train::,:]
        yval = ydata[num_train::,:]
         # run inference for tf-plugin to work correctly
        model(Xval, training=False)


        def data_generator():
            # Yields steps_per_epoch batches of new random inputs, each
            # labelled by the reference model.
            for i in range(steps_per_epoch):
                x = tf.random.normal((batch_size, num_inputs), dtype=tf.float32)
                y = tf.cast(ref_model(x),tf.float32)
                yield x,y

#         train_data = tf.data.Dataset.from_tensor_slices((Xtrain, ytrain))
#         train_data = train_data.shuffle(1024).batch(batch_size)
        train_data = tf.data.Dataset.from_generator(data_generator, (tf.float32, tf.float32))

        # NOTE(review): train_mse_metric and mse are created but never used
        # in this function.
        train_mse_metric = tf.keras.metrics.MeanSquaredError()
        mse = tf.keras.losses.MeanSquaredError()

        val_data = (Xval, yval)

        # Set after the warm-up fit above, so the returned model holds
        # exactly model_wts_init.
        # NOTE(review): the optimizer has still taken the warm-up step — confirm that is acceptable.
        model.set_weights(model_wts_init)

    return model, train_data, val_data


class TuneReporterCallback(keras.callbacks.Callback):
    """Keras callback that forwards per-epoch logs to Ray Tune.

    At the end of every epoch the Keras ``logs`` dict is reported via
    ``tune.report`` as ``keras_info``, together with its ``"mse"`` entry as
    ``mean_mse`` and its ``"loss"`` entry as ``mean_loss`` (``None`` when the
    key is absent).
    """

    def __init__(self, logs=None):
        """Create the callback.

        Args:
            logs: Unused; accepted only so existing call sites keep working.
        """
        super().__init__()
        # Counts how many epochs have been reported so far.
        self.iteration = 0

    def on_epoch_end(self, batch, logs=None):
        """Report the finished epoch's logs to Ray Tune.

        Args:
            batch: Positional index supplied by Keras. The parameter name is
                kept for backward compatibility.
                NOTE(review): for ``on_epoch_end`` Keras passes the epoch index.
            logs: Metric dict for the epoch; ``None`` is treated as empty.
        """
        # Avoid a shared mutable default and tolerate an explicit None.
        logs = {} if logs is None else logs
        self.iteration += 1
        tune.report(
            keras_info=logs,
            mean_mse=logs.get("mse"),
            mean_loss=logs.get("loss"),
        )

def tune_model(config):  # TODO: Change me.
    """Ray Tune trainable: train one model for a single hyperparameter set.

    Args:
        config: Trial configuration with the keys ``"lr"``, ``"batch_size"``
            and ``"accum_bits"``.

    The model and data come from ``create_model``. The plugin's bfloat16
    accumulator width is set from ``config["accum_bits"]`` after the model
    has been created, and training runs for 50 epochs with per-epoch
    reporting through ``TuneReporterCallback``.
    """
    from luminous_plugin import PluginLogLevel, set_luminous_params

    model, train_data, val_data = create_model(config["lr"], config["batch_size"])

    with tf.device('cpu'):
        # Keep only the best model (by training loss) in "model.h5".
        best_loss_checkpoint = ModelCheckpoint(
            "model.h5",
            monitor='loss',
            save_best_only=True,
            save_freq=1,
        )
        # The Tune callback lets Tune see intermediate results every epoch.
        callbacks = [best_loss_checkpoint, TuneReporterCallback()]

    set_luminous_params(
        bfloat16_acc_bits=config['accum_bits'],
        log_level=PluginLogLevel.WARNING,
    )

    # Train the model.
    fit_options = dict(
        validation_data=val_data,
        verbose=0,
        epochs=50,
        callbacks=callbacks,
    )
    model.fit(train_data, **fit_options)


# Search space: every value is a grid_search entry; only accum_bits has more
# than one candidate.
hyperparameter_space = {
    "lr": tune.grid_search([1e-3]),  
    "batch_size": tune.grid_search([10000]),
    "accum_bits": tune.grid_search([24, 20, 18, 16]),
}

ray.shutdown()  # Restart Ray defensively in case the ray connection is lost. 
ray.init(log_to_driver=False)

# Run tune_model once per grid point.
analysis = tune.run(
    tune_model, 
    verbose=1, 
    config=hyperparameter_space,
    num_samples=1)

# Join the concatenated per-trial history frames with analysis.dataframe() on
# experiment_id. Columns present in both frames get pandas' default merge
# suffixes: "_x" for the left (history) frame and "_y" for the right one.
# NOTE(review): analysis.dataframe() is presumably one summary row per trial,
# carrying the "config/..." columns — confirm against the Ray Tune docs.
df = pd.merge(pd.concat(analysis.trial_dataframes, ignore_index=True),analysis.dataframe(),on='experiment_id' )


# One line per accum_bits value, log-scaled y axis.
# NOTE(review): mean_mse is reported from logs["mse"] in TuneReporterCallback,
# which looks like the training metric rather than "val_mse" — confirm that
# the 'Val MSE' axis label is correct.
fig2=sns.relplot(data=df, x="training_iteration_x", y="mean_mse_x", style='config/batch_size',hue="config/accum_bits", kind='line', height=10)
fig2.set(yscale='log'), plt.grid(), plt.xlabel('Epochs'), plt.ylabel('Val MSE')


In [None]:
# Same curves as the accum_bits plot, but coloured by learning rate and styled
# by batch size. Uses the merged `df` built in an earlier cell.
fig2=sns.relplot(data=df, x="training_iteration_x", y="mean_mse_x",hue="config/lr",style="config/batch_size", kind='line', height=10)
fig2.set(yscale='log'), plt.grid(), plt.xlabel('Epochs'), plt.ylabel('Val MSE')


In [None]:
# Reload the results of an earlier run from its experiment-state file instead
# of re-running the trials. The path is specific to one machine and one run.
analysis = tune.ExperimentAnalysis('/root/ray_results/tune_model_2021-09-08_21-58-53/experiment_state-2021-09-08_21-58-53.json')

In [None]:
# Rebuild the merged frame from whichever `analysis` object is current
# (per-trial histories joined with analysis.dataframe() on experiment_id;
# shared columns get the "_x"/"_y" merge suffixes).
df = pd.merge(pd.concat(analysis.trial_dataframes, ignore_index=True),analysis.dataframe(),on='experiment_id' )

# One line per accum_bits value, log-scaled y axis.
fig2=sns.relplot(data=df, x="training_iteration_x", y="mean_mse_x", style='config/batch_size',hue="config/accum_bits", kind='line', height=10)
fig2.set(yscale='log'), plt.grid(), plt.xlabel('Epochs'), plt.ylabel('Val MSE')


In [None]:
# Print the error log of a failed trial using an IPython shell escape.
# Both paths are specific to one machine and one run.
#!cat /root/ray_results/tune_model_2021-09-06_14-42-20/tune_model_a39e9_00004_4_accum_bits=12_2021-09-06_14-42-20/error.txt

!cat /root/ray_results/tune_model_2021-09-07_19-27-51/tune_model_b12d5_00003_3_accum_bits=12_2021-09-07_19-27-51/error.txt




In [None]:
# Re-draw the per-accum_bits curves from the current merged `df`
# (log-scaled y axis).
fig2=sns.relplot(data=df, x="training_iteration_x", y='mean_mse_x', style='config/batch_size',hue="config/accum_bits", kind='line', height=10)
fig2.set(yscale='log'), plt.grid(), plt.xlabel('Epochs'), plt.ylabel('Val MSE')


Using a custom training step, without parallelization (much slower).

In [None]:
import numpy as np

import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from luminous_plugin_dir import set_luminous_params
from datetime import datetime

import pandas as pd


def accum_dict_to_df(lcurves):
    """Flatten per-run learning curves into one long-format DataFrame.

    Args:
        lcurves: Mapping from an accumulator bit-width to that run's curves.
            Each value must be accepted by ``pd.DataFrame`` (for example a
            dict of equal-length lists) and every cell must expose a
            ``.numpy()`` method.

    Returns:
        A DataFrame holding every run's curves with each cell replaced by
        ``cell.numpy()``, plus an ``epoch`` column (row index + 1 within the
        run) and an ``accum_bits`` column holding the mapping key. An empty
        mapping yields an empty frame that still has those two columns.
    """
    if not lcurves:
        # pd.concat raises ValueError on an empty list; return an empty
        # frame instead so callers can still concatenate or plot it.
        return pd.DataFrame(columns=["epoch", "accum_bits"])

    frames = []
    for accum_bits, curves in lcurves.items():
        # Column-wise Series.map stands in for DataFrame.applymap, which is
        # deprecated in recent pandas; Series.map works on old and new versions.
        frame = pd.DataFrame(curves).apply(
            lambda column: column.map(lambda cell: cell.numpy())
        )
        frame["epoch"] = frame.index + 1
        frame["accum_bits"] = accum_bits
        frames.append(frame)

    return pd.concat(frames, ignore_index=True)


import seaborn as sns
import matplotlib.pyplot as plt
# Module-level experiment settings, read as globals by the code below
# (reference-model construction and exp_run).
#model params
num_inputs = 8
wts_scale = 3  # scale of the reference model's random weights
np.random.seed(113)
# model_weights = wts_scale*np.float32(np.random.randn(num_inputs,1))
# model_bias = wts_scale*np.float32(np.random.randn(1))

hidden = [128]     # one hidden layer of 128 units
act_fn = "linear"

err_scale = .001   # scale of the noise added to the reference weights
# wts_error_init = err_scale*np.float32(np.random.randn(num_inputs,1))
# bias_error_init = err_scale*np.float32(np.random.randn(1))

#data params
num_train = 3000
num_eval = 300
tf.random.set_seed(111)





# logdir = "logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")



def shallow_linear(num_inputs, suffix=''):
    """Build a single-layer linear regressor.

    The model maps an input named ``"features" + suffix`` with
    ``num_inputs`` features through one ``Dense(1)`` layer named
    ``"prediction" + suffix``.
    """
    feature_input = keras.Input(shape=(num_inputs,), name="features" + suffix)
    prediction = layers.Dense(1, name="prediction" + suffix)(feature_input)
    return keras.Model(inputs=feature_input, outputs=prediction)

def deep_linear(num_inputs, hidden = None, suffix='', activation=None):
    """Build a stack of Dense layers ending in a one-unit output layer.

    Args:
        num_inputs: Number of input features.
        hidden: Sequence of hidden-layer widths; any falsy value (``None``
            or an empty sequence) falls back to ``[128]``.
        suffix: Appended to the input ("features") and output ("prediction")
            layer names.
        activation: Activation passed to every hidden layer; falsy values
            are passed on as ``None``. The output layer has no activation.

    Returns:
        The assembled ``keras.Model``.
    """
    layer_widths = hidden if hidden else [128]
    hidden_activation = activation if activation else None

    features = keras.Input(shape=(num_inputs,), name="features" + suffix)
    tensor = features
    for width in layer_widths:
        tensor = layers.Dense(width, activation=hidden_activation)(tensor)

    prediction = layers.Dense(1, name="prediction" + suffix)(tensor)
    return keras.Model(inputs=features, outputs=prediction)
    

    
    
# Inputs for the train and validation rows together, drawn from a standard normal.
Xdata = tf.random.normal(shape=(num_train+num_eval,num_inputs), dtype=tf.float32)
with tf.device("cpu"):
    # The reference model defines the ground truth: its weights are replaced
    # by scaled standard-normal values of matching shape.
    ref_model = deep_linear(num_inputs, hidden = hidden, activation = act_fn)
    ref_weights = ref_model.get_weights()
    ref_init = [wts_scale*np.float32(np.random.randn(*x.shape)) for x in ref_weights]
    ref_model.set_weights(ref_init)
    # ydata = tf.matmul(Xdata,model_weights) + model_bias
    ref_model.compile()
    # Targets are the reference model's outputs on Xdata.
    ydata = ref_model(Xdata)
    # Re-seed NumPy before drawing the noise for the initial weights.
    np.random.seed(11)
    # Starting weights shared by every exp_run call: reference weights plus noise.
    model_wts_init = [x + err_scale*np.float32(np.random.randn(*x.shape)) for x in ref_init]

def exp_run(model_weights_init, accum_bits, num_train, batch_size, num_epochs, learning_rate):
    """Train one model with a custom loop at a given accumulator bit-width.

    Reads module-level state defined elsewhere in this file: ``num_inputs``,
    ``hidden``, ``act_fn``, ``Xdata``, ``ydata``, ``deep_linear`` and
    ``set_luminous_params``.

    Args:
        model_weights_init: List of weight arrays the model starts from.
        accum_bits: Value passed to ``set_luminous_params`` as
            ``bfloat16_acc_bits`` before the training loop starts.
        num_train: Number of leading rows of ``Xdata``/``ydata`` used for
            training; the remaining rows form the validation set.
        batch_size: Training batch size.
        num_epochs: Number of passes over the training data.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A tuple ``(curves, weights)``. ``curves`` is a dict of per-epoch
        lists of scalar tensors with the keys ``"train_mse"``, ``"val_mse"``
        and ``"val_mae"`` (the latter holds the maximum absolute validation
        error). ``weights`` is ``model.get_weights()`` after training.
    """
    with tf.device("cpu"):
        model = deep_linear(num_inputs, hidden, activation=act_fn)

        # Instantiate an optimizer and a loss function.
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        loss_fn = tf.keras.losses.MeanSquaredError()
        model.compile(
            optimizer=optimizer,
            loss=loss_fn,
        )

        # Warm-up fit on a single all-zero sample (kept from the original
        # flow, which runs it before the plugin parameters are set).
        X = np.zeros(shape=(1, num_inputs), dtype=np.float32)
        Y = np.array([0], dtype=np.float32)
        model.fit(X, Y, epochs=1, verbose=1)

        Xtrain = Xdata[0:num_train, :]
        ytrain = ydata[0:num_train, :]

        Xval = Xdata[num_train::, :]
        yval = ydata[num_train::, :]

        # Warm-up inference pass.
        model(Xval, training=False)

        # Load the requested starting weights only AFTER the warm-up fit.
        # The warm-up performs a real optimizer step, so setting the weights
        # first (as before) let that step alter them; create_model already
        # uses this order.
        # NOTE(review): the optimizer still carries state from the warm-up
        # step — confirm that is acceptable.
        model.set_weights(model_weights_init)

        train_data = tf.data.Dataset.from_tensor_slices((Xtrain, ytrain))
        train_data = train_data.shuffle(num_train).batch(batch_size)

        train_mse_metric = tf.keras.metrics.MeanSquaredError()
        mse = tf.keras.losses.MeanSquaredError()

    set_luminous_params(bfloat16_acc_bits=accum_bits)

    # Keep results for plotting.
    train_mse_results = []
    val_mse_results = []
    val_mae = []

    @tf.function
    def max_abs_error(pred, truth):
        return tf.reduce_max(tf.abs(pred - truth))

    @tf.function
    def train_step(x_batch, y_true):
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            batch_loss = loss_fn(y_true, y_pred)

        # Differentiate outside the tape context so the backward pass itself
        # is not recorded on the tape.
        grads = tape.gradient(batch_loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        grad_norm = tf.norm(grads[0])

        train_mse_metric.update_state(y_true, y_pred)
        return batch_loss, grad_norm

    for epoch in range(num_epochs):
        print("epoch", epoch)
        for x_batch, y_true in train_data:
            train_step(x_batch, y_true)

        # The metric accumulates over the epoch's batches; read, then reset.
        train_mse_results.append(train_mse_metric.result())
        train_mse_metric.reset_states()

        # Validation at the end of each epoch.
        val_pred = model(Xval, training=False)
        val_mse_results.append(mse(yval, val_pred))
        val_mae.append(max_abs_error(yval, val_pred))

    return {"train_mse": train_mse_results,
            "val_mse": val_mse_results,
            "val_mae": val_mae,
           }, model.get_weights()

# Sweep driver: for every (batch_size, lr) pair run exp_run once per
# accumulator width, then stack all learning curves into df_curves.
num_epochs = 50

df_batch_list=[]
for batch_size in [100]:
    
    df_list = []
    for lr in [1e-4]:#[2e-7]:

        # Per-accum_bits results for the current batch size and learning rate.
        lcurves = {}
        final_weights ={}

        for accum_bits in [24, 20, 16]:

            # Every run starts from the same model_wts_init.
            lcurves[accum_bits], final_weights[accum_bits] = exp_run(model_wts_init, 
                                                         accum_bits, 
                                                         num_train, 
                                                         batch_size, 
                                                         num_epochs,
                                                         lr)

        
        # Long-format curves for this learning rate, tagged with lr.
        df = accum_dict_to_df(lcurves)
        df['learning_rate'] = lr
        df_list.append(df)
    
    df_batch = pd.concat(df_list,ignore_index=True)
    df_batch['batch_size']=batch_size
    df_batch_list.append(df_batch)

df_curves = pd.concat(df_batch_list,ignore_index=True)
print('done.')  

# Validation MSE per epoch, one line per accum_bits value, log-scaled y axis.
fig2=sns.relplot(data=df_curves, x="epoch", y="val_mse",hue="accum_bits", kind='line', height=10)
fig2.set(yscale='log'), plt.grid()