In [19]:
import torch

def compute_variance_telegram_bridge(S, i, j, w_10, w_1t, w_t0):
    # Kronecker delta in PyTorch, resulting in a tensor of 0s and 1s
    kronecker_delta_ij = (i == j).float()
    
    # Calculate each term based on the provided LaTeX expression
    term1 = -6 * i**2 * (S * w_10 * kronecker_delta_ij - w_10 + 1)
    term2 = (S * w_1t * w_t0 * kronecker_delta_ij - w_1t * w_t0 + 1)**2
    term3 = (w_10 * (S * kronecker_delta_ij - 1) + 1)
    term4 = term3 * (2 * S**2 + 3 * S + w_1t * w_t0 * (2 * S**2 + 3 * S + 1) - w_1t * (2 * S**2 + 3 * S + 1) - w_t0 * (2 * S**2 + 3 * S + 1) + 1)
    term5 = 6 * (S * w_10 * kronecker_delta_ij - w_10 + 1) * (S * i**2 * w_1t * w_t0 * kronecker_delta_ij - i**2 * w_1t * w_t0 + i**2 * w_1t - j**2 * w_1t * w_t0 + j**2 * w_t0)
    
    # Combine terms to compute the variance
    numerator = term1 * term2 + term3 * (term4 + term5)
    denominator = 6 * term3**2 * (S * w_10 * kronecker_delta_ij - w_10 + 1)
    
    variance = numerator / denominator
    
    return variance

In [20]:
def compute_mean(S, i, j, w_10, w_1t, w_t0):
    # Kronecker delta in PyTorch
    kronecker_delta_ij = (i == j).float()

    # Compute the mean using the provided expression
    mean = (i * (S * w_1t * w_t0 * kronecker_delta_ij - w_1t * w_t0 + 1)) / (w_10 * (S * kronecker_delta_ij - 1) + 1)

    return mean

In [27]:
def compute_second_moment(S, i, j, w_10, w_1t, w_t0):
    # Kronecker delta in PyTorch
    kronecker_delta_ij = (i == j).float()
    
    # Precompute term0
    term0 = (S**3/3 + S**2/2 + S/6)

    # Calculate each term based on the provided LaTeX expression
    term1 = (i**2 * w_1t * w_t0 * kronecker_delta_ij) / (w_10 * (kronecker_delta_ij - 1/S) + 1/S)
    term2 = (w_1t * w_t0 * term0) / (S**2 * w_10 * kronecker_delta_ij - S * w_10 + S)
    term3 = - (w_1t * term0) / (S**2 * w_10 * kronecker_delta_ij - S * w_10 + S)
    term4 = - (w_t0 * term0) / (S**2 * w_10 * kronecker_delta_ij - S * w_10 + S)
    term5 = term0 / (S**2 * w_10 * kronecker_delta_ij - S * w_10 + S)
    term6 = - (i**2 * w_1t * w_t0) / (S * (w_10 * (kronecker_delta_ij - 1/S) + 1/S))
    term7 = (i**2 * w_1t) / (S * (w_10 * (kronecker_delta_ij - 1/S) + 1/S))
    term8 = - (j**2 * w_1t * w_t0) / (S * (w_10 * (kronecker_delta_ij - 1/S) + 1/S))
    term9 = (j**2 * w_t0) / (S * (w_10 * (kronecker_delta_ij - 1/S) + 1/S))
    
    # Combine terms to compute the second moment
    second_moment = term1 + term2 + term3 + term4 + term5 + term6 + term7 + term8 + term9
    
    return second_moment

# Example usage
S_value = torch.tensor(10.0)  # State space size
i = torch.randint(1,9,(3,4)).float()  # Tensor for i
j = torch.randint(1,9,(3,4)).float()  # Tensor for j, must be the same size as i
w_10_value = torch.tensor(0.1)  # Weight w_10
w_1t_value = torch.tensor(0.2)  # Weight w_1t
w_t0_value = torch.tensor(0.3)  # Weight w_t0

# Call the function
second_moment_value = compute_second_moment(S_value, i, j, w_10_value, w_1t_value, w_t0_value)
variance_value = compute_variance_telegram_bridge(S_value, i, j, w_10_value, w_1t_value, w_t0_value)
mean_value = compute_mean(S_value, i, j, w_10_value, w_1t_value, w_t0_value)
variance_value2 = second_moment_value - mean_value**2

print(second_moment_value)

tensor([[30.6222, 44.3579, 34.9778, 42.4222],
        [33.9778, 15.9895, 40.5778, 28.9111],
        [42.6222, 29.8222, 46.6222, 35.8444]])


In [28]:
variance_value

tensor([[ -8.6489,   2.3129, -34.8375,  32.6044],
        [-19.4746,  10.0769, -29.2375,   1.6395],
        [  3.3511,  -9.4489,   7.3511, -17.6079]])

In [29]:
variance_value2

tensor([[ -8.6489,   2.3129, -34.8375,  32.6044],
        [-19.4746,  10.0769, -29.2375,   1.6395],
        [  3.3511,  -9.4489,   7.3511, -17.6079]])