In [1]:
# Centres of the two target Gaussians (used by mu[0]/F[1] and mu[1]/F[2])
# and of the wide sampler that proposes new support points.
pcentre = [4, 4]
ncentre = [-4, -4]
scentre = [-8, -8]
# Bounds and spacing of the evaluation mesh (np.arange(low, high, step)).
low, high, step = -7, 7, 0.2
# Number of powercell_update calls, i.e. support points added by the main loop.
I = 100

# samplers for the two target distributions and for new support points

In [2]:
import numpy as np
# Target measures: mu[k](n) returns n i.i.d. draws, shape (n, 2), from a
# unit-covariance Gaussian centred at pcentre (k=0) or ncentre (k=1).
# The centres are looked up at call time, not captured here.
mu = {
    0: lambda n=1: np.array(pcentre) + np.random.randn(n, 2),
    1: lambda n=1: np.array(ncentre) + np.random.randn(n, 2),
}


def sampler(n=1):
    """Draw n points, shape (n, 2), from N(scentre, 16**2 * I); used to propose new support points."""
    return np.array(scentre) + 16 * np.random.randn(n, 2)

# mesh

In [3]:
# Square evaluation grid: both axes span [low, high) with spacing `step`.
x1 = np.arange(low, high, step)
x2 = x1.copy()
# meshgrid's default indexing is 'xy' (Cartesian): X1 varies along columns.
X1, X2 = np.meshgrid(x1, x2)

# multivariate normal

In [4]:
from scipy.stats import multivariate_normal
def gaussian_pdf_on_grid(centre, G1, G2):
    """Evaluate the N(centre, I_2) density at every node of the meshgrid (G1, G2).

    Returns an array with the same shape as ``G1`` (and ``G2``).
    """
    points = np.column_stack([np.ravel(G1), np.ravel(G2)])
    # Reshape to the grid's own shape; (len(G1), len(G2)) is only correct when
    # the grid is square, because len() of a 2-D meshgrid is its row count.
    return multivariate_normal.pdf(points, centre, np.eye(2)).reshape(G1.shape)


# Densities of the two target Gaussians (centred at pcentre and ncentre) on the mesh.
F = {}
F[1] = gaussian_pdf_on_grid(pcentre, X1, X2)
F[2] = gaussian_pdf_on_grid(ncentre, X1, X2)

# powercell density

In [5]:
from scipy.spatial.distance import cdist
def sample_powercell_density(X = None, w = None, nu = None, n_samples=64000, metric='euclidean'):
    """Monte-Carlo estimate of the nu-mass of each weighted cell of the sites X.

    Draws ``n_samples`` points y from ``nu`` and assigns each to the site i that
    minimises ``dist(X[i], y) - w[i]``, so a larger weight enlarges the cell.

    Parameters
    ----------
    X : array-like, shape (n, d)
        Cell sites.
    w : array-like, shape (n,) or (n, 1)
        Per-site weights.
    nu : callable
        ``nu(k)`` returns ``k`` samples as an array of shape (k, d).
    n_samples : int
        Number of Monte-Carlo samples (default 64000, as before).
    metric : str
        Distance passed to ``scipy.spatial.distance.cdist``. The default
        'euclidean' reproduces the original behaviour. NOTE(review): a classical
        power (Laguerre) diagram uses 'sqeuclidean'; confirm which is intended.

    Returns
    -------
    ndarray, shape (n, 1)
        Fraction of samples falling in each cell (sums to 1).
    """
    n = np.asarray(X).shape[0]
    if n == 1:
        # A single site owns the whole space; no sampling needed.
        return np.ones((n, 1))

    Y = np.asarray(nu(n_samples))
    w = np.asarray(w, dtype=float).reshape(-1, 1)
    # (n, len(Y)) weighted costs; each sample goes to its cheapest site.
    owner = np.argmin(cdist(X, Y, metric=metric) - w, axis=0)
    # Vectorised replacement for the per-sample Python counting loop.
    counts = np.bincount(owner, minlength=n)
    return (counts / len(Y)).reshape(-1, 1)

# weight update

