Skip to content
188c5c6 Sep 7, 2019
1 contributor

### Users who have contributed to this file

64 lines (54 sloc) 1.61 KB
 """ Code by Tae-Hwan Hung(@graykode) https://en.wikipedia.org/wiki/Dirichlet_distribution 3-Class Example """ from random import randint import numpy as np from matplotlib import pyplot as plt def normalization(x, s): """ :return: normalizated list, where sum(x) == s """ return [(i * s) / sum(x) for i in x] def sampling(): return normalization([randint(1, 100), randint(1, 100), randint(1, 100)], s=1) def gamma_function(n): cal = 1 for i in range(2, n): cal *= i return cal def beta_function(alpha): """ :param alpha: list, len(alpha) is k :return: """ numerator = 1 for a in alpha: numerator *= gamma_function(a) denominator = gamma_function(sum(alpha)) return numerator / denominator def dirichlet(x, a, n): """ :param x: list of [x[1,...,K], x[1,...,K], ...], shape is (n_trial, K) :param a: list of coefficient, a_i > 0 :param n: number of trial :return: """ c = (1 / beta_function(a)) y = [c * (xn[0] ** (a[0] - 1)) * (xn[1] ** (a[1] - 1)) * (xn[2] ** (a[2] - 1)) for xn in x] x = np.arange(n) return x, y, np.mean(y), np.std(y) n_experiment = 1200 for ls in [(6, 2, 2), (3, 7, 5), (6, 2, 6), (2, 3, 4)]: alpha = list(ls) # random samping [x[1,...,K], x[1,...,K], ...], shape is (n_trial, K) # each sum of row should be one. x = [sampling() for _ in range(1, n_experiment + 1)] x, y, u, s = dirichlet(x, alpha, n=n_experiment) plt.plot(x, y, label=r'\$\alpha=(%d,%d,%d)\$' % (ls[0], ls[1], ls[2])) plt.legend() plt.savefig('graph/dirichlet.png') plt.show()
You can’t perform that action at this time.