# Benchmarks

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import jax
import jax.numpy as jnp

In [3]:
from zotbin.binned import load_binned

In [4]:
zedges, ell, ngals, noise, cl_in = load_binned('binned_40.npz')

In [33]:
nbin = 8
nzbin = len(zedges) - 1
w = jnp.ones((nbin, nzbin), jnp.float32) / nbin
weights = jnp.array([w, w])

## Cl Reweighting

In [34]:
@jax.jit
def reweight_cl(weights, ngals, cl_in):
    """
    """
    # assert len(weights) == len(ngals)
    nprobe = weights.shape[0]
    offset = 0
    w = [None] * nprobe
    nzbin = np.array([len(W) for W in weights])
    nout = np.sum(nzbin * (1 + np.arange(nprobe)))
    cl_out = [None] * nout
    for i1 in range(nprobe):
        nrow = len(weights[i1])
        rowstep = nprobe - i1
        for i2 in range(i1, nprobe):
            #assert weights[i2].shape[1] == len(ngals[i2])
            W = weights[i2] * ngals[i2]
            W /= jnp.sum(W, axis=1, keepdims=True)
            w[i2] = W
            cl = jnp.einsum('ip,spqk,jq->sijk', w[i1], cl_in[i2][i1], w[i2])
            for j in range(nrow):
                start = j if i1 == i2 else 0
                cl_out[offset + j * rowstep + i2 - i1] = cl[:, j, start:]
        offset += nrow * rowstep
    return jnp.concatenate(cl_out, axis=1)

In [35]:
cl = reweight_cl(weights, ngals, cl_in)

In [36]:
cl.shape

(8, 136, 100)

In [8]:
%timeit reweight_cl(weights, ngals, cl_in).block_until_ready()

394 µs ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Noise Reweighting

In [10]:
import functools

In [25]:
@functools.partial(jax.jit, static_argnums=(1, 4))
def reweight_noise_cl(weights, gals_per_arcmin2, ngals, noise, nell):
    """
    """
    #assert len(weights) == len(noise)
    nprobe = weights.shape[0]
    noise_out = []
    ntracers = 0
    for i in range(nprobe):
        noise_inv_in = 1 / (ngals[i] * noise[i])
        noise_inv_out = gals_per_arcmin2 * weights[i].dot(noise_inv_in)
        noise_out.append(1 / noise_inv_out)
        ntracers += len(noise_inv_out)
    noise = jnp.concatenate(noise_out)

    # Define an ordering for the blocks of the signal vector
    cl_index = []
    for i in range(ntracers):
        for j in range(i, ntracers):
            cl_index.append((i, j))

    # Only include a noise contribution for the auto-spectra
    def get_noise_cl(inds):
        i, j = inds
        delta = 1.0 - jnp.clip(jnp.abs(i - j), 0.0, 1.0)
        return noise[i] * delta * jnp.ones(nell)

    return jax.lax.map(get_noise_cl, jnp.array(cl_index)) #, cl_index

In [26]:
nl = reweight_noise_cl(weights, 20., ngals, noise, len(ell))

In [37]:
nl.shape

(136, 100)

In [16]:
%timeit reweight_noise_cl(weights, 20., ngals, noise, len(ell)).block_until_ready()

2.32 ms ± 9.33 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Covariance Assembly

In [27]:
def get_cl_index(ntracers):
    cl_index = []
    for i in range(ntracers):
        for j in range(i, ntracers):
            cl_index.append((i, j))
    return cl_index

cl_index = get_cl_index(2 * nbin)

In [45]:
def get_cov_blocks(cl_index):
    def find_index(a, b):
        if (a, b) in cl_index:
            return cl_index.index((a, b))
        else:
            return cl_index.index((b, a))
    cov_blocks = []
    for (i, j) in cl_index:
        for (m, n) in cl_index:
            cov_blocks.append(
                (find_index(i, m), find_index(j, n), find_index(i, n), find_index(j, m))
            )
    return cov_blocks

In [72]:
@functools.partial(jax.jit, static_argnums=(4,))
def reweighted_cov(cl_out, nl_out, cov_blocks, ell, fsky):
    """
    """
    # This is essentially jc.angular_cl.gaussian_cl_covariance without using probes...
    cl_obs = cl_out + nl_out
    ncl = cl_obs.shape[0]
    norm = (2 * ell + 1) * jnp.gradient(ell) * fsky

    def get_cov_block(inds):
        a, b, c, d = inds
        return cl_obs[a] * cl_obs[b] + cl_obs[c] * cl_obs[d]

    # Build a sparse representation of the output covariance.
    return jax.lax.map(get_cov_block, cov_blocks).reshape((ncl, ncl, len(ell))) / norm

In [65]:
cov_blocks = jnp.array(get_cov_blocks(cl_index))

In [76]:
cov_blocks

DeviceArray([[  0,   0,   0,   0],
             [  0,   1,   1,   0],
             [  0,   2,   2,   0],
             ...,
             [134, 134, 134, 134],
             [134, 135, 135, 134],
             [135, 135, 135, 135]], dtype=int32)