In [6]:
def weight_update(X = None, w = None, mu = None, alpha=0.001, beta=0.99, tol=0.01,
                  max_iter=None, density_fn=None):
    """Fit cell weights so every cell of X carries mass 1/n under ``mu``.

    Runs momentum (heavy-ball) ascent on the weights using the stochastic
    gradient ``1/n - rho(w)``, where rho is the estimated mass of each cell.
    A cell with too little mass has its weight raised.

    Parameters
    ----------
    X : array-like, shape (n, d)
        Cell sites, passed through to ``density_fn``.
    w : array-like, shape (n,) or (n, 1)
        Initial weights (not modified).
    mu : callable
        Sampler for the target measure, passed through to ``density_fn``.
    alpha : float
        Step size.
    beta : float
        Momentum coefficient.
    tol : float
        Stop once ``||1/n - rho||`` is at most ``tol``.
    max_iter : int or None
        Maximum number of weight updates. None (default) means unbounded, as
        before; because rho is a noisy estimate the loop can run for a long time.
    density_fn : callable or None
        ``density_fn(X, w, mu) -> (n, 1)`` cell masses; defaults to
        ``sample_powercell_density``.

    Returns
    -------
    ndarray, shape (n, 1)
        Updated weights. The result is always a column vector, even when no
        update was needed.
    """
    if density_fn is None:
        density_fn = sample_powercell_density
    w = np.asarray(w, dtype=float).reshape(-1, 1)
    n = w.shape[0]
    z = np.zeros_like(w)  # momentum buffer
    grad = 1 / n - density_fn(X, w, mu)
    norm_grad = np.linalg.norm(grad)
    i = 1
    while norm_grad > tol:
        if max_iter is not None and i > max_iter:
            break
        if i % 100 == 0:
            print('Iter: ', i,  'norm: ',norm_grad)
        i = i + 1

        z = beta * z + grad
        w = w + alpha * z
        grad = 1 / n - density_fn(X, w, mu)
        norm_grad = np.linalg.norm(grad)

    return w

# powercell means

In [7]:
def sample_powercell_means(X, w, nu, n_samples=64000, metric='euclidean'):
    """Monte-Carlo estimate of the mean point of each weighted cell under ``nu``.

    Samples are assigned to cells exactly as in ``sample_powercell_density``:
    each sample goes to the site i minimising ``dist(X[i], y) - w[i]``.

    Parameters
    ----------
    X : array-like, shape (n, d)
        Cell sites.
    w : array-like, shape (n,) or (n, 1)
        Per-site weights.
    nu : callable
        ``nu(k)`` returns ``k`` samples as an array of shape (k, d).
    n_samples : int
        Number of Monte-Carlo samples (default 64000, as before).
    metric : str
        Distance passed to ``scipy.spatial.distance.cdist`` (default 'euclidean').

    Returns
    -------
    ndarray, shape (n, d)
        Mean of the samples in each cell. A cell that received no samples keeps
        its site position (previously this was 0/0, i.e. NaN).
    """
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    Y = np.asarray(nu(n_samples), dtype=float)
    if n == 1:
        # A lone site's cell is the whole space, so its mean is the sample mean.
        # The mean is taken over samples (axis 0), not over coordinates, and
        # shape (1, d) is kept.
        return Y.mean(axis=0, keepdims=True)

    w = np.asarray(w, dtype=float).reshape(-1, 1)
    owner = np.argmin(cdist(X, Y, metric=metric) - w, axis=0)
    counts = np.bincount(owner, minlength=n)
    sums = np.zeros((n, d))
    np.add.at(sums, owner, Y)  # unbuffered scatter-add of samples into their cells

    bary = X.copy()
    filled = counts > 0
    bary[filled] = sums[filled] / counts[filled, None]
    return bary

# powercell update

In [None]:
def _position_update(X, w, mu, n_samples):
    """Move each site to the mass-weighted average of its per-measure cell means."""
    n, d = X.shape
    weighted_sum = np.zeros((n, d))
    mass_sum = np.zeros((n, 1))
    for k in range(len(mu)):
        mass = np.reshape(sample_powercell_density(X, w[:, k], mu[k]), (n, 1))
        if n == 1:
            # A lone site's cell is the whole plane, so its mean is the measure's mean.
            means = np.mean(np.asarray(mu[k](n_samples)), axis=0, keepdims=True)
        else:
            means = np.reshape(sample_powercell_means(X, w[:, k], mu[k]), (n, d))
        # An empty cell (mass 0) contributes nothing, even if its estimated mean is NaN.
        weighted_sum += np.where(mass > 0, mass * means, 0.0)
        mass_sum += mass
    safe = np.where(mass_sum > 0, mass_sum, 1.0)
    # Sites with no mass under any measure stay where they are.
    return np.where(mass_sum > 0, weighted_sum / safe, X)


