# Benchmarks

In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import jax
import jax.numpy as jnp

In [3]:
from zotbin.binned import load_binned

In [4]:
zedges, ell, ngals, noise, cl_in = load_binned('binned_40.npz')

In [5]:
nbin = 8
nzbin = len(zedges) - 1
w = jnp.ones((nbin, nzbin), jnp.float32) / nbin
weights = jnp.array([w, w])

In [6]:
from zotbin.reweight import gaussian_cl_covariance, reweighted_cov

## Cl Reweighting

In [7]:
@jax.jit
def reweight_cl(weights, ngals, cl_in):
    """
    """
    # assert len(weights) == len(ngals)
    nprobe = weights.shape[0]
    offset = 0
    w = [None] * nprobe
    nzbin = np.array([len(W) for W in weights])
    nout = np.sum(nzbin * (1 + np.arange(nprobe)))
    cl_out = [None] * nout
    for i1 in range(nprobe):
        nrow = len(weights[i1])
        rowstep = nprobe - i1
        for i2 in range(i1, nprobe):
            #assert weights[i2].shape[1] == len(ngals[i2])
            W = weights[i2] * ngals[i2]
            W /= jnp.sum(W, axis=1, keepdims=True)
            w[i2] = W
            cl = jnp.einsum('ip,spqk,jq->sijk', w[i1], cl_in[i2][i1], w[i2])
            for j in range(nrow):
                start = j if i1 == i2 else 0
                cl_out[offset + j * rowstep + i2 - i1] = cl[:, j, start:]
        offset += nrow * rowstep
    return jnp.concatenate(cl_out, axis=1)

In [8]:
cl = reweight_cl(weights, ngals, cl_in)

In [9]:
cl.shape

(8, 136, 100)

In [10]:
%timeit reweight_cl(weights, ngals, cl_in).block_until_ready()

379 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Noise Reweighting

In [11]:
import functools

In [12]:
@functools.partial(jax.jit, static_argnums=(1, 4))
def reweight_noise_cl(weights, gals_per_arcmin2, ngals, noise, nell):
    """
    """
    #assert len(weights) == len(noise)
    nprobe = weights.shape[0]
    noise_out = []
    ntracers = 0
    for i in range(nprobe):
        noise_inv_in = 1 / (ngals[i] * noise[i])
        noise_inv_out = gals_per_arcmin2 * weights[i].dot(noise_inv_in)
        noise_out.append(1 / noise_inv_out)
        ntracers += len(noise_inv_out)
    noise = jnp.concatenate(noise_out)

    # Define an ordering for the blocks of the signal vector
    cl_index = []
    for i in range(ntracers):
        for j in range(i, ntracers):
            cl_index.append((i, j))

    # Only include a noise contribution for the auto-spectra
    def get_noise_cl(inds):
        i, j = inds
        delta = 1.0 - jnp.clip(jnp.abs(i - j), 0.0, 1.0)
        return noise[i] * delta * jnp.ones(nell)

    return jax.lax.map(get_noise_cl, jnp.array(cl_index)) #, cl_index

In [13]:
nl = reweight_noise_cl(weights, 20., ngals, noise, len(ell))

In [14]:
nl.shape

(136, 100)

In [15]:
%timeit reweight_noise_cl(weights, 20., ngals, noise, len(ell)).block_until_ready()

2.4 ms ± 4.28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Covariance Assembly

The original code:

In [16]:
def get_cl_index(ntracers):
    cl_index = []
    for i in range(ntracers):
        for j in range(i, ntracers):
            cl_index.append((i, j))
    return cl_index

cl_index = get_cl_index(2 * nbin)

In [17]:
cov = reweighted_cov(cl[-1], nl, cl_index, ell, 0.25)

In [18]:
%timeit reweighted_cov(cl[-1], nl, cl_index, ell, 0.25).block_until_ready()

275 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


The new code:

In [19]:
cov2 = gaussian_cl_covariance(ell, cl[-1], nl, 0.25)

In [20]:
assert jnp.allclose(cov, cov2)

In [21]:
%timeit gaussian_cl_covariance(ell, cl[-1], nl, 0.25).block_until_ready()

1.48 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
