In [None]:
# Collect everything yielded by partitions(3, 4) and show it as a single array.
partition_rows = list(partitions(3, 4))
output = np.array(partition_rows)
output

## Understanding Mixtures

In [None]:
n_components = 15

# Keep component_probs one-dimensional (one probability per component) with no batch shape:
# the mixture distribution's batch shape must either equal the components distribution's
# rightmost batch shape or be scalar, and scalar is the simplest choice.
# Uniform weights: every component gets 1 / n_components, so the probabilities sum to one.
# (The divisor previously also multiplied by `prob_batch_dim`, which was never assigned.)
component_probs = tf.ones([n_components]) / n_components
categorical = tfp.distributions.Categorical(probs=component_probs)
categorical

# For the `probs` argument, every dimension except the last is a batch dimension;
# the final dimension indexes the components.

In [None]:
# The rightmost batch dimension of the components must equal the number of mixture
# components, so any additional batching has to live in the dimensions before it.
anything = 3  # arbitrary size for the extra leading batch dimension
batch_dimension = (anything, n_components)
component_batch_size = 32

# Normal never has an event dimension (its entries are always independent),
# so the whole parameter shape becomes batch shape.
mu = tf.zeros(shape=batch_dimension)
sigma = tf.ones(shape=batch_dimension)

component_distribution = tfp.distributions.Normal(loc=mu, scale=sigma)
component_distribution


In [None]:
# Draw a sample from the component distribution to inspect the shape of what it returns.
component_distribution.sample()

In [None]:
# Combine the categorical weights with the batched Normal components;
# the rightmost batch dimension of the components is used as the component axis.
mixture = tfp.distributions.MixtureSameFamily(
    mixture_distribution=categorical,
    components_distribution=component_distribution
)
mixture

In [None]:
# DirichletMultinomial (DM) does have an event shape: it is taken from the rightmost
# dimension of the parameters. The remaining dimension is batch, and must equal the
# number of mixture components.
batch_size = n_components
event_size = 5
concentration = tf.random.uniform(shape=(batch_size, event_size))

total_votes = 10  # scalar, broadcast against concentration
component_distribution = tfp.distributions.DirichletMultinomial(
    total_count=total_votes,
    concentration=concentration
)
component_distribution


In [None]:
# Mixture over the DirichletMultinomial components, weighted by `categorical`.
mixture = tfp.distributions.MixtureSameFamily(
    mixture_distribution=categorical,
    components_distribution=component_distribution
)
mixture

In [None]:
# Draw a sample from the mixture built in the previous cell.
mixture.sample()

In [None]:
# Add a leading batch dimension so that many galaxies are handled at once.
# DM takes its event shape from the rightmost parameter dimension, so the remaining
# (n_galaxies, n_components) dimensions form the batch shape.
n_galaxies = 32

concentration = tf.random.uniform(shape=(n_galaxies, n_components, event_size))

total_votes = 10  # scalar, broadcast against concentration
component_distribution = tfp.distributions.DirichletMultinomial(
    total_count=total_votes,
    concentration=concentration
)
print(component_distribution)

mixture = tfp.distributions.MixtureSameFamily(
    mixture_distribution=categorical,
    components_distribution=component_distribution
)
print(mixture)

mixture.sample()

In [None]:
# Number of mixture components. Assigned here so the cell runs on a fresh kernel
# rather than depending on a value that is only set further down the notebook.
n_samples = 15

concentration = tf.random.uniform((n_samples, n_galaxies, event_size))
dirichlet = tfp.distributions.DirichletMultinomial(total_votes, concentration)
print(dirichlet)

# Move the galaxy batch dimension into the event, leaving n_samples as the only batch
# dimension so that MixtureSameFamily can use it as the component axis.
# reinterpreted_batch_ndims=1 spells out what the implicit default (all batch axes
# except the first) resolves to for this two-dimensional batch shape.
dirichlet_over_models = tfp.distributions.Independent(
    dirichlet,
    reinterpreted_batch_ndims=1
)
print(dirichlet_over_models)

# Equal weight for each of the n_samples components.
component_probs = tf.ones([n_samples]) / n_samples
categorical = tfp.distributions.Categorical(probs=component_probs)

mixture = tfp.distributions.MixtureSameFamily(
    categorical,
    dirichlet_over_models
)
print(mixture)

mixture.sample()

In [None]:
n_answers = 2
# One 2x2 block of concentrations, repeated five times along a new leading axis.
concentration_per_g = np.array([[2., 2.], [2., 2.]])
concentration = np.tile(concentration_per_g, (5, 1, 1))
concentration.shape

In [None]:
# Total count passed to the DirichletMultinomial built in the next cell.
total_votes = 10

