In [420]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions 
import numpy as np

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input

In [421]:
# Shape sanity check: a (10, 3) matrix times a (3, 1) matrix gives a (10, 1) result.
tf.matmul(np.random.randn(10,3), np.random.randn(3,1))

<tf.Tensor: shape=(10, 1), dtype=float64, numpy=
array([[ 3.4109625 ],
       [-0.76112431],
       [ 0.48006925],
       [-1.83191827],
       [ 1.08969601],
       [-1.26720735],
       [-1.65296943],
       [-0.02976456],
       [ 3.31460695],
       [ 0.80289594]])>

In [422]:
# 'Bi,ji->Bj' sums over the shared axis i, i.e. it computes a @ b^T:
# (20, 2) and (10, 2) operands produce a (20, 10) result.
tf.einsum('Bi,ji-> Bj', tf.random.normal((20,2)), tf.random.normal((10,2)))

<tf.Tensor: shape=(20, 10), dtype=float32, numpy=
array([[ 5.1593208e-01,  2.6104140e-01,  8.6369559e-02,  2.4951875e-01,
         1.6564128e-01,  1.3348212e+00, -7.9781890e-02, -3.5307246e-01,
         5.4648530e-01, -6.1291951e-01],
       [-1.0573052e-01,  5.8942366e-02, -6.6476643e-01,  5.8033776e-01,
         4.3039382e-01, -2.5066190e+00,  1.8957260e-01,  1.2130514e+00,
        -7.9454148e-01,  1.5436103e+00],
       [ 1.2663823e+00,  7.9205292e-01, -6.5879077e-01,  1.4622597e+00,
         1.0314589e+00,  2.7123299e-01,  3.7285637e-02,  6.6845578e-01,
         4.2283610e-01,  4.0383461e-01],
       [ 3.7207988e-01,  3.4152952e-01, -8.1977254e-01,  1.0407503e+00,
         7.5242907e-01, -2.0814068e+00,  1.7859469e-01,  1.3003315e+00,
        -5.3631562e-01,  1.4909524e+00],
       [ 8.0117995e-01,  2.8915757e-01,  8.0288547e-01, -2.6517391e-01,
        -2.2268820e-01,  4.3807673e+00, -3.0292287e-01, -1.7272242e+00,
         1.5540621e+00, -2.4173419e+00],
       [ 1.7577991e+00,  

In [423]:
class MyDense(tf.keras.layers.Layer):
    """Dense layer computing ``x @ kernel^T + bias`` with a squared-L2 kernel penalty.

    The kernel is stored as ``(units, last_dim)``, i.e. transposed relative to
    ``tf.keras.layers.Dense``. The sum of squared kernel entries is exposed
    through ``layer.losses`` once the layer has been built.

    Args:
        units: Number of output features.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        self.units = units

    @staticmethod
    def _kernel_penalty(kernel):
        """Sum of squared kernel entries (un-scaled L2 penalty)."""
        return tf.reduce_sum(tf.square(kernel))

    def build(self, input_shape):
        """Create the ``(units, last_dim)`` kernel and the ``(units,)`` bias."""
        last_dim = input_shape[-1]
        # The penalty is attached to the weight itself, so it is registered
        # exactly once, when the layer is built. Registering a callable with
        # add_loss() inside call() would append another copy on every new trace.
        self.kernel = self.add_weight(
            shape=(self.units, last_dim),
            initializer='random_normal',
            regularizer=self._kernel_penalty,
            name='kernel')
        self.bias = self.add_weight(
            shape=(self.units,),
            initializer='zeros', name='bias')

        super(MyDense, self).build(input_shape)

    @tf.function
    def call(self, x):
        """Apply the affine map to ``x`` (last axis must match ``last_dim``)."""
        # Equivalent to tf.matmul(x, self.kernel, transpose_b=True) + self.bias.
        return tf.einsum('Bi, ji -> Bj', x, self.kernel) + self.bias

In [424]:
class MyModel(tf.keras.Model):
    """Minimal subclassed model consisting of a single one-unit ``MyDense`` layer."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Single output unit.
        self.dense = MyDense(1)

    @tf.function
    def call(self, x):
        """Forward ``x`` through the dense layer and return its output."""
        return self.dense(x)

In [425]:
# Toy data: 20 samples with 2 features each, and one target per sample.
# Both are independent standard-normal draws, so there is no signal to fit.
x = tf.random.normal([20,2])
y = tf.random.normal([20,1])

In [426]:
model = MyModel()

In [427]:
#model.build((20,2))

In [428]:
#model.summary()

In [429]:
#model.call(x)

In [430]:
model.losses

[]

In [431]:
# SUM reduction: the data loss is summed rather than averaged over the batch.
loss_fn = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.SUM)
# Plain SGD with a fixed step size.
optimizer = tf.keras.optimizers.SGD(learning_rate=.01)

In [432]:
def train(x, y, model, loss_fn, optimizer, epochs, verbose=True):
    """Run ``epochs`` full-batch gradient steps on data loss plus model penalties.

    Args:
        x: Input batch passed to ``model``.
        y: Targets passed to ``loss_fn`` as ``loss_fn(y, pred)``.
        model: Callable Keras model exposing ``losses`` and ``trainable_weights``.
        loss_fn: Callable returning the data loss for ``(y, pred)``.
        optimizer: Object exposing ``apply_gradients``.
        epochs: Number of gradient steps to take.
        verbose: When True (default), print the model losses before the forward
            pass and the data loss / model losses after it, for every step.

    Returns:
        None. The model's weights are updated in place.
    """
    for _ in range(epochs):
        if verbose:
            print('before')
            print(model.losses)

        # tape.gradient is called once per step, so a non-persistent tape suffices.
        with tf.GradientTape() as tape:
            pred = model(x)
            loss = loss_fn(y, pred)

            # Read the penalties inside the tape so their gradients are recorded.
            penalties = model.losses

            if verbose:
                print('after')
                print(loss)
                print(penalties)

            # model.losses is a Python list. Reduce it to a scalar explicitly:
            # `loss + list` only happens to work for exactly one entry, yields
            # an empty target for zero entries, and over-counts the data loss
            # for several.
            total_loss = loss + tf.add_n(penalties) if penalties else loss

        loss_grads = tape.gradient(total_loss, model.trainable_weights)
        optimizer.apply_gradients(zip(loss_grads, model.trainable_weights))

In [433]:
train(x,y, model, loss_fn, optimizer, 3)

before
[]
after
tf.Tensor(22.471128, shape=(), dtype=float32)
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0011562583>]
before
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0039412044>]
after
tf.Tensor(22.195353, shape=(), dtype=float32)
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0039412044>]
before
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0057871793>]
after
tf.Tensor(22.16725, shape=(), dtype=float32)
[<tf.Tensor: shape=(), dtype=float32, numpy=0.0057871793>]