def powercell_update(X = None, w = None, mu = None, sampler = None, n_outer=10, n_samples=64000):
    """Add one sampled support point, then alternately refit weights and positions.

    Parameters
    ----------
    X : list or ndarray, shape (n, d)
        Current support points (an empty list on the first call). Not mutated.
    w : list or ndarray, shape (n, m)
        Current weights, one column per measure in ``mu``.
    mu : sequence of callables
        ``mu[k](k_samples)`` draws samples from the k-th measure.
    sampler : callable
        ``sampler(1)`` proposes the new support point, shape (1, d).
    n_outer : int
        Number of weight/position alternations (default 10, as before).
    n_samples : int
        Samples used for the single-site mean estimate.

    Returns
    -------
    (X, w) : ndarrays of shape (n + 1, d) and (n + 1, m)

    NOTE(review): the original loop computed per-measure cell masses (M) and
    means (B) into ``Xnew``/``Msum`` accumulators but never used them, so points
    never moved. The position step here completes that as
    ``X = sum_k M_k * B_k / sum_k M_k``, which is inferred from those accumulators.
    Confirm against the intended algorithm.
    """
    m = len(mu)
    y = np.asarray(sampler(1))
    if np.asarray(X).shape[0] == 0:
        # First call: build fresh arrays instead of appending to the caller's list.
        X = np.array(y[:1], dtype=float)
        w = np.zeros((1, m))
    else:
        X = np.vstack([X, y[0]])
        w = np.vstack([w, np.zeros((1, m))])

    for _ in range(n_outer):
        # Refit weights so each cell carries mass 1/n under every measure.
        for k in range(m):
            w[:, k] = np.squeeze(weight_update(X, w[:, k], mu[k]))
        X = _position_update(X, w, mu, n_samples)
    return X, w

# main loop: grow the support one sampled point at a time

In [None]:
from tqdm import tqdm
# NOTE(review): `filename` is not used by any cell visible here.
filename = 'fw-updates.gif'
X = []
w = []
for i in tqdm(range(I)):
    # Each call appends one point drawn by `sampler` and refits the weights.
    X, w = powercell_update(X, w, mu, sampler)

  1%|          | 1/100 [00:00<00:14,  6.69it/s]

Iter:  100 norm:  0.7071067811865476
Iter:  200 norm:  0.7071067811865476
Iter:  300 norm:  0.7071067811865476
Iter:  400 norm:  0.7071067811865476
Iter:  500 norm:  0.7071067811865476
Iter:  600 norm:  0.7071067811865476
Iter:  700 norm:  0.6437544330096153
Iter:  100 norm:  0.7071067811865476
Iter:  200 norm:  0.7071067811865476
Iter:  300 norm:  0.7071067811865476
Iter:  400 norm:  0.706974198665075
Iter:  500 norm:  0.5191489599123984
Iter:  600 norm:  0.5349483770545354


  2%|▏         | 2/100 [02:53<2:21:42, 86.76s/it]

Iter:  100 norm:  0.816496580927726
Iter:  200 norm:  0.816420035569416
Iter:  300 norm:  0.18235783382485016
Iter:  400 norm:  0.40915713270672877
Iter:  500 norm:  0.2917249391193127
Iter:  600 norm:  0.24005113929987223
Iter:  700 norm:  0.09447464817153552
Iter:  800 norm:  0.10398954916065083
Iter:  900 norm:  0.026920005315711152
Iter:  1000 norm:  0.0410635365007438
Iter:  100 norm:  0.27890928248818275
Iter:  200 norm:  0.21148354939963218
Iter:  300 norm:  0.17467274693590704
Iter:  400 norm:  0.10287224312198076
Iter:  500 norm:  0.046290714413210525


  3%|▎         | 3/100 [06:37<3:34:04, 132.42s/it]

