Skip to content

Commit

Permalink
fix upper limit prior_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
Antony Lewis committed Aug 5, 2015
1 parent 7ed44a5 commit f7214f2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
6 changes: 3 additions & 3 deletions getdist/mcsamples.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ def get1DDensityGridData(self, j, writeDataToFile=False, get_density=False, para
prior_mask[winw] = 0.5
prior_mask[: winw] = 0
if par.has_limits_top:
prior_mask[-winw] = 0.5
prior_mask[-(winw + 1)] = 0.5
prior_mask[-winw:] = 0
a0 = convolve1D(prior_mask, Kernel.Win, 'valid', cache=cache)
ix = np.nonzero(a0 * density1D.P)
Expand Down Expand Up @@ -1417,13 +1417,13 @@ def _setEdgeMask2D(self, parx, pary, prior_mask, winw, alledge=False):
prior_mask[:, winw] /= 2
prior_mask[:, :winw] = 0
if parx.has_limits_top:
prior_mask[:, -winw] /= 2
prior_mask[:, -(winw + 1)] /= 2
prior_mask[:, -winw:] = 0
if pary.has_limits_bot:
prior_mask[winw, :] /= 2
prior_mask[:winw:] = 0
if pary.has_limits_top:
prior_mask[-winw, :] /= 2
prior_mask[-(winw + 1), :] /= 2
prior_mask[-winw:, :] = 0
if alledge:
prior_mask[:, :winw] = 0
Expand Down
40 changes: 36 additions & 4 deletions getdist_tests/getdist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import subprocess
import shutil
from getdist import loadMCSamples, plots, IniFile
from getdist_tests.test_distributions import Test2DDistributions
from getdist_tests.test_distributions import Test2DDistributions, Gaussian1D, Gaussian2D
from getdist.mcsamples import MCSamples


class GetDistFileTest(unittest.TestCase):
Expand Down Expand Up @@ -101,9 +102,9 @@ class GetDistTest(unittest.TestCase):
def setUp(self):
np.random.seed(10)
self.testdists = Test2DDistributions()
self.samples = self.testdists.bimodal[0].MCSamples(12000, logLikes=True)

def testTables(self):
self.samples = self.testdists.bimodal[0].MCSamples(12000, logLikes=True)
self.assertEqual(str(self.samples.getLatex(limit=2)),
"(['x', 'y'], ['0.0^{+2.1}_{-2.1}', '0.0^{+1.3}_{-1.3}'])", "MCSamples.getLatex error")
table = self.samples.getTable(columns=1, limit=1, paramList=['x'])
Expand All @@ -117,11 +118,42 @@ def testLimits(self):
samples = self.testdists.cut_correlated.MCSamples(12000, logLikes=False)
stats = samples.getMargeStats()
lims = stats.parWithName('x').limits
self.assertAlmostEqual(lims[0].lower, 0.2175, 3)
self.assertAlmostEqual(lims[1].lower, 0.0548, 3)
self.assertAlmostEqual(lims[0].lower, 0.2205, 3)
self.assertAlmostEqual(lims[1].lower, 0.0491, 3)
self.assertTrue(lims[2].onetail_lower)

# check some analytics (note not very accurate actually)
samples = Gaussian1D(0, 1, xmax=1).MCSamples(1500000, logLikes=False)
stats = samples.getMargeStats()
lims = stats.parWithName('x').limits
self.assertAlmostEqual(lims[0].lower, -0.792815, 2)
self.assertAlmostEqual(lims[0].upper, 0.792815, 2)
self.assertAlmostEqual(lims[1].lower, -1.72718, 2)

def testDensitySymmetries(self):
# check flipping samples gives flipped density
samps = Gaussian1D(0, 1, xmin=-1, xmax=4).MCSamples(12000)
d = samps.get1DDensity('x')
samps.samples[:, 0] *= -1
samps = MCSamples(samples=samps.samples, names=['x'], ranges={'x':[-4, 1]})
d2 = samps.get1DDensity('x')
self.assertTrue(np.allclose(d.P, d2.P[::-1]))

samps = Gaussian2D([0, 0], np.diagflat([1, 2]), xmin=-1, xmax=2, ymin=0, ymax=3).MCSamples(12000)
d = samps.get2DDensity('x', 'y')
samps.samples[:, 0] *= -1
samps = MCSamples(samples=samps.samples, names=['x', 'y'], ranges={'x':[-2, 1], 'y':[0, 3]})
d2 = samps.get2DDensity('x', 'y')
self.assertTrue(np.allclose(d.P, d2.P[:, ::-1]))
samps.samples[:, 0] *= -1
samps.samples[:, 1] *= -1
samps = MCSamples(samples=samps.samples, names=['x', 'y'], ranges={'x':[-1, 2], 'y':[-3, 0]})
d2 = samps.get2DDensity('x', 'y')
self.assertTrue(np.allclose(d.P, d2.P[::-1, ::]))


def testPlots(self):
self.samples = self.testdists.bimodal[0].MCSamples(12000, logLikes=True)
g = plots.getSinglePlotter()
samples = self.samples
p = samples.getParams()
Expand Down

0 comments on commit f7214f2

Please sign in to comment.