In [419]:
model.losses

[<tf.Tensor: shape=(), dtype=float32, numpy=0.00482824>]

In [401]:
# Functional-API variant: 10 input features -> MyDense(10) -> Dense(1).
inputs = Input(shape=(10,))
# NOTE(review): `x` is rebound here to the symbolic MyDense output, replacing
# any `x` data tensor defined by an earlier cell.
x = MyDense(10)(inputs)
outputs = Dense(1)(x)
model = Model(inputs, outputs)
# Weight regularization.
# NOTE(review): `x` above is the layer's output tensor, not the MyDense layer,
# so `x.kernel` presumably would not resolve; keep a reference to the layer
# instance before re-enabling the line below.
#model.add_loss(lambda: tf.reduce_mean(x.kernel))

### TFP

In [2]:
tfp.__version__

'0.12.0-dev20200719'

In [7]:
# Number of groups. Kept as a Python int because it is used as a tensor shape.
num_groups = 3
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),
    # Independent reinterprets the length-`num_groups` batch axis as an event
    # axis, so lambda_k contributes one log-prob per joint sample. Without it
    # the per-component log-probs cannot be summed with the other terms when
    # the joint distribution is evaluated on a batch of samples.
    tfd.Independent(
        tfd.HalfCauchy(loc=tf.zeros([num_groups]), scale=2., name='lambda_k'),
        reinterpreted_batch_ndims=1),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag(  # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[..., tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])

In [13]:
# Number of groups. Kept as a Python int because it is used as a tensor shape.
num_groups = 3
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),
    # Independent turns the length-`num_groups` batch axis of the HalfCauchy
    # into an event axis (event_shape [num_groups], empty batch_shape).
    tfd.Independent(
        tfd.HalfCauchy(loc=tf.zeros([num_groups]), scale=2., name='lambda_k'),
        reinterpreted_batch_ndims=1),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag(  # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[..., tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])

In [29]:
def affine(x, kernel_diag, bias=tf.zeros([])):
    """Return ``x * kernel_diag + bias``, broadcasting both coefficients against ``x``."""
    # Multiplying by ones_like(x) broadcasts each coefficient with x's shape.
    template = tf.ones_like(x)
    scale = template * kernel_diag
    shift = template * bias
    return x * scale + shift

In [32]:
# Number of groups. Kept as a Python int because it is used as a tensor shape.
num_groups = 3
joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),
    # Independent reinterprets the length-`num_groups` batch axis as an event
    # axis, so every component of the joint has an empty batch_shape and
    # log_prob can be evaluated on batches of samples.
    tfd.Independent(
        tfd.HalfCauchy(loc=tf.zeros([num_groups]), scale=2., name='lambda_k'),
        reinterpreted_batch_ndims=1),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag(  # z_k ~ MVN(z_0, lambda_k)
        # affine() broadcasts z_0 across the `num_groups` entries of the mean.
        loc=affine(tf.ones([num_groups]), z_0[..., tf.newaxis]),
        scale_diag=lambda_k,
        name='z_k'),
])

