In [7]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_probability as tfp

# Short aliases for the Keras, TFP distributions and TFP bijectors namespaces.
tfk = tf.keras
tfd = tfp.distributions
tfb = tfp.bijectors

## tfp.bijectors.Blockwise

In [20]:
# Goal: keep one output dimension unconstrained and bound the rest.
# First experiment with Blockwise: Exp acts on the first element (block of
# size 1) and Sigmoid on the remaining two (block of size 2). Exp only
# produces positive values, so this first dimension is not truly unbounded.
blockwise = tfb.Blockwise([tfb.Exp(), tfb.Sigmoid()], block_sizes=[1, 2])

In [21]:
# Rank-1 input: element 0 goes through Exp, elements 1-2 go through Sigmoid.
blockwise.forward([0., 1., 2.])

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1.       , 0.7310586, 0.8807971], dtype=float32)>

In [22]:
# Same values with a leading batch axis; the result keeps the (1, 3) shape.
blockwise.forward([[0., 1., 2.]])

<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[1.       , 0.7310586, 0.8807971]], dtype=float32)>

In [23]:
# Reference values: apply each bijector directly to the element it is
# responsible for, to compare against the Blockwise output.
(
    tfb.Exp().forward(0.),
    tfb.Sigmoid().forward(1.),
    tfb.Sigmoid().forward(2.),
)

(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.7310586>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.8807971>)

## Enforce output samples in [0, 1] for a subset of dimensions

In [24]:
# Identity leaves the first dimension untouched; the chain squashes the
# remaining two into (0, 1). tfb.Chain applies its list right-to-left, so the
# chain computes 0.5 * (tanh(x) + 1): Tanh first, then the shift, then the scale.
bs = [
    tfb.Identity(),
    tfb.Chain(
        [
            tfb.Scale(scale=0.5),
            tfb.Shift(shift=1.0),
            tfb.Tanh(),
        ]
    ),
]
blockwise = tfb.Blockwise(bijectors=bs, block_sizes=[1, 2])

In [25]:
# Element 0 passes through Identity unchanged; elements 1-2 land in (0, 1).
blockwise.forward([0., 1., 2.])

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 0.8807971, 0.9820138], dtype=float32)>

In [26]:
# Stand-alone check of the bounding chain on the last input value:
# 0.5 * (tanh(2) + 1) should equal the last element of the Blockwise output.
tfb.Chain(
    [tfb.Scale(scale=.5), tfb.Shift(shift=1.), tfb.Tanh()]
).forward(2.)

<tf.Tensor: shape=(), dtype=float32, numpy=0.9820138>

## Abstract it: build the Blockwise bijector from `n_dim` and `bounded_idxs`

In [29]:
# Total number of output dimensions; each one gets its own Blockwise block.
n_dim = 3
# Zero-based positions that get the (0, 1)-bounding chain; all others use Identity.
bounded_idxs = [1, 2]

In [30]:
def unit_interval_bijector():
    """Return a fresh bijector computing 0.5 * (tanh(x) + 1), i.e. values in (0, 1).

    tfb.Chain applies its list right-to-left: Tanh runs first, then the +1
    shift, then the 0.5 scale.
    """
    return tfb.Chain(
        [tfb.Scale(scale=.5), tfb.Shift(shift=1.), tfb.Tanh()],
        )


# Bound unconditionally (instead of as a side effect inside the loop) so that
# later cells calling `sigmoid.forward(...)` also work when `bounded_idxs` is
# empty. Despite the name this is the Tanh-based chain, equal to sigmoid(2x),
# not tfb.Sigmoid.
sigmoid = unit_interval_bijector()

# One bijector per dimension: the bounding chain for indices listed in
# `bounded_idxs`, Identity for everything else.
blockwise_bijectors = [
    unit_interval_bijector() if i in bounded_idxs else tfb.Identity()
    for i in range(n_dim)
]
blockwise_bijector = tfb.Blockwise(blockwise_bijectors, block_sizes=[1]*n_dim)

In [31]:
# Exercise the bijector built from `n_dim` / `bounded_idxs`, not the
# hand-written `blockwise` from the previous section. The two apply the same
# per-element mapping, so the expected values are unchanged.
blockwise_bijector.forward([0., 1., 2.])

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.       , 0.8807971, 0.9820138], dtype=float32)>

In [32]:
# Reference values for the bounded dimensions. `sigmoid` is the Tanh-based
# chain defined in the cell that builds `blockwise_bijectors`.
(
    sigmoid.forward(1.),
    sigmoid.forward(2.),
)

(<tf.Tensor: shape=(), dtype=float32, numpy=0.8807971>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.9820138>)