In [None]:
import numpy
import torch
import pyro
import pyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt

from ppca import BayesianPCA

# Compute device for the model and data: the GPU when CUDA is available, otherwise the CPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Synthetic data: N_samples draws from a zero-mean Gaussian in N_dim dimensions
# with a diagonal covariance. The first two coordinates have variance 10 and the
# rest variance 1, so the data has two high-variance directions.
torch.manual_seed(0)  # fixed seed so the draw (and the fit below) is reproducible on Restart & Run All
N_dim = 10
N_samples = 1000
variances = torch.tensor([10.0, 10.0] + [1.0] * (N_dim - 2))
mvn = dist.MultivariateNormal(torch.zeros(N_dim), torch.diag(variances))
samp = mvn.sample([N_samples])
# Centre the sample: subtract the per-dimension empirical mean.
samp = samp - torch.mean(samp, dim=0, keepdim=True)


def hinton(matrix, max_weight=None, ax=None):
    """Draw a Hinton diagram for visualizing a weight matrix.

    Each entry ``matrix[x, y]`` is drawn as a square centred at ``(x, y)``
    whose side length is ``sqrt(abs(value) / max_weight)`` (so its area is
    proportional to the magnitude). Positive entries are white, and zero or
    negative entries are black, on a gray background. The y axis is inverted
    and tick marks are hidden.

    Parameters
    ----------
    matrix : array_like
        2-D array of weights to draw.
    max_weight : float, optional
        Magnitude that maps to a full-size (side 1) square. If falsy
        (``None`` or 0), it defaults to the smallest power of two that is
        >= the largest absolute entry, or 1.0 when the matrix is empty or
        all zeros.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    """
    matrix = numpy.asarray(matrix)
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        largest = numpy.abs(matrix).max() if matrix.size else 0.0
        # Round the scale up to a power of two. An all-zero (or empty) matrix
        # would give log2(0) = -inf and a zero scale, i.e. 0/0 -> NaN square
        # sizes, so fall back to a unit scale there.
        max_weight = 2 ** numpy.ceil(numpy.log2(largest)) if largest > 0 else 1.0

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in numpy.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = numpy.sqrt(abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()

In [None]:
# Build the Bayesian PCA model for N_dim-dimensional data and move it to `dev`.
# NOTE(review): a, b, c, d = 0.033 are presumably prior hyperparameters of
# BayesianPCA — confirm their meaning in the ppca module.
pca = BayesianPCA(N_dim, a=0.033, b=0.033, c=0.033, d=0.033)
pca = pca.to(dev)
# fit_map on the centred samples (2000 and lr=0.005 are passed straight through).
# It returns the fitted guide and `l`, which the next cell plots as a loss trace.
guide, l = pca.fit_map(samp.to(dev), 2000, dict(lr=0.005))

In [None]:
# Left: the loss trace `l` returned by fit_map, on a log scale.
# Right: Hinton diagram of the guide's median value for the 'W' site (transposed).
# NOTE(review): the log y-scale assumes every recorded loss is positive — confirm;
# non-positive values would silently disappear from the plot.
W = guide.median()['W'].detach().cpu().numpy()  # detach so .numpy() works even if the tensor tracks gradients
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
hinton(W.T, ax=ax[1])
ax[1].set_title("W (guide median, transposed)")
ax[0].plot(l)
ax[0].set_yscale('log')
ax[0].set(title="Training loss", xlabel="iteration", ylabel="loss")
fig.tight_layout()