In [14]:
joint_model

<tfp.distributions.JointDistributionSequential 'JointDistributionSequential' batch_shape=[[], [], []] event_shape=[[], [3], [3]] dtype=[float32, float32, float32]>

In [15]:
joint_model.sample(1)

[<tf.Tensor: shape=(1,), dtype=float32, numpy=array([-0.6760397], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[26.191511  ,  0.46299776,  0.58484864]], dtype=float32)>,
 <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[-12.368133  ,  -0.30241722,  -0.7617348 ]], dtype=float32)>]

In [16]:
joint_model.log_prob(joint_model.sample())
# NOTE(review): the batched call below is the one that was marked "ERROR".
# It presumably fails only when `lambda_k` is a bare HalfCauchy with
# batch_shape [3]: its [4, 3] log-probs cannot broadcast against the
# [4]-shaped terms. With the tfd.Independent wrapper it returns a (4,)
# tensor, as in the recorded output. Confirm against the model in scope.
joint_model.log_prob(joint_model.sample(4))

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([-17.527115, -22.11274 , -16.917967, -20.59402 ], dtype=float32)>

In [17]:
from tensorflow.keras.layers import Dense

In [111]:
class TestLayer(tf.keras.layers.Layer):
    """Wraps a Keras ``Dense`` sublayer and registers a squared-kernel penalty."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # The sublayer is created here, when the layer is built, not in __init__.
        self.dense = Dense(self.units)

    @tf.function
    def call(self, x):
        """Apply the Dense sublayer to ``x`` and register the kernel penalty."""
        projected = self.dense(x)
        # Registered as a zero-argument callable rather than a tensor. The
        # sublayer is called first, so its kernel exists when the penalty is read.
        self.add_loss(lambda: tf.reduce_sum(tf.square(self.dense.kernel)))
        return projected

In [112]:
# Functional model wrapping a single TestLayer, with a fixed batch of 10
# one-feature inputs.
xi = tf.keras.layers.Input(shape=[1], batch_size=10)
layer = TestLayer(1)
out = layer(xi)
model = tf.keras.models.Model(inputs=xi, outputs=out)

# `metrics` is documented as a list of metrics. A bare string is not accepted
# by every Keras version, so pass a one-element list.
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    loss='mse',
    metrics=['mse'])

In [113]:
import numpy as np
np.random.seed(4)
# Draw features and noise as (10, 1) columns. Adding (10,) noise to a (10, 1)
# tensor would broadcast to (10, 10) and give one row of targets per noise
# value instead of one target per sample.
x = tf.convert_to_tensor(np.random.randn(10, 1))
# Targets follow y = 2x + 0.5 plus Gaussian noise with standard deviation 0.1.
y = tf.convert_to_tensor(2*x + 0.5 + 0.1*np.random.randn(10, 1))

In [114]:

model.fit(x,y, batch_size=10, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fc700328790>