Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
1 contributor

Users who have contributed to this file

64 lines (54 sloc) 1.61 KB
"""
Code by Tae-Hwan Hung(@graykode)
https://en.wikipedia.org/wiki/Dirichlet_distribution
3-Class Example
"""
from random import randint
import numpy as np
from matplotlib import pyplot as plt
def normalization(x, s):
"""
:return: normalizated list, where sum(x) == s
"""
return [(i * s) / sum(x) for i in x]
def sampling():
return normalization([randint(1, 100),
randint(1, 100), randint(1, 100)], s=1)
def gamma_function(n):
cal = 1
for i in range(2, n):
cal *= i
return cal
def beta_function(alpha):
"""
:param alpha: list, len(alpha) is k
:return:
"""
numerator = 1
for a in alpha:
numerator *= gamma_function(a)
denominator = gamma_function(sum(alpha))
return numerator / denominator
def dirichlet(x, a, n):
"""
:param x: list of [x[1,...,K], x[1,...,K], ...], shape is (n_trial, K)
:param a: list of coefficient, a_i > 0
:param n: number of trial
:return:
"""
c = (1 / beta_function(a))
y = [c * (xn[0] ** (a[0] - 1)) * (xn[1] ** (a[1] - 1))
* (xn[2] ** (a[2] - 1)) for xn in x]
x = np.arange(n)
return x, y, np.mean(y), np.std(y)
n_experiment = 1200
for ls in [(6, 2, 2), (3, 7, 5), (6, 2, 6), (2, 3, 4)]:
alpha = list(ls)
# random samping [x[1,...,K], x[1,...,K], ...], shape is (n_trial, K)
# each sum of row should be one.
x = [sampling() for _ in range(1, n_experiment + 1)]
x, y, u, s = dirichlet(x, alpha, n=n_experiment)
plt.plot(x, y, label=r'$\alpha=(%d,%d,%d)$' % (ls[0], ls[1], ls[2]))
plt.legend()
plt.savefig('graph/dirichlet.png')
plt.show()
You can’t perform that action at this time.