WIP: Fixes a few bugs in three point function paircounting. #516

Merged
merged 8 commits on Aug 23, 2018
75 changes: 75 additions & 0 deletions nbodykit/algorithms/tests/test_threeptcf.py
@@ -5,6 +5,15 @@
import os

setup_logging("debug")

# The test result data (threeptcf_sim_result.dat) is computed with
# Daniel Eisenstein's C++ implementation on the same input data set
# for poles up to l=11; we should agree with it to high precision.
#
# If we need to reproduce these files:
# Nick Hand sent the code and instructions to Yu Feng on Aug-20-2018.

data_dir = os.path.join(os.path.split(os.path.abspath(__file__))[0], 'data')

@MPITest([4])
@@ -101,3 +110,69 @@ def test_survey_threeptcf(comm):

if comm.rank == 0:
os.remove(filename)

@MPITest([1])
def test_sim_threeptcf_pedantic(comm):
# the pedantic run (bunchsize=1) must agree with the default bunched run
CurrentMPIComm.set(comm)
BoxSize = 400.0

# load the test data
filename = os.path.join(data_dir, 'threeptcf_sim_data.dat')
cat = CSVCatalog(filename, names=['x', 'y', 'z', 'w'])
cat['Position'] = transform.StackColumns(cat['x'], cat['y'], cat['z'])
cat['Position'] *= BoxSize

cat = cat[::20]

# r binning
nbins = 8
edges = numpy.linspace(0, 200.0, nbins+1)

# run the algorithm
ells = [0, 2, 4, 8]
r = SimulationBox3PCF(cat, ells, edges, BoxSize=BoxSize, weight='w')
p_fast = r.run()
p_pedantic = r.run(pedantic=True)

# test equality
for ell in sorted(ells):
x1 = p_fast['corr_%d' %ell]
x2 = p_pedantic['corr_%d' %ell]
assert_allclose(x1, x2)

@MPITest([1])
def test_sim_threeptcf_shuffled(comm):

CurrentMPIComm.set(comm)
BoxSize = 400.0

# load the test data
filename = os.path.join(data_dir, 'threeptcf_sim_data.dat')
cat = CSVCatalog(filename, names=['x', 'y', 'z', 'w'])
cat['Position'] = transform.StackColumns(cat['x'], cat['y'], cat['z'])
cat['Position'] *= BoxSize

# use the full catalog (no subsampling in this test)

# r binning
nbins = 8
edges = numpy.linspace(0, 200.0, nbins+1)

# run the algorithm, requesting the poles in shuffled (descending) order
ells = list(range(0, 2))[::-1]
r = SimulationBox3PCF(cat, ells, edges, BoxSize=BoxSize, weight='w')

# load the result from file
truth = numpy.empty((8, 8, 11))  # (nbins, nbins, number of poles in the reference file)
with open(os.path.join(data_dir, 'threeptcf_sim_result.dat'), 'r') as ff:
for line in ff:
fields = line.split()
i, j = int(fields[0]), int(fields[1])
truth[i,j,:] = list(map(float, fields[2:]))
truth[j,i,:] = truth[i,j,:]

# test equality
for i, ell in enumerate(sorted(ells)):
x = r.poles['corr_%d' %ell]
assert_allclose(x * (4*numpy.pi)**2 / (2*ell+1), truth[...,i], rtol=1e-3, err_msg='mismatch for ell=%d' %ell)
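Aside (not part of the diff): the assertion above rescales nbodykit's corr_ell by (4*pi)**2 / (2*ell+1) before comparing against the C++ reference file, which suggests the two codes store the multipoles in conventions differing by exactly this ell-dependent factor. A minimal, self-contained sketch of that conversion, with dummy values standing in for real measurements:

import numpy

def to_reference_convention(corr_ell, ell):
    # rescale an nbodykit multipole matrix to the convention assumed by
    # the reference file, using the factor from the test above
    return corr_ell * (4 * numpy.pi)**2 / (2 * ell + 1)

# dummy 8x8 multipole matrix standing in for poles['corr_2']
corr_2 = numpy.full((8, 8), 0.1)
converted = to_reference_convention(corr_2, ell=2)
print(converted[0, 0])  # 0.1 * (4*pi)^2 / 5, roughly 3.158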
72 changes: 50 additions & 22 deletions nbodykit/algorithms/threeptcf.py
@@ -1,6 +1,6 @@
from nbodykit import CurrentMPIComm
from nbodykit.binned_statistic import BinnedStatistic

from collections import OrderedDict
import numpy
import logging
import kdcount
@@ -32,7 +32,7 @@ def __init__(self, source, poles, edges, required_cols, BoxSize=None, periodic=N
self.attrs['BoxSize'] = BoxSize
self.attrs['periodic'] = periodic

def _run(self, pos, w, pos_sec, w_sec, boxsize=None):
def _run(self, pos, w, pos_sec, w_sec, boxsize=None, bunchsize=10000):
"""
Internal function to run the 3PCF algorithm on the input data and
weights.
@@ -47,6 +47,8 @@ def _run(self, pos, w, pos_sec, w_sec, boxsize=None):
nbins = len(self.attrs['edges'])-1
Nell = len(self.attrs['poles'])
zeta = numpy.zeros((Nell,nbins,nbins), dtype='f8')
alms = {}
walms = {}

# compute the Ylm expressions we need
if self.comm.rank == 0:
@@ -90,29 +92,45 @@ def callback(r, i, j, iprim=None):
weights = Ylms[(l,m)] * w_sec[i]

# sum over for each radial bin
alm = numpy.zeros(nbins, dtype='c8')
alm += numpy.bincount(dig, weights=weights.real, minlength=nbins+2)[1:-1]
if m != 0:
alm += 1j*numpy.bincount(dig, weights=weights.imag, minlength=nbins+2)[1:-1]
alm = alms.setdefault((l, m), numpy.zeros(nbins, dtype='c16'))
walm = walms.setdefault((l, m), numpy.zeros(nbins, dtype='c16'))

# compute alm * conjugate(alm)
alm = w0*numpy.outer(alm, alm.conj())
if m != 0: alm += alm.T # add in the -m contribution for m != 0
zeta[l,...] += alm.real
r1 = numpy.bincount(dig, weights=weights.real, minlength=nbins+2)[1:-1]
alm[...] += r1
walm[...] += w0 * r1
if m != 0:
i1 = numpy.bincount(dig, weights=weights.imag, minlength=nbins+2)[1:-1]
alm[...] += 1j*i1
walm[...] += w0*1j*i1

# determine the rank with the largest load
loads = self.comm.allgather(len(pos))
largest_load = numpy.argmax(loads)
chunk_size = max(max(loads) // 10, 1)  # guard against zero for small loads

# compute multipoles for each primary
# compute multipoles for each primary (s vector in the paper)
for iprim in range(len(pos)):
# alms must be reset for each primary particle; the primary is the (s) vector in eqs. 8 and 15 of arXiv:1506.02040v2
alms.clear()
walms.clear()
tree_prim = kdcount.KDTree(numpy.atleast_2d(pos[iprim]), boxsize=boxsize).root
tree_sec.enum(tree_prim, rmax, process=callback, iprim=iprim)
tree_sec.enum(tree_prim, rmax, process=callback, iprim=iprim, bunch=bunchsize)

if self.comm.rank == largest_load and iprim % chunk_size == 0:
self.logger.info("%d%% done" % (10*iprim//chunk_size))

# combine alms into zeta(s);
# this cannot be done in the callback because zeta is a nonlinear
# function (an outer product) of alm, and the callback may run
# several times (in bunches) for a single primary
for (l, m) in alms:
alm = alms[(l, m)]
walm = walms[(l, m)]

# compute alm * conjugate(alm)
alm_w_alm = numpy.outer(walm, alm.conj())
if m != 0: alm_w_alm += alm_w_alm.T # add in the -m contribution for m != 0
zeta[Ylm_cache.ell_to_iell[l], ...] += alm_w_alm.real

# sum across all ranks
zeta = self.comm.allreduce(zeta)

@@ -121,14 +139,15 @@ def callback(r, i, j, iprim=None):
zeta /= (4*numpy.pi)

# make a BinnedStatistic
dtype = numpy.dtype([('corr_%d' %i, zeta.dtype) for i in range(zeta.shape[0])])
dtype = numpy.dtype([('corr_%d' % ell, zeta.dtype) for ell in self.attrs['poles']])
data = numpy.empty(zeta.shape[-2:], dtype=dtype)
for i in range(zeta.shape[0]):
data['corr_%d' %i] = zeta[i]
for i, ell in enumerate(self.attrs['poles']):
data['corr_%d' % ell] = zeta[i]

# save the result
edges = self.attrs['edges']
self.poles = BinnedStatistic(['r1', 'r2'], [edges, edges], data)
poles = BinnedStatistic(['r1', 'r2'], [edges, edges], data)
return poles

def __getstate__(self):
return {'poles':self.poles.data, 'attrs':self.attrs}
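Review note (a standalone sketch, not part of the diff): the core fix above defers the outer product until all secondaries of a primary have been accumulated. Since tree_sec.enum can deliver the pairs of one primary across several callback invocations (bunches), summing per-bunch outer products drops the cross-bunch terms; only the outer product of the fully accumulated alm is correct:

import numpy

rng = numpy.random.default_rng(0)

# pretend the secondaries of one primary arrive in two bunches, each
# contributing a partial alm vector over 4 radial bins
alm_bunch1 = rng.normal(size=4) + 1j * rng.normal(size=4)
alm_bunch2 = rng.normal(size=4) + 1j * rng.normal(size=4)

# wrong: outer product per bunch (what forming zeta inside the callback
# amounts to when the enumeration is bunched)
per_bunch = (numpy.outer(alm_bunch1, alm_bunch1.conj())
             + numpy.outer(alm_bunch2, alm_bunch2.conj()))

# right: accumulate alm over all bunches first, then one outer product
# per primary
alm = alm_bunch1 + alm_bunch2
accumulated = numpy.outer(alm, alm.conj())

# the difference is exactly the missing cross-bunch terms
print(numpy.allclose(per_bunch, accumulated))  # False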
@@ -229,10 +248,10 @@ def __init__(self, source, poles, edges, BoxSize=None, periodic=True, weight='We
raise ValueError("periodic pair counts cannot be computed for Rmax > BoxSize/2")

# run the algorithm
self.run()
self.poles = self.run()


def run(self):
def run(self, pedantic=False):
"""
Compute the three-point CF multipoles. This attaches the following
the attributes to the class:
@@ -262,7 +281,10 @@ def run(self):
self.logger, smoothing)

# run the algorithm
self._run(pos, w, pos_sec, w_sec, boxsize=boxsize)
if pedantic:
return self._run(pos, w, pos_sec, w_sec, boxsize=boxsize, bunchsize=1)
else:
return self._run(pos, w, pos_sec, w_sec, boxsize=boxsize)

class SurveyData3PCF(Base3PCF):
"""
@@ -334,7 +356,7 @@ def __init__(self, source, poles, edges, cosmo, domain_factor=4,
self.attrs['domain_factor'] = domain_factor

# run the algorithm
self.run()
self.poles = self.run()

def run(self):
"""
@@ -364,7 +386,7 @@ def run(self):
domain_factor=self.attrs['domain_factor'])

# run the algorithm
self._run(pos, w, pos_sec, w_sec)
return self._run(pos, w, pos_sec, w_sec)


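Review note: run() now returns the BinnedStatistic instead of setting self.poles as a side effect, and both constructors preserve the old attribute by assigning the return value. A hedged usage sketch, assuming the catalog and edges setup from the tests above:

# minimal usage sketch, names follow the tests in test_threeptcf.py
r = SimulationBox3PCF(cat, [0, 2, 4], edges, BoxSize=400.0, weight='w')

poles = r.run()                # explicit re-run; returns a BinnedStatistic
same = r.poles                 # still set by __init__ via self.poles = self.run()

zeta_0 = poles['corr_0']       # nbins x nbins matrix for ell = 0
strict = r.run(pedantic=True)  # bunchsize=1 path, used to cross-check the bunched accumulation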
class YlmCache(object):
Expand All @@ -384,6 +406,12 @@ def __init__(self, ells, comm):

self.ells = numpy.asarray(ells).astype(int)
self.max_ell = max(self.ells)

# lookup table from ell to iell, the index used when accumulating results.
self.ell_to_iell = numpy.empty(self.max_ell + 1, dtype=int)
for iell, ell in enumerate(self.ells):
self.ell_to_iell[ell] = iell

lms = [(l,m) for l in ells for m in range(0, l+1)]

# compute the Ylm string expressions in parallel
@@ -417,7 +445,7 @@ def from_cache(name, pow):
self._cache = {}

# make the Ylm functions
self._Ylms = {}
self._Ylms = OrderedDict()
for lm, expr in exprs:
expr = parse_expr(expr, local_dict={'zhat':zhat, 'xpyhat':xpyhat})
for var in args[lm]:
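Finally, a sketch of what the new ell_to_iell table (and the switch to zeta[Ylm_cache.ell_to_iell[l], ...] above) buys: when the requested poles are not contiguous, indexing zeta directly by l points at the wrong row, or past the rows that exist. A minimal illustration with an assumed pole list:

import numpy

ells = [0, 2, 4, 8]          # requested poles, not necessarily contiguous
max_ell = max(ells)

# zeta keeps one row per *requested* pole, not one per integer ell
nbins = 8
zeta = numpy.zeros((len(ells), nbins, nbins))

# lookup table from ell to its row index in zeta, mirroring YlmCache
ell_to_iell = numpy.empty(max_ell + 1, dtype=int)
for iell, ell in enumerate(ells):
    ell_to_iell[ell] = iell

print(ell_to_iell[8])        # 3: ell=8 accumulates into row 3, not row 8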