## Dirichlet Distribution

## Install Package

In [None]:
# Install Plotly, which plotdist_3d uses for the interactive 3-D scatter plot.
# To reproduce the original environment exactly, uncomment the pinned line
# below and use it instead of the unpinned install.
#!pip install plotly==5.2.1
!pip install plotly

## Import Packages

In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import plotly.graph_objects as go

## Define Functions

In [None]:
def plotdist_3d(s):
    """Show an interactive 3-D scatter plot of the samples with Plotly.

    Parameters
    ----------
    s : 2-D array indexed as ``s[:, k]``
        Samples; columns 0, 1 and 2 give the x, y and z coordinates of
        each plotted point.
    """
    scatter = go.Scatter3d(
        x=s[:, 0],
        y=s[:, 1],
        z=s[:, 2],
        mode='markers',
        marker=dict(size=1, color='blue', opacity=0.7),
    )

    # Zero margins so the 3-D scene fills the whole output area.
    no_margin_layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

    figure = go.Figure([scatter], no_margin_layout)
    figure.show()

In [None]:
def plotdist_2d(s, alpha=None):
    """Plot the first two coordinates of the samples inside the unit square.

    The left panel is a raw scatter of ``(s[:, 0], s[:, 1])``. The right
    panel is a filled seaborn kernel-density estimate of the same points.

    Parameters
    ----------
    s : 2-D array indexed as ``s[:, k]``
        Samples; only columns 0 and 1 are plotted.
    alpha : sequence, optional
        Concentration parameters shown in both panel titles. When omitted,
        the notebook-level global ``alpha`` is used, as the original
        version did implicitly. If no such global exists, the titles are
        left blank instead of raising NameError.
    """
    if alpha is None:
        # Backward compatibility: callers that pass only ``s`` keep getting
        # the title built from the notebook's global ``alpha``.
        alpha = globals().get('alpha')

    x = s[:, 0]
    y = s[:, 1]

    _, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].plot(x, y, '.')
    sns.kdeplot(x=x, y=y, cmap='Blues', fill=True, ax=axes[1])

    title = None
    if alpha is not None:
        # Join every component so the title is right for any dimension;
        # the old format string had exactly three placeholders.
        title = r'$\alpha=({})$'.format(','.join('{}'.format(a) for a in alpha))

    # Both panels share identical formatting.
    for ax in axes:
        ax.axis('square')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        if title is not None:
            ax.set_title(title)
    

In [None]:
def plotdist_bars(s, n_show=10):
    """Draw each of the first few samples as a small bar chart.

    Panel ``i`` shows sample ``s[i, :]``: bar ``k`` has height ``s[i, k]``.
    Panels are laid out in rows of at most five, so the default of ten
    samples reproduces the original 2 x 5 grid at figsize (10, 4).

    Parameters
    ----------
    s : 2-D array of shape (n_samples, n_components)
        Samples, one per row.
    n_show : int, default 10
        Maximum number of samples (rows of ``s``) to draw. Fewer panels are
        drawn when ``s`` has fewer rows; nothing is drawn if this is <= 0.
    """
    s = np.asarray(s)

    # Never index past the last row (the old fixed range(10) could).
    n_panels = min(n_show, s.shape[0])
    if n_panels <= 0:
        return  # nothing to draw

    # One bar per component, taken from the data instead of assuming 3.
    x = np.arange(s.shape[1])

    ncols = min(5, n_panels)
    nrows = -(-n_panels // ncols)  # ceiling division

    # 2 inches per panel gives the original (10, 4) for the 2 x 5 grid.
    fig = plt.figure(figsize=(2 * ncols, 2 * nrows))

    for k in range(n_panels):
        ax = fig.add_subplot(nrows, ncols, k + 1)
        ax.bar(x, s[k, :], tick_label=x)
        # Dirichlet components lie in [0, 1]; the extra 0.1 keeps
        # full-height bars from touching the top edge.
        ax.set_ylim([0, 1.1])
        ax.set_yticks([])
        

## Visualize Distribution

### Case: $\alpha = (5, 5, 5)$

In [None]:
alpha = np.array([5, 5, 5])

# Draw 5000 Dirichlet(alpha) samples (one 3-component vector per row).
# A seeded Generator (NumPy's recommended RNG API) makes the figures
# reproducible from run to run.
rng = np.random.default_rng(0)
s = rng.dirichlet(alpha, 5000)

plotdist_3d(s)

In [None]:
# Scatter and KDE of the first two coordinates of the samples drawn above.
plotdist_2d(s)

In [None]:
# Bar chart of each of the first ten samples drawn above.
plotdist_bars(s)

### Case: $\alpha = (1, 1, 1)$

In [None]:
alpha = np.array([1, 1, 1])

# Draw 5000 Dirichlet(alpha) samples (one 3-component vector per row).
# A seeded Generator (NumPy's recommended RNG API) makes the figures
# reproducible from run to run.
rng = np.random.default_rng(0)
s = rng.dirichlet(alpha, 5000)

plotdist_3d(s)

In [None]:
# Scatter and KDE of the first two coordinates of the samples drawn above.
plotdist_2d(s)

In [None]:
# Bar chart of each of the first ten samples drawn above.
plotdist_bars(s)

### Case: $\alpha = (0.2, 0.2, 0.2)$

In [None]:
alpha = np.array([0.2, 0.2, 0.2])

# Draw 5000 Dirichlet(alpha) samples (one 3-component vector per row).
# A seeded Generator (NumPy's recommended RNG API) makes the figures
# reproducible from run to run. Its dirichlet sampler is also more robust
# than the legacy one for small alpha values like these.
rng = np.random.default_rng(0)
s = rng.dirichlet(alpha, 5000)

plotdist_3d(s)

In [None]:
# Scatter and KDE of the first two coordinates of the samples drawn above.
plotdist_2d(s)

In [None]:
# Bar chart of each of the first ten samples drawn above.
plotdist_bars(s)