In [1]:
import numpy as np

In [120]:
# Source distribution p(x): four equally likely inputs.
px = np.full(4, 0.25, dtype=np.float32)

# Distribution over the two bottleneck codes z, both equally likely.
pz = np.full(2, 0.5, dtype=np.float32)

# p(y|x): one row per input x, one column per label y (each row sums to 1).
# Inputs 0 and 1 share the first row pattern, inputs 2 and 3 the second.
py_x = np.repeat(
    np.array([[0.9, 0.1], [0.1, 0.9]], dtype=np.float32),
    repeats=2,
    axis=0,
)

In [121]:
def _weighted_log_ratio(weights, num, den):
    """Return sum(weights * log(num / den)), skipping entries where weights == 0.

    Skipping zero-weight entries implements the convention 0 * log(.) = 0, so a
    zero probability does not turn the whole sum into nan. A positive weight
    paired with den == 0 (and num > 0) contributes +inf.
    """
    mask = weights > 0
    # num / 0 is an intentional +inf here; silence the divide-by-zero warning.
    with np.errstate(divide="ignore"):
        return np.sum(weights[mask] * np.log(num[mask] / den[mask]))


def bottleneck(pz_x, py_z, beta=1.0, *, p_x=None, p_z=None, p_y_x=None):
    """Evaluate the two terms of a discrete information-bottleneck objective.

    Parameters
    ----------
    pz_x : array, shape (|X|, |Z|)
        Encoder p(z|x); row x is a distribution over codes z.
    py_z : array, shape (|Z|, |Y|)
        Decoder q(y|z); row z is a distribution over labels y.
    beta : float
        Multiplier applied to the prediction (KL) term.
    p_x, p_z, p_y_x : array, optional (keyword-only)
        Source distribution p(x), reference code distribution p(z) and
        channel p(y|x). Each defaults to the notebook-level ``px``, ``pz`` and
        ``py_x`` respectively. Pass them explicitly to avoid depending on
        global notebook state.

    Returns
    -------
    mi : float
        sum_{x,z} p(x,z) log[p(x,z) / (p(x) p_z(z))], where p(x,z) = p(x) p(z|x).
        This equals I(X;Z) only when ``p_z`` is the marginal induced by the
        encoder, sum_x p(x) p(z|x). For any other ``p_z`` it is an upper bound:
        I(X;Z) + KL(marginal || p_z). To get the exact mutual information,
        pass ``p_z=(p_x[:, None] * pz_x).sum(axis=0)``.
    kl : float
        beta * sum_{x,z} p(x,z) KL(p(y|x) || q(y|z)).

    Zero-probability terms contribute 0 (convention 0 * log 0 = 0). As a
    result, deterministic encoders or decoders give finite values instead of
    nan. A positive-weight term whose reference probability is 0 gives +inf.
    """
    # Fall back to the notebook-level distributions when not given explicitly.
    p_x = np.asarray(px if p_x is None else p_x)
    p_z = np.asarray(pz if p_z is None else p_z)
    p_y_x = np.asarray(py_x if p_y_x is None else p_y_x)
    pz_x = np.asarray(pz_x)
    py_z = np.asarray(py_z)

    # Joint p(x, z) = p(x) p(z|x), shape (|X|, |Z|).
    pxz = p_x[:, np.newaxis] * pz_x
    # Reference product p(x) p_z(z) that the joint is compared against.
    outer = p_x[:, np.newaxis] * p_z[np.newaxis, :]
    mi = _weighted_log_ratio(pxz, pxz, outer)

    # Axes are (x, z, y). weights[x, z, y] = p(x, z) p(y|x), so the weighted
    # log-ratio of p(y|x) to q(y|z) is sum_{x,z} p(x,z) KL(p(y|x) || q(y|z)).
    weights = pxz[:, :, np.newaxis] * p_y_x[:, np.newaxis, :]
    target = np.broadcast_to(p_y_x[:, np.newaxis, :], weights.shape)
    decoded = np.broadcast_to(py_z[np.newaxis, :, :], weights.shape)
    kl = beta * _weighted_log_ratio(weights, target, decoded)

    return mi, kl

In [134]:
# Random search over encoders p(z|x) and decoders q(y|z): sample both from a
# flat Dirichlet, score each pair with `bottleneck`, and keep the lowest loss.
N_TRIALS = 10000  # number of random (encoder, decoder) pairs to evaluate
BETA = 100.0      # passed to bottleneck as the weight on its KL term
SEED = 0          # fixed so Restart & Run All reproduces the printed result

rng = np.random.default_rng(SEED)

best_loss = None
best_pz_x = None
best_py_z = None

for _ in range(N_TRIALS):
    # One random distribution over the pz.shape[0] codes per input x.
    pz_x = rng.dirichlet(np.ones(pz.shape[0]), size=px.shape[0])
    # One random distribution over the py_x.shape[1] labels per code z.
    py_z = rng.dirichlet(np.ones(py_x.shape[1]), size=pz.shape[0])

    mi, kl = bottleneck(pz_x, py_z, beta=BETA)
    loss = mi + kl

    # Skip candidates whose loss is nan or +/-inf.
    if not np.isfinite(loss):
        continue
    if best_loss is None or loss < best_loss:
        best_loss = loss
        best_pz_x = pz_x
        best_py_z = py_z

print(best_loss)
print(best_pz_x)
print(best_py_z)

10.194708579433948
[[0.00775231 0.99224769]
 [0.02227872 0.97772128]
 [0.90212296 0.09787704]
 [0.95360247 0.04639753]]
[[0.22434759 0.77565241]
 [0.90247304 0.09752696]]
