In [63]:
import jax.numpy as jnp
import jax.random as random
import jax
import numpy as np
import jax.lax as lax

In [18]:
x = [jnp.exp(2)]

In [29]:
normal_pdf = jax.scipy.stats.norm.pdf

In [30]:
normal_pdf?

In [45]:
normal_pdf(x=0, loc=(0,0), scale=(8,8))

TypeError: norm.logpdf requires ndarray or scalar arguments, got <class 'tuple'> at position 1.

In [69]:
trans_mat = jnp.array( [[0.7,0.3],[0.2,0.8]])

In [52]:
trans_mat[0,1]

DeviceArray(0.3, dtype=float32)

In [59]:
jnp.dot( trans_mat, np.array([1,1]) )

DeviceArray([1., 1.], dtype=float32)

In [120]:
def forward_backward_gaussian(trans_mat, means, standard_devs, init_probs, obs_data):
    def normal_pdf_vec(obs,means,standard_devs):
        return jnp.array([jax.scipy.stats.norm.pdf(obs, means[i], standard_devs[i]) for i in range(len(means))])

    # Compute forward probabilities using scan
    def forward_scan_fun(carry, obs_t):
        alpha_t = jnp.dot(carry, trans_mat) * normal_pdf_vec(obs_t, means, standard_devs)
        # alpha_t /= jnp.sum(alpha_t)
        return alpha_t, alpha_t  # so carry is alpha_t, and the y (stored output) is alpha_t

    forward_init = init_probs*normal_pdf_vec(obs_data[0],means,standard_devs)
    _, forward_after_init = lax.scan(f=forward_scan_fun, init=forward_init, xs=obs_data[1:])
        # outputs stack of y
    forward = jnp.append( jnp.array([forward_init]), forward_after_init, axis=0 )

    # Compute backward probabilities using scan
    def backward_scan_fun(carry, obs_t):
        beta_t = jnp.dot(trans_mat, carry * normal_pdf_vec(obs_t, means, standard_devs))
        # beta_t /= jnp.sum(beta_t)
        return beta_t, beta_t

    backward_init = jnp.array([ float(1) for _ in range(trans_mat.shape[0])])
    _, backward_after_init = lax.scan(f=backward_scan_fun, init= backward_init, xs=obs_data[1:], reverse=True)
    backward = jnp.append( backward_after_init, jnp.array([backward_init]), axis=0 )

    return forward, backward

In [121]:
out = forward_backward_gaussian(trans_mat, means=[-100,100], standard_devs=[1,1], init_probs=jnp.array([0.5,0.5]), obs_data=jnp.array([100,100,100,100,100]))

In [122]:
out

(DeviceArray([[0.        , 0.19947115],
              [0.        , 0.06366199],
              [0.        , 0.02031797],
              [0.        , 0.00648456],
              [0.        , 0.00206957]], dtype=float32),
 DeviceArray([[0.00389073, 0.01037529],
              [0.01219078, 0.03250875],
              [0.03819719, 0.10185917],
              [0.11968269, 0.31915385],
              [1.        , 1.        ]], dtype=float32))

In [83]:
obs_data = jnp.array([1,2,3,4,5])

In [86]:
obs_data[1:]

DeviceArray([2, 3, 4, 5], dtype=int32)

In [93]:
jnp.append( jnp.array([1,1]), out[0] )

DeviceArray([1.0000000e+00, 1.0000000e+00, 5.5704232e-02, 0.0000000e+00,
             0.0000000e+00, 6.6668326e-03, 0.0000000e+00, 2.1277452e-03,
             0.0000000e+00, 6.7907805e-04, 0.0000000e+00, 2.1673038e-04],            dtype=float32)

In [94]:
jnp.append?

In [108]:
out[0]

DeviceArray([[0.        , 0.06366199],
             [0.        , 0.02031797],
             [0.        , 0.00648456],
             [0.        , 0.00206957]], dtype=float32)

In [112]:
alpha_arr = jnp.append( jnp.array([[1,1]]) , out[0], axis = 0 )

In [113]:
alpha_arr

DeviceArray([[1.        , 1.        ],
             [0.        , 0.06366199],
             [0.        , 0.02031797],
             [0.        , 0.00648456],
             [0.        , 0.00206957]], dtype=float32)

In [123]:
out

(DeviceArray([[0.        , 0.19947115],
              [0.        , 0.06366199],
              [0.        , 0.02031797],
              [0.        , 0.00648456],
              [0.        , 0.00206957]], dtype=float32),
 DeviceArray([[0.00389073, 0.01037529],
              [0.01219078, 0.03250875],
              [0.03819719, 0.10185917],
              [0.11968269, 0.31915385],
              [1.        , 1.        ]], dtype=float32))

In [132]:
out[0][0]

DeviceArray([0.        , 0.19947115], dtype=float32)

In [133]:
jnp.bincount?

In [134]:
import numpy as np

In [135]:
np.asmatrix([1,1,1
             ])

matrix([[1, 1, 1]])

In [136]:
jnp.asarray([1,1,1])

DeviceArray([1, 1, 1], dtype=int32)