In [1]:
import sympy as sp
from sympy import print_latex

In [2]:
# Define the symbols
S, w_1t, w_t0, w_10 = sp.symbols('S w_1t w_t0 w_10', real=True, positive=True)
i, j, k = sp.symbols('i j k', integer=True)

\begin{align*}
    p(x_t=k|x_0=i,x_1=j) = \frac{p(x_1|x_t=k) p(x_t=k|x_0)}{p(x_1|x_0)}
\end{align*}}

In [3]:
# Define the probability function
x_1 = i
x_0 = j
x_t = k
P_10 = 1/S + w_10 * (-1/S + sp.KroneckerDelta(i, j))
P_1t = 1/S + w_1t * (-1/S + sp.KroneckerDelta(i, k))
P_t0 = 1/S + w_t0 * (-1/S + sp.KroneckerDelta(k, j))

P_bridge = (P_1t*P_t0)/P_10

In [4]:
P_bridge.simplify()

(w_1t*(S*KroneckerDelta(i, k) - 1) + 1)*(w_t0*(S*KroneckerDelta(j, k) - 1) + 1)/(S*(w_10*(S*KroneckerDelta(i, j) - 1) + 1))

In [5]:
# Calculate the expected value (mean)
mean_expr = sp.Sum(i * P_bridge, (k, 1, S))
mean = mean_expr.doit().simplify()
mean

Piecewise((i*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_1t*w_t0 + 1)/(w_10*(S*KroneckerDelta(i, j) - 1) + 1), (S >= i) & (S >= j) & (i >= 1) & (j >= 1)), (i*(1 - w_1t)/(w_10*(S*KroneckerDelta(i, j) - 1) + 1), (S >= j) & (j >= 1)), (i*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_t0 + 1)/(w_10*(S*KroneckerDelta(i, j) - 1) + 1), (S >= i) & (i >= 1)), (i*(w_1t*w_t0 - w_1t - w_t0 + 1)/(w_10*(S*KroneckerDelta(i, j) - 1) + 1), True))

In [6]:
mean_ = mean.args[0].args[0]
mean_

i*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_1t*w_t0 + 1)/(w_10*(S*KroneckerDelta(i, j) - 1) + 1)

In [30]:
print_latex(mean_)

\frac{i \left(S w_{1t} w_{t0} \delta_{i j} - w_{1t} w_{t0} + 1\right)}{w_{10} \left(S \delta_{i j} - 1\right) + 1}


In [27]:
# square the mean
mean_squared = mean**2
mean_squared.simplify()
mean_squared

Piecewise((i**2*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_1t*w_t0 + 1)**2/(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2, (S >= i) & (S >= j) & (i >= 1) & (j >= 1)), (i**2*(1 - w_1t)**2/(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2, (S >= j) & (j >= 1)), (i**2*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_t0 + 1)**2/(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2, (S >= i) & (i >= 1)), (i**2*(w_1t*w_t0 - w_1t - w_t0 + 1)**2/(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2, True))

In [28]:
mean_squared_ = mean_squared.args[0].args[0]
mean_squared_

i**2*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_1t*w_t0 + 1)**2/(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2

In [29]:
print_latex(mean_squared_)

\frac{i^{2} \left(S w_{1t} w_{t0} \delta_{i j} - w_{1t} w_{t0} + 1\right)^{2}}{\left(w_{10} \left(S \delta_{i j} - 1\right) + 1\right)^{2}}


In [10]:
# Calculate the second moment E[X^2]
second_moment_expr = sp.Sum(k**2 * P_bridge, (k, 1, S))
second_moment = second_moment_expr.doit()
second_moment 

Piecewise((i**2*w_1t*w_t0*KroneckerDelta(i, j)/(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S) + w_1t*w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - w_1t*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) + (S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - i**2*w_1t*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)) + i**2*w_1t/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)) - j**2*w_1t*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)) + j**2*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)), (S >= i) & (S >= j) & (i >= 1) & (j >= 1)), (w_1t*w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - w_1t*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) + (S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - j**2*w_

In [11]:
second_moment_ = second_moment.args[0].args[0]
second_moment_

i**2*w_1t*w_t0*KroneckerDelta(i, j)/(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S) + w_1t*w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - w_1t*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) + (S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S) - i**2*w_1t*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)) + i**2*w_1t/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)) - j**2*w_1t*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)) + j**2*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S))

