<a href="https://colab.research.google.com/github/chihway/cosmology_on_beach_2022/blob/main/Tutorial3_cosmology_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Explore MCMC chains**

Time: 20 min

In this tutorial we will learn how to plot and interpret MCMC chains.

We will learn to:
* Plot different numbers of parameters from the chains and print their 1-sigma constraints
* Compare chains from different experiments and understand what the differences between them do and do not mean

We will be making the plots using the software package [chainconsumer](https://github.com/Samreay/ChainConsumer).

In [None]:
import os

import numpy as np
import pylab as mplot
# show figures inline; %matplotlib inline avoids the namespace-wide imports of %pylab
%matplotlib inline

In [None]:
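# we will need to install chainconsumer to plot the chains
# the cells below use the pre-1.0 ChainConsumer interface (add_chain with lists, configure),
# so we pin the version accordingly
!pip install "chainconsumer<1.0"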

In [None]:
from chainconsumer import ChainConsumer

We will be using the DES-SV and KiDS-450 chains from [this paper](https://arxiv.org/abs/1808.07335). The paper first tries to reproduce the cosmic shear results from the [DES-SV paper](https://arxiv.org/abs/1507.05552) and the [KiDS-450 paper](https://arxiv.org/abs/1601.05786). It then unifies the various analysis choices and compares the two results again. Let's again download the data.

In [None]:
!rm -rf data_3
!curl -O https://portal.nersc.gov/cfs/lsst/chihway/data_3.tar.gz
!tar -xvzf data_3.tar.gz

### First let's plot the chains from the paper directly

In [None]:
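def weighted_percentile(data, percents, weights=None):
    """Percentile(s) of `data`, optionally using per-sample weights."""
    if weights is None:
        return np.percentile(data, percents)
    # sort the samples and carry the weights along
    ind = np.argsort(data)
    d = data[ind]
    w = weights[ind]
    # cumulative weight fraction (in percent) at each sorted sample
    p = 1.*w.cumsum()/w.sum()*100
    # interpolate the sample value at the requested percentile(s)
    y = np.interp(percents, p, d)
    return y

def get_s8(omegam, sigma8):
    """S8 = sigma8 * sqrt(Omega_m / 0.3)."""
    return sigma8*np.sqrt(omegam/0.3)

def constraints(s8, w):
    """Print and return the weighted mean and the lower/upper 1-sigma errors."""
    # the 16th and 84th percentiles bracket the central 68% of the samples
    low = weighted_percentile(s8, 16, weights=w)
    high = weighted_percentile(s8, 84, weights=w)
    mean = np.average(s8, weights=w)

    # LaTeX-formatted result, with explicit signs on the lower and upper errors
    print("$%.4f_{-%.4f}^{+%.4f}$" % (mean, mean-low, high-mean))
    return mean, mean-low, high-mean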
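
As a quick sanity check of the helpers above, the sketch below applies the same weighted-percentile logic to synthetic Gaussian samples rather than the survey chains. The sample size, the mean of 0.8, the width of 0.05, the integer weights and the random seed are all assumptions chosen for illustration. With uniform weights, the 16th and 84th percentiles should sit roughly one standard deviation below and above the mean, and integer weights should behave like repeating each sample that many times.

```python
import numpy as np

def weighted_percentile(data, percents, weights=None):
    # same logic as the helper defined in the cell above
    if weights is None:
        return np.percentile(data, percents)
    ind = np.argsort(data)
    d = data[ind]
    w = weights[ind]
    p = 1.0 * w.cumsum() / w.sum() * 100
    return np.interp(percents, p, d)

# synthetic "chain": Gaussian samples with an assumed mean and width
rng = np.random.default_rng(0)
samples = rng.normal(loc=0.8, scale=0.05, size=200_000)
uniform_w = np.ones_like(samples)

# 1-sigma errors measured from the 16th and 84th percentiles
mean = np.average(samples, weights=uniform_w)
lower_err = mean - weighted_percentile(samples, 16, weights=uniform_w)
upper_err = weighted_percentile(samples, 84, weights=uniform_w) - mean
print("mean = %.4f, lower error = %.4f, upper error = %.4f" % (mean, lower_err, upper_err))

# integer weights should act like repeating each sample that many times
int_w = rng.integers(1, 4, size=samples.size)
median_weighted = weighted_percentile(samples, 50, weights=int_w.astype(float))
median_repeated = np.percentile(np.repeat(samples, int_w), 50)
print("weighted median = %.4f, repeated-sample median = %.4f" % (median_weighted, median_repeated))
```

The two medians agree only approximately, because the interpolation in `weighted_percentile` and in `np.percentile` place the percentile grid slightly differently.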

In [None]:
kids_paper = 'data_3/kids450dir'
des_paper = 'data_3/dessv_chain_reduced_v2.txt'

In [None]:
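data_params = []
weights = []
file_list = [des_paper, kids_paper]

# DES-SV: a single chain file. The third-from-last line of the file stores an
# nsample value, which we extract with a shell pipeline; only the last nsample rows are used
data = np.loadtxt(file_list[0])
os.system("cat "+file_list[0]+" | tail -3 | head -1 | sed s/'='/' '/|awk '{print $2}'>nsample")
nsample = int(np.loadtxt('nsample'))
# column 14 is used as the weight, column 0 as Omega_m and column 3 as sigma8
des = data[-nsample:]
weights.append(des[:,14])
data_params.append([des[:,0], get_s8(des[:,0], des[:,3])])

# KiDS-450: eight chain files. Drop the first 30% of each (burn-in) and concatenate the rest
Om = np.array([])
s8 = np.array([])
ww = np.array([])
for i in range(8):
    data = np.loadtxt(file_list[1]+'_'+str(i+1)+'.txt')
    N = len(data)
    # column 10 is used as Omega_m, column 14 as sigma8 and column 0 as the weight
    Om = np.concatenate((Om, data[int(N*0.3):,10]), axis=0)
    s8 = np.concatenate((s8, data[int(N*0.3):,14]), axis=0)
    ww = np.concatenate((ww, data[int(N*0.3):,0]), axis=0)
# s8 holds sigma8 at this point; convert it to S8 before storing
data_params.append([Om, get_s8(Om, s8)])
weights.append(ww)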
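
The shell pipeline used above to read `nsample` can also be written in pure Python, which avoids the temporary `nsample` file and does not depend on `tail`, `sed` and `awk` being available. The sketch below mirrors the pipeline step by step: take the third-from-last line, replace the first `=` with a space, and read the second whitespace-separated field. The helper name `read_nsample` and the toy file written here are hypothetical, and the trailer layout (a line such as `#nsample=3`) is an assumption inferred from the pipeline, not something checked against the real chain files.

```python
import os
import tempfile

import numpy as np

def read_nsample(path):
    """Pure-Python stand-in for: tail -3 | head -1 | sed s/'='/' '/ | awk '{print $2}'."""
    with open(path) as f:
        lines = f.read().splitlines()
    trailer = lines[-3]                              # tail -3 | head -1
    fields = trailer.replace("=", " ", 1).split()    # sed (first '=' only), then awk field splitting
    return int(float(fields[1]))                     # awk '{print $2}'

# toy chain file (hypothetical layout): a header, five samples, three trailer comments
toy_lines = [
    "#omega_m sigma8 weight",
    "0.25 0.85 1.0",
    "0.28 0.82 2.0",
    "0.30 0.80 1.0",
    "0.32 0.78 3.0",
    "0.35 0.75 1.0",
    "#nsample=3",
    "#placeholder_a=0",
    "#placeholder_b=0",
]
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as tmp:
    tmp.write("\n".join(toy_lines) + "\n")
    toy_path = tmp.name

nsample = read_nsample(toy_path)
data = np.loadtxt(toy_path)      # lines starting with '#' are skipped as comments
kept = data[-nsample:]           # keep only the last nsample rows, as in the cell above
os.remove(toy_path)

print("nsample =", nsample, "| rows kept:", kept.shape[0], "of", data.shape[0])
```

In the notebook, `nsample = read_nsample(file_list[0])` would then replace the `os.system(...)` call and the following `np.loadtxt('nsample')` line, provided the real files follow the assumed layout.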

In [None]:
c = ChainConsumer()
c.add_chain(data_params[0], weights=weights[0], parameters=[r"$\Omega_{\rm m}$", r"$S_{8}$"], name='DES-SV')
c.add_chain(data_params[1], weights=weights[1], name='KiDS-450')

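# sigma2d=False draws the 68% and 95% contours (1D-Gaussian sigma levels);
# kde=1.5 smooths the contours with a kernel density estimate
c.configure(colors=['orange','g'], label_font_size=18, contour_label_font_size=20,
            tick_font_size=20, linewidths=[1.5,1.5,1.5,1.5], sigma2d=False, shade=True,
            kde=1.5, shade_alpha=[0.2,1,0.2,0.7], bar_shade=True, sigmas=[0,1,2])
# fix the axis ranges for Omega_m and S8 so the plots are directly comparable
fig = c.plotter.plot(extents=[(0.02,0.95),(0.4,1.1)])
# enlarge the default figure by 4.5 inches in each dimension
fig.set_size_inches(4.5 + fig.get_size_inches())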


### Next let's plot the "unified" chains

In [None]:
kids2_g = 'data_3/mcmc_kids_matched2_covg.txt'
des2_g = 'data_3/mcmc_des_matched2_covg.txt'


In [None]:
data_params = []
weights = []
file_list = [des2_g, kids2_g]

# DES-SV (unified analysis): read nsample from the file trailer and keep the last nsample rows
data = np.loadtxt(file_list[0])
os.system("cat "+file_list[0]+" | tail -3 | head -1 | sed s/'='/' '/|awk '{print $2}'>nsample")
nsample = int(np.loadtxt('nsample'))
# the last column is used as the weight, column 13 as Omega_m and column 12 as sigma8
des = data[-nsample:]
weights.append(des[:,-1])
data_params.append([des[:,13], get_s8(des[:,13], des[:,12])])

# KiDS-450 (unified analysis): same procedure, with Omega_m in column 15 and sigma8 in column 14
data = np.loadtxt(file_list[1])
os.system("cat "+file_list[1]+" | tail -3 | head -1 | sed s/'='/' '/|awk '{print $2}'>nsample")
nsample = int(np.loadtxt('nsample'))
kids = data[-nsample:]
weights.append(kids[:,-1])
data_params.append([kids[:,15], get_s8(kids[:,15], kids[:,14])])


In [None]:
c = ChainConsumer()
c.add_chain(data_params[0], weights=weights[0], parameters=[r"$\Omega_{\rm m}$", r"$S_{8}$"], name='DES-SV')
c.add_chain(data_params[1], weights=weights[1], name='KiDS-450')

c.configure(colors=['orange','g'], label_font_size=18, contour_label_font_size=20, 
            tick_font_size=20, linewidths=[1.5,1.5,1.5,1.5], sigma2d=False, shade=True, 
            kde=1.5, shade_alpha=[0.2,1,0.2,0.7], bar_shade=True, sigmas=[0,1,2])
fig = c.plotter.plot(extents=[(0.02,0.95),(0.4,1.1)])
fig.set_size_inches(4.5 + fig.get_size_inches()) 
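
One simple way to put a number on the visual comparison of two chains is to express the difference of their weighted $S_8$ means in units of the combined 1-sigma width. The sketch below does this for two synthetic Gaussian chains; the means, widths, weights and the helper name `summarize` are assumptions for illustration and are not the DES-SV or KiDS-450 values. It treats the two posteriors as independent Gaussians in a single parameter, so it is only a rough guide and says nothing about differences elsewhere in parameter space.

```python
import numpy as np

def summarize(x, w):
    """Weighted mean and weighted standard deviation of a 1D chain (hypothetical helper)."""
    mean = np.average(x, weights=w)
    std = np.sqrt(np.average((x - mean)**2, weights=w))
    return mean, std

# two synthetic S8 chains with assumed means and widths
rng = np.random.default_rng(1)
s8_a = rng.normal(0.80, 0.06, size=100_000)
s8_b = rng.normal(0.74, 0.04, size=100_000)
w_a = np.ones_like(s8_a)
w_b = np.ones_like(s8_b)

mean_a, std_a = summarize(s8_a, w_a)
mean_b, std_b = summarize(s8_b, w_b)

# difference of the means in units of the quadrature sum of the two widths
n_sigma = abs(mean_a - mean_b) / np.sqrt(std_a**2 + std_b**2)
print("chain A: %.3f +/- %.3f" % (mean_a, std_a))
print("chain B: %.3f +/- %.3f" % (mean_b, std_b))
print("difference: %.2f sigma" % n_sigma)
```

With the real chains, `data_params[i][1]` and `weights[i]` from the cells above would take the place of the synthetic samples and weights.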
