In [85]:
import numpy as np
from exceptiongroup import print_exc
from scipy.stats import truncnorm
from six import print_
from tornado.gen import multi


def multiplyGauss(m1, s1, m2, s2):
    s = 1 / (1 / s1 + 1 / s2)
    m = (m1 / s1 + m2 / s2) * s
    return m, s

def divideGauss(m1, s1, m2, s2):
    m,s =  multiplyGauss(m1, s1, m2, -s2)
    return m, s

def truncGaussMM(a,b,m0,s0):
    a_scaled, b_scaled = (a - m0) / np.sqrt(s0), (b - m0) / np.sqrt(s0)
    m = truncnorm.mean(a_scaled,b_scaled, loc=m0, scale=np.sqrt(s0))
    s = truncnorm.var(a_scaled, b_scaled, loc=m0, scale=np.sqrt(s0))
    return m,s

def conditional_outcome(s1, s2, result, sigma_t):
    mean_diff = s1 - s2
    std_diff = sigma_t

    # Set truncation limits based on match outcome
    if result == 1:
        a, b = 0, np.inf  # Truncate for player 1 winning
    else:
        a, b = -np.inf, 0  # Truncate for player 2 winning

    # Standardize the truncation limits
    a_std = (a - mean_diff) / std_diff
    b_std = (b - mean_diff) / std_diff

    return a_std, b_std, mean_diff, std_diff


def gibbs_sampler_result(mu_s1, mu_s2, sigma_s1, sigma_s2, sigma_t, n_iter, result):
    s1_samples = np.zeros(n_iter)
    s2_samples = np.zeros(n_iter)
    t_samples = np.zeros(n_iter)
    y_samples = np.zeros(n_iter)

    s1_current = mu_s1
    s2_current = mu_s2

    #  prior means and variances
    mu_prior = np.array([mu_s1, mu_s2]).reshape((2, 1))
    sigma_prior = np.array([[sigma_s1, 0], [0, sigma_s2]])
    sigma_prior_inv = np.linalg.inv(sigma_prior)

    A = np.array([1, -1]).reshape((1, 2))
    A_T = A.T

    for i in range(n_iter):
        #  Sample t given s1_current, s2_current (t|s1,s2)
        a_std, b_std, mean_diff, std_diff = conditional_outcome(s1_current, s2_current, result, sigma_t)
        t_current = truncnorm.rvs(a=a_std, b=b_std, loc=mean_diff, scale=std_diff)

        #  Sample s1 and s2 given t_current(s1,s2|t) 
        #  posterior covariance
        sigma_post = np.linalg.inv(sigma_prior_inv + (A_T @ A) / sigma_t**2)
        # posterior mean
        mu_post = sigma_post @ (sigma_prior_inv @ mu_prior + (A_T * t_current) / sigma_t**2)
        s1_current, s2_current = np.random.multivariate_normal(mu_post.flatten(), sigma_post)
        s1_samples[i] = s1_current
        s2_samples[i] = s2_current
        t_samples[i] = t_current
        y_samples[i] = result

    # Compute sample means and variances
    mean_s1 = np.mean(s1_samples)
    var_s1 = np.var(s1_samples, ddof=1)
    mean_s2 = np.mean(s2_samples)
    var_s2 = np.var(s2_samples, ddof=1)

    return mean_s1, var_s1, mean_s2, var_s2

In [88]:
m1, s1 =  25, (25/3)**2
m2, s2 = 25, (25/3)**2
sv = (25/6)**2 
y0 = 1
mu_diff_m = m1 - m2
mu_diff_s = s1 + s2 + sv


mu3_m = m1
mu3_s = s1

mu4_m = m2 
mu4_s = s2 


#print("Mu4_m: ", mu4_m, "Mu4_s: ", mu4_s, "Mu3_m: ", mu3_m, "Mu3_s: ", mu3_s)

mu5_m = mu_diff_m 
mu5_s = mu_diff_s 


if y0 == 1:
    a,b = 0, np.inf
else:
    a,b = -np.inf, 0


pt_m, pt_s = truncGaussMM(a, b, mu5_m, mu5_s)
# Compute the message from t to f_3
mu6_m, mu6_s = divideGauss(pt_m, pt_s, mu5_m, mu5_s)
# Compute the message from f_3 to s1
mu7_m = mu6_m  + mu3_m 
mu7_s = mu6_s + sv  + mu3_s
# Compute the message from f_3 to s2
px_m, px_s = multiplyGauss(mu3_m, mu3_s, mu7_m, mu7_s)
print(f"Mean of s1: {px_m}")
print(f"Variance of s1: {px_s}")



#compute the message from f_3 to s2
mu_8m = mu6_m - mu4_m
mu_8s = mu6_s + sv + mu4_s
px_m, px_s = multiplyGauss(mu4_m, mu4_s, mu_8m, mu_8s)
print(f"Mean of s2: {px_m}")
print(f"Variance of 2: {px_s}")
##Run the gibbs sampler and the moment matching for two players and plot to compare the results 
mu_s1, mu_s2 = 25, 25
sigma_s1, sigma_s2 = (25/3)**2, (25/3)**2
sigma_t = (25/6)**2
n_iter = 1000
result = 1
mean_s1, var_s1, mean_s2, var_s2 = gibbs_sampler_result(mu_s1, mu_s2, sigma_s1, sigma_s2, sigma_t, n_iter, result)

print(f"Mean of s1: {mean_s1}")
print(f"Variance of s1: {var_s1}")
print(f"Mean of s2: {mean_s2}")
print(f"Variance of s2: {var_s2}")



Mean of s1: 29.43269200446036
Variance of s1: 49.79568603803762
Mean of s2: 15.285585951847445
Variance of 2: 49.79568603803762
Mean of s1: 27.44757619291085
Variance of s1: 63.958132020097324
Mean of s2: 22.53944047177495
Variance of s2: 63.631304455835625