In [None]:
n_samples = 2

# Build the distribution from the stacked concentrations (cast to float32),
# then draw one set of counts and pull it back into NumPy.
concentration_f32 = tf.constant(concentration, dtype=tf.float32)
dirichlet = tfp.distributions.DirichletMultinomial(
    total_count=total_votes,
    concentration=concentration_f32,
    validate_args=True
)
sampled_counts = dirichlet.sample()
counts = sampled_counts.numpy()
counts


In [None]:
# Probability of the sampled counts under the distribution that generated them.
print(dirichlet.prob(counts))

In [None]:
# Shape of the sampled counts array.
counts.shape

In [None]:
# dirichlet.prob(np.random.randint(size=(5, 2, 2))

In [None]:

# MixtureSameFamily treats the rightmost batch dimension of `dirichlet` as the component
# axis, so the resulting mixture has one fewer batch dimension.
# All-zero logits are equal unnormalised log-probabilities, i.e. equally likely components.
component_probs = tf.zeros([n_samples])
categorical = tfp.distributions.Categorical(logits=component_probs)

mixture = tfp.distributions.MixtureSameFamily(
    mixture_distribution=categorical,
    components_distribution=dirichlet
)
print(mixture)

# Sampling would return the result of picking a random component each time:
# mixture.sample()

In [None]:
# Evaluate the mixture's probability at a sample drawn from the mixture itself.
mixture.prob(mixture.sample())

In [None]:
# Every possible split of total_votes between two answers: (0, N), (1, N - 1), ..., (N, 0).
votes = [(first_answer, total_votes - first_answer) for first_answer in range(total_votes + 1)]
# Repeat each split five times so it lines up with a batch of five.
votes_batched = [[split] * 5 for split in votes]
x = tf.constant(np.array(votes_batched), dtype=tf.float32)
x.shape

In [None]:
# Draw a sample from the current mixture.
mixture.sample()

In [None]:
# Probability of each vote split in x, one mixture.prob call per leading-axis entry.
probs = [mixture.prob(vote_split) for vote_split in x]
plt.plot(probs)

In [None]:
from scipy.integrate import dblquad


def _integrand(inner, outer):
    """Product of the two integration variables (dblquad passes the inner variable first)."""
    return inner * outer


def _inner_lower(outer):
    """Lower limit of the inner variable: always zero."""
    return 0


def _inner_upper(outer):
    """Upper limit of the inner variable: the line 1 - 2 * outer."""
    return 1 - 2 * outer


# Integrate the product over the triangle 0 <= outer <= 0.5, 0 <= inner <= 1 - 2 * outer.
# dblquad returns a (value, estimated absolute error) pair.
area = dblquad(_integrand, 0, 0.5, _inner_lower, _inner_upper)
area

In [None]:
# Ask the mixture for its entropy.
# NOTE(review): `mixture` here looks like the MixtureSameFamily built above, which may not
# implement an analytic entropy - confirm this cell runs on a fresh kernel.
mixture.entropy()

In [None]:
# Ask the mixture for a lower bound on its entropy.
# NOTE(review): entropy_lower_bound is presumably a tfp.distributions.Mixture method; verify it
# is available on whatever `mixture` is bound to when this cell runs.
mixture.entropy_lower_bound()

In [None]:
# Display the current mixture object.
mixture

In [None]:
n_samples = 15
# All-zero logits are equal unnormalised log-probabilities, so each of the n_samples
# components is equally likely.
component_probs = tf.zeros([n_samples])
categorical = tfp.distributions.Categorical(logits=component_probs)
# print(categorical)

In [None]:

# Mixture (unlike MixtureSameFamily) takes an explicit list of component distributions.
# It requires `cat` and every component to have matching batch shapes.

# Alternatives tried previously:
# components = [tfp.distributions.Multinomial(total_count=10., logits=tf.zeros(event_size)) for _ in range(n_samples)]
# concentration = tf.random.uniform(shape=[event_size])

concentration = tf.random.uniform(shape=[batch_size, event_size])
components = [tfp.distributions.DirichletMultinomial(total_votes, concentration) for _ in range(n_samples)]

print(components[0])

# Each component has batch shape [batch_size], whereas the scalar-batch `categorical`
# does not, and Mixture rejects that mismatch. Build a categorical with the same batch
# shape: one row of equal logits (equal component probability) per batch member.
batched_categorical = tfp.distributions.Categorical(logits=tf.zeros([batch_size, n_samples]))

mixture = tfp.distributions.Mixture(
    batched_categorical,
    components,
    validate_args=True
)
print(mixture)
print(mixture.sample())

# this is the result of having picked a random component each time
# mixture.sample()