Iter:  100 norm:  0.8660254037844386
Iter:  200 norm:  0.8660254037844386
Iter:  300 norm:  0.27081882583339967
Iter:  400 norm:  0.2899875662407248
Iter:  500 norm:  0.28903068997269776
Iter:  600 norm:  0.22912565956499448
Iter:  700 norm:  0.21488785342773858
Iter:  800 norm:  0.10177506841452502
Iter:  900 norm:  0.08913429945786162
Iter:  1000 norm:  0.036770541931092735
Iter:  1100 norm:  0.03055403129089516
Iter:  1200 norm:  0.01596873471134299
Iter:  100 norm:  0.8660254037844386
Iter:  200 norm:  0.646916989503749
Iter:  300 norm:  0.2935344213937588
Iter:  400 norm:  0.28904928094774124
Iter:  500 norm:  0.2563150641483543
Iter:  600 norm:  0.2885841202034404
Iter:  700 norm:  0.0558607582870681
Iter:  800 norm:  0.08539306235085055
Iter:  900 norm:  0.06631114513224756
Iter:  1000 norm:  0.03311384688417747
Iter:  1100 norm:  0.02662214810136945


  4%|▍         | 4/100 [12:03<4:49:12, 180.76s/it]

Iter:  100 norm:  0.7293301068687321
Iter:  200 norm:  0.23682894093759954
Iter:  300 norm:  0.22384221432532003
Iter:  400 norm:  0.22543619610795645
Iter:  500 norm:  0.3271197495817186
Iter:  600 norm:  0.19793809944257196
Iter:  700 norm:  0.22429498268533002
Iter:  800 norm:  0.13775536478607195
Iter:  900 norm:  0.2822427686739733
Iter:  1000 norm:  0.10205859606740628
Iter:  100 norm:  0.6860592196958565
Iter:  200 norm:  0.24586349446421737
Iter:  300 norm:  0.22631840838220882
Iter:  400 norm:  0.19936163013557173
Iter:  500 norm:  0.2074836919419813
Iter:  600 norm:  0.13552479016770325
Iter:  700 norm:  0.07321753293376015
Iter:  800 norm:  0.04200638786058318


  5%|▌         | 5/100 [16:28<5:12:54, 197.63s/it]

Iter:  100 norm:  0.18259719247710474
Iter:  200 norm:  0.18265974743586322
Iter:  300 norm:  0.18258497549015248
Iter:  400 norm:  0.18259783024635612
Iter:  500 norm:  0.18257899037913586
Iter:  600 norm:  0.06813022724593933
Iter:  700 norm:  0.1867029146876518
Iter:  800 norm:  0.12047775109585412
Iter:  900 norm:  0.097406675254392
Iter:  1000 norm:  0.09307626311368722
Iter:  1100 norm:  0.04264106109310406
Iter:  1200 norm:  0.03772419742836929
Iter:  1300 norm:  0.013950398086464541
Iter:  100 norm:  0.1816674283539556
Iter:  200 norm:  0.09621367685824785
Iter:  300 norm:  0.09312708331003035
Iter:  400 norm:  0.08065980857931714
Iter:  500 norm:  0.030805904360845893
Iter:  600 norm:  0.015890658288231278


  6%|▌         | 6/100 [20:56<5:28:02, 209.39s/it]

Iter:  100 norm:  0.15441734368125848
Iter:  200 norm:  0.1535687693135624
Iter:  300 norm:  0.08497475908715271
Iter:  400 norm:  0.13403532363784146
Iter:  500 norm:  0.05127702426112393
Iter:  600 norm:  0.035103671992946735
Iter:  100 norm:  0.15462111332488165
Iter:  200 norm:  0.15433236764026817
Iter:  300 norm:  0.1544052735191172
Iter:  400 norm:  0.15433269984197578
Iter:  500 norm:  0.15437528692497918
Iter:  600 norm:  0.154327128250191
Iter:  700 norm:  0.15432492613460372
Iter:  800 norm:  0.1543323407477177
Iter:  900 norm:  0.13799046125281048
Iter:  1000 norm:  0.24345723647185638
Iter:  1100 norm:  0.2184616441285681
Iter:  1200 norm:  0.09124213212166375


# Keyboard shortcut 'r': [run all cells](https://stackoverflow.com/questions/33143753/jupyter-ipython-notebooks-shortcut-for-run-all)

In [None]:
# Register an 'r' command-mode shortcut that runs every cell, via the
# `Jupyter`/`IPython` JavaScript globals of the classic Notebook front end.
# NOTE(review): those globals are presumably absent under JupyterLab; confirm.
get_ipython().run_cell_magic('javascript', '', "\nJupyter.keyboard_manager.command_shortcuts.add_shortcut('r', {\n    help : 'run all cells',\n    help_index : 'zz',\n    handler : function (event) {\n        IPython.notebook.execute_all_cells();\n        return false;\n    }}\n);")