In [2]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import tensorflow_probability as tfp
import scipy
import numpy as np
import keras
import datetime

2024-05-20 07:56:41.269069: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-20 07:56:41.332575: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-20 07:56:41.332602: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-20 07:56:41.333683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-20 07:56:41.341091: I tensorflow/core/platform/cpu_feature_guar

In [3]:
num_locations_S=4
# model these 4 locations with mixture of 4 components
num_components_C=num_locations_S
scale = 1
means = [5.0, 15.0, 25.0, 35.0]

# simulate data for each location
data_distributions = [scipy.stats.norm(loc=mean, scale=scale) for mean in means]

num_examples_T = 100
y_TS = np.array([dist.rvs(num_examples_T) for dist in data_distributions]).T

# x doesn't matter because we aren't actually learning
num_features_F = 20
x_TF = np.random.randn(num_examples_T, num_features_F)


In [4]:
class MixtureWeightLayer(keras.layers.Layer):
    """Dumb layer that just returns mixture weights
    Constrained to unit norm
    """
    def __init__(self, num_locations, num_components=2, **kwargs):
        super().__init__(**kwargs)
        self.w = self.add_weight(shape=(1,num_locations, num_components ),
            initializer="uniform",
            trainable=True,
        )
        
        self.softmax = keras.layers.Softmax(axis=1)

    def call(self, inputs):
        return self.softmax(self.w)

In [38]:
num_score_func_samples=50

In [39]:
# build layers
inputs = keras.Input(shape=num_features_F)
component_layers = [keras.layers.Dense(1, activation='softplus') for _ in range(num_components_C)]
mixture_weight_layer = MixtureWeightLayer(num_locations_S, num_components_C)
mixture_weights = mixture_weight_layer(inputs)
assert(mixture_weights.shape == (1,num_locations_S, num_components_C))

# Add a component dimension to outputs of each component model
reshape_layer = keras.layers.Reshape(name='mix_reshape', target_shape=(-1,1))
# Concatenate components along new dimension
concat_layer = keras.layers.Concatenate(name='mix_concat',axis=-1)

# get tfp mixture model
mixture_distribution_layer = tfp.layers.DistributionLambda(lambda params: 
        tfp.distributions.MixtureSameFamily(mixture_distribution=
                                                tfp.distributions.Categorical(probs=params[0]),
                                            components_distribution=
                                                tfp.distributions.Normal(loc=params[1],
                                                                        scale=scale,
                                                                        validate_args=True)),
                                                           convert_to_tensor_fn=lambda s: s.sample(num_score_func_samples))


In [40]:
# build model
component_predictions = [component(inputs) for component in component_layers]
combined_components = concat_layer([reshape_layer(member) for member in component_predictions])
output_distribution = mixture_distribution_layer([mixture_weights, combined_components])
model = keras.Model(inputs=inputs,outputs=output_distribution)

In [41]:
learning_rate = 1e-3
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

In [42]:

def overall_gradient_calculation(gradient_BSp, decision_gradient_BS):

    # parameters have their own shape. Here we find that shape and add appropriate dimension to gradient so we can broadcast
    num_param_dims = tf.rank(gradient_BSp)-2
    num_param_dims_tf = tf.cast(num_param_dims, tf.int32)
    new_shape = tf.concat([tf.shape(decision_gradient_BS), tf.ones([num_param_dims_tf], tf.int32)], axis=0)
    decision_gradient_BSp=tf.reshape(decision_gradient_BS, new_shape)

    # chain rule, multiply gradients
    overall_gradient_BSp = gradient_BSp*decision_gradient_BSp

    # sum over batch and location
    overall_gradient = tf.reduce_sum(overall_gradient_BSp, axis=[0,1])
    return overall_gradient

