In [143]:
from bokeh.layouts import column
from bokeh.models import ColumnDataSource, CustomJS, Slider
from bokeh.plotting import Figure, output_file, show, output_notebook
import bokeh.models as bmo
from bokeh.palettes import d3

import numpy as np

# Configure Bokeh to display `show(...)` output inline in the notebook.
output_notebook()
# Alternative output target (disabled): write a standalone HTML file instead.
#output_file("../_includes/js_on_change.html")

from bokeh.transform import linear_cmap, factor_cmap
# --- Synthetic data: N points from a K-component 2-D Gaussian mixture -------
# A seeded generator makes Restart & Run All reproduce the same figure.
SEED = 0
rng = np.random.default_rng(SEED)

N = 1000     # number of sampled points
maxK = 20    # presumably the largest component count the UI should allow -- confirm
K = 5        # number of mixture components sampled here
alpha = 1    # symmetric Dirichlet concentration used for the mixture weights

pi = rng.dirichlet(alpha * np.ones(K))    # mixture weights, shape (K,), sums to 1
means = rng.standard_normal((K, 2))       # one 2-D centre per component
z = rng.choice(K, p=pi, size=N)           # component label in [0, K) for every point

# Each point is its component centre plus isotropic Gaussian noise (std 0.15).
data = means[z, :] + 0.15 * rng.standard_normal((N, 2))
x = data[:, 0]
y = data[:, 1]


# Column data shared between the Python side and the browser-side callbacks:
# point coordinates (x, y) and the component label (z) of every point.
source = ColumnDataSource(data=dict(x=x, y=y, z=z))

# NOTE(review): `Figure` (capital F) is the legacy spelling; newer Bokeh releases
# may only provide `figure` -- confirm against the installed version.
plot = Figure(width=800, height=500)
# Colour each point by its label `z`. The upper end of the colour scale is tied to
# `maxK` instead of a repeated literal 20, so the two cannot drift apart.
plot.scatter('x', 'y', source=source, color=linear_cmap('z', "Turbo256", 0, maxK))

# JavaScript source for a small random-sampling helper library, kept as a Python
# string so it can be prepended to the `code` of a CustomJS callback (`callback2`
# concatenates it in front of its own body).
#
# Helpers defined by the snippet:
#   rnormal()           -- one draw via the Box-Muller transform; u and v are
#                          redrawn while they are 0 so Math.log(u) stays finite.
#   sum(nums)           -- plain loop sum over an array.
#   rbeta(alpha, beta)  -- g1 / (g1 + g2) with g1 = rgamma(alpha, 1) and
#                          g2 = rgamma(beta, 1).
#   rgamma(alpha, beta) -- three branches on alpha (> 1, == 1, otherwise); every
#                          branch multiplies the draw by `beta`. There is no
#                          validation of the arguments.
#                          NOTE(review): structure and constants resemble CPython's
#                          random.gammavariate -- confirm provenance.
#   rdirichlet(alpha)   -- one rgamma(alpha[i], 1) per entry, each divided by
#                          the sum of all draws.
#   rcategorical(pi)    -- draws u = Math.random() and walks the running sum of
#                          `pi`; returns the 0-based index where u <= running sum,
#                          or the last index if that never happens.
code_rand_library = """

    function rnormal() {
            var u = 0, v = 0;
            while(u === 0) u = Math.random(); 
            while(v === 0) v = Math.random();
            return Math.sqrt( -2.0 * Math.log( u ) ) * Math.cos( 2.0 * Math.PI * v );
        }

    function sum(nums) {
      var accumulator = 0;
      for (var i = 0, l = nums.length; i < l; i++)
        accumulator += nums[i];
      return accumulator;
    }

    function rbeta(alpha, beta) {
      var alpha_gamma = rgamma(alpha, 1);
      return alpha_gamma / (alpha_gamma + rgamma(beta, 1));
    }

    var SG_MAGICCONST = 1 + Math.log(4.5);
    var LOG4 = Math.log(4.0);

    function rgamma(alpha, beta) {
      if (alpha > 1) {
        var ainv = Math.sqrt(2.0 * alpha - 1.0);
        var bbb = alpha - LOG4;
        var ccc = alpha + ainv;

        while (true) {
          var u1 = Math.random();
          if (!((1e-7 < u1) && (u1 < 0.9999999))) {
            continue;
          }
          var u2 = 1.0 - Math.random();
          var v = Math.log(u1/(1.0-u1))/ainv;
          var x = alpha*Math.exp(v);
          var z = u1*u1*u2;
          var r = bbb+ccc*v-x;
          if (r + SG_MAGICCONST - 4.5*z >= 0.0 || r >= Math.log(z)) {
            return x * beta;
          }
        }
      }
      else if (alpha == 1.0) {
        var u = Math.random();
        while (u <= 1e-7) {
          u = Math.random();
        }
        return -Math.log(u) * beta;
      }
      else { 
        while (true) {
          var u3 = Math.random();
          var b = (Math.E + alpha)/Math.E;
          var p = b*u3;
          if (p <= 1.0) {
            var x = Math.pow(p, (1.0/alpha));
          }
          else {
            var x = -Math.log((b-p)/alpha);
          }
          var u4 = Math.random();
          if (p > 1.0) {
            if (u4 <= Math.pow(x, (alpha - 1.0))) {
              break;
            }
          }
          else if (u4 <= Math.exp(-x)) {
            break;
          }
        }
        return x * beta;
      }
    }
    
    function rdirichlet(alpha) {
      var gammas = [];
      for (var i = 0, l = alpha.length; i < l; i++)
        gammas.push(rgamma(alpha[i], 1));
        
      var accum = sum(gammas)
      
      for (var i = 0, l = gammas.length; i < l; i++)
        gammas[i] = gammas[i]/accum
      return gammas
    }
    
    function rcategorical(pi){
      var u = Math.random();
      var k = 0
      var cum = pi[0]
      for (var i = 1, l = pi.length; i < l; i++)
        if (u <= cum){
            break;
        }
        else{
            cum = cum + pi[i]
            k = k+1
        }
      return k
    }
      
    """

