In [3]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import pandas as pd
import arviz as az
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px

In [11]:
# Short aliases for the TFP namespaces used in the model cells below.
tfd, tfb = tfp.distributions, tfp.bijectors
# `Root` wraps the top-level (parent-free) nodes yielded inside a
# JointDistributionCoroutine model.
root = tfd.JointDistributionCoroutine.Root

## Latent AR Process and Smoothing with a Gaussian Random Walk (GRW)

In [2]:
# Synthetic data: a smooth curve f(x) observed with additive Gaussian noise.
SEED = 42            # fixed so Restart & Run All reproduces the same y
NOISE_SCALE = 1.0    # standard deviation of the observation noise
num_steps = 100      # number of observations (also the latent path length in the model)

rng = np.random.default_rng(SEED)
x = np.linspace(0, 50, num_steps)
f = np.exp(1.0 + np.sqrt(x) - np.exp(x / 15.0))
y = f + rng.normal(scale=NOISE_SCALE, size=x.shape)

In [20]:
# Noisy observations against the true curve that generated them.
fig = go.Figure()
fig.add_trace(go.Scatter(x=x, y=y, mode='markers', name='observed'))
fig.add_trace(go.Scatter(x=x, y=f, mode='lines', name='f(x)'))
fig.update_layout(xaxis_title="x", yaxis_title="y", title="Simulated data")

In [21]:
@tfd.JointDistributionCoroutine
def smoothing_grw():
    """Gaussian-random-walk (GRW) smoother for a series of length `num_steps`.

    Generative model:
      * alpha ~ Beta(5, 1): fraction of the total variance assigned to the
        random-walk increments; the remaining (1 - alpha) goes to the
        observation noise.
      * variance ~ HalfNormal(10): total variance that alpha splits.
      * z ~ Normal(0, sqrt(variance * alpha)), `num_steps` iid increments.
        The latent path is cumsum(z), i.e. a random walk (AR(1) with
        coefficient 1).
      * obs ~ Normal(cumsum(z), sqrt(variance * (1 - alpha))), named "obs"
        so it can be pinned to data when sampling the posterior.
    """
    alpha = yield root(tfd.Beta(5., 1.))
    variance = yield root(tfd.HalfNormal(10.))
    # Split the total variance between process noise and observation noise.
    sigma_process = tf.sqrt(variance * alpha)
    sigma_obs = tf.sqrt(variance * (1. - alpha))
    # `num_steps` comes from the data-generation cell and must match len(y).
    z = yield tfd.Sample(tfd.Normal(0., sigma_process), num_steps)
    # `[..., None]` broadcasts the (possibly per-chain) noise scale over the
    # time axis. reinterpreted_batch_ndims=1 makes exactly the time axis the
    # event; the `None` default would instead fold "all but the first" batch
    # axis, which breaks unbatched sampling/log_prob.
    yield tfd.Independent(
        tfd.Normal(tf.math.cumsum(z, axis=-1), sigma_obs[..., None]),
        reinterpreted_batch_ndims=1,
        name="obs",
    )

In [22]:
# Wrap TFP's windowed-adaptation NUTS driver in a tf.function with autograph
# disabled and XLA compilation enabled, so the sampler runs as one compiled graph.
run_mcmc = tf.function(
    tfp.experimental.mcmc.windowed_adaptive_nuts,
    autograph=False,
    jit_compile=True,
)

In [23]:
%%time
mcmc_samples, sampler_stats = run_mcmc(1000, smoothing_grw, n_chains= 4, num_adaptation_steps= 1000, obs= tf.constant(y[None, ...], dtype= tf.float32))

2024-01-28 16:16:15.402543: W tensorflow/compiler/tf2xla/kernels/assert_op.cc:38] Ignoring Assert operator mcmc_retry_init/assert_equal_1/Assert/AssertGuard/Assert
I0000 00:00:1706454990.567371       1 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2024-01-28 16:16:30.832007: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream.


CPU times: user 1min 2s, sys: 874 ms, total: 1min 3s
Wall time: 1min 4s


In [29]:
# Posterior of the latent path cumsum(z); the last sampled component is z.
nsample, nchain = mcmc_samples[-1].shape[:2]

# Pool chains: (nsample, nchain, num_steps) -> (nsample * nchain, num_steps).
z = tf.reshape(tf.math.cumsum(mcmc_samples[-1], axis=-1), [nsample * nchain, -1])
latent_draws = z.numpy()  # hand plotly plain NumPy arrays, not TF tensors
latent_mean = latent_draws.mean(axis=0)
latent_lo, latent_hi = np.percentile(latent_draws, [5, 95], axis=0)

fig = go.Figure()
fig.add_trace(go.Scatter(x=x, y=y, mode='markers', name='observed'))
fig.add_trace(go.Scatter(x=x, y=f, mode='lines', name='f(x)'))
# Pointwise 90% band: draw the upper edge invisibly, then fill the lower edge up to it.
fig.add_trace(go.Scatter(x=x, y=latent_hi, mode='lines', line=dict(width=0),
                         showlegend=False, name="95th percentile"))
fig.add_trace(go.Scatter(x=x, y=latent_lo, mode='lines', line=dict(width=0),
                         fill='tonexty', name="90% interval"))
fig.add_trace(go.Scatter(x=x, y=latent_mean, mode='lines', name="posterior mean"))
fig.update_layout(xaxis_title="x", yaxis_title="y", title="Latent AR & Smoothing with GRW")