In [162]:
sqrt_norm = jnp.sqrt((2 * ell + 1) * np.gradient(ell))

In [206]:
import jax_cosmo.sparse

In [250]:
@jax.jit
def gaussian_cl_covariance_helper(ell, cl_signal, cl_noise, p, q, f_sky):
    cl_obs = cl_signal + cl_noise
    norm = (2 * ell + 1) * jnp.gradient(ell) * f_sky
    outer = cl_obs.reshape(-1, 1, len(ell)) * cl_obs / norm
    return outer[p, q] + outer[q, p]

In [251]:
def gaussian_cl_covariance(ell, ntracer, cl_signal, cl_noise, f_sky=0.25, sparse=True):
    ell = jnp.atleast_1d(ell)
    p, q = get_cov_pq(ntracer)
    cov = gaussian_cl_covariance_helper(ell, cl_signal, cl_noise, p, q, f_sky)
    return cov if sparse else jax_cosmo.sparse.to_dense(cov)

In [252]:
cov3 = gaussian_cl_covariance(ell, 16, cl[-1], nl, ell)

In [253]:
jnp.allclose(cov, cov3)

DeviceArray(True, dtype=bool)

In [248]:
%timeit gaussian_cl_covariance(ell, 16, cl[-1], nl).block_until_ready()

1.48 ms ± 18.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [244]:
%timeit gaussian_cl_covariance_helper(ell, cl[-1], nl, p, q).block_until_ready()

1.31 ms ± 19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [127]:
cov3.shape

(136, 136, 100)

In [227]:
_cov_pq_cache = {}

def get_cov_pq(ntracer):
    
    if ntracer not in _cov_pq_cache:
    
        j = jnp.arange(ntracer)
        i = j.reshape(-1, 1)
        k_of_ij = (2 * ntracer - i - 1) * i // 2 + j

        i_of_k, j_of_k = [], []
        k = 0
        for i in range(ntracer):
            for j in range(i, ntracer):
                i_of_k.append(i)
                j_of_k.append(j)
                assert k_of_ij[i, j] == k
                k += 1
        i_of_k = jnp.array(i_of_k)
        j_of_k = jnp.array(j_of_k)

        k1 = jnp.arange(len(i_of_k))
        k2 = k1.reshape(-1, 1)
        p = k_of_ij[i_of_k[k1], i_of_k[k2]]
        q = k_of_ij[j_of_k[k1], j_of_k[k2]]
        _cov_pq_cache[ntracer] = (p, q)

    return _cov_pq_cache[ntracer]

In [228]:
p, q = get_cov_pq(16)

In [229]:
%timeit get_cov_pq(16)

131 ns ± 0.0508 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)


In [61]:
get_cov_blocks(get_cl_index(3))

[(0, 0, 0, 0),
 (0, 1, 1, 0),
 (0, 2, 2, 0),
 (1, 1, 1, 1),
 (1, 2, 2, 1),
 (2, 2, 2, 2),
 (0, 1, 0, 1),
 (0, 3, 1, 1),
 (0, 4, 2, 1),
 (1, 3, 1, 3),
 (1, 4, 2, 3),
 (2, 4, 2, 4),
 (0, 2, 0, 2),
 (0, 4, 1, 2),
 (0, 5, 2, 2),
 (1, 4, 1, 4),
 (1, 5, 2, 4),
 (2, 5, 2, 5),
 (1, 1, 1, 1),
 (1, 3, 3, 1),
 (1, 4, 4, 1),
 (3, 3, 3, 3),
 (3, 4, 4, 3),
 (4, 4, 4, 4),
 (1, 2, 1, 2),
 (1, 4, 3, 2),
 (1, 5, 4, 2),
 (3, 4, 3, 4),
 (3, 5, 4, 4),
 (4, 5, 4, 5),
 (2, 2, 2, 2),
 (2, 4, 4, 2),
 (2, 5, 5, 2),
 (4, 4, 4, 4),
 (4, 5, 5, 4),
 (5, 5, 5, 5)]

In [73]:
cov2 = reweighted_cov(cl[-1], nl, cov_blocks, ell, 0.25)

In [135]:
jnp.array_equal(cov, cov2)

DeviceArray(True, dtype=bool)

In [39]:
# Original
%timeit reweighted_cov(cl[-1], nl, cl_index, ell, 0.25).block_until_ready()

274 ms ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [59]:
# Pass in python list cov_blocks
%timeit reweighted_cov(cl[-1], nl, cov_blocks, ell, 0.25).block_until_ready()

236 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [69]:
# Pass in jnp.array cov_blocks
%timeit reweighted_cov(cl[-1], nl, cov_blocks, ell, 0.25).block_until_ready()

238 ms ± 2.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [71]:
# Move /norm out of norm
%timeit reweighted_cov(cl[-1], nl, cov_blocks, ell, 0.25).block_until_ready()

237 ms ± 890 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [75]:
# Include cov_blocks in jit
%timeit reweighted_cov(cl[-1], nl, cov_blocks, ell, 0.25).block_until_ready()

235 ms ± 459 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
