In [1]:
import sympy
import sympy.stats

In [2]:
temp: list[sympy.Symbol] = sympy.symbols(r'z_t z_{t-1} \gamma \beta_t k_t z_0')
(zt, ztm1, gamma, betat, kt, z0) = temp

# the method in the paper is equivalent to setting gamma = 0.5.

# eq. 5
ktm1 = kt / (1-betat)
bt = gamma*(1-kt)
btm1 = gamma*(1-ktm1)


def bernoulli_posterior(x, positive_prob) -> sympy.Expr:
    '''
    This function computes p(x|conditions) for a variable x that 
    follows a Bernoulli distribution B(x; positive_prob(conditions)), where positive_prob is a function of the conditions.
    '''
    return x*(positive_prob) + (1-x)*(1-positive_prob)

# eq. 4
# q(z^t|z^{t-1}) = B(zt; ztm1(1-betat)+gamma*betat)
q_zt_ztm1 = bernoulli_posterior(zt, ztm1*(1-betat)+gamma*betat)

# eq. 5
# q(z^t|z^0) = B(zt; kt*z0+bt)
q_zt_z0 = bernoulli_posterior(zt, kt*z0+bt)

# eq. 5
# q(z^{t-1}|z^0) = B(ztm1; ktm1*z0+btm1)
q_ztm1_z0 = bernoulli_posterior(ztm1, ktm1*z0+btm1)

# the denoizer predicts p_\theta(z^0|z^t) = B(z^0; pred1)
pred1 = sympy.symbols('pred1')
pred0 = 1-pred1

# eq. 9
# q(z^{t-1}|z^t, z^0) =  q(z^t|z^{t-1}) * q(z^{t-1}|z^0) / q(z^t|z^0)
q_ztm1_zt_z0 = (q_zt_ztm1 * q_ztm1_z0) / q_zt_z0
q_ztm1_zt_z0.simplify()

-(z_t*(\beta_t*\gamma - z_{t-1}*(\beta_t - 1)) - (z_t - 1)*(-\beta_t*\gamma + z_{t-1}*(\beta_t - 1) + 1))*(z_{t-1}*(\gamma*(\beta_t - 1)*(\beta_t + k_t - 1) - k_t*z_0*(\beta_t - 1)) + (z_{t-1} - 1)*(\gamma*(\beta_t - 1)*(\beta_t + k_t - 1) + k_t*z_0*(1 - \beta_t) - (\beta_t - 1)**2))/((\beta_t - 1)**2*(z_t*(\gamma*(k_t - 1) - k_t*z_0) + (z_t - 1)*(\gamma*(k_t - 1) - k_t*z_0 + 1)))

In [20]:
# subsititute the variables that we don't care much about in the following steps.

# q(z^{t-1}=1|z^t, z^0)
q_ztm1_eq_1_zt_z0 = q_ztm1_zt_z0.subs(gamma, sympy.Rational(.5)).subs(betat, 0.2).subs(zt, 0).subs(kt, 0.1).subs(ztm1, 1)

In [32]:
# eq 8 (ground truth)
q_z0_eq_1 = q_ztm1_eq_1_zt_z0.subs(z0,1)
q_z0_eq_0 = q_ztm1_eq_1_zt_z0.subs(z0,0)
p_ztm1 = q_z0_eq_1*pred1 + q_z0_eq_0*pred0
d = p_ztm1.simplify()
d

ValueError: only one element tensors can be converted to Python scalars

In [29]:
# eq. 10 on the paper. Its numerical result does not match eq. 8.
import IPython
import IPython.display


e = ((1-betat)*zt+gamma*betat)*(kt*pred1+bt*0.5)
f = ((1-betat)*(1-zt)+gamma*betat)*(kt*pred0+bt*0.5)

# the eq. 10 claims this is the correct value of pθ(zt−1|zt) 
g = (e/(e+f)).subs(gamma, sympy.Rational(.5)).subs(betat, 0.2).subs(zt, 0).subs(kt, 0.1).simplify()
IPython.display.display(g)
IPython.display.display((d-g).simplify()) # large

(-0.01*pred1 - 0.0225)/(0.08*pred1 - 0.315)

(0.00363636363636364*pred1**2 + 0.00204545454545456*pred1 - 0.00255681818181817)/(0.08*pred1 - 0.315)

In [30]:
# This should be the correct version of eq. 10. But it only works for pred1 = 0 or 1.
e = ((1-betat)*zt+gamma*betat)*(kt*pred1+bt-gamma*betat)
f = ((1-betat)*(1-zt)+gamma*betat)*(kt*pred0+bt-gamma*betat)


g = (e/(e+f)).subs(gamma, sympy.Rational(.5)).subs(betat, 0.2).subs(zt, 0).subs(kt, 0.1).simplify()
IPython.display.display(g)
IPython.display.display((d-g).simplify()) # small. zero when pred1 = 0 or 1.

(-0.01*pred1 - 0.035)/(0.08*pred1 - 0.44)

(0.00363636363636364*pred1**2 - 0.00363636363636364*pred1 + 6.93889390390723e-18)/(0.08*pred1 - 0.44)

In [None]:
# The paper used this idea that directly plugs in pred1 into z0. Works for pred1 = 0 or 1.
d = q_ztm1_eq_1_zt_z0.subs(z0, pred1)
d.simplify()

(-0.0125*pred1 - 0.04375)/(0.1*pred1 - 0.55)