In [1]:
#!/usr/bin/env python3
"""
Hierarchical Dirichlet Process Hidden Markov Model (HDPHMM).
The HDPHMM object collects a number of observed emission sequences, and estimates
latent states at every time point, along with a probability structure that ties latent
states to emissions. This structure involves
  + A starting probability, which dictates the probability that the first state
  in a latent seqeuence is equal to a given symbol. This has a hierarchical Dirichlet
  prior.
  + A transition probability, which dictates the probability that any given symbol in
  the latent sequence is followed by another given symbol. This shares the same
  hierarchical Dirichlet prior as the starting probabilities.
  + An emission probability, which dictates the probability that any given emission
  is observed conditional on the latent state at the same time point. This uses a
  Dirichlet prior.
Fitting HDPHMMs requires MCMC estimation. MCMC estimation is thus used to calculate the
posterior distribution for the above probabilities. In addition, we can use MAP
estimation (for example) to fix latent states, and facilitate further analysis of a
Chain.
"""
# Support typehinting.
from __future__ import annotations
from typing import Any, Union, Optional, Set, Dict, Iterable, List, Callable, Generator

import numpy as np
import random
import copy
import terminaltables
import tqdm
import functools
import multiprocessing
import string
from scipy import special, stats
from sympy.functions.combinatorial.numbers import stirling
from chain import Chain
from utils import label_generator, dirichlet_process_generator, shrink_probabilities
from warnings import catch_warnings

In [2]:
# Shorthand for numeric types.
Numeric = Union[int, float]

# Oft-used dictionary initializations with shorthands.
DictStrNum = Dict[Optional[str], Numeric]
InitDict = DictStrNum
DictStrDictStrNum = Dict[Optional[str], DictStrNum]
NestedInitDict = DictStrDictStrNum

In [3]:
NestedInitDict

typing.Dict[typing.Union[str, NoneType], typing.Dict[typing.Union[str, NoneType], typing.Union[int, float]]]

In [4]:
emission_sequences = [[7,6,3,53,45,8,75,109],[7,45,1,8,7,6,2,67]]
emissions=None

chains = [Chain(sequence) for sequence in emission_sequences]

priors = {
            "alpha": lambda: np.random.gamma(2, 2),
            "gamma": lambda: np.random.gamma(3, 3),
            "alpha_emission": lambda: np.random.gamma(2, 2),
            "gamma_emission": lambda: np.random.gamma(3, 3),
            "kappa": lambda: np.random.beta(1, 1)}

hyperparameters = {param: prior() for param, prior in priors.items()}

n_initial: InitDict
n_emission: NestedInitDict
n_transition: NestedInitDict
n_initial = {None: 0}
n_emission = {None: {None: 0}}
n_transition = {None: {None: 0}}

p_initial: InitDict
p_emission: NestedInitDict
p_transition: NestedInitDict
p_initial = {None: 1}
p_emission = {None: {None: 1}}
p_transition = {None: {None: 1}}

auxiliary_transition_variables: NestedInitDict
beta_transition: InitDict
beta_emission: InitDict
auxiliary_transition_variables = {None: {None: 0}}
beta_transition = {None: 1}
beta_emission = {None: 1}

if emissions is None:
        emissions = functools.reduce(  # type: ignore
                set.union, (set(c.emission_sequence) for c in chains), set()
            )
elif not isinstance(emissions, set):
    raise ValueError("emissions must be a set")
    emissions = emissions  # type: ignore
    states: Set[Optional[str]] = set()
        
_label_generator = label_generator(string.ascii_lowercase)


def c() -> int:
    """
    Number of chains in the HMM.
    :return: int
    """
    return len(chains)
    

def k() -> int:
    """
    Number of latent states in the HMM currently.
    :return: int
    """
    return len(states)
    

def n() -> int:
    """
    Number of unique emissions. If `emissions` was specified when the HDPHMM was
    created, then this counts the number of elements in `emissions`. Otherwise,
    counts the number of observed emissions across all emission sequences.
    :return: int
    """
    return len(emissions)  

eps = 1e-02


In [5]:
base_sequence = [[7,6,3,53,45,8,75,109],[7,45,1,8,7,6,2,67]]

In [6]:
import bayesian_hmm
sequences = [[7,6,3,53,45,8,75,109],[7,45,1,8,7,6,2,67]]
sequences = [[1,2,3],[4,5,6]]

hmm = bayesian_hmm.HDPHMM(sequences, sticky=False)

hmm.initialise(k=10)


  "bayesian_hmm is in beta testing and future versions may behave differently"


In [10]:
print("states: ", hmm.states, "\n")

print("n_initial ", hmm.n_initial, "\n")

print("n_transition: ", hmm.n_transition, "\n")

print("n_emission: ", hmm.n_emission, "\n")

print("hyperparams: ", hmm.hyperparameters, "\n")

print("p_initial: ", hmm.p_initial, "\n")
print("p_emission: ", hmm.p_emission, "\n")
print("p_transition: ", hmm.p_transition, "\n")

print("auxiliary_transition_variables: ", hmm.auxiliary_transition_variables, "\n")
print("beta_transition: ", hmm.beta_transition, "\n")
print("beta_emission: ", hmm.beta_emission, "\n")
print("alpha_emission: ", hmm.beta_emission, "\n")

print("chain object: ", hmm.chains, "\n")

print("latent sequence: ", hmm.chains[0].latent_sequence, "\n")

print("number of chains: ", hmm.c, "\n")

print("length of chains: ", hmm.n, "\n")

print("tabulate: ", hmm.tabulate(), "\n")







states:  {'h', 'i', 'g', 'c', 'f', 'b', 'e', 'd', 'j', 'a'} 

n_initial  {'h': 1, 'i': 0, None: 0, 'g': 0, 'c': 1, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0} 

n_transition:  {'h': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 1, 'j': 0, 'a': 0}, 'i': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0}, None: {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0}, 'g': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0}, 'c': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 1, 'j': 0, 'a': 0}, 'f': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0}, 'b': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0}, 'e': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f': 0, 'b': 0, 'e': 0, 'd': 0, 'j': 0, 'a': 0}, 'd': {'h': 0, 'i': 0, None: 0, 'g': 0, 'c': 0, 'f'

In [36]:
next(hmm._label_generator)

'g1'

In [39]:
{param: prior() for param, prior in hmm.priors.items()}

{'alpha': 0.8538839655070956,
 'gamma': 18.492683631680094,
 'alpha_emission': 1.0899950209165887,
 'gamma_emission': 20.0570355275628,
 'kappa': 0}

In [48]:
sorted(np.random.dirichlet([18.492683631680094 / (10 + 1)] * (10 + 1)),reverse=True,)
[18.492683631680094 / (10 + 1)] * (10 + 1)

[1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632,
 1.6811530574254632]

In [21]:
for chain in hmm.chains:
    for t in range(chain.T - 1):
        print(chain.latent_sequence[t])
        print(chain.emission_sequence[t])
    

h
1
d
2
c
4
d
5


In [23]:
set(sorted(functools.reduce(set.union, (set(c.latent_sequence) for c in hmm.chains), set())))

{'c', 'd', 'e', 'f', 'h'}

In [32]:
set(sorted(functools.reduce(set.union, (set(c.latent_sequence) for c in hmm.chains), set())))

{'c', 'd', 'e', 'f', 'h'}

In [34]:
zxcv = {"asd":2,"qwer":3}
zxcv.pop("asd")

2

In [40]:
hmm.priors['alpha']()

7.416536776226354

In [62]:
from scipy import special
#help(special.gamma)
special.gamma(1)

1.0

In [9]:
hmm.states.union({"a1",None})