# Browser-side handler: overwrites every y value with x ** (slider value) and then
# signals Bokeh that the data source changed so the glyphs are redrawn.
# The original y values are not kept, so moving the slider back does not restore them.
# NOTE(review): this handler is attached to the slider titled "K", yet it does not
# touch the cluster labels -- it looks like leftover demo behaviour; confirm intent.
callback = CustomJS(args=dict(source=source), code="""
    const data = source.data;
    const f = cb_obj.value
    const x = data['x']
    const y = data['y']
    for (let i = 0; i < x.length; i++) {
        y[i] = Math.pow(x[i], f)
    }
    source.change.emit();
""")

# Browser-side handler that re-randomises the scatter. `code_rand_library` is
# prepended so its sampling helpers (rbeta and friends) are in scope for the body.
# Every invocation overwrites each x and y with an independent rbeta(0.1, 2) draw,
# then signals Bokeh that the data source changed.
# NOTE(review): the slider value (cb_obj.value) is never read, so the slider this
# is attached to only triggers a redraw; the number of points stays x.length.
callback2 = CustomJS(args=dict(source=source), code=code_rand_library + """
    const data = source.data;
    const x = data['x']
    const y = data['y']

    for (let i = 0; i < x.length; i++) {
        x[i] = rbeta(0.1,2)
        y[i] = rbeta(0.1,2)
    }
    source.change.emit();
""")

# Slider titled "K": its upper bound and initial value come from the constants
# defined with the data (maxK, K) instead of repeating the literals 20 and 5.
slider = Slider(start=1, end=maxK, value=K, step=1, title="K")
slider.js_on_change('value', callback)

# Slider titled "N", initialised from N instead of a repeated literal 1000.
# NOTE(review): with start=1 and step=10 the step grid is presumably 1, 11, 21, ...,
# which would leave both the initial value and `end` off-grid -- consider start=10.
slider2 = Slider(start=1, end=2000, value=N, step=10, title="N")
slider2.js_on_change('value', callback2)

# Stack both sliders above the plot and display the result.
layout = column(slider, slider2, plot)

show(layout)

In [98]:
# NOTE(review): the recorded output below, (200,), does not match N = 1000 in the
# cell above, and this cell's execution count (98) precedes that cell's (143).
# The output looks stale -- re-run after Restart Kernel -> Run All.
x.shape

(200,)