In [1]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import tensorflow_probability as tfp
import scipy
import numpy as np
import keras
import datetime

2024-05-20 08:02:20.745341: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-20 08:02:20.806385: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-20 08:02:20.806422: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-20 08:02:20.814061: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-20 08:02:20.826538: I tensorflow/core/platform/cpu_feature_guar

In [2]:
num_locations_S = 4
# model these 4 locations with mixture of 4 components
num_components_C = num_locations_S
scale = 1
means = [5.0, 15.0, 25.0, 35.0]

# Seeded generator so that a fresh kernel regenerates exactly the same synthetic data.
DATA_SEED = 0
data_rng = np.random.default_rng(DATA_SEED)

# simulate data for each location: one Normal(mean, scale) per location
data_distributions = [scipy.stats.norm(loc=mean, scale=scale) for mean in means]

num_examples_T = 100
# Stack one row of draws per location, then transpose to (num_examples_T, num_locations_S).
y_TS = np.array(
    [dist.rvs(num_examples_T, random_state=data_rng) for dist in data_distributions]
).T

# x doesn't matter because we aren't actually learning
num_features_F = 20
x_TF = data_rng.standard_normal((num_examples_T, num_features_F))


In [3]:
class MixtureWeightLayer(keras.layers.Layer):
    """Input-independent layer that returns learnable mixture weights.

    The layer owns a single trainable tensor of shape
    ``(1, num_locations, num_components)`` and returns its softmax taken over
    the component (last) axis, so every location's weights are non-negative
    and sum to 1. The leading axis of size 1 is a broadcastable batch axis.

    Args:
        num_locations: number of locations (second axis of the weight tensor).
        num_components: number of mixture components per location (last axis).
        **kwargs: forwarded to ``keras.layers.Layer``.
    """
    def __init__(self, num_locations, num_components=2, **kwargs):
        super().__init__(**kwargs)
        self.w = self.add_weight(shape=(1, num_locations, num_components),
            initializer="uniform",
            trainable=True,
        )

        # Normalise over the component axis. The weights are consumed as
        # categorical probabilities whose category axis is the last one;
        # normalising over axis 1 (locations) would leave each location's
        # component weights un-normalised.
        self.softmax = keras.layers.Softmax(axis=-1)

    def call(self, inputs):
        """Return the normalised weights; ``inputs`` is ignored."""
        return self.softmax(self.w)

In [4]:
# build layers
# `shape` is documented as a tuple that excludes the batch axis; a one-element tuple
# is accepted everywhere, whereas a bare int is not guaranteed to be.
inputs = keras.Input(shape=(num_features_F,))

# One single-unit network per mixture component; softplus keeps each predicted mean positive.
component_layers = [keras.layers.Dense(1, activation='softplus') for _ in range(num_components_C)]

# The mixture weights do not depend on the inputs; the layer is called on `inputs`
# only so that it becomes part of the functional graph.
mixture_weight_layer = MixtureWeightLayer(num_locations_S, num_components_C)
mixture_weights = mixture_weight_layer(inputs)
assert mixture_weights.shape == (1, num_locations_S, num_components_C)

# Add a component dimension to outputs of each component model
reshape_layer = keras.layers.Reshape(name='mix_reshape', target_shape=(-1, 1))
# Concatenate components along new dimension
concat_layer = keras.layers.Concatenate(name='mix_concat', axis=-1)

# get tfp mixture model
# params[0]: mixture weights (categorical probabilities over components)
# params[1]: component means, with components on the last axis
# The component standard deviation is the fixed module-level `scale`.
mixture_distribution_layer = tfp.layers.DistributionLambda(
    lambda params: tfp.distributions.MixtureSameFamily(
        mixture_distribution=tfp.distributions.Categorical(probs=params[0]),
        components_distribution=tfp.distributions.Normal(loc=params[1],
                                                         scale=scale,
                                                         validate_args=True)))


2024-05-20 07:46:35.617116: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2024-05-20 07:46:35.617156: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:129] retrieving CUDA diagnostic information for host: s1cmp008.pax.tufts.edu
2024-05-20 07:46:35.617161: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:136] hostname: s1cmp008.pax.tufts.edu
2024-05-20 07:46:35.617241: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:159] libcuda reported version is: 535.129.3
2024-05-20 07:46:35.617274: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:163] kernel reported version is: 535.129.3
2024-05-20 07:46:35.617279: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:241] kernel version seems to match DSO: 535.129.3


In [5]:
# build model
# Run every component network on the shared input tensor.
component_predictions = list(map(lambda layer: layer(inputs), component_layers))

