In [1]:
import DLlib as dl
import tf2gan as gan
import wflib as wf

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

# Short aliases for the TFP submodules used below.
tfd, tfb = tfp.distributions, tfp.bijectors

class Rician(tfd.Distribution):
    """Rician distribution with noncentrality ``nu`` and scale ``sigma``.

    A Rician variable is the magnitude ``sqrt((X + nu)**2 + Y**2)`` of a 2-D
    Gaussian with ``X, Y ~ N(0, sigma**2)`` i.i.d.; its support is ``x >= 0``.

    Args:
      nu: tensor-like noncentrality parameter. Its dtype becomes the
        distribution's dtype.
      sigma: tensor-like scale parameter. Entries below ``1e-16`` are replaced
        by ``1e-16`` so divisions by ``sigma**2`` stay finite.
      validate_args: forwarded to ``tfd.Distribution`` (no extra checks are
        performed by this class).
      allow_nan_stats: forwarded to ``tfd.Distribution``.
      name: name scope for the created tensors.

    The batch shape is the broadcast of ``nu`` and ``sigma``; the event shape
    is scalar.
    """

    def __init__(self, nu, sigma, validate_args=False, allow_nan_stats=True, name="Rician"):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._nu = tf.convert_to_tensor(nu, name="nu")
            sigma_aux = tf.convert_to_tensor(sigma, name="sigma")
            # Floor sigma so 1 / sigma**2 never divides by zero.
            self._sigma = tf.where(sigma_aux < 1e-16, 1e-16, sigma_aux)
            super(Rician, self).__init__(
                dtype=self._nu.dtype,
                reparameterization_type=tfd.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name
            )

    @property
    def nu(self):
        """Noncentrality parameter tensor."""
        return self._nu

    @property
    def sigma(self):
        """Scale parameter tensor (after the 1e-16 floor)."""
        return self._sigma

    def _batch_shape_tensor(self):
        return tf.broadcast_dynamic_shape(tf.shape(self._nu), tf.shape(self._sigma))

    def _batch_shape(self):
        return tf.broadcast_static_shape(self._nu.shape, self._sigma.shape)

    def _event_shape_tensor(self):
        return tf.constant([], dtype=tf.int32)

    def _event_shape(self):
        return tf.TensorShape([])

    def _log_prob(self, x):
        """log p(x) = log x - 2 log sigma - (x^2 + nu^2) / (2 sigma^2) + log I0(x nu / sigma^2).

        ``x = 0`` gives ``-inf``; ``x < 0`` (outside the support) gives NaN
        through ``log(x)``.
        """
        x = tf.convert_to_tensor(x, dtype=self.dtype)
        nu = self._nu
        sigma = self._sigma

        # Argument of the Bessel function.
        arg = x * nu / (sigma**2)

        # Exponentially scaled Bessel function for numerical stability:
        # log(I0(a)) = log(I0e(a)) + |a|
        log_bessel = tf.math.log(tf.math.bessel_i0e(arg)) + tf.abs(arg)

        log_unnorm = (
            tf.math.log(x) - 2.0 * tf.math.log(sigma)
            - (x**2 + nu**2) / (2.0 * sigma**2)
        )
        return log_unnorm + log_bessel

    def _sample_n(self, n, seed=None):
        """Draw ``n`` samples of shape ``[n] + batch_shape``.

        Uses Rician(nu, sigma) = sqrt((X + nu)^2 + Y^2) with X, Y ~ N(0, sigma^2) iid.
        """
        # Broadcast both parameters to the full batch shape so the output is
        # [n] + batch_shape even when nu broadcasts beyond sigma's shape.
        batch_shape = self._batch_shape_tensor()
        nu = tf.broadcast_to(self._nu, batch_shape)
        sigma = tf.broadcast_to(self._sigma, batch_shape)
        normal = tfd.Normal(loc=tf.zeros_like(sigma), scale=sigma)
        # Draw X and Y in ONE call with a leading axis of size 2. Two separate
        # calls reusing the same `seed` can return identical draws (X == Y),
        # e.g. for stateless seeds or inside tf.function.
        xy = normal.sample([2, n], seed=seed)
        return tf.sqrt(tf.square(xy[0] + nu) + tf.square(xy[1]))

    def _laguerre_half(self):
        """Laguerre function L_{1/2}(x) evaluated at x = -nu^2 / (2 sigma^2).

        L_{1/2}(x) = exp(x/2) * [(1 - x) I0(-x/2) - x I1(-x/2)].
        With exponentially scaled Bessels, I_k(z) = I_ke(z) * exp(|z|), and
        x <= 0 (so z = -x/2 >= 0), the prefactor exp(x/2) * exp(|z|) is exactly
        1, leaving an expression that cannot overflow.
        """
        x = -tf.square(self._nu) / (2.0 * tf.square(self._sigma))
        z = -x / 2.0
        return (1.0 - x) * tf.math.bessel_i0e(z) - x * tf.math.bessel_i1e(z)

    def _mean(self):
        """E[R] = sigma * sqrt(pi / 2) * L_{1/2}(-nu^2 / (2 sigma^2))."""
        # (pi / 2) ** 0.5 is a Python float, so it adopts sigma's dtype
        # (tf.sqrt of a Python float would force a float32 constant).
        return self._sigma * (np.pi / 2.0) ** 0.5 * self._laguerre_half()

    def _variance(self):
        """Var[R] = 2 sigma^2 + nu^2 - (pi sigma^2 / 2) * L_{1/2}(-nu^2 / (2 sigma^2))^2."""
        sigma_sq = tf.square(self._sigma)
        var = (
            2.0 * sigma_sq
            + tf.square(self._nu)
            - (np.pi * sigma_sq / 2.0) * tf.square(self._laguerre_half())
        )
        # At high nu/sigma the nu^2 term and the Laguerre term nearly cancel, and
        # float rounding can push the result below zero; a variance never is.
        return tf.maximum(var, 0.0)


