In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions 

In [2]:
tfp.__version__

'0.12.0-dev20200719'

In [7]:
num_groups = 3.
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),       
    tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag( # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[...,tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])

In [13]:
num_groups = 3.
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),       
    tfd.Independent(tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'), reinterpreted_batch_ndims=1),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag( # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[...,tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])

In [29]:
def affine(x, kernel_diag, bias=tf.zeros([])):
    """`kernel_diag * x + bias` with broadcasting."""
    kernel_diag = tf.ones_like(x) * kernel_diag
    bias = tf.ones_like(x) * bias
    return x * kernel_diag + bias

In [32]:
num_groups = 3
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),       
    tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag( # z_k ~ MVN(z_0, lambda_k)
        loc=affine(tf.ones([num_groups]), z_0[...,tf.newaxis]),
        scale_diag=lambda_k,
        name='z_k'),
])

In [14]:
joint_model

<tfp.distributions.JointDistributionSequential 'JointDistributionSequential' batch_shape=[[], [], []] event_shape=[[], [3], [3]] dtype=[float32, float32, float32]>

In [15]:
joint_model.sample(1)

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.6760397], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[26.191511  ,  0.46299776,  0.58484864]], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-12.368133  ,  -0.30241722,  -0.7617348 ]], dtype=float32)>]

In [16]:
joint_model.log_prob(joint_model.sample())
# ERROR 
joint_model.log_prob(joint_model.sample(4))

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([-17.527115, -22.11274 , -16.917967, -20.59402 ], dtype=float32)>

In [17]:
from tensorflow.keras.layers import Dense

In [111]:
class TestLayer(tf.keras.layers.Layer):
    
    def __init__(self, units, **kwargs):
        super(TestLayer, self).__init__(**kwargs)
        self.units=units
        
        
    def build(self, input_shape):
        self.dense = Dense(self.units)
        
    @tf.function  
    def call(self, x):
        
        out = self.dense(x)
        #import ipdb; ipdb.set_trace()
        self.add_loss(lambda: tf.reduce_sum(tf.square(self.dense.kernel)))
        #self.add_loss(tf.reduce_sum(x))
        return out

In [112]:
# model = tf.keras.models.Sequential([
#     TestLayer(1)
# ])
xi = tf.keras.layers.Input(shape=[1], batch_size=10)
layer = TestLayer(1)
out = layer(xi)
model = tf.keras.models.Model(inputs=xi, outputs=out)

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss='mse', metrics='mse')

In [113]:
import numpy as np
np.random.seed(4)
x = tf.convert_to_tensor(np.random.randn(10))[..., tf.newaxis]
y = tf.convert_to_tensor(2*x + 0.5 + 0.1*np.random.randn(10))[..., tf.newaxis]

In [114]:

model.fit(x,y, batch_size=10, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fc700328790>