Skip to content

Commit

Permalink
DEMO: Update PCA demo
Browse files Browse the repository at this point in the history
  • Loading branch information
jluttine committed May 5, 2015
1 parent 02fca27 commit dbbb621
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions bayespy/demos/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ def model(M, N, D):
Y = nodes.GaussianARD(F, tau,
name='Y')

return (Y, F, W, X, tau, alpha)
# Initialize some nodes randomly
X.initialize_from_random()
W.initialize_from_random()

return VB(Y, F, W, X, tau, alpha)


@bpplt.interactive
def run(M=10, N=100, D_y=3, D=5, seed=42, rotate=False, maxiter=100, debug=False, plot=True):
def run(M=10, N=100, D_y=3, D=5, seed=42, rotate=False, maxiter=1000, debug=False, plot=True):

if seed is not None:
np.random.seed(seed)
Expand All @@ -84,45 +88,35 @@ def run(M=10, N=100, D_y=3, D=5, seed=42, rotate=False, maxiter=100, debug=False
w = np.random.normal(0, 1, size=(M,1,D_y))
x = np.random.normal(0, 1, size=(1,N,D_y))
f = misc.sum_product(w, x, axes_to_sum=[-1])
y = f + np.random.normal(0, 0.2, size=(M,N))
y = f + np.random.normal(0, 0.1, size=(M,N))

# Construct model
(Y, F, W, X, tau, alpha) = model(M, N, D)
Q = model(M, N, D)

# Data with missing values
mask = random.mask(M, N, p=0.5) # randomly missing
y[~mask] = np.nan
Y.observe(y, mask=mask)

# Construct inference machine
Q = VB(Y, W, X, tau, alpha)

# Initialize some nodes randomly
X.initialize_from_random()
W.initialize_from_random()
Q['Y'].observe(y, mask=mask)

# Run inference algorithm
if rotate:
# Use rotations to speed up learning
rotW = transformations.RotateGaussianARD(W, alpha)
rotX = transformations.RotateGaussianARD(X)
rotW = transformations.RotateGaussianARD(Q['W'], Q['alpha'])
rotX = transformations.RotateGaussianARD(Q['X'])
R = transformations.RotationOptimizer(rotW, rotX, D)
for ind in range(maxiter):
Q.update()
if debug:
R.rotate(check_bound=True,
check_gradient=True)
else:
R.rotate()

else:
# Use standard VB-EM alone
Q.update(repeat=maxiter)
if debug:
Q.callback = lambda : R.rotate(check_bound=True,
check_gradient=True)
else:
Q.callback = R.rotate

# Use standard VB-EM alone
Q.update(repeat=maxiter)

# Plot results
if plot:
plt.figure()
bpplt.timeseries_normal(F, scale=2)
bpplt.timeseries_normal(Q['F'], scale=2)
bpplt.timeseries(f, color='g', linestyle='-')
bpplt.timeseries(y, color='r', linestyle='None', marker='+')

Expand Down

0 comments on commit dbbb621

Please sign in to comment.