# Give each prediction a trailing component axis, then join them along that axis.
combined_components = concat_layer(list(map(reshape_layer, component_predictions)))

# Pair the mixture weights with the component means to form the output distribution.
output_distribution = mixture_distribution_layer([mixture_weights, combined_components])
model = keras.Model(inputs=inputs, outputs=output_distribution)

In [6]:
# Step size for the Adam optimizer that applies the manually assembled gradients.
learning_rate = 0.001
optimizer = keras.optimizers.Adam(learning_rate)

In [7]:

def overall_gradient_calculation(gradient_BSp, decision_gradient_BS):
    """Chain-rule step: weight per-(batch, location) parameter gradients and sum them.

    Args:
        gradient_BSp: tensor whose first two axes are batch and location and whose
            remaining axes have the shape of one parameter.
        decision_gradient_BS: tensor over the same two leading axes holding the
            gradient of the loss with respect to each decision.

    Returns:
        Tensor with the parameter's own shape: the product of the two inputs,
        summed over the batch and location axes.
    """
    # Every axis past the first two belongs to the parameter itself.
    trailing_axes = tf.cast(tf.rank(gradient_BSp) - 2, tf.int32)

    # Append one singleton axis per parameter axis so the decision gradient
    # broadcasts against the parameter gradient.
    padded_shape = tf.concat(
        [tf.shape(decision_gradient_BS), tf.ones([trailing_axes], tf.int32)],
        axis=0,
    )
    decision_gradient_BSp = tf.reshape(decision_gradient_BS, padded_shape)

    # Multiply (chain rule), then collapse the batch and location axes.
    weighted_BSp = gradient_BSp * decision_gradient_BSp
    return tf.reduce_sum(weighted_BSp, axis=[0, 1])

def score_function_trick(jacobian_MBSp, decision_MBS):
    """Score function trick: scale log-probability Jacobians by the sampled decisions.

    Args:
        jacobian_MBSp: tensor whose first three axes are sample, batch and location
            and whose remaining axes have the shape of one parameter.
        decision_MBS: tensor over the same three leading axes holding the decision
            made for each sample.

    Returns:
        Tensor with the sample axis averaged out (batch, location, then the
        parameter's axes).
    """
    # Every axis past the first three belongs to the parameter itself.
    extra_axes = tf.cast(tf.rank(jacobian_MBSp) - 3, tf.int32)

    # Pad the decision's shape with singleton axes so it broadcasts over the parameter axes.
    singleton_tail = tf.ones([extra_axes], tf.int32)
    target_shape = tf.concat([tf.shape(decision_MBS), singleton_tail], axis=0)
    decision_MBSp = tf.reshape(decision_MBS, target_shape)

    # Weight each sample's log-probability gradient by its decision, then average over samples.
    weighted_MBSp = jacobian_MBSp * decision_MBSp
    return tf.reduce_mean(weighted_MBSp, axis=0)


In [9]:
# Number of draws taken from the predicted mixture for the score-function estimate.
num_score_func_samples = 100
logs = {}

# Every run writes its trace to a fresh, timestamped directory.
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = f'logs/func/{stamp}_samp{num_score_func_samples}'
writer = tf.summary.create_file_writer(logdir)

tf.summary.trace_on(graph=True, profiler=True)

# Both tapes record the same forward pass: one is used for the Jacobian of the
# sample log-probabilities, the other for the gradient of the loss.
with tf.GradientTape() as jacobian_tape:
    with tf.GradientTape() as loss_tape:
        mixture_model = model(x_TF, training=True)

        # Draw samples and score them; the samples are detached before scoring.
        sample_y_MBS = mixture_model.sample(num_score_func_samples)
        stopped_samples = tf.stop_gradient(sample_y_MBS)
        sample_log_probs_MBS = mixture_model.log_prob(stopped_samples)

        # The decision is the sample itself; the loss compares the mean decision to y.
        sample_decisions_MBS = tf.identity(sample_y_MBS)
        expected_decisions_BS = tf.reduce_mean(sample_decisions_MBS, axis=0)
        loss_B = keras.losses.mean_squared_error(y_TS, expected_decisions_BS)

# Gradient of the expected decision w.r.t. each trainable weight (score function trick).
jacobian_pMBS = jacobian_tape.jacobian(sample_log_probs_MBS, model.trainable_weights)
param_gradient_pBS = list(
    map(lambda j: score_function_trick(j, sample_decisions_MBS), jacobian_pMBS)
)

