<a href="https://colab.research.google.com/github/josephasal/cosmo_inference/blob/adaptive_cov/mcmc/convergence_diagnostics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Gelman-Rubin diagnostic

def gelman_rubin(chains):
  """
  Gelman-Rubin diagnostic (R-hat) for a single parameter.

  Compares the variance between several chains with the variance within each
  chain; values close to 1 indicate the chains agree with each other.

  input:
    chains: 2-D array-like of shape (n_iterations, n_chains); each row is one
      iteration and each column is one chain (walker).

  output:
    R: estimate of the potential scale reduction factor, sqrt(V / W).

  raises:
    ValueError: if chains is not 2-D, or has fewer than 2 iterations or
      fewer than 2 chains (B and W are undefined in those cases).
  """
  chains = np.asarray(chains, dtype=float)
  if chains.ndim != 2:
    raise ValueError(f"chains must be 2-D (n_iterations, n_chains), got shape {chains.shape}")
  n, m = chains.shape
  if n < 2 or m < 2:
    raise ValueError(f"need at least 2 iterations and 2 chains, got shape {chains.shape}")

  # Mean of each chain (column), then the grand mean over all chains
  chain_means = np.mean(chains, axis=0)
  overall_mean = np.mean(chain_means)

  # Between-chain variance: deviations are squared *before* summing
  B = (n / (m - 1)) * np.sum((chain_means - overall_mean) ** 2)

  # Average within-chain variance (unbiased per-chain variance, ddof=1)
  W = np.sum(np.var(chains, axis=0, ddof=1)) / m

  # Pooled variance estimate combining within- and between-chain terms
  V = ((n - 1) / n) * W + ((m + 1) / (m * n)) * B

  return np.sqrt(V / W)

In [None]:
# Compare our MCMC sampler against emcee using the effective sample size (ESS): the Gelman-Rubin diagnostic cannot be used on emcee because its walkers (chains) are not independent.

#Effective sample size

def autocorrelation(x, lag):
  """
  Sample autocorrelation of the series x at a given lag.

  Computed as the sum of lagged products of deviations from the mean divided
  by the total sum of squared deviations (a variance-type normalisation, not
  a standard deviation), so autocorrelation(x, 0) == 1.

  inputs:
    x: 1-D array-like of samples.
    lag: integer lag with 0 <= lag <= len(x); lag == len(x) has no
      overlapping pairs, so the numerator is 0.

  output:
    Autocorrelation at that lag (NaN, with a NumPy warning, if x is constant).

  raises:
    ValueError: if lag is negative or larger than len(x).
  """
  x = np.asarray(x, dtype=float)
  n = len(x)
  # Out-of-range lags make the two slices differ in length, which can either
  # silently broadcast or fail with a shape error; reject them explicitly.
  if lag < 0 or lag > n:
    raise ValueError(f"lag must be between 0 and len(x)={n}, got {lag}")

  dev = x - np.mean(x)  # mean computed once and shared by both sums
  covariance = np.sum(dev[:n - lag] * dev[lag:])
  total_ss = np.sum(dev ** 2)
  return covariance / total_ss

def eff_sample_size(chain):
  """
  Effective sample size (ESS) of a single chain: an estimate of how many
  independent draws the autocorrelated samples are worth.

  Uses ESS = N / (1 + 2 * sum_{t=1}^{T} rho_t)  (BDA3, chapter 11), where
  rho_t is the lag-t autocorrelation (same estimator as autocorrelation())
  and T is the first odd positive integer for which rho_{T+1} + rho_{T+2}
  is negative. Only lags 1 .. N//2 - 1 are used; if no negative pair is
  found in that range, the sum stops at the last odd T whose following pair
  still fits.

  inputs:
    chain: 1-D array-like of samples from the MCMC.

  outputs:
    Effective sample size (float). Returns N when the chain is too short
    (N < 4) to estimate any autocorrelation in that lag range.
    NOTE: if rho_1 < -0.5 (strongly anti-correlated chain) the denominator
    can be <= 0, giving an infinite or negative ESS; no cap is applied.
  """
  x = np.asarray(chain, dtype=float)
  N = len(x)
  max_lag = N // 2 - 1

  if max_lag < 1:
    return float(N)

  # Deviations from the mean are shared by every lag
  dev = x - np.mean(x)
  total_ss = np.sum(dev ** 2)

  def rho(lag):
    # Lag-`lag` autocorrelation, normalised by the total sum of squares
    return np.sum(dev[:N - lag] * dev[lag:]) / total_ss

  # T is at least 1, so rho_1 is always included
  rho_sum = rho(1)
  t = 1
  # Extend T two lags at a time (T stays odd) while the next pair is non-negative
  while t + 2 <= max_lag:
    pair = rho(t + 1) + rho(t + 2)
    if pair < 0:
      break
    rho_sum += pair
    t += 2

  return N / (1 + 2 * rho_sum)

# Note: eff_sample_size only handles a single chain; the next cell applies it to every chain to compute a total ESS.

In [2]:
# Multi-chain ESS calculator

def eff_sample_size_multichain(chains):
  """
  Total effective sample size across several chains.

  Each column of `chains` is treated as its own chain: eff_sample_size is
  evaluated on every column and the per-chain values are added together.

  Inputs:
  chains: array of shape (iterations, n_walker)

  Outputs:
  total_ess: sum of the per-chain effective sample sizes (0 when there are
  no columns)

  NOTE(review): adding per-chain ESS values treats the chains as mutually
  independent; for correlated walkers (e.g. an emcee ensemble) this may
  overstate the total.
  """
  n_walkers = chains.shape[1]
  total_ess = 0
  for col in range(n_walkers):
    # ESS of one walker's trace, accumulated in column order
    total_ess += eff_sample_size(chains[:, col])
  return total_ess