def score_function_trick(jacobian_MBSp, decision_MBS):

    num_param_dims = tf.rank(jacobian_MBSp)-3
    
    # expand decision to match jacobian
    num_param_dims_tf = tf.cast(num_param_dims, tf.int32)
    new_shape = tf.concat([tf.shape(decision_MBS), tf.ones([num_param_dims_tf], tf.int32)], axis=0)
    decision_MBSp=tf.reshape(decision_MBS, new_shape)

    # do score function trick, scale gradient of log probability our function
    scaled_jacobian_MBSp = jacobian_MBSp*decision_MBSp

    # average over sample dims
    param_gradient_BSp = tf.reduce_mean(scaled_jacobian_MBSp, axis=0)
    return param_gradient_BSp


In [43]:
@tf.function
def get_log_probs(mixture_model):
    stopped_samples = tf.stop_gradient(mixture_model)
    sample_log_probs_MBS = mixture_model.log_prob(stopped_samples)
    return sample_log_probs_MBS

In [44]:

logs = {}
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = f'logs/func/{stamp}_coerce_tffuncsamp{num_score_func_samples}'
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
with tf.GradientTape() as jacobian_tape, tf.GradientTape() as loss_tape:
    mixture_model = model(x_TF, training=True)
    #sample_y_MBS = mixture_model.sample(num_score_func_samples)
    #stopped_samples = tf.stop_gradient(mixture_model)
    #sample_log_probs_MBS = mixture_model.log_prob(stopped_samples)
    sample_log_probs_MBS = get_log_probs(mixture_model)

    sample_decisions_MBS = tf.identity(mixture_model)
    expected_decisions_BS = tf.reduce_mean(sample_decisions_MBS, axis=0)
    loss_B = keras.losses.mean_squared_error(y_TS, expected_decisions_BS)

jacobian_pMBS = jacobian_tape.jacobian(sample_log_probs_MBS, model.trainable_weights)
param_gradient_pBS = [score_function_trick(j, sample_decisions_MBS) for j in jacobian_pMBS]

loss_gradients_BS = loss_tape.gradient(loss_B, expected_decisions_BS)
overall_gradient = [overall_gradient_calculation(g, loss_gradients_BS) for g in param_gradient_pBS]

optimizer.apply_gradients(zip(overall_gradient, model.trainable_weights))

with writer.as_default():
    tf.summary.trace_export(
        name="my_func_trace",
        step=0,
        profiler_outdir=logdir)
    tf.summary.trace_off()

2024-05-20 08:23:58.941309: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:104] Profiler session initializing.
2024-05-20 08:23:58.941338: I external/local_tsl/tsl/profiler/lib/profiler_session.cc:119] Profiler session started.


In [15]:
sample_log_probs_MBS.shape

<tf.Tensor: shape=(100, 4), dtype=float32, numpy=
array([[-1.0694184 , -1.3291337 , -1.059331  , -1.6327939 ],
       [-4.17158   , -1.0699117 , -1.1160941 , -1.4857092 ],
       [-1.4357271 , -1.2709595 , -1.1660409 , -3.5272446 ],
       [-1.6549181 , -1.2147293 , -2.0589461 , -2.158225  ],
       [-1.0923643 , -0.96904993, -1.8280009 , -0.93384254],
       [-1.4407645 , -1.4145877 , -3.8095746 , -2.2560375 ],
       [-1.2318821 , -1.1555796 , -1.1743959 , -1.4317412 ],
       [-1.318476  , -2.421885  , -1.4265397 , -2.0935707 ],
       [-1.8221313 , -1.1581162 , -1.4536357 , -1.2060163 ],
       [-7.105071  , -1.0565109 , -1.0803047 , -1.2971982 ],
       [-1.7958286 , -1.5772939 , -1.3834233 , -1.8899655 ],
       [-1.1924962 , -1.4505279 , -0.9608245 , -0.96449506],
       [-1.2257168 , -1.4363852 , -1.0039036 , -1.0558329 ],
       [-1.4705768 , -1.0905652 , -1.0862833 , -1.0654685 ],
       [-1.7284241 , -1.5138774 , -1.5474112 , -1.9084129 ],
       [-1.0247228 , -1.5127444 , -

In [17]:
stopped_samples.shape

TensorShape([100, 4])