In [None]:
# Demo batch of 8 Rician distributions: increasing noncentrality, shared sigma = 0.04.
nu_values = [1e-2, 2e-2, 5e-2, 10e-2, 20e-2, 30e-2, 50e-2, 100e-2]
# sigma's length is derived from nu_values instead of a hard-coded 8, so editing
# the nu list cannot silently break the parameter broadcast.
rician = Rician(nu=nu_values, sigma=tf.repeat([4e-2], len(nu_values)))
rician.log_prob(1.0)   # should run fine

In [None]:
rician.sample(sample_shape=5) * 200  # 5 draws per batch member, rescaled by 200

In [None]:
# Moments of the rescaled variable 200 * R. The mean scales by 200, but the
# variance scales by 200**2 (Var[c R] = c^2 Var[R]).
print(200 * rician.mean(), 200**2 * rician.variance())

In [None]:
# Flexible input shape
# Minimal dense model with a Rician output head.
inputs = tf.keras.Input(shape=(10,))   # 10 input features; only the batch dimension is flexible

hidden = tf.keras.layers.Dense(16, activation="relu")(inputs)
params = tf.keras.layers.Dense(2)(hidden)   # unconstrained (nu, sigma) pre-activations

def rician_fn(params):
    """Build a Rician distribution from a parameter tensor.

    The last axis of ``params`` is split into two equal halves. The first half
    is used as ``nu`` unchanged. The second goes through softplus plus a 1e-5
    floor so ``sigma`` is strictly positive.
    """
    nu, sigma = tf.split(params, 2, axis=-1)
    sigma = tf.nn.softplus(sigma) + 1e-5
    return Rician(nu=nu, sigma=sigma)

outputs = tfp.layers.DistributionLambda(make_distribution_fn=rician_fn)(params)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


In [None]:
# Example: input is an image
# Example: per-pixel Rician output for an image input.
inputs = tf.keras.Input(shape=(64, 64, 3))   # 64x64 RGB image (example); batch size is flexible

hidden = tf.keras.layers.Conv2D(16, 3, activation="relu", padding="same")(inputs)
# params: [batch, 64, 64, 2] -> channel 0 = nu, channel 1 = sigma.
# The sigmoid already keeps sigma non-negative (Rician itself floors it at
# 1e-16), so no extra softplus is applied here.
params = tf.keras.layers.Conv2D(2, 1, activation="sigmoid", padding="same")(hidden)

outputs = tfp.layers.DistributionLambda(
    lambda t: Rician(nu=t[..., :1], sigma=t[..., 1:])
)(params)

model = tf.keras.Model(inputs=inputs, outputs=outputs)


In [None]:
# Smoke test: push a random batch of two 64x64x3 inputs through the model.
x = tf.random.normal(shape=(2, 64, 64, 3), dtype=tf.float32)
y = model(x)  # output of the DistributionLambda head
print(y.shape)

In [None]:
# True if ANY per-pixel variance is NaN (reduce_all would only flag the case
# where every single entry is NaN).
tf.math.reduce_any(tf.math.is_nan(y.variance()))

In [None]:
# True if ANY per-pixel mean is NaN (reduce_all would only flag the case
# where every single entry is NaN).
tf.math.reduce_any(tf.math.is_nan(y.mean()))