In [None]:

# Equal unnormalised log-probabilities (all-zero logits) give every component equal weight.
component_probs = tf.zeros([n_samples])
categorical = tfp.distributions.Categorical(logits=component_probs)
print(categorical)

# Alternative component family:
# components = [tfp.distributions.Multinomial(total_count=10., logits=tf.zeros(event_size)) for _ in range(n_samples)]

# Every component shares the same all-ones concentration vector of length event_size.
concentration = tf.ones(shape=[event_size])
components = [
    tfp.distributions.DirichletMultinomial(total_votes, concentration)
    for _ in range(n_samples)
]

print(components[0])

mixture = tfp.distributions.Mixture(
    cat=categorical,
    components=components,
    validate_args=True
)
print(mixture)
print(mixture.sample())

# Sampling gives the result of picking a random component each time:
# mixture.sample()

In [None]:
# Construct a Uniform distribution with its default arguments and display it.
tfp.distributions.Uniform()

In [None]:
# mixture.kl_divergence(tfp.distributions.Uniform())

In [None]:
# Mixture of n_samples DirichletMultinomial components over a three-way event.
# There is no galaxy dimension here: the only batch dimension is the component axis.
n_samples = 15

# Equal weight for each component.
component_probs = tf.ones([n_samples]) / n_samples
categorical = tfp.distributions.Categorical(probs=component_probs)

# The number of mixture components must equal the rightmost batch dimension of the components.
batch_size = n_samples
event_size = 3

concentration = tf.random.uniform(shape=(n_samples, event_size))

total_votes = 10  # scalar, broadcast against concentration
component_distribution = tfp.distributions.DirichletMultinomial(
    total_count=total_votes,
    concentration=concentration
)
print(component_distribution)

mixture = tfp.distributions.MixtureSameFamily(
    mixture_distribution=categorical,
    components_distribution=component_distribution
)
print(mixture)

mixture.sample()

In [None]:
# Rebuild the n_samples-component DirichletMultinomial mixture over a three-way event
# with freshly drawn random concentrations. The component axis is the only batch dimension.
n_samples = 15

# Uniform component weights.
component_probs = tf.ones([n_samples]) / n_samples
categorical = tfp.distributions.Categorical(probs=component_probs)

# Rightmost batch dimension of the components == number of mixture components.
batch_size = n_samples
event_size = 3

concentration = tf.random.uniform(shape=(n_samples, event_size))

total_votes = 10  # scalar, broadcast against concentration
component_distribution = tfp.distributions.DirichletMultinomial(
    total_count=total_votes,
    concentration=concentration
)
print(component_distribution)

mixture = tfp.distributions.MixtureSameFamily(
    mixture_distribution=categorical,
    components_distribution=component_distribution
)
print(mixture)

mixture.sample()

In [None]:
# Select the columns from question.start_index through question.end_index (inclusive)
# for every row, and convert them to float32.
question_columns = slice(question.start_index, question.end_index + 1)
samples_by_q = samples[:, question_columns].astype(np.float32)
samples_by_q.shape

In [None]:

# Scalar settings; per-row tensors were tried previously:
# expected_votes = tf.ones(1000) * 40.
# n_samples = tf.ones(1000) * 15.
n_samples = 15
expected_votes = 40
mixture = acquisition_utils.dirichlet_mixture(
    samples_by_q,
    expected_votes,
    n_samples
)

In [None]:
# Display the mixture returned by acquisition_utils.dirichlet_mixture.
mixture

In [None]:
# Draw a sample from the mixture and check its shape.
draw = mixture.sample()
draw.shape

In [None]:
# Display the values of the sample drawn above.
draw

In [None]:
# Random two-answer vote splits: the first answer gets 0-39 votes and the second
# answer gets the remainder, so every pair sums to 40.
# Earlier attempt at building x:
# x = np.array([np.linspace(0., 40.) for _ in range(32)])
# x = np.stack([x], axis=1)
votes = np.random.randint(low=0, high=40, size=(32, 15))
remaining_votes = 40 - votes
x = np.stack((votes, remaining_votes), axis=-1)
print(x.shape)
# Unfinished plotting of the log-probabilities over x:
# print(x.shape)
# log_probs = mixture.log_prob(x)  # batch broadcasted
# print(log_probs.shape)
# fig, axes = plt.subplots(nrows=10, figsize=(8, 20))
# for n in range(10):
#     ax = axes[n]
#     ax.plot(x, log_probs[n])
# fig.tight_layout()

In [None]:
# Mean of the mixture distribution.
mixture.mean()

In [None]:
# Draw another sample from the mixture and keep it as new_samples.
new_samples = mixture.sample()