Skip to content

Commit

Permalink
weighted_thin commutes with chain combining
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed May 18, 2021
1 parent 704a20d commit 846ee48
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ install:
fi
- python --version
- pip install .
- pip install PyYAML
- pip install PyYAML flake8
- getdist --help
- git clone --depth=1 https://github.com/cmbant/getdist_testchains

script:
- flake8 . --select=E9,F63,F7,F82 --show-source --statistics
- python -m unittest getdist.tests.getdist_test

deploy:
Expand Down
57 changes: 37 additions & 20 deletions getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
from copy import deepcopy
from collections import namedtuple
from typing import Sequence, Any, Optional, Union
from typing import Sequence, Any, Optional, Union, List

# whether to write to terminal chain names and burn in details when loaded from file
print_load_details = True
Expand All @@ -26,14 +26,12 @@ class WeightedSampleError(Exception):
"""
An exception that is raised when a WeightedSamples error occurs
"""
pass


class ParamError(WeightedSampleError):
"""
An Exception that indicates a bad parameter.
"""
pass


def last_modified(files):
Expand Down Expand Up @@ -914,7 +912,7 @@ def random_single_samples_indices(self, random_state=None):
thin_ix.append(i)
return np.array(thin_ix, dtype=int)

def thin(self, factor):
def thin(self, factor: int):
"""
Thin the samples by the given factor, giving set of samples with unit weight
Expand All @@ -924,21 +922,16 @@ def thin(self, factor):
self.setSamples(self.samples[thin_ix, :], loglikes=None if self.loglikes is None else self.loglikes[thin_ix],
min_weight_ratio=-1)

def weighted_thin(self, factor):
def weighted_thin(self, factor: int):
"""
Thin the samples by the given factor, preserving the weights.
This function also preserves separate chains.
Thin the samples by the given factor, preserving the weights (not all set to 1).
:param factor: The (integer) factor to thin by
"""
unique, counts = self.thin_indices_and_weights(factor, self.weights)
self.setSamples(self.samples[unique, :],
loglikes=None if self.loglikes is None
else self.loglikes[unique],
weights=counts,
min_weight_ratio=-1)
if self.chain_offsets is not None:
self.chain_offsets = np.array([np.sum(unique < off)
for off in self.chain_offsets])
loglikes=None if self.loglikes is None else self.loglikes[unique],
weights=counts, min_weight_ratio=-1)

def filter(self, where):
"""
Expand Down Expand Up @@ -1080,6 +1073,7 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab
"""

self.chains = None
self.chain_offsets = None
super().__init__(**kwargs)
self.jobItem = jobItem
self.ignore_lines = float(kwargs.get('ignore_rows', 0))
Expand Down Expand Up @@ -1134,7 +1128,7 @@ def filter(self, where):
"""

if self.chains is None:
if hasattr(self, 'chain_offsets'):
if self.chain_offsets is not None:
# must update chain_offsets to be able to correctly split back into separate filtered chains if needed
lens = [0]
for off1, off2 in zip(self.chain_offsets[:-1], self.chain_offsets[1:]):
Expand All @@ -1144,6 +1138,24 @@ def filter(self, where):
else:
raise ValueError('chains are separated, makeSingle first or call filter on individual chains')

def weighted_thin(self, factor: int):
    """
    Thin the samples by the given factor, giving (in general) non-unit integer weights.
    Thinning is applied to each chain separately, so separate chains are preserved.

    If the chains are currently separated they remain separated afterwards; if they had
    been combined, they are split using the stored chain offsets, thinned one by one,
    and then combined again with makeSingle().

    :param factor: The (integer) factor to thin by
    """
    if not self.chains and self.chain_offsets is None:
        # No separate chains and no stored offsets to rebuild them from:
        # defer to the parent-class implementation, which thins all samples together.
        return super().weighted_thin(factor)
    # Remember whether the chains were already separated before thinning, so the
    # original (separated or combined) state can be restored afterwards.
    was_separated = self.chains
    per_chain = self.getSeparateChains()
    for one_chain in per_chain:
        one_chain.weighted_thin(factor)
    self.chains = per_chain
    if not was_separated:
        # The samples were combined on entry, so combine the thinned chains again.
        self.makeSingle()
    self.needs_update = True

def getParamNames(self):
"""
Get :class:`~.paramnames.ParamNames` object with names for the parameters
Expand Down Expand Up @@ -1265,7 +1277,7 @@ def _makeParamvec(self, par):
elif par == 'loglike':
return self.loglikes
else:
raise ValueError('Unknown parameter %s' % par)
raise ParamError('Unknown parameter %s' % par)
return super()._makeParamvec(par)

def updateChainBaseStatistics(self):
Expand Down Expand Up @@ -1422,6 +1434,8 @@ def makeSingle(self):
:return: self
"""
if not self.chains:
raise ValueError('There are no separated chains for makeSingle()')
self.chain_offsets = np.cumsum(np.array([0] + [chain.samples.shape[0] for chain in self.chains]))
weights = None if self.chains[0].weights is None else np.hstack([chain.weights for chain in self.chains])
loglikes = None if self.chains[0].loglikes is None else np.hstack([chain.loglikes for chain in self.chains])
Expand All @@ -1430,7 +1444,7 @@ def makeSingle(self):
self.needs_update = True
return self

def getSeparateChains(self):
def getSeparateChains(self) -> List['WeightedSamples']:
"""
Gets a list of samples for separate chains.
If the chains have already been combined, uses the stored sample offsets to reconstruct the array
Expand All @@ -1441,9 +1455,12 @@ def getSeparateChains(self):
if self.chains is not None:
return self.chains
chainlist = []
for off1, off2 in zip(self.chain_offsets[:-1], self.chain_offsets[1:]):
chainlist.append(WeightedSamples(samples=self.samples[off1:off2], weights=self.weights[off1:off2],
loglikes=self.loglikes[off1:off2]))
if self.chain_offsets is None:
raise WeightedSampleError('Samples were not combined from separate chains')
else:
for off1, off2 in zip(self.chain_offsets[:-1], self.chain_offsets[1:]):
chainlist.append(WeightedSamples(samples=self.samples[off1:off2], weights=self.weights[off1:off2],
loglikes=self.loglikes[off1:off2]))
return chainlist

def removeBurnFraction(self, ignore_frac):
Expand Down

0 comments on commit 846ee48

Please sign in to comment.