Skip to content

Commit

Permalink
Fixed bug in how sign of low-rank STA is computed. Added tests for all
Browse files Browse the repository at this point in the history
of visualizations module.
  • Loading branch information
bnaecker committed Nov 16, 2016
1 parent 6fc55a0 commit 453b9b0
Show file tree
Hide file tree
Showing 17 changed files with 358 additions and 103 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ dist/
.coverage
htmlcov/
.cache/
.DS_Store
25 changes: 14 additions & 11 deletions pyret/filtertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ def lowranksta(f_orig, k=10):
the rank-k filter
u : array_like
the top k spatial components (each row is a component)
the top ``k`` temporal components (each column is a component).
s : array_like
the top k singular values
the top ``k`` singular values.
u : array_like
the top k temporal components (each column is a component)
v : array_like
the top ``k`` spatial components (each row is a component). These
components have all spatial dimensions collapsed to one.
Notes
-----
Expand All @@ -205,23 +206,25 @@ def lowranksta(f_orig, k=10):
"""

# work with a copy of the filter (prevents corrupting the input)
f = f_orig.copy()
f = f_orig.copy() - f_orig.mean()

# Compute the SVD of the full filter
assert f.ndim >= 2, "Filter must be at least 2-D"
u, s, v = np.linalg.svd(f.reshape(f.shape[0], -1) - np.mean(f),
full_matrices=False)
u, s, v = np.linalg.svd(f.reshape(f.shape[0], -1), full_matrices=False)

# Keep the top k components
k = np.min([k, s.size])
u = u[:, :k]
s = s[:k]
v = v[:k, :]

# Compute the rank-k filter
fk = (u[:, :k].dot(np.diag(s[:k]).dot(v[:k, :]))).reshape(f.shape)
fk = (u.dot(np.diag(s).dot(v))).reshape(f.shape)

# Ensure that the computed filter components have the correct sign.
# The mean-subtracted filter should have positive projection onto
# the low-rank filter.
sign = np.sign(fk.ravel().dot((f - np.mean(f)).ravel()))
# The full STA should have positive projection onto first temporal
# component of the low-rank STA.
sign = np.sign(np.einsum('i,ijk->jk', u[:, 0], f).sum())
u *= sign
v *= sign

Expand Down

0 comments on commit 453b9b0

Please sign in to comment.