In [None]:
import os

import numpy
import torch
import pyro
import pyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
%matplotlib widget

from ppca.vanilla import *

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def hinton(matrix, max_weight=None, ax=None):
    """Draw a Hinton diagram for visualizing a weight matrix.

    Every entry is drawn as a square with side ``sqrt(abs(w) / max_weight)``
    (so its area is proportional to the magnitude): white for positive
    weights, black otherwise.

    Note the orientation: entry ``matrix[i, j]`` is centred at x=i, y=j (with
    the y-axis inverted), so rows of ``matrix`` run along the horizontal axis.
    Pass ``matrix.T`` to get the conventional one-row-per-line layout.

    Args:
        matrix: 2-D array-like of weights; converted with ``numpy.asarray``.
        max_weight: Magnitude that maps to a full-size (unit) square. If
            falsy, the smallest power of two >= the largest absolute weight
            is used (1.0 when every weight is zero or the matrix is empty).
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
    """
    ax = ax if ax is not None else plt.gca()
    matrix = numpy.asarray(matrix)

    if not max_weight:
        max_abs = numpy.abs(matrix).max() if matrix.size else 0.0
        # Guard the all-zero / empty case: log2(0) = -inf would make
        # max_weight 0 and every square size NaN (0 / 0).
        if max_abs > 0:
            max_weight = 2 ** numpy.ceil(numpy.log2(max_abs))
        else:
            max_weight = 1.0

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in numpy.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = numpy.sqrt(abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()

In [None]:
# Synthetic data: N_SAMPLES draws from a zero-mean Gaussian in N_dim dimensions
# with a diagonal covariance whose first two variances are 10 and the rest 1,
# so the leading principal subspace is spanned by the first two coordinates.
SEED = 0
N_SAMPLES = 1000
N_dim = 10
# Seed torch (plus Python's and NumPy's RNGs) so samples and later fits are reproducible.
pyro.set_rng_seed(SEED)

variances = torch.tensor([10.0, 10.0] + [1.0] * (N_dim - 2))
mvn = dist.MultivariateNormal(torch.zeros(N_dim), torch.diag(variances))
samp_raw = mvn.sample([N_SAMPLES])
# Center the data (subtract the empirical mean of each dimension).
samp = samp_raw - torch.mean(samp_raw, dim=0, keepdim=True)

# 3-D scatter of the first three coordinates. All three axes share the same
# limits so the two high-variance directions are visibly wider than the third.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(samp[:, 0], samp[:, 1], samp[:, 2])
ax.set_xlim([-10, 10])
ax.set_ylim([-10, 10])
ax.set_zlim([-10, 10])
ax.set(xlabel='dim 0', ylabel='dim 1', zlabel='dim 2',
       title='Centered samples (first 3 dims)');

In [None]:
# MAP fit of the vanilla probabilistic PCA model.
N_LATENT = 3        # second ProbabilisticPCA argument — presumably the latent dimension
N_STEPS_MAP = 1000  # second fit_map argument — presumably the number of optimisation steps

# Pyro's parameter store is global: clear it so re-running this cell (or running
# it after another fit) starts from fresh parameters instead of stale ones.
pyro.clear_param_store()
pca = ProbabilisticPCA(N_dim, N_LATENT).to(dev)
# fix_sigma=False presumably lets the noise scale be learned too — TODO confirm in ppca.vanilla.
l, guide = pca.fit_map(samp.to(dev), N_STEPS_MAP, {'lr': 0.01}, fix_sigma=False)
_, ax = plt.subplots()
ax.plot(l)
ax.set(title='Vanilla PPCA MAP fit')
pyro.param("loc").item()

In [None]:
# Hinton diagram of the learned loadings: pull the weights of pca.W off the
# module as a host-side NumPy array and draw them on a fresh figure.
_, ax = plt.subplots()
weight = pca.W.weight
W = weight.detach().cpu().numpy()
hinton(W, ax=ax)

In [None]:
#comp = torch.pca_lowrank(samp, 1)
#comp[2].T[:1]

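## Variational Bayesian PCA

Fit `BayesianPCA` (from `ppca.variational`) to the same centered samples and inspect the guide's median estimates of `W`, `alpha` and `tau`.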

In [None]:
from ppca.variational import *

In [None]:
# Start from an empty Pyro param store: the vanilla fit above leaves its
# parameters (e.g. "loc") in the global store, and pyro.param() returns an
# existing entry rather than re-initialising it, which would leak state
# between the two models.
pyro.clear_param_store()
# Second fit_map argument — presumably the number of optimisation steps.
# TODO(review): the vanilla fit uses 1000; a single step looks like a debugging
# leftover (the loss curve plotted next would be a single point) — confirm.
N_STEPS_VARIATIONAL = 1
pca = BayesianPCA(N_dim).to(dev)
l, guide = pca.fit_map(samp.to(dev), N_STEPS_VARIATIONAL, {'lr': 0.001})

In [None]:
# Loss curve of the Bayesian PCA fit, then point estimates from the guide.
_, ax = plt.subplots()
ax.plot(l)
ax.set(title='Bayesian PCA fit')

# Compute the guide medians once and reuse them, so W, alpha and the key scan
# all come from the same dictionary.
median = guide.median()
# Number of per-column 'w...' sites, if the model defines any (the loadings are
# read below from the single 'W' site).
num_ws = sum(1 for name in median if name.startswith('w'))
W = median['W']
alphas = median['alpha']
alphas

In [None]:
# beta is the reciprocal of the guide's median 'tau'; it is printed and used
# as a magnitude threshold on W.
tau_median = guide.median()['tau']
beta = 1 / tau_median
print(beta)
# Boolean mask of loadings whose magnitude exceeds beta (not used in the plot below).
filt = torch.abs(W).gt(beta).cpu().numpy()
_, ax = plt.subplots()
W_np = W.detach().cpu().numpy()
# Transposed so the diagram shows one row of W per horizontal line.
hinton(W_np.T, ax=ax)

In [None]:
# Graphical-model diagram of the Bayesian PCA model, traced on the centered
# samples. pyro.render_model needs the graphviz package; the graph it returns
# is the cell's displayed value.
model_args = (samp.to(dev),)
pyro.render_model(pca.model, model_args=model_args)