In [7]:
import itertools
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import numpy as np

from infomax_class import distribution

In [None]:
class variable_length_generative_model(ABC):
    """Generative model p(x | theta) with discrete observations and a gridded prior over theta.

    Subclasses only implement :meth:`observation_likelihood`; marginals,
    posteriors, KL divergences, mutual information and the Blahut-Arimoto
    prior are all derived from it. Observations are the integers
    ``0 .. n_possible_obs - 1`` and are treated as i.i.d. given theta.
    The prior is a project ``distribution`` object; this class reads its
    ``eval_points``, ``prob_densities``, ``eval_num`` and ``range``
    attributes and calls its ``set_probs`` method.
    """

    def __init__(self, param_range, n_possible_obs):
        # TODO generalise to continuous observations
        self.prior = distribution(param_range)
        self.n_possible_obs = n_possible_obs

        # sequence length N -> list of all length-N observation sequences
        self.possible_sequences = {}  # TODO do this with a storage struct
        # (prior probabilities..., N) -> per-(sequence, theta) KL terms
        self.kl_components = {}
        # TODO load storage

    def set_prior(self, prob_densities):
        """Overwrite the prior's probability values via ``distribution.set_probs``."""
        self.prior.set_probs(prob_densities)

    def possible_observation_sequences(self, N_distr):
        """Return all observation sequences of length ``N_distr`` (cached per length).

        Args:
            N_distr: the sequence length, an integer >= 0.

        Returns:
            A list of ``n_possible_obs ** N_distr`` lists in
            ``itertools.product`` order. The cached list itself is returned,
            so callers must not mutate it.

        Raises:
            NotImplementedError: if ``N_distr`` is not an integer; variable
                sequence lengths are not supported yet.
        """
        # TODO do this for variable Ns (a distribution over sequence lengths)
        if not isinstance(N_distr, (int, np.integer)):
            raise NotImplementedError(
                "possible_observation_sequences only supports a fixed integer sequence length"
            )
        n = int(N_distr)
        if n not in self.possible_sequences:
            self.possible_sequences[n] = [
                list(seq) for seq in itertools.product(range(self.n_possible_obs), repeat=n)
            ]
        return self.possible_sequences[n]

    @abstractmethod
    def observation_likelihood(self, observation, param_value):
        """Return p(x | theta) for a single observation x and parameter value theta."""

    def _likelihoods_for_all_param(self, observation):
        """Return an array of p(x | theta) for every theta on the prior's grid."""
        return np.array([self.observation_likelihood(observation, self.prior.eval_points[i])
                         for i in range(self.prior.eval_num)])

    def observation_marginal(self, observation):
        """Return p(x) = sum_theta p(x | theta) p(theta) under the current prior."""
        # p(x) = \sum_\theta p(x | \theta) p(\theta)
        return np.sum(self._likelihoods_for_all_param(observation) * self.prior.prob_densities)

    def sequence_likelihood(self, observations, param_value):
        """Return p(X | theta): the product of the per-observation likelihoods.

        The product form relies on the observations being i.i.d. given theta.
        """
        # p(X | \theta) = \prod_x p(x | \theta)
        return np.prod(np.array([self.observation_likelihood(o, param_value) for o in observations]))

    def sequence_marginal(self, observations):
        """Return p(X) = sum_theta p(X | theta) p(theta) under the current prior.

        This is the probability of the whole sequence when theta is
        integrated out against the prior rather than fixed.
        """
        # p(X) = \sum_\theta p(X | \theta) p(\theta)
        all_like = [self.sequence_likelihood(observations, self.prior.eval_points[i]) * self.prior.prob_densities[i]
                    for i in range(self.prior.eval_num)]
        return sum(all_like)

    def _observation_prob_ratios(self, observation):
        """Return p(x | theta) / p(x) for every theta on the prior's grid."""
        # p(x | \theta) / \sum_\theta p(x | \theta) p(\theta)
        likelihoods = self._likelihoods_for_all_param(observation)
        marginal = np.sum(likelihoods * self.prior.prob_densities)
        return likelihoods / marginal

    def _predictive_distr(self, act_prior):
        """Return the predictive distribution over single observations, sum_theta p(x | theta) act_prior(theta)."""
        possible_observations = list(range(self.n_possible_obs))
        predictive_probs = [sum([self.observation_likelihood(o, act_prior.eval_points[th]) * act_prior.prob_densities[th]
                                 for th in range(act_prior.eval_num)])
                            for o in possible_observations]
        return distribution(possible_observations, predictive_probs)

    def KL_divergences(self, N, posterior=None, M=0, clip=1e-6):
        """Return KL[p(X | theta) || p(X)] for every theta on the prior's grid.

        X ranges over all observation sequences of length ``N``. The
        per-(sequence, theta) terms are cached, keyed by the current prior
        probabilities and ``N``. The key does not include the grid points
        themselves.

        If ``posterior`` is given and ``M > 0``, a sample-based estimate is
        returned instead. It draws ``M`` single observations from the
        posterior-predictive distribution, combines them with the
        length-``N`` KLs, and clips the result at 0.

        Args:
            N: integer length of the observation sequences.
            posterior: distribution over theta used to draw predictive samples.
            M: number of predictive samples (0 disables sampling).
            clip: currently unused; kept for interface compatibility.

        Returns:
            Array with one entry per theta grid point.
        """
        # we reuse computation in this house
        storage_key = tuple(list(self.prior.prob_densities) + [N])
        if storage_key not in self.kl_components:
            all_sequences = self.possible_observation_sequences(N)
            # p(X | \theta), shape (theta, X); rows with p(\theta) == 0 are left at 0
            all_sequence_likelihoods = np.zeros((self.prior.eval_num, len(all_sequences)))
            for i_theta, p_theta in enumerate(self.prior.prob_densities):
                if p_theta == 0:
                    continue
                for i_obs, obs in enumerate(all_sequences):
                    all_sequence_likelihoods[i_theta, i_obs] = self.sequence_likelihood(obs, self.prior.eval_points[i_theta])

            like_prior = all_sequence_likelihoods.transpose() * self.prior.prob_densities  # p(X | \theta) p(\theta), shape (X, theta)
            sequence_marginals = np.sum(like_prior.transpose(), axis=0)  # p(X) \forall X

            # take log(0) := 0 so that zero-probability terms contribute nothing below
            log_sequence_marginals = np.log(sequence_marginals, out=np.zeros_like(sequence_marginals, dtype=np.float64),
                                            where=(sequence_marginals != 0))
            log_sequence_likelihoods = np.log(all_sequence_likelihoods, out=np.zeros_like(all_sequence_likelihoods, dtype=np.float64),
                                              where=(all_sequence_likelihoods != 0))
            logdiff = log_sequence_likelihoods - log_sequence_marginals

            # p(X | \theta) [\log p(X | \theta) - \log p(X)], shape (X, theta)
            kl_components = all_sequence_likelihoods.transpose() * (logdiff.transpose())
            self.kl_components[storage_key] = kl_components
        future_KLs = np.nansum(self.kl_components[storage_key], axis=0)  # KL[p(X | \theta) || p(X)] \forall \theta

        if posterior is None or M == 0:
            return future_KLs

        # take M samples from the posterior-predictive distribution \sum_\theta p(x | \theta) p(\theta | X_old)
        # TODO reuse this part as well
        samples = self._predictive_distr(posterior).sample(M)
        sample_likelihoods = np.array([self._likelihoods_for_all_param(s) for s in samples])  # M x theta_res
        sample_prob_ratios = np.array([self._observation_prob_ratios(s) for s in samples])  # M x theta_res
        log_sample_prob_ratios = np.log(sample_prob_ratios, out=np.zeros_like(sample_prob_ratios, dtype=np.float64),
                                        where=(sample_prob_ratios != 0))
        # we average over the samples; the formula-based variant would be
        # KLs = np.sum(sample_prob_ratios * (log_sample_prob_ratios + future_KLs), axis=0) / M
        # TODO why does this work way better than the one in the formulas???
        KLs = np.sum(sample_likelihoods * (log_sample_prob_ratios + future_KLs), axis=0) / M
        # the sample-based approximation can come back negative, so we clip
        return np.maximum(KLs, 0.)

    def mutual_information(self, N, posterior=None, M=0, clip=1e-6):
        """Return I(theta; X) = sum_theta p(theta) KL[p(X | theta) || p(X)] under the current prior.

        ``posterior``, ``M`` and ``clip`` are forwarded to :meth:`KL_divergences`.
        """
        # TODO implement the version with sampling
        return np.nansum(self.prior.prob_densities * self.KL_divergences(N, posterior=posterior, M=M, clip=clip))

    def blahut_arimoto_prior(self, N, prior_res, n_step, posterior=None, M=0, min_delta=0, plot=False):
        """Run Blahut-Arimoto updates on the prior in place, starting from uniform.

        Each step sets p(theta) proportional to p(theta) * exp(KL(theta)). Without sampling
        (``posterior is None or M == 0``), iteration stops early once the
        mutual information improves by at most ``min_delta``.

        Args:
            N: sequence length used for the KL divergences.
            prior_res: number of prior probability values; presumably must
                match the prior's grid size (``eval_num``) — TODO confirm.
            n_step: maximum number of update steps.
            posterior, M: forwarded to :meth:`KL_divergences`.
            min_delta: early-stopping threshold on the mutual-information gain.
            plot: if True, plot the final prior and the mutual-information trace.
        """
        self.set_prior(np.ones(prior_res) / prior_res)
        MIs = [self.mutual_information(N, posterior=posterior, M=M)]

        for _ in range(n_step):
            act_kl = np.asarray(self.KL_divergences(N, posterior=posterior, M=M))
            # subtracting the max leaves the normalised update unchanged but prevents exp() overflow
            exp_kl = np.exp(act_kl - np.max(act_kl))
            unnorm_new_p = exp_kl * self.prior.prob_densities
            self.set_prior(unnorm_new_p / np.sum(unnorm_new_p))
            MIs.append(self.mutual_information(N, posterior=posterior, M=M))
            if posterior is None or M == 0:
                if MIs[-1] - MIs[-2] <= min_delta:
                    break

        if plot:
            plt.subplot(1, 2, 1)
            self.prior.plot()
            plt.subplot(1, 2, 2)
            plt.plot(MIs)

    def posterior(self, observations):
        """Return p(theta | X) on the prior's grid as a new ``distribution``.

        Raises:
            RuntimeError: if the prior probabilities have not been set.
        """
        if self.prior.prob_densities is None:
            raise RuntimeError("Prior not set, cannot calculate posterior.")
        unnorm_post = np.array([self.sequence_likelihood(observations, self.prior.eval_points[th]) * self.prior.prob_densities[th]
                                for th in range(self.prior.eval_num)])
        post = unnorm_post / np.sum(unnorm_post)
        return distribution(self.prior.range, post)

[1.  2.  3.2 4. ]
