Skip to content

Commit

Permalink
BUG oborchers#35 -- Updated the remove_principal_components function
Browse files Browse the repository at this point in the history
  • Loading branch information
GrantWilliams committed Feb 9, 2021
1 parent bf09ea9 commit ad88c81
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
7 changes: 4 additions & 3 deletions fse/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def remove_principal_components(
weights : ndarray, optional
Weights to be used to weigh the components which are removed from the vectors
inplace : bool, optional
If true, removes the componentens from the vectors inplace (memory efficient)
If true, removes the components from the vectors inplace (memory efficient)
Returns
-------
Expand All @@ -132,15 +132,16 @@ def remove_principal_components(
output = None
if len(components) == 1:
if not inplace:
output = vectors.dot(w_comp.transpose()) * w_comp
output = vectors - vectors.dot(w_comp.transpose()) * w_comp
else:
vectors -= vectors.dot(w_comp.transpose()) * w_comp
else:
if not inplace:
output = vectors.dot(w_comp.transpose()).dot(w_comp)
output = vectors - vectors.dot(w_comp.transpose()).dot(w_comp)
else:
vectors -= vectors.dot(w_comp.transpose()).dot(w_comp)
elapsed = time()

logger.info(
f"removing {len(components)} principal components took {int(elapsed-start)}s"
)
Expand Down
26 changes: 22 additions & 4 deletions fse/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest

import numpy as np
from numpy.testing import assert_allclose, assert_raises

from fse.models.utils import compute_principal_components, remove_principal_components

Expand Down Expand Up @@ -30,21 +31,38 @@ def test_compute_large_components(self):

def test_remove_components_inplace(self):
m = np.ones((500, 10), dtype=np.float32)
c = np.copy(m)
out = compute_principal_components(vectors=m)
remove_principal_components(m, svd_res=out)
self.assertTrue(np.allclose(0.0, m, atol=1e-5))
assert_allclose(m, 0.0, atol=1e-5)
with assert_raises(AssertionError):
assert_allclose(m, c)


def test_remove_components(self):
m = np.ones((500, 10), dtype=np.float32)
c = np.copy(m)
out = compute_principal_components(vectors=m)
res = remove_principal_components(m, svd_res=out, inplace=False)
self.assertTrue(np.allclose(1.0, res, atol=1e-5))
assert_allclose(res, 0.0, atol=1e-5)
assert_allclose(m, c)

def test_remove_weighted_components(self):
def test_remove_weighted_components_inplace(self):
m = np.ones((500, 10), dtype=np.float32)
c = np.copy(m)
out = compute_principal_components(vectors=m)
remove_principal_components(m, svd_res=out, weights=np.array([0.5]))
self.assertTrue(np.allclose(0.75, m))
assert_allclose(m, 0.75, atol=1e-5)
with assert_raises(AssertionError):
assert_allclose(m, c)

def test_remove_weighted_components(self):
m = np.ones((500, 10), dtype=np.float32)
c = np.copy(m)
out = compute_principal_components(vectors=m)
res = remove_principal_components(m, svd_res=out, weights=np.array([0.5]), inplace=False)
assert_allclose(res, 0.75, atol=1e-5)
assert_allclose(m, c)

def test_madvise(self):
from pathlib import Path
Expand Down

0 comments on commit ad88c81

Please sign in to comment.