In [None]:
# %load /Users/lzkelley/init.ipy
# Initialize Auto-Reloading Magic
%reload_ext autoreload
%autoreload 2

# Standard Imports
import os
import sys
import json
import copy
import shutil
import datetime
from collections import OrderedDict
from importlib import reload
import warnings

# Package Imports
import astropy as ap
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patheffects

import sympy as sym
import sympy.physics
import sympy.physics.units as spu
from sympy.physics.units.systems import cgs
from sympy.physics.units import c, cm, g, s, km, gravitational_constant as G
pi = sym.pi

import h5py
import pandas as pd
import tqdm.notebook as tqdm
import corner

from zcode.constants import *
import zcode.plot as zplot
import zcode.math as zmath
import zcode.inout as zio
from zcode.inout.notebook import *
import zcode.astro as zastro

import kalepy as kale

# Silence noisy numpy floating-point warnings (division by zero, invalid values, overflow)
np.seterr(divide='ignore', invalid='ignore', over='ignore')
# NOTE(review): this hides *every* UserWarning for the whole session (including ones from
# matplotlib / astropy that may be informative) -- consider narrowing the filter.
warnings.filterwarnings("ignore", category=UserWarning)

# Plotting settings
# Apply the 'default' style FIRST: `style.use('default')` resets rcParams to matplotlib's
# defaults, so calling it after the `mpl.rc(...)` lines below silently discarded them.
mpl.style.use('default')   # avoid dark backgrounds from dark theme vscode
# NOTE(review): 'family' is 'serif' but 'Times' is listed under 'sans-serif'; matplotlib picks
# fonts from the list matching 'family', so this entry has no effect -- confirm whether
# 'serif': ['Times'] was intended.
mpl.rc('font', **{'family': 'serif', 'sans-serif': ['Times'], 'size': 12})
mpl.rc('lines', solid_capstyle='round')
mpl.rc('mathtext', fontset='cm')
plt.rcParams.update({'grid.alpha': 0.5})

# Render sympy expressions with MathJax in notebook output
sym.init_printing(use_latex='mathjax')


In [None]:
import cosmopy
# Re-import cosmopy so local source edits are picked up, then build an instance with
# the constructor's default arguments; `reload` returns the (same) module object.
cosmo = reload(cosmopy).Cosmology()

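Round-trip check: convert each redshift forward to a derived quantity (age, lookback time, luminosity distance, comoving distance), then convert that quantity back to redshift, and plot the fractional difference from the input redshift.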

In [None]:
# Round trip: z -> quantity -> z should recover the input redshift.
NUM = 100
SEED = 1234   # fixed seed so the figure is reproducible on Restart & Run All
rng = np.random.default_rng(SEED)
redz = np.sort(10**rng.uniform(-2, 1, NUM))

funcs = ['tage', 'tlbk', 'dlum', 'dcom']
funcs_for = ["z_to_" + ff for ff in funcs]
funcs_rev = [ff + "_to_z" for ff in funcs]

fig, axes = plt.subplots(figsize=[20, 6], ncols=len(funcs))

for ax, ffor, frev in zip(axes, funcs_for, funcs_rev):
    val = getattr(cosmo, ffor)(redz)
    vz = getattr(cosmo, frev)(val)
    # Fractional error normalized by the smaller value, so over- and under-shoots by the
    # same factor give errors of equal magnitude.
    err = (vz - redz) / np.minimum(vz, redz)

    ax.set(title=ffor + ' -> ' + frev, xscale='log', xlabel='redz', yscale='linear', ylabel='error')
    ax.plot(redz, err)

plt.show()


Compare the interpolated `z_to_*` functions against the built-in `astropy.cosmology` calculations.

In [None]:
def get_cosmo_errors(redz, cosmo, funcs_forward, funcs_check):
    """Fractional errors of `cosmo` functions against paired reference methods.

    Parameters
    ----------
    redz : array_like
        Redshifts at which every function is evaluated.
    cosmo : object
        Instance exposing every method named in `funcs_forward` and `funcs_check`.
    funcs_forward : sequence of str
        Names of methods on `cosmo` whose outputs are compared directly
        (presumably plain numbers/arrays in cgs units -- TODO confirm).
    funcs_check : sequence of str
        Names of reference methods on `cosmo`, paired element-wise with `funcs_forward`.
        Each must return an object with a `.cgs.value` attribute (e.g. an astropy Quantity).

    Returns
    -------
    errors : list
        One entry per (forward, check) pair: ``(val - chk) / np.minimum(chk, val)``.

    Raises
    ------
    ValueError
        If `funcs_forward` and `funcs_check` have different lengths; `zip` would otherwise
        silently drop the unmatched names.
    """
    funcs_forward = list(funcs_forward)
    funcs_check = list(funcs_check)
    if len(funcs_forward) != len(funcs_check):
        raise ValueError("`funcs_forward` ({}) and `funcs_check` ({}) must have the same length!".format(
            len(funcs_forward), len(funcs_check)))

    errors = []
    for func, check in zip(funcs_forward, funcs_check):
        val = getattr(cosmo, func)(redz)
        # Strip units from the reference result so it can be compared against `val`
        chk = getattr(cosmo, check)(redz).cgs.value
        # Normalize by the smaller value so over/under-shoots are weighted symmetrically
        errors.append((val - chk) / np.minimum(chk, val))

    return errors

In [None]:
# Compare each interpolated `z_to_*` function against its reference method.
NUM = 1000
SEED = 2345   # fixed seed so the figure is reproducible on Restart & Run All
rng = np.random.default_rng(SEED)
redz = np.sort(10**rng.uniform(-2, 1, NUM))

funcs = ['tage', 'tlbk', 'dlum', 'dcom']
funcs_check = ['age', 'lookback_time', 'luminosity_distance', 'comoving_distance']
funcs_for = ["z_to_" + ff for ff in funcs]

errors = get_cosmo_errors(redz, cosmo, funcs_for, funcs_check)

fig, axes = plt.subplots(figsize=[20, 6], ncols=len(funcs))

for ax, ffor, check, err in zip(axes, funcs_for, funcs_check, errors):
    ax.set(title=ffor + ' vs. ' + check, xscale='log', xlabel='redz', yscale='linear', ylabel='error')
    ax.plot(redz, err)

plt.show()


# Compare different numbers of interpolation points

In [None]:
# choose grid sizes to check
grid_sizes = [5, 10, 20, 40, 80, 160]

# choose redshifts (fixed seed so the figure is reproducible on Restart & Run All)
NUM = 300
SEED = 3456
rng = np.random.default_rng(SEED)
redz = np.sort(10**rng.uniform(-2, 1, NUM))

err_ave = np.zeros((len(grid_sizes), len(funcs)))
err_std = np.zeros_like(err_ave)
err_max = np.zeros_like(err_ave)

# calculate errors for each grid size
# NOTE: use a separate name (not `cosmo`) so this loop does not overwrite the default
# instance that the earlier cells use -- otherwise re-running them afterwards would
# silently test the last (size=160) grid instead.
for ii, nn in enumerate(grid_sizes):
    cosmo_grid = cosmopy.Cosmology(size=nn)
    errors = get_cosmo_errors(redz, cosmo_grid, funcs_for, funcs_check)
    # mean of |err|, standard-deviation of the signed err, and maximum of |err|
    for jj, err in enumerate(errors):
        err_fabs = np.fabs(err)
        err_ave[ii, jj] = np.mean(err_fabs)
        err_std[ii, jj] = np.std(err)
        err_max[ii, jj] = np.max(err_fabs)

# Plot: one panel per function; the shared legend is drawn on the first panel only
line_specs = [(err_max, '-', 'max'), (err_std, '--', 'std'), (err_ave, ':', 'ave')]
fig, axes = plt.subplots(figsize=[10, 4], ncols=len(funcs))
for jj, (ax, ffor) in enumerate(zip(axes, funcs_for)):
    ax.set(title=ffor, xscale='log', xlabel='grid size', yscale='log', ylabel='error')
    for arr, ls, label in line_specs:
        ax.plot(grid_sizes, arr[:, jj], ls=ls, label=label)

axes[0].legend()
plt.show()