Skip to content

Commit

Permalink
fixed columns identified allowing for some floating error
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Apr 25, 2021
1 parent 5abf2cb commit 18bb14f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 24 deletions.
28 changes: 12 additions & 16 deletions getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,9 +1001,11 @@ def deleteFixedParams(self):
fixed = []
values = []
for i in range(self.samples.shape[1]):
if np.all(self.samples[:, i] == self.samples[0, i]):
fixed.append(i)
values.append(self.samples[0, i])
if np.isclose(self.samples[0, i], self.samples[-1, i]):
mean = np.average(self.samples[:, i])
if np.allclose(self.samples[:, i], mean, rtol=1e-12, atol=0):
fixed.append(i)
values.append(mean)
self.changeSamples(np.delete(self.samples, fixed, 1))
return fixed, values

Expand Down Expand Up @@ -1072,7 +1074,7 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab
"""

self.chains = None
WeightedSamples.__init__(self, **kwargs)
super().__init__(**kwargs)
self.jobItem = jobItem
self.ignore_lines = float(kwargs.get('ignore_rows', 0))
self.root = root
Expand Down Expand Up @@ -1236,11 +1238,11 @@ def getParamSampleDict(self, ix, want_derived=True):
:return: ordered dictionary of parameter values
"""
res = dict()
res['weight'] = self.weights[ix]
res['loglike'] = self.loglikes[ix]
for i, name in enumerate(self.paramNames.names):
if want_derived or not name.isDerived:
res[name.name] = self.samples[ix, i]
res['weight'] = self.weights[ix]
res['loglike'] = self.loglikes[ix]
return res

def _makeParamvec(self, par):
Expand All @@ -1250,7 +1252,7 @@ def _makeParamvec(self, par):
par = par.name
if isinstance(par, str):
return self.samples[:, self.index[par]]
return WeightedSamples._makeParamvec(self, par)
return super()._makeParamvec(par)

def updateChainBaseStatistics(self):
# old name, use updateBaseStatistics
Expand Down Expand Up @@ -1450,17 +1452,11 @@ def deleteFixedParams(self):
Delete parameters that are fixed (the same value in all samples)
"""
if self.samples is not None:
fixed, values = WeightedSamples.deleteFixedParams(self)
fixed, values = super().deleteFixedParams()
self.chains = None
else:
fixed = []
values = []
chain = self.chains[0]
for i in range(chain.n):
if np.all(chain.samples[:, i] == chain.samples[0, i]):
fixed.append(i)
values.append(chain.samples[0, i])
for chain in self.chains:
fixed, values = self.chains[0].deleteFixedParams()
for chain in self.chains[1:]:
chain.changeSamples(np.delete(chain.samples, fixed, 1))
if hasattr(self, 'ranges'):
for ix, value in zip(fixed, values):
Expand Down
23 changes: 16 additions & 7 deletions getdist/cobaya_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,20 @@ def MCSamplesFromCobaya(info, collections, name_tag=None,
if ignore_rows != 0 and skip != 0:
logging.warning("You are asking for rows to be ignored (%r), but some (%r) were "
"already ignored in the original chain.", ignore_rows, skip)
var_params = [k for k, v in info_params.items() if is_sampled_param(v) or is_derived_param(v)]
var_params = [k for k, v in info_params.items() if
is_sampled_param(v) or is_derived_param(v)]
assert set(columns[2:]) == set(var_params), (
"Info and collection(s) are not compatible, because their parameters differ: "
"the collection(s) have %r and the info has %r. " % (columns[2:], var_params) +
"the collection(s) have %r and the info has %r. " % (
columns[2:], var_params) +
"Are you sure that you are using an *updated* info dictionary "
"(i.e. the output of `cobaya.run`)?")
# We need to use *collection* sorting, not info sorting!
names = [p + ("*" if is_derived_param(info_params[p]) else "")
for p in columns[2:]]
labels = [(info_params[p] or {}).get(_p_label, p) for p in columns[2:]]
ranges = {p: get_range(info_params[p]) for p in info_params} # include fixed parameters not in columns
ranges = {p: get_range(info_params[p]) for p in
info_params} # include fixed parameters not in columns
renames = {p: info_params.get(p, {}).get(_p_renames, []) for p in columns[2:]}
samples = [c[c.data.columns[2:]].values for c in collections]
weights = [c[_weight].values for c in collections]
Expand All @@ -111,6 +114,10 @@ def MCSamplesFromCobaya(info, collections, name_tag=None,
settings=settings)


def str_to_list(x):
return [x] if isinstance(x, str) else x


def get_info_params(info):
"""
Extracts parameter info from the new yaml format.
Expand All @@ -128,9 +135,9 @@ def get_info_params(info):
remove = info.get(_post, {}).get("remove", {})
for param in remove.get(_params, []) or []:
info_params_full.pop(param, None)
for like in remove.get(_likelihood, []) or []:
for like in str_to_list(remove.get(_likelihood) or []):
likes.remove(like)
for prior in remove.get(_prior, []) or []:
for prior in str_to_list(remove.get(_prior)) or []:
priors.remove(prior)
add = info.get(_post, {}).get("add", {})
# Adding derived params and updating 1d priors
Expand Down Expand Up @@ -181,7 +188,8 @@ def get_range(param_info):
value = float(value)
except ValueError:
# e.g. lambda function values
lims = (lambda i: [i.get("min", -np.inf), i.get("max", np.inf)])(param_info or {})
lims = (lambda i: [i.get("min", -np.inf), i.get("max", np.inf)])(
param_info or {})
else:
lims = (value, value)
return lims[0] if lims[0] != -np.inf else None, lims[1] if lims[1] != np.inf else None
Expand Down Expand Up @@ -242,7 +250,8 @@ def expand_info_param(info_param):


def get_sampler_type(filename_or_info, default_sampler_for_chain_type="mcmc"):
sampler = list(yaml_file_or_dict(filename_or_info).get(_sampler, [default_sampler_for_chain_type]))[0]
sampler = list(yaml_file_or_dict(filename_or_info).get(_sampler, [
default_sampler_for_chain_type]))[0]
return {"mcmc": "mcmc", "polychord": "nested", "minimize": "minimize"}[sampler]


Expand Down
2 changes: 1 addition & 1 deletion getdist/inifile.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def expand_placeholders(self, s):
index = s.index(')')
var = s[:index]
if var in os.environ:
res = res + os.environ[var]
res += os.environ[var]
else:
res = res + c
index += 1
Expand Down

0 comments on commit 18bb14f

Please sign in to comment.