{None, 'a', 'a1', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'}

In [42]:
hmm.p_initial['h']

0.08372070135868223

In [43]:
for t in range(10 - 2, -1, -1):
    print(t)

8
7
6
5
4
3
2
1
0


In [6]:
k = 20
"""
Initialise the HDPHMM. This involves:
+ Choosing starting values for all hyperparameters
+ Initialising all Chains (see Chain.initialise for further info)
+ Initialising priors for probabilities (i.e. the Hierarchical priors)
+ Updating all counts
        
sampling latent states, auxiliary beam variables,
Typically called directly from a HDPHMM object.
:param k: number of symbols to sample from for latent states
:return: None
"""
    
# create as many states as needed
states = [next(_label_generator) for _ in range(k)]
states = set(states)

# set hyperparameters
hyperparameters = {param: prior() for param, prior in priors.items()}

# initialise chains
for c in chains:
    c.initialise(states)

# initialise hierarchical priors
temp_beta = sorted(
    np.random.dirichlet(
                [self.hyperparameters["gamma"] / (self.k + 1)] * (self.k + 1)
            ),
            reverse=True,
        )
beta_transition = dict(zip(list(self.states) + [None], temp_beta))
beta_transition = shrink_probabilities(beta_transition)
auxiliary_transition_variables = {
            s1: {s2: 1 for s2 in self.states.union({None})}
            for s1 in self.states.union({None})
        }

# update counts before resampling
_n_update()

# resample remaining hyperparameters
resample_beta_transition()
resample_beta_emission()
resample_p_initial()
resample_p_transition()
resample_p_emission()

# set initialised flag
_initialised = True

TypeError: 'set' object is not subscriptable

In [6]:

print("states: ", hmm.states)

print("gamma: ", hyperparameters["gamma"])

print("n_initial: ", n_initial) # number of times a state is the initial state of the sequence
print("n_transition: ", n_transition[label]) # 
print("n_emission: ", n_emission) # the number of times a particular emission belongs to each state

print("p_initial: ",p_initial)
print("p_emission: ", p_emission)
print("p_transition: ",p_transition)

print("auxiliary_transition_variables: ", auxiliary_transition_variables)
print("beta_transition: ", beta_transition)
print("beta_emission", beta_emission)



NameError: name 'states' is not defined

In [156]:
# update emission probabilities
### don't understand the draws from dirichlet
temp_p_emission = np.random.dirichlet([hyperparameters["alpha"] * beta_emission[e] for e in emissions])


### e.g.: array([1.])
p_emission[label] = dict(zip(emissions, temp_p_emission))
### e.g.: {None: {None: 1}, 'a1': {None: 0.9999999999999999}}
        
            
# save label as a state 
states = states.union({label})

KeyError: 1

In [162]:
beta_emission

{None: 1}

In [98]:
class HDPHMM(object):
    """
    The Hierarchical Dirichlet Process Hidden Markov Model object. In fact, this is a
    sticky-HDPHMM, since we allow a biased self-transition probability.
    """
    
    def __init__(
        self,
        emission_sequences: Iterable[List[Optional[str]]],
        emissions=None,  # type: ignore
        # emissions: Optional[Iterable[Union[str, int]]] = None # ???
        sticky: bool = True,
        priors: Dict[str, Callable[[], Any]] = None,
    ) -> None:
        """
        Create a Hierarchical Dirichlet Process Hidden Markov Model object, which can
        (optionally) be sticky. The emission sequences must be provided, although all
        other parameters are initialised with reasonable default values. It is
        recommended to specify the `sticky` parameter, depending on whether you believe
        the HMM to have a high probability of self-transition.
        
        :param emission_sequences: iterable, containing the observed emission sequences.
        emission sequences can be different lengths, or zero length.
        
        :param emissions: set, optional. If not all emissions are guaranteed to be
        observed in the data, this can be used to force non-zero emission probabilities
        for unobserved emissions.
        
        :param sticky: bool, flag to indicate whether the HDPHMM is sticky or not.
        Sticky HDPHMMs have an additional value (kappa) added to the probability of self
        transition. It is recommended to set this depending on the knowledge of the
        problem at hand.
        
        :param priors: dict, containing priors for the model hyperparameters. Priors
        should be functions with zero arguments. The following priors are accepted:
          + alpha: prior distribution of the alpha parameter. Alpha
            parameter is the value used in the hierarchical Dirichlet prior for
            transitions and starting probabilities. Higher values of alpha keep rows of
            the transition matrix more similar to the beta parameters.
          + gamma: prior distribution of the gamma parameter. Gamma controls the
            strength of the uninformative prior in the starting and transition
            distributions. Hence, it impacts the likelihood of resampling unseen states
            when estimating beta coefficients. That is, higher values of gamma mean the
            HMM is more likely to explore new states when resampling.
          + alpha_emission: prior distribution of the alpha parameter for the
            emission prior distribution. Alpha controls how tightly the conditional
            emission distributions follow their hierarchical prior. Hence, higher values
            of alpha_emission mean more strength in the hierarchical prior.
          + gamma_emission: prior distribution of the gamma parameter for the
            emission prior distribution. Gamma controls the strength of the
            uninformative prior in the emission distribution. Hence, higher values of
            gamma mean more strength of belief in the prior.
          + kappa: prior distribution of the kappa parameter for the
            self-transition probability. Ignored if `sticky==False`. Kappa prior should
            have support in (0, 1) only. Higher values of kappa mean the chain is more
            likely to explore states with high self-transition probabilty.
        """
        
        
        # store chains
        self.chains = [Chain(sequence) for sequence in emission_sequences]
        
        # sticky flag
        if type(sticky) is not bool:
            raise ValueError("`sticky` must be type bool")
        self.sticky = sticky
        
        # store hyperparameter priors as callables
        self.priors = {
            "alpha": lambda: np.random.gamma(2, 2),
            "gamma": lambda: np.random.gamma(3, 3),
            "alpha_emission": lambda: np.random.gamma(2, 2),
            "gamma_emission": lambda: np.random.gamma(3, 3),
            "kappa": lambda: np.random.beta(1, 1),
        }
        # update prior params if given
        if priors is not None:
            self.priors.update(priors)
        if len(self.priors) > 5:
            raise ValueError("Unknown hyperparameter priors present")
            
        # set kappa prior to one if not sticky    
        if not self.sticky:
            self.priors["kappa"] = lambda: 0
            if priors is not None and "kappa" in priors:
                raise ValueError("`sticky` is False, but kappa prior function given")
                
        # store initial hyperparameter values using existing callables
        self.hyperparameters = {param: prior() for param, prior in self.priors.items()}
        
        # use internal properties to store fit hyperparameters

        self.n_initial: InitDict
        self.n_emission: NestedInitDict
        self.n_transition: NestedInitDict
        self.n_initial = {None: 0}
        self.n_emission = {None: {None: 0}}
        self.n_transition = {None: {None: 0}}
        
        # use internal properties to store current state for probabilities
        self.p_initial: InitDict
        self.p_emission: NestedInitDict
        self.p_transition: NestedInitDict
        self.p_initial = {None: 1}
        self.p_emission = {None: {None: 1}}
        self.p_transition = {None: {None: 1}}
        
        # store derived hyperparameters
        self.auxiliary_transition_variables: NestedInitDict
        self.beta_transition: InitDict
        self.beta_emission: InitDict
        self.auxiliary_transition_variables = {None: {None: 0}}
        self.beta_transition = {None: 1}
        self.beta_emission = {None: 1}
        
        # states & emissions
        # TODO: figure out emissions's type...
        if emissions is None:
            emissions = functools.reduce(  # type: ignore
                set.union, (set(c.emission_sequence) for c in self.chains), set()
            )
        elif not isinstance(emissions, set):
            raise ValueError("emissions must be a set")
        self.emissions = emissions  # type: ignore
        self.states: Set[Optional[str]] = set()

        # generate non-repeating character labels for latent states
        self._label_generator = label_generator(string.ascii_lowercase)

        # keep flag to track initialisation
        self._initialised = False
        
    @property
    def initialised(self) -> bool:
        """
        Test whether a HDPHMM is initialised.
        :return: bool
        """
        return self._initialised
    
    @initialised.setter
    def initialised(self, value: Any) -> None:
        if value:
            raise AssertionError("HDPHMM must be initialised through initialise method")
        elif not value:
            self._initialised = False
        else:
            raise ValueError("initialised flag must be Boolean")
            
    @property
    def c(self) -> int:
        """
        Number of chains in the HMM.
        :return: int
        """
        return len(self.chains)
    
    
    @property
    def k(self) -> int:
        """
        Number of latent states in the HMM currently.
        :return: int
        """
        return len(self.states)
    
    @property
    def n(self) -> int:
        """
        Number of unique emissions. If `emissions` was specified when the HDPHMM was
        created, then this counts the number of elements in `emissions`. Otherwise,
        counts the number of observed emissions across all emission sequences.
        :return: int
        """
        return len(self.emissions)
    
    
    def tabulate(self) -> np.array:
        """
        Convert the latent and emission sequences for all chains into a single numpy
        array. Array contains an index which matches a Chain's index in
        HDPHMM.chains, the current latent state, and the emission for all chains at
        all times.
        :return: numpy array with dimension (l, 3), where l is the length of the Chain
        """
        hmm_array = np.concatenate(
            tuple(
                np.concatenate(
                    (np.array([[n] * self.chains[n].T]).T, self.chains[n].tabulate()),
                    axis=1,
                )
                for n in range(self.c)
            ),
            axis=0,
        )
        return hmm_array
    
    def __repr__(self) -> str:
        return "<bayesian_hmm.HDPHMM, size {C}>".format(C=self.c)
    
    def __str__(self, print_len: int = 15) -> str:
        fs = (
            "bayesian_hmm.HDPHMM,"
            + " ({C} chains, {K} states, {N} emissions, {Ob} observations)"
        )
        return fs.format(C=self.c, K=self.k, N=self.n, Ob=sum(c.T for c in self.chains))
    
    
    def state_generator(self, eps: Numeric = 1e-12) -> Generator[str, None, None]:
        """
        Create a new state for the HDPHMM, and update all parameters accordingly.
        This involves updating
          + The counts for the new symbol
          + The auxiliary variables for the new symbol
          + The probabilities for the new symbol
          + The states captured by the HDPHMM
        :return: str, label of the new state
        """
        while True:
            
            """
            self.'n' s
            """
            
            label = next(self._label_generator) # generate label for state name

            # update counts with zeros (assume _n_update called later)
            # state irrelevant for constant count (all zeros)
            self.n_initial[label] = 0 # set n_initial value corresponding to label to zero
            self.n_transition[label] = {s: 0 for s in self.states.union({label, None})} # n_transitions for new state
            for s in self.states: # could probably use same format as line above
                self.n_transition[s].update({label: 0}) # n_transition for new state to 0
            self.n_emission[label] = {e: 0 for e in self.emissions} # n_emissions for new state to 0

            # update auxiliary transition variables
            self.auxiliary_transition_variables[label] = { #### don't quite undertand aux transition variables
                s2: 1 for s2 in list(self.states) + [label, None]
            }
            for s1 in self.states:
                self.auxiliary_transition_variables[s1][label] = 1
            """
            self.beta_transition
            """
            # update beta_transition value and split out from current pseudo state
            temp_beta = np.random.beta(1, self.hyperparameters["gamma"]) # use prior gamma to generate value from beta
            self.beta_transition[label] = temp_beta * self.beta_transition[None] # Stick breaking: None is dict key
            self.beta_transition[None] = (1 - temp_beta) * self.beta_transition[None] # update value for None key using 1-p
            
            """
            self.p_initial
            """
            # update starting probability 
            ### same thing as before
            temp_p_initial = np.random.beta(1, self.hyperparameters["gamma"])  # use prior gamma to generate value from beta
            self.p_initial[label] = temp_p_initial * self.p_initial[None] # stick breaking: break up temp_p_initial
            self.p_initial[None] = (1 - temp_p_initial) * self.p_initial[None] # and then isolate the rest
            
            """
            self.p_transition
            """
            
            # update transition from new state
            ### draw from dirichlet dist using betas generated from stick breaking
            temp_p_transition = np.random.dirichlet([self.beta_transition[s] for s in list(self.states) + [label, None]])
            p_transition_label = dict(zip(list(self.states) + [label, None], temp_p_transition)) # e.g.: {'a1': 0.9781489319558994, None: 0.021851068044100645}
            self.p_transition[label] = shrink_probabilities(p_transition_label, eps) ### shrink resulting probs ever so slightly so that they don't quite sum up to 1
                                    ### e.g. {None: {None: 1}, 'a1': {'a1': 0.9781489319549431, None: 0.021851068045056942}}
                
                
                
            # update transitions into new state
            for state in self.states.union({None}): # (note that label not included in self.states)
                
                
                ### more stick breaking
                temp_p_transition = np.random.beta(1, self.hyperparameters["gamma"])
                self.p_transition[state][label] = (self.p_transition[state][None] * temp_p_transition)
                self.p_transition[state][None] = self.p_transition[state][None] * (1 - temp_p_transition)

                ### {None: {None: 0.07818045843222077, 'a1': 0.9218195415677792}, 'a1': {'a1': 0.9781489319549431, None: 0.021851068045056942}}
                
            # update emission probabilities
            ### don't understand the draws from dirichlet
            temp_p_emission = np.random.dirichlet([self.hyperparameters["alpha"] * self.beta_emission[e] for e in self.emissions])
                ### e.g.: array([1.])
            self.p_emission[label] = dict(zip(self.emissions, temp_p_emission))
                ### e.g.: {None: {None: 1}, 'a1': {None: 0.9999999999999999}}
        
            
            # save label as a state 
            self.states = self.states.union({label})

            yield label
    
    
    
    
    
    
    
            
            
    
    
        

In [121]:
base_sequence = [[7,6,3,53,45,8,75,109],[7,45,1,8,7,6,2,67]]

In [102]:
h = HDPHMM(emission_sequence)

In [61]:
asdf.priors['alpha']

<function __main__.HDPHMM.__init__.<locals>.<lambda>()>

In [69]:
zxcv = lambda: np.random.gamma(2, 2)

In [91]:
states = set()

In [92]:
states.union({"thing",None})

{None, 'thing'}

In [74]:
priors = {
            "alpha": lambda: np.random.gamma(2, 2),
            "gamma": lambda: np.random.gamma(3, 3),
            "alpha_emission": lambda: np.random.gamma(2, 2),
            "gamma_emission": lambda: np.random.gamma(3, 3),
            "kappa": lambda: np.random.beta(1, 1),
}

In [75]:
priors

{'alpha': <function __main__.<lambda>()>,
 'gamma': <function __main__.<lambda>()>,
 'alpha_emission': <function __main__.<lambda>()>,
 'gamma_emission': <function __main__.<lambda>()>,
 'kappa': <function __main__.<lambda>()>}

In [76]:
{param: prior() for param, prior in priors.items()}

{'alpha': 1.0819281624774297,
 'gamma': 2.1454465997108616,
 'alpha_emission': 1.681501928063262,
 'gamma_emission': 4.761217109850887,
 'kappa': 0.13788877309004735}

In [117]:
#!/usr/bin/env python3
"""
Hierarchical Dirichlet Process Hidden Markov Model (HDPHMM).
The HDPHMM object collects a number of observed emission sequences, and estimates
latent states at every time point, along with a probability structure that ties latent
states to emissions. This structure involves
  + A starting probability, which dictates the probability that the first state
  in a latent seqeuence is equal to a given symbol. This has a hierarchical Dirichlet
  prior.
  + A transition probability, which dictates the probability that any given symbol in
  the latent sequence is followed by another given symbol. This shares the same
  hierarchical Dirichlet prior as the starting probabilities.
  + An emission probability, which dictates the probability that any given emission
  is observed conditional on the latent state at the same time point. This uses a
  Dirichlet prior.
Fitting HDPHMMs requires MCMC estimation. MCMC estimation is thus used to calculate the
posterior distribution for the above probabilities. In addition, we can use MAP
estimation (for example) to fix latent states, and facilitate further analysis of a
Chain.
"""
# Support typehinting.
from __future__ import annotations
from typing import Any, Union, Optional, Set, Dict, Iterable, List, Callable, Generator

import numpy as np
import random
import copy
import terminaltables
import tqdm
import functools
import multiprocessing
import string
from scipy import special, stats
from sympy.functions.combinatorial.numbers import stirling
from chain import Chain
from utils import label_generator, dirichlet_process_generator, shrink_probabilities
from warnings import catch_warnings

# Shorthand for numeric types.
Numeric = Union[int, float]

# Oft-used dictionary initializations with shorthands.
DictStrNum = Dict[Optional[str], Numeric]
InitDict = DictStrNum
DictStrDictStrNum = Dict[Optional[str], DictStrNum]
NestedInitDict = DictStrDictStrNum


class HDPHMM(object):
    """
    The Hierarchical Dirichlet Process Hidden Markov Model object. In fact, this is a
    sticky-HDPHMM, since we allow a biased self-transition probability.
    """

    def __init__(
        self,
        emission_sequences: Iterable[List[Optional[str]]],
        emissions=None,  # type: ignore
        # emissions: Optional[Iterable[Union[str, int]]] = None # ???
        sticky: bool = True,
        priors: Dict[str, Callable[[], Any]] = None,
    ) -> None:
        """
        Create a Hierarchical Dirichlet Process Hidden Markov Model object, which can
        (optionally) be sticky. The emission sequences must be provided, although all
        other parameters are initialised with reasonable default values. It is
        recommended to specify the `sticky` parameter, depending on whether you believe
        the HMM to have a high probability of self-transition.
        
        :param emission_sequences: iterable, containing the observed emission sequences.
        emission sequences can be different lengths, or zero length.
        
        :param emissions: set, optional. If not all emissions are guaranteed to be
        observed in the data, this can be used to force non-zero emission probabilities
        for unobserved emissions.
        
        :param sticky: bool, flag to indicate whether the HDPHMM is sticky or not.
        Sticky HDPHMMs have an additional value (kappa) added to the probability of self
        transition. It is recommended to set this depending on the knowledge of the
        problem at hand.
        
        :param priors: dict, containing priors for the model hyperparameters. Priors
        should be functions with zero arguments. The following priors are accepted:
          + alpha: prior distribution of the alpha parameter. Alpha
            parameter is the value used in the hierarchical Dirichlet prior for
            transitions and starting probabilities. Higher values of alpha keep rows of
            the transition matrix more similar to the beta parameters.
          + gamma: prior distribution of the gamma parameter. Gamma controls the
            strength of the uninformative prior in the starting and transition
            distributions. Hence, it impacts the likelihood of resampling unseen states
            when estimating beta coefficients. That is, higher values of gamma mean the
            HMM is more likely to explore new states when resampling.
          + alpha_emission: prior distribution of the alpha parameter for the
            emission prior distribution. Alpha controls how tightly the conditional
            emission distributions follow their hierarchical prior. Hence, higher values
            of alpha_emission mean more strength in the hierarchical prior.
          + gamma_emission: prior distribution of the gamma parameter for the
            emission prior distribution. Gamma controls the strength of the
            uninformative prior in the emission distribution. Hence, higher values of
            gamma mean more strength of belief in the prior.
          + kappa: prior distribution of the kappa parameter for the
            self-transition probability. Ignored if `sticky==False`. Kappa prior should
            have support in (0, 1) only. Higher values of kappa mean the chain is more
            likely to explore states with high self-transition probabilty.
        """
        # store chains
        self.chains = [Chain(sequence) for sequence in emission_sequences]

        # sticky flag
        if type(sticky) is not bool:
            raise ValueError("`sticky` must be type bool")
        self.sticky = sticky

        # store hyperparameter priors
        self.priors = {
            "alpha": lambda: np.random.gamma(2, 2),
            "gamma": lambda: np.random.gamma(3, 3),
            "alpha_emission": lambda: np.random.gamma(2, 2),
            "gamma_emission": lambda: np.random.gamma(3, 3),
            "kappa": lambda: np.random.beta(1, 1),
        }
        if priors is not None:
            self.priors.update(priors)
        if len(self.priors) > 5:
            raise ValueError("Unknown hyperparameter priors present")

        if not self.sticky:
            self.priors["kappa"] = lambda: 0
            if priors is not None and "kappa" in priors:
                raise ValueError("`sticky` is False, but kappa prior function given")

        # store initial hyperparameter values
        self.hyperparameters = {param: prior() for param, prior in self.priors.items()}

        # use internal properties to store fit hyperparameters
        self.n_initial: InitDict
        self.n_emission: NestedInitDict
        self.n_transition: NestedInitDict
        self.n_initial = {None: 0}
        self.n_emission = {None: {None: 0}}
        self.n_transition = {None: {None: 0}}

        # use internal properties to store current state for probabilities
        self.p_initial: InitDict
        self.p_emission: NestedInitDict
        self.p_transition: NestedInitDict
        self.p_initial = {None: 1}
        self.p_emission = {None: {None: 1}}
        self.p_transition = {None: {None: 1}}

        # store derived hyperparameters
        self.auxiliary_transition_variables: NestedInitDict
        self.beta_transition: InitDict
        self.beta_emission: InitDict
        self.auxiliary_transition_variables = {None: {None: 0}}
        self.beta_transition = {None: 1}
        self.beta_emission = {None: 1}

        # states & emissions
        # TODO: figure out emissions's type...
        if emissions is None:
            emissions = functools.reduce(  # type: ignore
                set.union, (set(c.emission_sequence) for c in self.chains), set()
            )
        elif not isinstance(emissions, set):
            raise ValueError("emissions must be a set")
        self.emissions = emissions  # type: ignore
        self.states: Set[Optional[str]] = set()

        # generate non-repeating character labels for latent states
        self._label_generator = label_generator(string.ascii_lowercase)

        # keep flag to track initialisation
        self._initialised = False

    @property
    def initialised(self) -> bool:
        """
        Test whether a HDPHMM is initialised.
        :return: bool
        """
        return self._initialised

    @initialised.setter
    def initialised(self, value: Any) -> None:
        if value:
            raise AssertionError("HDPHMM must be initialised through initialise method")
        elif not value:
            self._initialised = False
        else:
            raise ValueError("initialised flag must be Boolean")

    @property
    def c(self) -> int:
        """
        Number of chains in the HMM.
        :return: int
        """
        return len(self.chains)

    @property
    def k(self) -> int:
        """
        Number of latent states in the HMM currently.
        :return: int
        """
        return len(self.states)

    @property
    def n(self) -> int:
        """
        Number of unique emissions. If `emissions` was specified when the HDPHMM was
        created, then this counts the number of elements in `emissions`. Otherwise,
        counts the number of observed emissions across all emission sequences.
        :return: int
        """
        return len(self.emissions)

    def tabulate(self) -> np.array:
        """
        Convert the latent and emission sequences for all chains into a single numpy
        array. Array contains an index which matches a Chain's index in
        HDPHMM.chains, the current latent state, and the emission for all chains at
        all times.
        :return: numpy array with dimension (l, 3), where l is the length of the Chain
        """
        hmm_array = np.concatenate(
            tuple(
                np.concatenate(
                    (np.array([[n] * self.chains[n].T]).T, self.chains[n].tabulate()),
                    axis=1,
                )
                for n in range(self.c)
            ),
            axis=0,
        )
        return hmm_array

    def __repr__(self) -> str:
        return "<bayesian_hmm.HDPHMM, size {C}>".format(C=self.c)

    def __str__(self, print_len: int = 15) -> str:
        fs = (
            "bayesian_hmm.HDPHMM,"
            + " ({C} chains, {K} states, {N} emissions, {Ob} observations)"
        )
        return fs.format(C=self.c, K=self.k, N=self.n, Ob=sum(c.T for c in self.chains))

    def state_generator(self, eps: Numeric = 1e-12) -> Generator[str, None, None]:
        """
        Create a new state for the HDPHMM, and update all parameters accordingly.
        This involves updating
          + The counts for the new symbol
          + The auxiliary variables for the new symbol
          + The probabilities for the new symbol
          + The states captured by the HDPHMM
        :return: str, label of the new state
        """
        while True:
            label = next(self._label_generator)

            # update counts with zeros (assume _n_update called later)
            # state irrelevant for constant count (all zeros)
            self.n_initial[label] = 0
            self.n_transition[label] = {s: 0 for s in self.states.union({label, None})}
            for s in self.states:
                self.n_transition[s].update({label: 0})
            self.n_emission[label] = {e: 0 for e in self.emissions}

            # update auxiliary transition variables
            self.auxiliary_transition_variables[label] = {
                s2: 1 for s2 in list(self.states) + [label, None]
            }
            for s1 in self.states:
                self.auxiliary_transition_variables[s1][label] = 1

            # update beta_transition value and split out from current pseudo state
            temp_beta = np.random.beta(1, self.hyperparameters["gamma"])
            self.beta_transition[label] = temp_beta * self.beta_transition[None]
            self.beta_transition[None] = (1 - temp_beta) * self.beta_transition[None]

            # update starting probability
            temp_p_initial = np.random.beta(1, self.hyperparameters["gamma"])
            self.p_initial[label] = temp_p_initial * self.p_initial[None]
            self.p_initial[None] = (1 - temp_p_initial) * self.p_initial[None]

            # update transition from new state
            temp_p_transition = np.random.dirichlet(
                [self.beta_transition[s] for s in list(self.states) + [label, None]]
            )
            p_transition_label = dict(
                zip(list(self.states) + [label, None], temp_p_transition)
            )
            self.p_transition[label] = shrink_probabilities(p_transition_label, eps)

            # update transitions into new state
            for state in self.states.union({None}):
                # (note that label not included in self.states)
                temp_p_transition = np.random.beta(1, self.hyperparameters["gamma"])
                self.p_transition[state][label] = (
                    self.p_transition[state][None] * temp_p_transition
                )
                self.p_transition[state][None] = self.p_transition[state][None] * (
                    1 - temp_p_transition
                )

            # update emission probabilities
            temp_p_emission = np.random.dirichlet(
                [
                    self.hyperparameters["alpha"] * self.beta_emission[e] for e in self.emissions
                ]
            )
            self.p_emission[label] = dict(zip(self.emissions, temp_p_emission))

            # save label
            self.states = self.states.union({label})

            yield label

    def initialise(self, k: int = 20) -> None:
        """
        Initialise the HDPHMM. This involves:
          + Choosing starting values for all hyperparameters
          + Initialising all Chains (see Chain.initialise for further info)
          + Initialising priors for probabilities (i.e. the Hierarchical priors)
          + Updating all counts
        
        sampling latent states, auxiliary beam variables,
        Typically called directly from a HDPHMM object.
        :param k: number of symbols to sample from for latent states
        :return: None
        """
        # create as many states as needed
        states = [next(self._label_generator) for _ in range(k)]
        self.states = set(states)

        # set hyperparameters
        self.hyperparameters = {param: prior() for param, prior in self.priors.items()}

        # initialise chains
        for c in self.chains:
            c.initialise(states)

        # initialise hierarchical priors
        temp_beta = sorted(
            np.random.dirichlet(
                [self.hyperparameters["gamma"] / (self.k + 1)] * (self.k + 1)
            ),
            reverse=True,
        )
        beta_transition = dict(zip(list(self.states) + [None], temp_beta))
        self.beta_transition = shrink_probabilities(beta_transition)
        self.auxiliary_transition_variables = {
            s1: {s2: 1 for s2 in self.states.union({None})}
            for s1 in self.states.union({None})
        }

        # update counts before resampling
        self._n_update()

        # resample remaining hyperparameters
        self.resample_beta_transition()
        self.resample_beta_emission()
        self.resample_p_initial()
        self.resample_p_transition()
        self.resample_p_emission()

        # set initialised flag
        self._initialised = True

    def update_states(self):
        """
        Remove defunct states from the internal set of states, and merge all parameters
        associated with these states back into the 'None' values.
        """
        # identify states to merge
        states_prev = self.states
        states_next = set(
            sorted(
                functools.reduce(
                    set.union, (set(c.latent_sequence) for c in self.chains), set()
                )
            )
        )
        states_removed = states_prev - states_next

        # merge old probabilities into None
        for state in states_removed:
            # remove entries and add to aggregate None state
            self.beta_transition[None] += self.beta_transition.pop(state)
            self.p_initial[None] += self.p_initial.pop(state)
            for s1 in states_next.union({None}):
                self.p_transition[s1][None] += self.p_transition[s1].pop(state)

            # remove transition vector entirely
            del self.p_transition[state]

        # update internal state tracking
        self.states = states_next

    def _n_update(self):
        """
        Update counts required for resampling probabilities. These counts are used
        to sample from the posterior distribution for probabilities. This function
        should be called after any latent state is changed, including after resampling.
        :return: None
        """
        # check that all chains are initialised
        if any(not chain.initialised_flag for chain in self.chains):
            raise AssertionError(
                "Chains must be initialised before calculating fit parameters"
            )

        # transition count for non-oracle transitions
        n_initial = {s: 0 for s in self.states.union({None})}
        n_emission = {
            s: {e: 0 for e in self.emissions} for s in self.states.union({None})
        }
        n_transition = {
            s1: {s2: 0 for s2 in self.states.union({None})}
            for s1 in self.states.union({None})
        }

        # increment all relevant hyperparameters while looping over sequence
        for chain in self.chains:
            # increment initial states emitted by oracle
            n_initial[chain.latent_sequence[0]] += 1

            # increment emissions only for final state
            n_emission[chain.latent_sequence[chain.T - 1]][
                chain.emission_sequence[chain.T - 1]
            ] += 1

            # increment all transitions and emissions within chain
            for t in range(chain.T - 1):
                n_emission[chain.latent_sequence[t]][chain.emission_sequence[t]] += 1
                n_transition[chain.latent_sequence[t]][
                    chain.latent_sequence[t + 1]
                ] += 1

        # store recalculated fit hyperparameters
        self.n_initial = n_initial
        self.n_emission = n_emission
        self.n_transition = n_transition

    @staticmethod
    def _resample_auxiliary_transition_atom_complete(
        alpha, beta, n, use_approximation=True
    ):
        """
        Use a resampling approach that estimates probabilities for all auxiliary
        transition parameters. This avoids the slowdown in convergence caused by
        Metropolis Hastings rejections, but is more computationally costly.
        :param alpha:
        :param beta:
        :param n:
        :param use_approximation:
        :return:
        """
        # initialise values required to resample
        p_required = np.random.uniform(0, 1)
        m = 0
        p_cumulative = 0
        scale = alpha * beta

        if not use_approximation:
            # use precise probabilities
            try:
                logp_constant = np.log(special.gamma(scale)) - np.log(
                    special.gamma(scale + n)
                )
                while p_cumulative == 0 or p_cumulative < p_required and m < n:
                    # accumulate probability
                    m += 1
                    logp_accept = (
                        m * np.log(scale)
                        + np.log(stirling(n, m, kind=1))
                        + logp_constant
                    )
                    p_cumulative += np.exp(logp_accept)
            # after one failure use only the approximation
            except (RecursionError, OverflowError):
                # correct for failed case before
                m -= 1
        while p_cumulative < p_required and m < n:
            # problems with stirling recursion (large n & m), use approximation instead
            # magic number is the Euler constant
            # approximation derived in documentation
            m += 1
            logp_accept = (
                m
                + (m + scale - 0.5) * np.log(scale)
                + (m - 1) * np.log(0.57721 + np.log(n - 1))
                - (m - 0.5) * np.log(m)
                - scale * np.log(scale + n)
                - scale
            )
            p_cumulative += np.exp(logp_accept)
        # breaks out of loop after m is sufficiently large
        return max(m, 1)

    @staticmethod
    def _resample_auxiliary_transition_atom_mh(
        alpha, beta, n, m_curr, use_approximation=True
    ):
        """
        Use a Metropolos Hastings resampling approach that often rejects the proposed
        value. This can cause the convergence to slow down (as the values are less
        dynamic) but speeds up the computation.
        :param alpha:
        :param beta:
        :param n:
        :param m_curr:
        :param use_approximation:
        :return:
        """
        # propose new m
        n = max(n, 1)
        m_proposed = random.choice(range(1, n + 1))
        if m_curr > n:
            return m_proposed

        # find relative probabilities
        if use_approximation and n > 10:
            logp_diff = (
                (m_proposed - 0.5) * np.log(m_proposed)
                - (m_curr - 0.5) * np.log(m_curr)
                + (m_proposed - m_curr) * np.log(alpha * beta * np.exp(1))
                + (m_proposed - m_curr) * np.log(0.57721 + np.log(n - 1))
            )
        else:
            p_curr = float(stirling(n, m_curr, kind=1)) * ((alpha * beta) ** m_curr)
            p_proposed = float(stirling(n, m_proposed, kind=1)) * (
                (alpha * beta) ** m_proposed
            )
            logp_diff = np.log(p_proposed) - np.log(p_curr)

        # use MH variable to decide whether to accept m_proposed
        with catch_warnings(record=True) as caught_warnings:
            p_accept = min(1, np.exp(logp_diff))
            p_accept = bool(np.random.binomial(n=1, p=p_accept))  # convert to boolean
            if caught_warnings:
                p_accept = True

        return m_proposed if p_accept else m_curr

    @staticmethod
    def _resample_auxiliary_transition_atom(
        state_pair,
        alpha,
        beta,
        n_initial,
        n_transition,
        auxiliary_transition_variables,
        resample_type="mh",
        use_approximation=True,
    ):
        """
        Resampling the auxiliary transition atoms should be performed before resampling
        the transition beta values. This is the static method, creating to allow for
        parallelised resampling.
        :param state_pair:
        :param alpha:
        :param beta:
        :param n_initial:
        :param n_transition:
        :param auxiliary_transition_variables:
        :param resample_type:
        :param use_approximation:
        :return:
        """
        # extract states
        state1, state2 = state_pair

        # apply resampling
        if resample_type == "mh":
            return HDPHMM._resample_auxiliary_transition_atom_mh(
                alpha,
                beta[state2],
                n_initial[state2] + n_transition[state1][state2],
                auxiliary_transition_variables[state1][state2],
                use_approximation,
            )
        elif resample_type == "complete":
            return HDPHMM._resample_auxiliary_transition_atom_complete(
                alpha,
                beta[state2],
                n_initial[state2] + n_transition[state1][state2],
                use_approximation,
            )
        else:
            raise ValueError("resample_type must be either mh or complete")

    # TODO: decide whether to use either MH resampling or approximation sampling and
    # remove the alternative, unnecessary complexity in code
    def _resample_auxiliary_transition_variables(
        self, ncores=1, resample_type="mh", use_approximation=True
    ):
        # standard process uses typical list comprehension
        if ncores < 2:
            self.auxiliary_transition_variables = {
                s1: {
                    s2: HDPHMM._resample_auxiliary_transition_atom(
                        (s1, s2),
                        alpha=self.hyperparameters["alpha"],
                        beta=self.beta_transition,
                        n_initial=self.n_initial,
                        n_transition=self.n_transition,
                        auxiliary_transition_variables=self.auxiliary_transition_variables,
                        resample_type=resample_type,
                        use_approximation=use_approximation,
                    )
                    for s2 in self.states
                }
                for s1 in self.states
            }

        # parallel process uses anonymous functions and mapping
        else:
            # specify ordering of states
            state_pairs = [(s1, s2) for s1 in self.states for s2 in self.states]

            # parallel process resamples
            resample_partial = functools.partial(
                HDPHMM._resample_auxiliary_transition_atom,
                alpha=self.hyperparameters["alpha"],
                beta=self.beta_transition,
                n_initial=self.n_initial,
                n_transition=self.n_transition,
                auxiliary_transition_variables=self.auxiliary_transition_variables,
                resample_type=resample_type,
                use_approximation=use_approximation,
            )
            pool = multiprocessing.Pool(processes=ncores)
            auxiliary_transition_variables = pool.map(resample_partial, state_pairs)
            pool.close()

            # store as dictionary
            for pair_n in range(len(state_pairs)):
                state1, state2 = state_pairs[pair_n]
                self.auxiliary_transition_variables[state1][
                    state2
                ] = auxiliary_transition_variables[pair_n]

    def _get_beta_transition_metaparameters(self):
        """
        Calculate parameters for the Dirichlet posterior of the transition beta
        variables (with infinite states aggregated into 'None' state)
        :return: dict, with a key for each state and None, and values equal to parameter
        values
        """
        # aggregate
        params = {
            s2: sum(self.auxiliary_transition_variables[s1][s2] for s1 in self.states)
            for s2 in self.states
        }
        params[None] = self.hyperparameters["gamma"]
        return params

    def resample_beta_transition(
        self, ncores=1, auxiliary_resample_type="mh", use_approximation=True, eps=1e-12
    ):
        """
        Resample the beta values used to calculate the starting and transition
        probabilities.
        :param ncores: int. Number of cores to use in multithreading. Values below 2
        mean the resampling step is not parallelised.
        :param auxiliary_resample_type: either "mh" or "complete". Impacts the way
        in which the auxiliary transition variables are estimated.
        :param use_approximation: bool, flag to indicate whether an approximate
        resampling should occur. ignored if `auxiliary_resample_type` is "mh"
        :param eps: shrinkage parameter to avoid rounding error.
        :return: None
        """
        # auxiliary variables must be resampled to resample beta variables
        self._resample_auxiliary_transition_variables(
            ncores=ncores,
            resample_type=auxiliary_resample_type,
            use_approximation=use_approximation,
        )

        # resample from Dirichlet posterior
        params = self._get_beta_transition_metaparameters()
        temp_result = np.random.dirichlet(list(params.values())).tolist()
        beta_transition = dict(zip(list(params.keys()), temp_result))
        self.beta_transition = shrink_probabilities(beta_transition, eps)

    def calculate_beta_transition_loglikelihood(self):
        # get Dirichlet hyperparameters
        params = self._get_beta_transition_metaparameters()
        ll_transition = np.log(
            stats.dirichlet.pdf(
                [self.beta_transition[s] for s in params.keys()],
                [params[s] for s in params.keys()],
            )
        )
        return ll_transition

    def _get_beta_emission_metaparameters(self):
        """
        Calculate parameters for the Dirichlet posterior of the emission beta variables
        (with infinite states aggregated into 'None' state)
        :return: dict, with a key for each emission, and values equal to parameter
        values
        """
        # aggregate
        params = {
            e: sum(self.n_emission[s][e] for s in self.states)
            + self.hyperparameters["gamma_emission"] / self.n
            for e in self.emissions
        }
        return params

    def resample_beta_emission(self, eps=1e-12):
        """
        Resample the beta values used to calculate the emission probabilties.
        :param eps: Minimum value for expected value before resampling.
        :return: None.
        """
        # resample from Dirichlet posterior
        params = self._get_beta_emission_metaparameters()
        temp_result = np.random.dirichlet(list(params.values())).tolist()
        beta_emission = dict(zip(list(params.keys()), temp_result))
        self.beta_emission = shrink_probabilities(beta_emission, eps)

    def calculate_beta_emission_loglikelihood(self):
        # get Dirichlet hyperparameters
        params = self._get_beta_emission_metaparameters()
        ll_emission = np.log(
            stats.dirichlet.pdf(
                [self.beta_emission[e] for e in self.emissions],
                [params[e] for e in self.emissions],
            )
        )
        return ll_emission

    def _get_p_initial_metaparameters(self):
        params = {
            s: self.n_initial[s]
            + self.hyperparameters["alpha"] * self.beta_transition[s]
            for s in self.states
        }
        params[None] = self.hyperparameters["alpha"] * self.beta_transition[None]
        return params

    def resample_p_initial(self, eps=1e-12):
        """
        Resample the starting probabilities. Performed as a sample from the posterior
        distribution, which is a Dirichlet with pseudocounts and actual counts combined.
        :param eps: minimum expected value.
        :return: None.
        """
        params = self._get_p_initial_metaparameters()
        temp_result = np.random.dirichlet(list(params.values())).tolist()
        p_initial = dict(zip(list(params.keys()), temp_result))
        self.p_initial = shrink_probabilities(p_initial, eps)

    def calculate_p_initial_loglikelihood(self):
        params = self._get_p_initial_metaparameters()
        ll_initial = np.log(
            stats.dirichlet.pdf(
                [self.p_initial[s] for s in params.keys()],
                [params[s] for s in params.keys()],
            )
        )
        return ll_initial

    def _get_p_transition_metaparameters(self, state):
        if self.sticky:
            params = {
                s2: self.n_transition[state][s2]
                + self.hyperparameters["alpha"]
                * (1 - self.hyperparameters["kappa"])
                * self.beta_transition[s2]
                for s2 in self.states
            }
            params[None] = (
                self.hyperparameters["alpha"]
                * (1 - self.hyperparameters["kappa"])
                * self.beta_transition[None]
            )
            params[state] += (
                self.hyperparameters["alpha"] * self.hyperparameters["kappa"]
            )
        else:
            params = {
                s2: self.n_transition[state][s2]
                + self.hyperparameters["alpha"] * self.beta_transition[s2]
                for s2 in self.states
            }
            params[None] = self.hyperparameters["alpha"] * self.beta_transition[None]

        return params

    def resample_p_transition(self, eps=1e-12):
        """
        Resample the transition probabilities from the current beta values and kappa
        value, if the chain is sticky.
        :param eps: minimum expected value passed to Dirichlet sampling step
        :return: None
        """
        # empty current transition values
        self.p_transition = {}

        # refresh each state in turn
        for state in self.states:
            params = self._get_p_transition_metaparameters(state)
            temp_result = np.random.dirichlet(list(params.values())).tolist()
            p_transition_state = dict(zip(list(params.keys()), temp_result))
            self.p_transition[state] = shrink_probabilities(p_transition_state, eps)

        # add transition probabilities from unseen states
        # note: no stickiness update because these are aggregated states
        params = {
            k: self.hyperparameters["alpha"] * v
            for k, v in self.beta_transition.items()
        }
        temp_result = np.random.dirichlet(list(params.values())).tolist()
        p_transition_none = dict(zip(list(params.keys()), temp_result))
        self.p_transition[None] = shrink_probabilities(p_transition_none, eps)

    def calculate_p_transition_loglikelihood(self):
        """
        Note: this calculates the likelihood over all entries in the transition matrix.
        If chains have been resampled (this is the case during MCMC sampling, for
        example), then there may be entries in the transition matrix that no longer
        correspond to actual states.
        :return:
        """
        ll_transition = 0
        states = self.p_transition.keys()

        # get probability for each state
        for state in states:
            params = self._get_p_transition_metaparameters(state)
            ll_transition += np.log(
                stats.dirichlet.pdf(
                    [self.p_transition[state][s] for s in states],
                    [params[s] for s in states],
                )
            )

        # get probability for aggregate state
        params = {
            k: self.hyperparameters["alpha"] * v
            for k, v in self.beta_transition.items()
        }
        ll_transition += np.log(
            stats.dirichlet.pdf(
                [self.p_transition[None][s] for s in states],
                [params[s] for s in states],
            )
        )

        return ll_transition

    def _get_p_emission_metaparameters(self, state):
        params = {
            e: self.n_emission[state][e]
            + self.hyperparameters["alpha_emission"] * self.beta_emission[e]
            for e in self.emissions
        }
        return params

    def resample_p_emission(self, eps=1e-12):
        """
        resample emission parameters from emission priors and counts.
        :param eps: minimum expected value passed to Dirichlet distribution
        :return: None
        """
        # find hyperparameters
        for state in self.states:
            params = self._get_p_emission_metaparameters(state)
            temp_result = np.random.dirichlet(list(params.values())).tolist()
            p_emission_state = dict(zip(list(params.keys()), temp_result))
            self.p_emission[state] = shrink_probabilities(p_emission_state, eps)

        # add emission probabilities from unseen states
        params = {
            k: self.hyperparameters["alpha_emission"] * v
            for k, v in self.beta_emission.items()
        }
        temp_result = np.random.dirichlet(list(params.values())).tolist()
        p_emission_none = dict(zip(list(params.keys()), temp_result))
        self.p_emission[None] = shrink_probabilities(p_emission_none, eps)

    def calculate_p_emission_loglikelihood(self):
        ll_emission = 0

        # get probability for each state
        for state in self.states:
            params = self._get_p_emission_metaparameters(state)
            ll_emission += np.log(
                stats.dirichlet.pdf(
                    [self.p_emission[state][e] for e in self.emissions],
                    [params[e] for e in self.emissions],
                )
            )

        # get probability for aggregate state
        params = {
            k: self.hyperparameters["alpha_emission"] * v
            for k, v in self.beta_emission.items()
        }
        ll_emission += np.log(
            stats.dirichlet.pdf(
                [self.p_emission[None][e] for e in self.emissions],
                [params[e] for e in self.emissions],
            )
        )

        return ll_emission

    def print_fit_parameters(self):
        """
        Prints a copy of the current state counts.
        Used for convenient checking in a command line environment.
        For dictionaries containing the raw values, use the `n_*` attributes.
        :return:
        """
        # create copies to avoid editing
        n_initial = copy.deepcopy(self.n_initial)
        n_emission = copy.deepcopy(self.n_emission)
        n_transition = copy.deepcopy(self.n_transition)

        # make nested lists for clean printing
        initial = [[str(s)] + [str(n_initial[s])] for s in self.states]
        initial.insert(0, ["S_i", "Y_0"])
        emissions = [
            [str(s)] + [str(n_emission[s][e]) for e in self.emissions]
            for s in self.states
        ]
        emissions.insert(0, ["S_i \\ E_i"] + list(map(str, self.emissions)))
        transitions = [
            [str(s1)] + [str(n_transition[s1][s2]) for s2 in self.states]
            for s1 in self.states
        ]
        transitions.insert(0, ["S_i \\ S_j"] + list(map(lambda x: str(x), self.states)))

        # format tables
        ti = terminaltables.DoubleTable(initial, "Starting state counts")
        te = terminaltables.DoubleTable(emissions, "Emission counts")
        tt = terminaltables.DoubleTable(transitions, "Transition counts")
        ti.padding_left = 1
        ti.padding_right = 1
        te.padding_left = 1
        te.padding_right = 1
        tt.padding_left = 1
        tt.padding_right = 1
        ti.justify_columns[0] = "right"
        te.justify_columns[0] = "right"
        tt.justify_columns[0] = "right"

        # print tables
        print("\n")
        print(ti.table)
        print("\n")
        print(te.table)
        print("\n")
        print(tt.table)
        print("\n")

        #
        return None

    def print_probabilities(self):
        """
        Prints a copy of the current probabilities.
        Used for convenient checking in a command line environment.
        For dictionaries containing the raw values, use the `p_*` attributes.
        :return:
        """
        # create copies to avoid editing
        p_initial = copy.deepcopy(self.p_initial)
        p_emission = copy.deepcopy(self.p_emission)
        p_transition = copy.deepcopy(self.p_transition)

        # convert to nested lists for clean printing
        p_initial = [[str(s)] + [str(round(p_initial[s], 3))] for s in self.states]
        p_emission = [
            [str(s)] + [str(round(p_emission[s][e], 3)) for e in self.emissions]
            for s in self.states
        ]
        p_transition = [
            [str(s1)] + [str(round(p_transition[s1][s2], 3)) for s2 in self.states]
            for s1 in self.states
        ]
        p_initial.insert(0, ["S_i", "Y_0"])
        p_emission.insert(0, ["S_i \\ E_j"] + [str(e) for e in self.emissions])
        p_transition.insert(0, ["S_i \\ E_j"] + [str(s) for s in self.states])

        # format tables
        ti = terminaltables.DoubleTable(p_initial, "Starting state probabilities")
        te = terminaltables.DoubleTable(p_emission, "Emission probabilities")
        tt = terminaltables.DoubleTable(p_transition, "Transition probabilities")
        te.padding_left = 1
        te.padding_right = 1
        tt.padding_left = 1
        tt.padding_right = 1
        te.justify_columns[0] = "right"
        tt.justify_columns[0] = "right"

        # print tables
        print("\n")
        print(ti.table)
        print("\n")
        print(te.table)
        print("\n")
        print(tt.table)
        print("\n")

        #
        return None

    def calculate_chain_loglikelihood(self):
        """
        Calculate the negative log likelihood of the chain, given its current
        latent states. This is calculated based on the observed emission sequences only,
        and not on the probabilities of the hyperparameters.
        :return:
        """
        return sum(
            chain.neglogp_chain(self.p_initial, self.p_emission, self.p_transition)
            for chain in self.chains
        )

    def calculate_loglikelihood(self):
        """
        Negative log-likelihood of the entire HDPHMM object. Combines the likelihoods of
        the transition and emission beta parameters, and of the chains themselves.
        Does not include the probabilities of the hyperparameter priors.
        :return: non-negative float
        """
        return (
            self.calculate_beta_transition_loglikelihood()
            + self.calculate_beta_emission_loglikelihood()
            + self.calculate_p_initial_loglikelihood()
            + self.calculate_p_transition_loglikelihood()
            + self.calculate_p_emission_loglikelihood()
            + self.calculate_chain_loglikelihood()
        )

    def resample_chains(self, ncores=1):
        """
        Resample the latent states in all chains. This uses Beam sampling to improve the
        resampling time.
        :param ncores: int, number of threads to use in multithreading.
        :return: None
        """
        # extract probabilities
        p_initial, p_emission, p_transition = (
            self.p_initial,
            self.p_emission,
            self.p_transition,
        )

        # create temporary function for mapping
        resample_partial = functools.partial(
            Chain.resample_latent_sequence,
            states=list(self.states) + [None],
            p_initial=copy.deepcopy(p_initial),
            p_emission=copy.deepcopy(p_emission),
            p_transition=copy.deepcopy(p_transition),
        )

        # parallel process resamples
        pool = multiprocessing.Pool(processes=ncores)
        new_latent_sequences = pool.map(
            resample_partial,
            ((chain.emission_sequence, chain.latent_sequence) for chain in self.chains),
        )
        pool.close()

        # assign returned latent sequences back to Chains
        for i in range(self.c):
            self.chains[i].latent_sequence = new_latent_sequences[i]

        # update chains using results
        # TODO: parameter check if we should be using alpha or gamma as parameter
        state_generator = dirichlet_process_generator(
            self.hyperparameters["gamma"], output_generator=self.state_generator()
        )
        for chain in self.chains:
            chain.latent_sequence = [
                s if s is not None else next(state_generator)
                for s in chain.latent_sequence
            ]

        # update counts
        self._n_update()

    def maximise_hyperparameters(self):
        """
        Choose the MAP (maximum a posteriori) value for the hyperparameters.
        Not yet implemented.
        :return: None
        """
        raise NotImplementedError(
            "This has not yet been written!"
            + " Ping the author if you want it to happen."
        )
        pass

    def resample_hyperparameters(self):
        """
        Resample hyperparameters using a Metropolis Hastings algorithm. Uses a
        straightforward resampling approach, which (for each hyperparameter) samples a
        proposed value according to the prior distribution, and accepts the proposed
        value with probability scaled by the relative probabilities of the model under
        the current and proposed model.
        :return: None
        """
        # iterate and accept each in order
        for param_name in self.priors.keys():
            # don't update kappa if not a sticky chain
            if param_name == "kappa" and not self.sticky:
                continue

            # get current negative log likelihood
            likelihood_curr = self.calculate_loglikelihood()

            # log-likelihood under new parameter value
            param_current = self.hyperparameters[param_name]
            self.hyperparameters[param_name] = self.priors[param_name]()
            likelihood_proposed = self.calculate_loglikelihood()

            # find Metropolis Hasting acceptance probability
            p_accept = min(1, np.exp(likelihood_proposed - likelihood_curr))

            # choose whether to accept
            alpha_accepted = bool(np.random.binomial(n=1, p=p_accept))

            # if we do not accept, revert to the previous value
            if not alpha_accepted:
                self.hyperparameters[param_name] = param_current

    def mcmc(self, n=1000, burn_in=500, save_every=10, ncores=1, verbose=True):
        """
        Use Markov chain Monte Carlo to estimate the starting, transition, and emission
        parameters of the HDPHMM, as well as the number of latent states.
        :param n: int, number of iterations to complete.
        :param burn_in: int, number of iterations to complete before savings results.
        :param save_every: int, only iterations which are a multiple of `save_every`
        will have their results appended to the results.
        :param ncores: int, number of cores to use in multithreaded latent state
        resampling.
        :param verbose: bool, flag to indicate whether iteration-level statistics should
        be printed.
        :return: A dict containing results from every saved iteration. Includes:
          + the number of states of the HDPHMM
          + the negative log likelihood of the entire model
          + the negative log likelihood of the chains only
          + the hyperparameters of the HDPHMM
          + the emission beta values
          + the transition beta values
          + all probability dictionary objects
        """
        # store hyperparameters in a single dict
        results = {
            "state_count": list(),
            "loglikelihood": list(),
            "chain_loglikelihood": list(),
            "hyperparameters": list(),
            "beta_emission": list(),
            "beta_transition": list(),
            "parameters": list(),
        }

        for i in tqdm.tqdm(range(n)):
            # update statistics
            states_prev = copy.copy(self.states)

            # work down hierarchy when resampling
            self.update_states()
            self.resample_hyperparameters()
            self.resample_beta_transition(ncores=ncores)
            self.resample_beta_emission()
            self.resample_p_initial()
            self.resample_p_transition()
            self.resample_p_emission()
            self.resample_chains(ncores=ncores)

            # update computation-heavy statistics
            likelihood_curr = self.calculate_loglikelihood()

            # print iteration summary if required
            if verbose:
                if i == burn_in:
                    tqdm.tqdm.write("Burn-in period complete")
                states_taken = states_prev - self.states
                states_added = self.states - states_prev
                msg = [
                    "Iter: {}".format(i),
                    "Likelihood: {0:.1f}".format(likelihood_curr),
                    "states: {}".format(len(self.states)),
                ]
                if len(states_added) > 0:
                    msg.append("states added: {}".format(states_added))
                if len(states_taken) > 0:
                    msg.append("states removed: {}".format(states_taken))
                tqdm.tqdm.write(", ".join(msg))

            # store results
            if i >= burn_in and i % save_every == 0:
                # get hyperparameters as nested lists
                p_initial = copy.deepcopy(self.p_initial)
                p_emission = copy.deepcopy(self.p_emission)
                p_transition = copy.deepcopy(self.p_transition)

                # save new data
                results["state_count"].append(self.k)
                results["loglikelihood"].append(likelihood_curr)
                results["chain_loglikelihood"].append(
                    self.calculate_chain_loglikelihood()
                )
                results["hyperparameters"].append(copy.deepcopy(self.hyperparameters))
                results["beta_emission"].append(self.beta_emission)
                results["beta_transition"].append(self.beta_transition)
                results["parameters"].append(
                    {
                        "p_initial": p_initial,
                        "p_emission": p_emission,
                        "p_transition": p_transition,
                    }
                )

        # return saved observations
        return results

In [118]:
h = HDPHMM(emission_sequence)

In [119]:
asdf = h.state_generator()

KeyError: 1

In [46]:
def dirichlet_process_generator(alpha: Numeric = 1, 
                            output_generator: Iterator[Union[str, int]] = None
                            ) -> Generator[Union[str, int], None, None]:
    
    """
    Returns a generator object which yields subsequent draws from a single dirichlet
    process.
    :param alpha: alpha parameter of the Dirichlet process
    :param output_generator: generator which yields unique symbols (likely hdphmm.state_generator)
    :return: generator object
    """

    if output_generator is None:
        
        output_generator = itertools.count(start=0, step=1)
    
    count = 0
    weights = {}
    
    while True:
        
        if random.uniform(0, 1) > (count / (count + alpha)):

            val = next(output_generator)
            weights[val] = 1

        else:

            val = np.random.choice(list(weights.keys()), 1, p=list(x / count for x in weights.values()))[0]
            weights[val] += 1
        
        count += 1
        yield val

In [61]:
state_generator = dirichlet_process_generator(0.1, output_generator=None)

In [64]:
next(state_generator)

1

In [56]:
output_generator = itertools.count(start=0, step=1)

In [57]:
next(output_generator)

0

In [74]:
weights = {}
count = 0
alpha = 2
print((count / (count + alpha)))
random.uniform(0, 1) > (count / (count + alpha))

0.0


True

In [75]:
val = "s1"
weights[val] = 1
count += 1

In [78]:
print((count / (count + alpha)))
random.uniform(0, 1) > (count / (count + alpha))

0.3333333333333333


True

In [79]:
val = "s2"
weights[val] = 1
count += 1

In [85]:
print((count / (count + alpha)))
random.uniform(0, 1) > (count / (count + alpha))

0.5


False

In [90]:
val = np.random.choice(list(weights.keys()), 1, p=list(x / count for x in weights.values()))[0]
weights[val] += 1

In [91]:
weights

{'s1': 2, 's2': 1}

In [92]:
help(stats.dirichlet.pdf)

Help on method pdf in module scipy.stats._multivariate:

pdf(x, alpha) method of scipy.stats._multivariate.dirichlet_gen instance
    The Dirichlet probability density function.
    
    Parameters
    ----------
    x : array_like
        Quantiles, with the last axis of `x` denoting the components.
    alpha : array_like
        The concentration parameters. The number of entries determines the
        dimensionality of the distribution.
    
    Returns
    -------
    pdf : ndarray or scalar
        The probability density function evaluated at `x`.



In [96]:
hmm.beta_transition.keys()

dict_keys(['h', 'i', 'g', 'c', 'f', 'b', 'e', 'd', 'j', 'a', None])

In [97]:
hmm.emissions

{1, 2, 3, 4, 5, 6}

In [98]:
hmm.beta_emission

{1: 0.16062472439721218,
 2: 0.20022194358747714,
 3: 0.21289338107264402,
 4: 0.16328969313068475,
 5: 0.06818009095316378,
 6: 0.19479016685881823}

In [107]:
list(hmm.beta_transition.values())


[0.06241856843259871,
 0.07278857983245424,
 0.06438768983012266,
 0.12743544586776925,
 0.07636547601500883,
 0.06509871959553026,
 0.06382791636834269,
 0.13185779665042985,
 0.06685624019244683,
 0.09771843016884878,
 0.17124513704644787]

In [108]:

hmm.beta_emission

{1: 0.16062472439721218,
 2: 0.20022194358747714,
 3: 0.21289338107264402,
 4: 0.16328969313068475,
 5: 0.06818009095316378,
 6: 0.19479016685881823}

In [111]:
sum(np.random.dirichlet([5,2,8,5,9]))

0.9999999999999998

In [113]:

hmm.p_emission

{None: {1: 0.008023688212314832,
  2: 0.07453561468241945,
  3: 0.04478334158922741,
  4: 0.0008554700449518961,
  5: 0.05941198698216928,
  6: 0.812389898488917},
 'h': {1: 0.47188276811129487,
  2: 0.014946020356219036,
  3: 0.16031628743502602,
  4: 0.07971479741005484,
  5: 0.03925680519764025,
  6: 0.2338833214897649},
 'i': {1: 0.2518067035177135,
  2: 0.2082118786837343,
  3: 0.4321979290323758,
  4: 0.07316626959046446,
  5: 3.020657152264115e-05,
  6: 0.034587012604189366},
 'g': {1: 0.11119935397314314,
  2: 0.4690399959807394,
  3: 0.2002519803060443,
  4: 0.1499471259889298,
  5: 0.014545606025112819,
  6: 0.05501593772603031},
 'c': {1: 0.09797127356797518,
  2: 0.08264704494816406,
  3: 0.14507454095685524,
  4: 0.10560101059126319,
  5: 0.10483523830545825,
  6: 0.46387089163028394},
 'f': {1: 0.005086090918046418,
  2: 0.22044184772941178,
  3: 0.018582141874106898,
  4: 0.5663189780130751,
  5: 0.011871118485002297,
  6: 0.17769982298035755},
 'b': {1: 0.81664164091769