Skip to content

Commit

Permalink
Added support for change-points in serial transition models
Browse files Browse the repository at this point in the history
  • Loading branch information
christophmark committed Nov 16, 2016
1 parent 0bf13d8 commit 59b24c2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 9 deletions.
2 changes: 1 addition & 1 deletion bayesloop/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def _unpackChangepointNames(self, transitionModel):

# extend hyper-parameter based on current (sub-)model
if hasattr(transitionModel, 'hyperParameterNames'):
if str(transitionModel) == 'Change-point model':
if str(transitionModel) == 'Change-point':
paramList.extend(transitionModel.hyperParameterNames)

return paramList
Expand Down
72 changes: 64 additions & 8 deletions bayesloop/transitionModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def __init__(self, name='tChange', value=None, prior=None):
self.tOffset = 0 # is set to the time of the last Breakpoint by SerialTransition model

def __str__(self):
return 'Change-point model'
return 'Change-point'

def computeForwardPrior(self, posterior, t):
"""
Expand Down Expand Up @@ -635,10 +635,14 @@ class SerialTransitionModel:
"""
Different models act at different time steps. To model fundamental changes in parameter dynamics, different
transition models can be serially coupled. Depending on the time step, a corresponding sub-model is chosen to
compute the new prior distribution from the posterior distribution.
compute the new prior distribution from the posterior distribution. If a break-point lies in between two transition
models, the parameter values do not change abruptly at the time step of the break-point, whereas a change-point not
only changes the transition model, but also allows the parameters to change (the parameter distribution is re-set to
the prior distribution).
Args:
*args: Sequence of transition models and breakpoints (for n models, n-1 breakpoints have to be provided)
*args: Sequence of transition models and break-points/change-points (for n models, n-1
break-points/change-points have to be provided)
Example:
::
Expand All @@ -661,20 +665,39 @@ def __init__(self, *args):
self.hyperParameterValues = []
self.prior = []
self.models = []
self.changePointMask = []
for arg in args:
if str(arg) == 'Break-point': # definition of break-point
if str(arg) == 'Break-point':
self.hyperParameterNames.append(arg.name)
self.prior.append(arg.prior)

# exclude 'all' case, conversion to list is needed to avoid future warning about element-wise comparison
if isinstance(arg.value, str) and arg.value == 'all': # 'all' is passed without type change
self.hyperParameterValues.append(arg.value)
elif isinstance(arg.value, Iterable): # convert list/tuple in numpy array
self.hyperParameterValues.append(np.array(arg.value))
else: # single values are passed without type change
self.hyperParameterValues.append(arg.value)
self.changePointMask.append(0)
elif str(arg) == 'Change-point':
name = arg.hyperParameterNames[0]
value = arg.hyperParameterValues[0]
self.hyperParameterNames.append(name)
self.prior.append(arg.prior)

# exclude 'all' case, conversion to list is needed to avoid future warning about element-wise comparison
if isinstance(value, str) and value == 'all': # 'all' is passed without type change
self.hyperParameterValues.append(value)
elif isinstance(value, Iterable): # convert list/tuple in numpy array
self.hyperParameterValues.append(np.array(value))
else: # single values are passed without type change
self.hyperParameterValues.append(value)
self.changePointMask.append(1)
else: # sub-model
self.models.append(arg)

self.changePointMask = np.array(self.changePointMask).astype(np.bool)

# check: break times have to be passed in monotonically increasing order
# since multiple values can be passed for one break-point at init, we check first values only
firstValues = []
Expand All @@ -688,13 +711,13 @@ def __init__(self, *args):

if not all(x < y if not ((isinstance(x, str) and x == 'all') or (isinstance(y, str) and y == 'all')) else True
for x, y in zip(firstValues, firstValues[1:])):
raise ConfigurationError('Time steps for structural breaks have to be passed in monotonically increasing '
'order.')
raise ConfigurationError('Time steps for structural breaks and/or change-pointshave to be passed in '
'monotonically increasing order.')

# check: n models require n-1 break times
if not (len(self.models)-1 == len(self.hyperParameterValues)):
raise ConfigurationError('Wrong number of structural breaks/models. For n models, n-1 structural breaks '
'are required.')
raise ConfigurationError('Wrong number of structural breaks/change-points and models. For n models, n-1 '
'structural breaks/change-points are required.')

def __str__(self):
return 'Serial transition model'
Expand All @@ -717,6 +740,7 @@ def computeForwardPrior(self, posterior, t):
self.models[modelIndex].study = self.study # study needs to be propagated
self.models[modelIndex].tOffset = self.hyperParameterValues[modelIndex-1] if modelIndex > 0 else 0
newPrior = self.models[modelIndex].computeForwardPrior(posterior, t)
newPrior = self._forwardChangePointCheck(newPrior, t)
return newPrior

def computeBackwardPrior(self, posterior, t):
Expand All @@ -727,8 +751,40 @@ def computeBackwardPrior(self, posterior, t):
self.models[modelIndex].study = self.study # study needs to be propagated
self.models[modelIndex].tOffset = self.hyperParameterValues[modelIndex-1] if modelIndex > 0 else 0
newPrior = self.models[modelIndex].computeBackwardPrior(posterior, t)
newPrior = self._backwardChangePointCheck(newPrior, t)
return newPrior

def _forwardChangePointCheck(self, posterior, t):
"""
This function checks if a change-point is set to the current time step and replaces the posterior with the prior
distribution, just like the change-point transition model. This allows to use change-points in serial transition
models.
Args:
posterior(ndarray): Parameter distribution from current time step
t(int): integer time step
Returns:
ndarray: Prior parameter distribution for subsequent time step
"""
if t in np.array(self.hyperParameterValues)[self.changePointMask]:
# check if custom prior is used by observation model
if hasattr(self.study.observationModel.prior, '__call__'):
prior = self.study.observationModel.prior(*self.study.grid)
elif isinstance(self.study.observationModel.prior, np.ndarray):
prior = deepcopy(self.study.observationModel.prior)
else:
prior = np.ones(self.study.gridSize) # flat prior

# normalize prior (necessary in case an improper prior is used)
prior /= np.sum(prior)
return prior
else:
return posterior

def _backwardChangePointCheck(self, posterior, t):
return self._forwardChangePointCheck(posterior, t - 1)


class BreakPoint:
"""
Expand Down

0 comments on commit 59b24c2

Please sign in to comment.