# Gradient of the loss w.r.t. the expected decision, chained with the above.
loss_gradients_BS = loss_tape.gradient(loss_B, expected_decisions_BS)
overall_gradient = list(
    map(lambda g: overall_gradient_calculation(g, loss_gradients_BS), param_gradient_pBS)
)

optimizer.apply_gradients(zip(overall_gradient, model.trainable_weights))

# Export the collected trace to the log directory, then make sure tracing is off.
with writer.as_default():
    tf.summary.trace_export(name="my_func_trace", step=0, profiler_outdir=logdir)
    tf.summary.trace_off()

2024-05-20 07:48:20.392541: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-05-20 07:48:20.392568: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.


In [46]:
# Display the score-function gradient estimates: a list with one tensor per entry of
# model.trainable_weights, as produced by score_function_trick.
# NOTE(review): this shows whatever the most recently executed training cell assigned
# to the name — confirm which run the displayed values belong to.
param_gradient_pBS

[<tf.Tensor: shape=(100, 4, 20, 1), dtype=float32, numpy=
 array([[[[ 3.4712605e-02],
          [ 5.1432997e-03],
          [-1.1872584e-03],
          ...,
          [ 3.0226505e-03],
          [-1.1700059e-02],
          [-1.0001410e-02]],
 
         [[ 3.3066045e-02],
          [ 4.8993328e-03],
          [-1.1309420e-03],
          ...,
          [ 2.8792738e-03],
          [-1.1145078e-02],
          [-9.5270034e-03]],
 
         [[ 3.0951781e-02],
          [ 4.5860657e-03],
          [-1.0586287e-03],
          ...,
          [ 2.6951709e-03],
          [-1.0432455e-02],
          [-8.9178402e-03]],
 
         [[ 3.7666213e-02],
          [ 5.5809300e-03],
          [-1.2882793e-03],
          ...,
          [ 3.2798396e-03],
          [-1.2695586e-02],
          [-1.0852406e-02]]],
 
 
        [[[-4.0223960e-02],
          [-1.7453597e-01],
          [ 1.4151731e-01],
          ...,
          [-3.3374926e-01],
          [ 2.6767838e-01],
          [-2.0758545e-03]],
 
         

In [48]:
# Set linear weights to 0
# Assumes model.trainable_variables is laid out as (kernel, bias) for each component
# layer in turn, followed by the mixture-weight variable — the printed names/shapes
# below let you verify this. Indices are derived from num_components_C rather than
# hard-coded so the cell keeps working if the number of components changes.
assert len(means) == num_components_C, "need exactly one target mean per component"

for index in range(0, 2 * num_components_C, 2):
    variable = model.trainable_variables[index]
    print(f'Name: {variable.name}')
    print(f'Shape: {variable.shape}')

    variable.assign(tf.zeros(variable.shape))
# set biases to true value
for index, bias in zip(range(1, 2 * num_components_C, 2), means):
    variable = model.trainable_variables[index]
    print(f'Name: {variable.name}')
    print(f'Shape: {variable.shape}')

    # The component layers end in a softplus, so store the pre-activation value
    # whose softplus equals the target mean.
    variable.assign(tfp.math.softplus_inverse([bias]))
index = -1
variable = model.trainable_variables[index]
print(f'Name: {variable.name}')
print(f'Shape: {variable.shape}')

# mixture weights are an identity matrix (1st component = first location). Add batch dimension
weights = tf.expand_dims(tf.eye(num_locations_S, num_components_C), axis=0)
# log because we are using softmax; the small offset keeps log(0) finite
variable.assign(tf.math.log(weights+1e-13))

Name: dense_20/kernel:0
Shape: (20, 1)
Name: dense_21/kernel:0
Shape: (20, 1)
Name: dense_22/kernel:0
Shape: (20, 1)
Name: dense_23/kernel:0
Shape: (20, 1)
Name: dense_20/bias:0
Shape: (1,)
Name: dense_21/bias:0
Shape: (1,)
Name: dense_22/bias:0
Shape: (1,)
Name: dense_23/bias:0
Shape: (1,)
Name: Variable:0
Shape: (1, 4, 4)


<tf.Variable 'UnreadVariable' shape=(1, 4, 4) dtype=float32, numpy=
array([[[  0.      , -29.933605, -29.933605, -29.933605],
        [-29.933605,   0.      , -29.933605, -29.933605],
        [-29.933605, -29.933605,   0.      , -29.933605],
        [-29.933605, -29.933605, -29.933605,   0.      ]]], dtype=float32)>

In [63]:
@tf.function(autograph=False)
def get_prob_decisions(model):
    """Sample from a distribution and return log-probs, decisions and their mean.

    Args:
        model: object exposing ``sample(n)`` and ``log_prob(x)``; callers in this
            notebook pass the distribution returned by the Keras model.

    Returns:
        Tuple ``(log_probs, decisions, expected_decisions)``: the log-probability of
        each (gradient-stopped) sample, the samples themselves used as decisions,
        and the decisions averaged over the leading sample axis.

    The number of samples is read from the module-level ``num_score_func_samples``.
    """
    draws_MBS = model.sample(num_score_func_samples)

    # Score the draws as constants so no gradient flows through the sampling path.
    log_probs_MBS = model.log_prob(tf.stop_gradient(draws_MBS))

    # The decision for each draw is simply the draw itself.
    decisions_MBS = tf.identity(draws_MBS)
    mean_decisions_BS = tf.reduce_mean(decisions_MBS, axis=0)
    return log_probs_MBS, decisions_MBS, mean_decisions_BS

@tf.function
def get_loss(y_batch, expected_decisions_BS):
    """Return the Keras mean squared error between targets and expected decisions.

    Args:
        y_batch: target values.
        expected_decisions_BS: predicted (sample-averaged) decisions.
    """
    return keras.losses.mean_squared_error(y_batch, expected_decisions_BS)

In [64]:
@tf.function
def train_step(model, x_batch, y_batch):
    """Run one optimisation step using score-function gradient estimates.

    Args:
        model: Keras model whose output is a distribution.
        x_batch: input features for the step.
        y_batch: targets compared against the expected decisions.

    Returns:
        The loss returned by ``get_loss`` for this batch (computed before the update).

    Side effects:
        Updates ``model.trainable_weights`` through the module-level ``optimizer``.
    """
    # Both tapes watch the same forward pass; one serves the Jacobian of the
    # log-probabilities, the other the gradient of the loss.
    with tf.GradientTape() as jacobian_tape:
        with tf.GradientTape() as loss_tape:
            predicted_mixture = model(x_batch, training=True)
            log_probs_MBS, decisions_MBS, expected_BS = get_prob_decisions(predicted_mixture)
            loss_B = get_loss(y_batch, expected_BS)

    weights = model.trainable_weights

    # Per-weight gradient of the expected decision via the score function trick.
    jacobians = jacobian_tape.jacobian(log_probs_MBS, weights)
    decision_grads = [score_function_trick(jac, decisions_MBS) for jac in jacobians]

    # Chain with the gradient of the loss w.r.t. the expected decision.
    loss_grads_BS = loss_tape.gradient(loss_B, expected_BS)
    weight_grads = [
        overall_gradient_calculation(grad, loss_grads_BS) for grad in decision_grads
    ]

    optimizer.apply_gradients(zip(weight_grads, weights))
    return loss_B
     

In [65]:
logs = {}

# Every run writes its trace to a fresh, timestamped directory.
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/func/%s' % stamp
writer = tf.summary.create_file_writer(logdir)

# Keras callback hooks and a mini-batch loop over a dataset are not wired in:
# each epoch is a single full-batch step on (x_TF, y_TS), profiled and traced.
for epoch in range(1):
    tf.summary.trace_on(graph=True, profiler=True)

    # Both tapes record the same forward pass.
    with tf.GradientTape() as jacobian_tape:
        with tf.GradientTape() as loss_tape:
            mixture_model = model(x_TF, training=True)
            (sample_log_probs_MBS,
             sample_decisions_MBS,
             expected_decisions_BS) = get_prob_decisions(mixture_model)
            loss_B = get_loss(y_TS, expected_decisions_BS)

    # Score-function estimate of the gradient of the expected decision per weight.
    jacobian_pMBS = jacobian_tape.jacobian(sample_log_probs_MBS, model.trainable_weights)
    param_gradient_pBS = list(
        map(lambda j: score_function_trick(j, sample_decisions_MBS), jacobian_pMBS)
    )

    # Chain with the gradient of the loss w.r.t. the expected decision.
    loss_gradients_BS = loss_tape.gradient(loss_B, expected_decisions_BS)
    overall_gradient = list(
        map(lambda g: overall_gradient_calculation(g, loss_gradients_BS), param_gradient_pBS)
    )

    optimizer.apply_gradients(zip(overall_gradient, model.trainable_weights))

    # Export the collected trace, then make sure tracing is off.
    with writer.as_default():
        tf.summary.trace_export(name="my_func_trace", step=0, profiler_outdir=logdir)
    tf.summary.trace_off()


2024-05-17 14:34:53.558406: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-05-17 14:34:53.558434: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.
