In [69]:
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

class Rician(tfd.Distribution):
    def __init__(self, nu, sigma, validate_args=False, allow_nan_stats=True, name="Rician"):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._nu = tf.convert_to_tensor(nu, name="nu")
            self._sigma = tf.convert_to_tensor(sigma, name="sigma")
            super(Rician, self).__init__(
                dtype=self._nu.dtype,
                reparameterization_type=tfd.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name
            )

    @property
    def nu(self):
        return self._nu

    @property
    def sigma(self):
        return self._sigma

    def _batch_shape_tensor(self):
        return tf.shape(self._nu)

    def _batch_shape(self):
        return self._nu.shape

    def _event_shape_tensor(self):
        return tf.constant([], dtype=tf.int32)

    def _event_shape(self):
        return tf.TensorShape([])

    def _log_prob(self, x):
        x = tf.convert_to_tensor(x, dtype=self.dtype)
        nu = self._nu
        sigma = self._sigma

        log_unnorm = (
            tf.math.log(x) - 2.0 * tf.math.log(sigma)
            - (x**2 + nu**2) / (2.0 * sigma**2)
        )
        bessel_term = tf.math.log(tf.math.bessel_i0(x * nu / (sigma**2)))
        return log_unnorm + bessel_term

    def _sample_n(self, n, seed=None):
        # Sampling: Rician(nu, sigma) = sqrt((X + nu)^2 + Y^2), 
        # where X,Y ~ N(0, sigma^2) iid
        print([n], self._batch_shape_tensor())
        shape = tf.constant([n]) #tf.concat([[n], self._batch_shape_tensor(), self._event_shape_tensor()], axis=0)
        print(shape, self._nu.shape)
        normal = tfd.Normal(loc=0.0, scale=self._sigma)
        x = normal.sample(shape, seed=seed)
        y = normal.sample(shape, seed=seed)
        print(x.shape, y.shape)
        return tf.sqrt((x + self._nu)**2 + y**2)


In [68]:
rician = Rician(nu=2.0, sigma=1.0)
rician.log_prob(1.0)   # should run fine

<tf.Tensor: shape=(), dtype=float32, numpy=-1.6760064>

In [None]:
rician._batch_shape_tensor().numpy()

In [None]:
# Flexible input shape
inputs = tf.keras.Input(shape=(10,))   # last dimension flexible

hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
params = tf.keras.layers.Dense(2)(hidden)

def rician_fn(params):
    nu, sigma = tf.split(params, 2, axis=-1)
    sigma = tf.nn.softplus(sigma) + 1e-5
    return Rician(nu=nu, sigma=sigma)

outputs = tfp.layers.DistributionLambda(make_distribution_fn=rician_fn)(params)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


In [73]:
# Example: input is an image
inputs = tf.keras.Input(shape=(64, 64, 3))   # 64x64 RGB image (example)

hidden = tf.keras.layers.Conv2D(16, 3, activation="relu", padding="same")(inputs)
params = tf.keras.layers.Conv2D(2, 1, padding="same")(hidden)  
param_nu = tf.keras.layers.Lambda(lambda z: z[...,:1])(params)
param_nu = tf.keras.layers.Flatten()(param_nu)
param_sigma = tf.keras.layers.Lambda(lambda z: z[...,1:])(params)
param_sigma = tf.keras.layers.Flatten()(param_sigma)
params_flat = tf.keras.layers.concatenate([param_nu, param_sigma])
n_out = params_flat.shape[-1]//2
# shape = [batch, 64, 64, 2]  -> channel 0 = nu, channel 1 = sigma

def rician_fn(params):
    nu, sigma = tf.split(params, 2, axis=-1)   # split along last axis
    print(nu.shape, sigma.shape)
    sigma = tf.nn.softplus(sigma) + 1e-5       # ensure positivity
    return Rician(nu=nu, sigma=sigma)

outputs = tfp.layers.DistributionLambda(
                            lambda t: Rician(
                                nu=t[...,:1],
                                sigma=t[...,1:])
                            )(params)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


[1] Tensor("distribution_lambda_32/tensor_coercible/value/Rician/sample/Shape:0", shape=(4,), dtype=int32)
Tensor("distribution_lambda_32/tensor_coercible/value/Rician/sample/Const:0", shape=(1,), dtype=int32) (None, 64, 64, 1)
(1, None, 64, 64, 1) (1, None, 64, 64, 1)


In [74]:
x = tf.random.normal((1,64,64,3),dtype=tf.float32)
y = model(x)
print(y.shape)

[1] tf.Tensor([ 1 64 64  1], shape=(4,), dtype=int32)
tf.Tensor([1], shape=(1,), dtype=int32) (1, 64, 64, 1)
(1, 1, 64, 64, 1) (1, 1, 64, 64, 1)
(1, 64, 64, 1)


In [81]:
print(y.nu, y.sample())

[1] tf.Tensor([ 1 64 64  1], shape=(4,), dtype=int32)
tf.Tensor([1], shape=(1,), dtype=int32) (1, 64, 64, 1)
(1, 1, 64, 64, 1) (1, 1, 64, 64, 1)
tf.Tensor(
[[[[ 0.3237818 ]
   [ 0.14536268]
   [ 0.10078318]
   ...
   [ 0.33602127]
   [ 0.5561326 ]
   [ 0.02317528]]

  [[ 0.5800871 ]
   [-0.28341073]
   [ 0.6700015 ]
   ...
   [ 0.05599475]
   [ 0.24417934]
   [ 0.0644777 ]]

  [[ 0.7425677 ]
   [-0.230116  ]
   [ 0.30500317]
   ...
   [-0.21282737]
   [ 0.30063182]
   [ 0.0340409 ]]

  ...

  [[ 0.1470594 ]
   [-0.64656484]
   [ 0.31771302]
   ...
   [ 0.40191418]
   [-0.43665752]
   [ 0.30076498]]

  [[-0.339457  ]
   [ 1.5285726 ]
   [-0.09898955]
   ...
   [ 0.05756407]
   [ 0.97202563]
   [ 0.39080742]]

  [[ 0.05871027]
   [-0.06961875]
   [-0.02006909]
   ...
   [-0.27445498]
   [ 0.25613886]
   [ 0.13800645]]]], shape=(1, 64, 64, 1), dtype=float32) tf.Tensor(
[[[[0.13576257]
   [0.40081102]
   [0.9177094 ]
   ...
   [2.350984  ]
   [0.5635589 ]
   [0.3954569 ]]

  [[1.6421397 ]
