# Imports

In [3]:
from scipy.stats import truncnorm, multivariate_normal, norm
import numpy as np
import matplotlib.pyplot as plt
import time
import seaborn as sns
import pandas as pd

# Hyperparameters

In [4]:
# Mean and standard deviation used as the Gaussian prior of each skill s_1, s_2
mu_1 = mu_2 = 25
sigma_1 = sigma_2 = 25 / 3
# Standard deviation of t given the two skills
sigma_t = 25 / 6
# Run length and burn-in count (presumably for the Gibbs sampler below)
iterations, burn_in = 1200, 25

# Q4 - Gibbs sampler

In [5]:
# sample from a truncated normal distribution
def truncated_normal(mean, std, lower, upper):
    """Draw one sample from N(mean, std**2) truncated to [lower, upper].

    scipy's ``truncnorm`` takes its bounds in standard-normal units, so the
    requested bounds are shifted by ``mean`` and scaled by ``std`` first.
    """
    lo_z = (lower - mean) / std
    hi_z = (upper - mean) / std
    return truncnorm.rvs(lo_z, hi_z, loc=mean, scale=std)

In [2]:
# Gibbs Sampler
def gibbs_sampler(mu_1, sigma_1, mu_2, sigma_2, y, iterations, t_sigma=None):
    """Gibbs-sample (s_1, s_2, t) given the outcome y.

    Each sweep draws the two skills jointly from the Gaussian conditional
    p(s_1, s_2 | t), then draws t from a normal centred at s_1 - s_2 that is
    truncated to the sign implied by y.

    Args:
        mu_1, sigma_1: prior mean and standard deviation of s_1.
        mu_2, sigma_2: prior mean and standard deviation of s_2.
        y: outcome; 1 restricts t to t > 0, any other value to t < 0.
        iterations: number of Gibbs sweeps. No burn-in is discarded here.
        t_sigma: standard deviation of t given the skills. When None, the
            module-level ``sigma_t`` is used, so existing calls are unchanged.

    Returns:
        Array of shape (iterations, 3) with columns s_1, s_2, t. The shape is
        (0, 3) when no sweeps are run, so column indexing always works.
    """
    t_std = sigma_t if t_sigma is None else t_sigma

    # Start the chain at a t whose sign is consistent with y.
    start = np.abs(np.random.randn())
    t = start if y == 1 else -start

    # Truncation interval for t: t > 0 when y == 1, t < 0 otherwise.
    lower, upper = (0, np.inf) if y == 1 else (-np.inf, 0)

    # Everything below is independent of t, so it is computed once.
    prior_precision = np.array([[1.0 / (sigma_1 * sigma_1), 0.0],
                                [0.0, 1.0 / (sigma_2 * sigma_2)]])  # 2 x 2
    t_precision = 1.0 / (t_std * t_std)
    prior_mean = np.array([mu_1, mu_2]).reshape(2, 1)  # 2 x 1
    A = np.array([1.0, -1.0]).reshape(1, 2)  # 1 x 2, so that A @ s = s_1 - s_2

    # Covariance of p(s_1, s_2 | t); it does not depend on the value of t.
    cov_s_given_t = np.linalg.inv(prior_precision + (A.T @ (t_precision * A)))
    prior_term = prior_precision @ prior_mean  # 2 x 1
    t_weight = A.T * t_precision  # 2 x 1, multiplied by the current t below

    samples = []
    for _ in range(iterations):
        # Step 1: draw s_1, s_2 from N(mean_s|t, cov_s|t).
        mean_s_given_t = (cov_s_given_t @ (prior_term + t_weight * t)).reshape(2,)
        s_1, s_2 = multivariate_normal.rvs(mean=mean_s_given_t, cov=cov_s_given_t)

        # Step 2: draw t from the truncated normal centred at s_1 - s_2.
        t = truncated_normal(s_1 - s_2, t_std, lower, upper)

        samples.append((s_1, s_2, t))

    # reshape keeps three columns even when the list is empty.
    return np.array(samples, dtype=float).reshape(-1, 3)

# Q5

# Q6

# Q7

# Q8