In [1]:
import numpy as np

In [49]:
class GWO:
    """Grey Wolf Optimizer (GWO) for minimising a vectorised objective.

    The population is an ``(n, d)`` array of ``n`` wolves in ``d`` dimensions.
    On every epoch each wolf takes one move towards each of the three current
    best wolves (alpha, beta, delta). Its new position is the mean of those
    three moves.

    Parameters
    ----------
    x : ndarray, shape (n, d)
        Initial population; ``n`` must be at least 3.
    f : callable
        Objective called on the whole ``(n, d)`` population. It must return
        ``n`` values that can be argsorted; lower is better.
    n_epochs : int
        Number of steps yielded by iteration. It is also the length of the
        linear schedule that shrinks the step-coefficient bound from 2 to 0.
    prng : numpy.random.RandomState, optional
        Source of randomness; a fresh unseeded ``RandomState`` if omitted.

    Raises
    ------
    ValueError
        If the population has fewer than 3 wolves.
    """

    def __init__(self, x, f, n_epochs=100, prng=None):
        self.n, self.d = x.shape
        if self.n < 3:
            # _one_step unpacks exactly three leaders; fail early and clearly.
            raise ValueError(
                "GWO needs at least 3 wolves (alpha, beta, delta); got {}".format(self.n))
        self.n_epochs = n_epochs
        self._prng = prng if prng is not None else np.random.RandomState()
        self._f = f
        self._t = 0
        self._x = x
        self._fvals = None  # objective values of self._x, filled by _get_best3
        self._best3 = self._get_best3()

    def _create_a(self):
        """Return one coefficient per wolf, A = a * (2*r - 1) with r ~ U[0, 1).

        The bound ``a`` decays linearly from 2 at t=0 to 0 at t=n_epochs-1.
        """
        unif = self._prng.random(size=self.n)
        # With n_epochs == 1 the original denominator would be 0.
        span = max(self.n_epochs - 1, 1)
        return 2*(1 - self._t/span)*(2*unif - 1)

    def _create_c(self):
        """Return one coefficient per wolf, C = 2*r with r ~ U[0, 1)."""
        unif = self._prng.random(size=self.n)
        return 2*unif

    def _follow_one(self, leader):
        """Move every wolf relative to ``leader`` (shape (d,)); return (n, d).

        Computes X' = L - A * |C*L - X| row-wise, with one scalar A and one
        scalar C per wolf.
        """
        # Draw order (A, then C) is kept so seeded runs stay reproducible.
        a = self._create_a()
        c = self._create_c()
        # np.outer(c, leader)[i] == c[i] * leader, i.e. C*L for wolf i.
        distance = np.abs(np.outer(c, leader) - self._x)
        return leader[np.newaxis, :] - a[:, np.newaxis] * distance

    def _get_best3(self):
        """Evaluate the population once, cache the values, return the 3 best row indices."""
        self._fvals = self._f(self._x)
        return np.argsort(self._fvals)[:3]

    def _one_step(self):
        """Advance one epoch: replace the population and re-rank it."""
        xa, xb, xc = self._x[self._best3]
        x1 = self._follow_one(xa)
        x2 = self._follow_one(xb)
        x3 = self._follow_one(xc)
        self._x = (x1 + x2 + x3) / 3.0
        self._t += 1
        self._best3 = self._get_best3()

    def __iter__(self):
        """Run ``n_epochs`` steps, yielding ``(x, fvals, best3)`` after each.

        ``fvals`` are the objective values of the yielded population ``x``.
        They are the same values used to rank it. ``best3`` holds the row
        indices of the three lowest values, best first. The objective is
        evaluated once per step.

        The epoch counter is not reset, so iterating a second time continues
        from the current population with ``t`` past the schedule.
        """
        for _ in range(self.n_epochs):
            self._one_step()
            yield self._x, self._fvals, self._best3

In [50]:
def f(x, k):
    """Sum of squared offsets of ``x`` from ``k`` along axis 1.

    For a 2-D ``x`` this is the squared Euclidean distance of each row to
    ``k``. ``k`` may be a scalar (the point (k, ..., k)) or anything that
    broadcasts against ``x``.
    """
    offset = x - k
    return np.square(offset).sum(axis=1)

In [56]:
# Demo: minimise the squared distance to (TARGET, ..., TARGET) and print the
# current best (alpha) wolf after every epoch.
SEED = 12345
N_WOLVES, N_DIMS = 6, 4
TARGET = 10
N_EPOCHS = 100

prng = np.random.RandomState(SEED)
x_init = prng.uniform(-100, 100, size=(N_WOLVES, N_DIMS))
gwo = GWO(x_init, lambda pop: f(pop, TARGET), N_EPOCHS, prng=prng)

# Distinct loop names keep `x_init` pointing at the starting population.
for population, fvals, best3 in gwo:
    print(population[best3[0]])


[ 22.1277083    4.8330249  -11.61328428 -27.95097773]
[21.02951363  8.66215152  9.07953409 -8.78045209]
[17.31128372  7.7328166  14.03121218 -9.9871219 ]
[16.49486332  8.30662219 17.22625844 -6.10829973]
[ 6.64784855 11.93164357 18.10583267 -7.37154595]
[18.4726593  14.23095188 23.53340665 -2.19699889]
[14.05254114 10.75722109 18.08723369 -4.3289203 ]
[11.8939974   8.96310111 15.27277517 -7.48844826]
[12.95585688  9.80284396 16.8404095  -1.10278335]
[14.92430338 11.15707011 19.30614739  3.74322087]
[ 8.35867468  6.32838019 10.93313571  0.35117634]
[10.51919424  8.05582575 13.7427872  -1.76376609]
[10.87908359  8.3428659  14.2285808  -0.35336904]
[ 9.5611202   7.3289597  12.50188104 -1.01059087]
[11.94142609  9.15409946 15.61446486  0.18985303]
[12.12245138  9.29282302 15.85117866  0.27451437]
[6.72793921 5.15863092 8.79808419 0.42099255]
[ 8.7793183   6.7288832  11.47793804  0.38815135]
[10.39261879  7.96501567 13.5864379   0.15891615]
[ 8.64685235  6.62757281 11.30482296 -0.05931043]