In [14]:
second_moment_.args[0]

(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S)

In [15]:
second_moment_.args[1]

-w_1t*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S)

In [16]:
second_moment_.args[2]

-w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S)

In [17]:
second_moment_.args[3]

w_1t*w_t0*(S**3/3 + S**2/2 + S/6)/(S**2*w_10*KroneckerDelta(i, j) - S*w_10 + S)

In [18]:
second_moment_.args[4]

i**2*w_1t/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S))

In [19]:
second_moment_.args[5]

j**2*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S))

In [20]:
second_moment_.args[6]

i**2*w_1t*w_t0*KroneckerDelta(i, j)/(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S)

In [21]:
second_moment_.args[7]

-i**2*w_1t*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S))

In [22]:
second_moment_.args[8]

-j**2*w_1t*w_t0/(S*(w_10*(KroneckerDelta(i, j) - 1/S) + 1/S))

In [32]:
# Compute the variance Var(X) = E[X^2] - mean^2
variance_expr = second_moment - mean_squared
variance = variance_expr.simplify()
variance

Piecewise(((-6*i**2*(S*w_10*KroneckerDelta(i, j) - w_10 + 1)*(S*w_1t*w_t0*KroneckerDelta(i, j) - w_1t*w_t0 + 1)**2 + (w_10*(S*KroneckerDelta(i, j) - 1) + 1)*((w_10*(S*KroneckerDelta(i, j) - 1) + 1)*(2*S**2 + 3*S + w_1t*w_t0*(2*S**2 + 3*S + 1) - w_1t*(2*S**2 + 3*S + 1) - w_t0*(2*S**2 + 3*S + 1) + 1) + 6*(S*w_10*KroneckerDelta(i, j) - w_10 + 1)*(S*i**2*w_1t*w_t0*KroneckerDelta(i, j) - i**2*w_1t*w_t0 + i**2*w_1t - j**2*w_1t*w_t0 + j**2*w_t0)))/(6*(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2*(S*w_10*KroneckerDelta(i, j) - w_10 + 1)), (S >= i) & (S >= j) & (i >= 1) & (j >= 1)), ((-6*i**2*(w_1t - 1)**2*(S*w_10*KroneckerDelta(i, j) - w_10 + 1) + (w_10*(S*KroneckerDelta(i, j) - 1) + 1)*(-6*j**2*w_t0*(w_1t - 1)*(S*w_10*KroneckerDelta(i, j) - w_10 + 1) + (w_10*(S*KroneckerDelta(i, j) - 1) + 1)*(2*S**2 + 3*S + w_1t*w_t0*(2*S**2 + 3*S + 1) - w_1t*(2*S**2 + 3*S + 1) - w_t0*(2*S**2 + 3*S + 1) + 1)))/(6*(w_10*(S*KroneckerDelta(i, j) - 1) + 1)**2*(S*w_10*KroneckerDelta(i, j) - w_10 + 1)), (S >= j) & (j 

In [34]:
print_latex(variance_)

\frac{- 6 i^{2} \left(S w_{10} \delta_{i j} - w_{10} + 1\right) \left(S w_{1t} w_{t0} \delta_{i j} - w_{1t} w_{t0} + 1\right)^{2} + \left(w_{10} \left(S \delta_{i j} - 1\right) + 1\right) \left(\left(w_{10} \left(S \delta_{i j} - 1\right) + 1\right) \left(2 S^{2} + 3 S + w_{1t} w_{t0} \cdot \left(2 S^{2} + 3 S + 1\right) - w_{1t} \left(2 S^{2} + 3 S + 1\right) - w_{t0} \cdot \left(2 S^{2} + 3 S + 1\right) + 1\right) + 6 \left(S w_{10} \delta_{i j} - w_{10} + 1\right) \left(S i^{2} w_{1t} w_{t0} \delta_{i j} - i^{2} w_{1t} w_{t0} + i^{2} w_{1t} - j^{2} w_{1t} w_{t0} + j^{2} w_{t0}\right)\right)}{6 \left(w_{10} \left(S \delta_{i j} - 1\right) + 1\right)^{2} \left(S w_{10} \delta_{i j} - w_{10} + 1\right)}


In [44]:
print_latex(second_moment_)

\frac{i^{2} w_{1t} w_{t0} \delta_{i j}}{w_{10} \left(\delta_{i j} - \frac{1}{S}\right) + \frac{1}{S}} + \frac{w_{1t} w_{t0} \left(\frac{S^{3}}{3} + \frac{S^{2}}{2} + \frac{S}{6}\right)}{S^{2} w_{10} \delta_{i j} - S w_{10} + S} - \frac{w_{1t} \left(\frac{S^{3}}{3} + \frac{S^{2}}{2} + \frac{S}{6}\right)}{S^{2} w_{10} \delta_{i j} - S w_{10} + S} - \frac{w_{t0} \left(\frac{S^{3}}{3} + \frac{S^{2}}{2} + \frac{S}{6}\right)}{S^{2} w_{10} \delta_{i j} - S w_{10} + S} + \frac{\frac{S^{3}}{3} + \frac{S^{2}}{2} + \frac{S}{6}}{S^{2} w_{10} \delta_{i j} - S w_{10} + S} - \frac{i^{2} w_{1t} w_{t0}}{S \left(w_{10} \left(\delta_{i j} - \frac{1}{S}\right) + \frac{1}{S}\right)} + \frac{i^{2} w_{1t}}{S \left(w_{10} \left(\delta_{i j} - \frac{1}{S}\right) + \frac{1}{S}\right)} - \frac{j^{2} w_{1t} w_{t0}}{S \left(w_{10} \left(\delta_{i j} - \frac{1}{S}\right) + \frac{1}{S}\right)} + \frac{j^{2} w_{t0}}{S \left(w_{10} \left(\delta_{i j} - \frac{1}{S}\right) + \frac{1}{S}\right)}


In [39]:
import torch

def compute_mean(S, i, j, w_10, w_1t, w_t0):
    # Kronecker delta in PyTorch
    kronecker_delta_ij = (i == j).float()

    # Compute the mean using the provided expression
    mean = (i * (S * w_1t * w_t0 * kronecker_delta_ij - w_1t * w_t0 + 1)) / (w_10 * (S * kronecker_delta_ij - 1) + 1)

    return mean

# Example usage
S_value = torch.tensor(10.0)  # State space size
i = torch.randint(1,9,(3,4)).float()  # Tensor for i
j = torch.randint(1,9,(3,4)).float()  # Tensor for j, must be the same size as i
w_10_value = torch.tensor(0.1)  # Weight w_10
w_1t_value = torch.tensor(0.2)  # Weight w_1t
w_t0_value = torch.tensor(0.3)  # Weight w_t0

# Call the function
second_moment_value = compute_second_moment(S_value, i, j, w_10_value, w_1t_value, w_t0_value)
mean_value = compute_mean(S_value, i, j, w_10_value, w_1t_value, w_t0_value)
variance_value = second_moment_value - mean_value**2
print(mean)
print(second_moment)
print(variance)

tensor([[6.2667, 8.3556, 3.1333, 2.0889],
        [5.2222, 4.1778, 7.3111, 4.0526],
        [8.3556, 3.1333, 1.0444, 7.3111]])
tensor([[42.6222, 34.9778, 34.9556, 28.8444],
        [28.9111, 26.7111, 38.2444, 24.2421],
        [38.1778, 34.9556, 25.1778, 35.8444]])
tensor([[  3.3511, -34.8375,  25.1378,  24.4810],
        [  1.6395,   9.2573, -15.2079,   7.8183],
        [-31.6375,  25.1378,  24.0869, -17.6079]])