In [None]:
# Shape of nu with a leading unit dimension prepended.
tf.concat([tf.ones([1], dtype=tf.int32), tf.shape(y.nu)], 0)

## Try the `wflib` operators

In [18]:
# Synthetic 2-channel input of shape (1, 4, 64, 64, 2); the two channels are
# presumably real/imaginary parts of a complex signal (TODO confirm).
A = tf.concat(
    [tf.random.normal((1, 4, 64, 64, 1), dtype=tf.float32) for _ in range(2)],
    axis=-1,
)
# Channel magnitude sqrt(c0^2 + c1^2), keeping a trailing singleton axis.
A_abs = tf.math.sqrt(tf.square(A[..., :1]) + tf.square(A[..., 1:]))
# Zero out magnitudes below 0.01.
A_abs = tf.where(A_abs < 0.01, tf.zeros_like(A_abs), A_abs)

# R2 generator on the magnitude input, configured with a softplus output activation.
G_A2R2 = dl.UNet(
    input_shape=(None, 64, 64, 1),
    bayesian=True,
    ME_layer=True,
    filters=12,
    output_activation='softplus',
    output_initializer='he_uniform',
)

In [21]:
# Second UNet, fed the full 2-channel input A.
G_A2B = dl.UNet(input_shape=(None, 64, 64, 2), bayesian=True, ME_layer=True, filters=12)
uncertain_loss_R2 = gan.VarMeanSquaredErrorR2()
# Adam with lr 1e-3; beta_1=0.9 and beta_2=0.999 are the Keras defaults.
G_R2_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

In [29]:
# One training step of the R2 generator (G_A2R2) through the cycle loss.
# Only G_A2R2's variables receive gradients; G_A2B runs in training mode but its
# weights are not updated here.
R2_TV_weight = 0.0   # total-variation regulariser weight (currently disabled)
R2_L1_weight = 0.0   # L1 regulariser weight (currently disabled)

with tf.GradientTape() as t:
    A2B_R2 = G_A2R2(A_abs, training=True)
    A2B_FM = G_A2B(A, training=True)

    A2B_PM = tf.concat([A2B_FM, A2B_R2], axis=-1)
    A2B_WF, A2B2A_abs = wf.acq_to_acq(A, A2B_PM, only_mag=True)
    # NOTE(review): the mask tests A's first channel for zeros, whereas the
    # target A_abs was thresholded separately (values < 0.01 set to 0), so the
    # two zero patterns need not coincide — confirm which mask is intended.
    A2B2A_abs = tf.where(A[..., :1] != 0.0, A2B2A_abs, 0.0)
    print(tf.math.reduce_any(tf.math.is_nan(A2B2A_abs)).numpy())

    A2B2A_var = wf.acq_uncertainty(tf.stop_gradient(A2B_WF), A2B_FM, A2B_R2, ne=A.shape[1], rem_R2=False, only_mag=True)
    A2B2A_sampled_var = tf.concat([A2B2A_abs, A2B2A_var], axis=-1)  # shape: [nb,ne,hgt,wdt,2]
    print(tf.math.reduce_any(tf.math.is_nan(A2B2A_var)).numpy(), tf.reduce_max(A2B2A_var).numpy())

    A2B2A_cycle_loss = uncertain_loss_R2(A_abs, A2B2A_sampled_var)
    print(A2B2A_cycle_loss.numpy())

    R2_TV = tf.reduce_sum(tf.image.total_variation(A2B_R2[:, 0, :, :, :]))
    R2_L1 = tf.reduce_sum(tf.reduce_mean(tf.abs(A2B_R2), axis=(1, 2, 3, 4)))
    reg_term = R2_TV * R2_TV_weight + R2_L1 * R2_L1_weight

    G_loss = A2B2A_cycle_loss + reg_term
    print(G_loss.numpy())

# Standard GradientTape pattern: gradients and the optimizer step run after the
# recording context has closed.
G_grad = t.gradient(G_loss, G_A2R2.trainable_variables)
for gg in G_grad:
    tf.debugging.assert_all_finite(gg, 'Applied gradients must be all finite')
    tf.debugging.assert_type(gg, tf_type=tf.float32)
G_R2_optimizer.apply_gradients(zip(G_grad, G_A2R2.trainable_variables))

False
False 322035600000.0
4.2334957
4.2334957


In [30]:
# log(1e-5) in float32 (1e-5 is also the sigma floor used in the dense Rician head).
tf.math.log(tf.constant(1e-5, dtype=tf.float32)).numpy()

-11.512925