Skip to content

Commit

Permalink
Merge pull request #19 from sjanssen2/seedset_protoclass
Browse files Browse the repository at this point in the history
Seedset protoclass
  • Loading branch information
qiyunzhu committed Jul 24, 2017
2 parents ad1a207 + 57a45a0 commit 2c30853
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ help:
test:
$(TEST_COMMAND)
pep8:
flake8 genome-subsampler setup.py
flake8 genomesubsampler setup.py
html:
make -C doc clean html

Expand Down
31 changes: 27 additions & 4 deletions genomesubsampler/prototypeSelection.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def prototype_selection_constructive_maxdist(dm: DistanceMatrix,
return [dm.ids[idx] for idx, x in enumerate(uncovered) if not x]


def _protoclass(dm, epsilon):
def _protoclass(dm, epsilon, seedset=None):
'''Heuristically select n prototypes for a fixed epsilon radius.
A ball is drawn around every element in the distance matrix with radius
Expand All @@ -289,6 +289,10 @@ def _protoclass(dm, epsilon):
epsilon: float
Radius for the balls to be "drawn". As a rule of thumb, the larger
epsilon, the fewer prototypes are found.
seedset: iterable of str
A set of element IDs that are pre-selected as prototypes. Remaining
prototypes are then recruited with the prototype selection algorithm.
Warning: It will most likely violate the global objective function.
Returns
-------
Expand Down Expand Up @@ -316,11 +320,25 @@ def _protoclass(dm: DistanceMatrix, epsilon: float) -> List[str]:
# found prototypes
prototypes = []

# if we have a non empty seedset, we create a new list of those elements
# which is later consumed by the while loop.
seeds = []
if seedset is not None:
seeds = list(seedset)

while True:
# candidate for a new prototype is the element whose epsilon ball
# covers most other elements.
idx_max = scores.argmax()
if (scores[idx_max] > 0):
if (scores[idx_max] > 0) or (len(seeds) > 0):
if len(seeds) > 0:
# if a seedset is given, the best candidate is not the above,
# but an element of the seedset. This is repeated until all
# elements of the seedset have been consumed. The loop then
# defaults to the normal routine, i.e. uses the scores.argmax()
# element as the next prototype
idx_max = dm.ids.index(seeds[0])
seeds = seeds[1:]
# candidate is new prototype, add it to the list
prototypes.append(idx_max)
# which elements have been just covered by the new prototype
Expand All @@ -336,7 +354,8 @@ def _protoclass(dm: DistanceMatrix, epsilon: float) -> List[str]:
return np.array(dm.ids)[prototypes]


def prototype_selection_constructive_protoclass(dm, num_prototypes, steps=100):
def prototype_selection_constructive_protoclass(dm, num_prototypes, steps=100,
seedset=None):
'''Heuristically select k prototypes for given distance matrix.
Prototype selection is NP-hard. This is an implementation of a greedy
Expand All @@ -361,6 +380,10 @@ def prototype_selection_constructive_protoclass(dm, num_prototypes, steps=100):
otherwise no reduction is necessary.
steps: int
Maximal number of steps used to find a suitable epsilon.
seedset: iterable of str
A set of element IDs that are pre-selected as prototypes. Remaining
prototypes are then recruited with the prototype selection algorithm.
Warning: It will most likely violate the global objective function.
Returns
-------
Expand Down Expand Up @@ -409,7 +432,7 @@ def prototype_selection_constructive_protoclass(dm: DistanceMatrix,
# increase the stepsize in each iteration to converge faster
stepSize *= 1.1
# call the protoclass with a defined epsilon
prototypes = _protoclass(dm, epsilon)
prototypes = _protoclass(dm, epsilon, seedset)
# check if the direction of the epsilon change has changed
if len(prototypes) > num_prototypes:
direction = +1
Expand Down
20 changes: 20 additions & 0 deletions genomesubsampler/tests/test_prototypeSelection.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ def test__protoclass(self):
res)
self.assertAlmostEqual(101.91549799314, distance_sum(res, self.dm100))

# test seedset function, i.e. are 'A' and 'B' included in prototypes
res = _protoclass(self.dm20, 0.405, seedset=['A', 'B'])
self.assertCountEqual(res, ['A', 'B', 'D', 'Q'])

# test that the seed elements are still returned when epsilon is too high
res = _protoclass(self.dm20, 0.805, seedset=['A', 'B'])
self.assertCountEqual(res, ['A', 'B'])

def test_prototype_selection_constructive_pMedian(self):
self.assertRaisesRegex(
ValueError,
Expand Down Expand Up @@ -568,6 +576,12 @@ def test_seedset(self):
self.assertCountEqual(('A', 'P', 'T', 'C', 'O'), res)
self.assertAlmostEqual(5.4494, distance_sum(res, self.dm20))

seedset = set(['H', 'C'])
res = prototype_selection_constructive_protoclass(
self.dm20, 5, seedset=seedset)
self.assertCountEqual(('H', 'C', 'Q', 'A', 'G'), res)
self.assertAlmostEqual(5.2747, distance_sum(res, self.dm20))

# then include different elements, to see result changes, and score
# (sum of distances) slightly drops.
seedset = ['G', 'I']
Expand All @@ -590,6 +604,12 @@ def test_seedset(self):
self.assertCountEqual(('A', 'G', 'I', 'K', 'T'), res)
self.assertAlmostEqual(5.3082, distance_sum(res, self.dm20))

seedset = set(['G', 'I'])
res = prototype_selection_constructive_protoclass(
self.dm20, 5, seedset=seedset)
self.assertCountEqual(('I', 'G', 'B', 'Q', 'A'), res)
self.assertAlmostEqual(5.1918, distance_sum(res, self.dm20))

# test on the n=100 distance matrix
seedset = ['550.L1S18.s.1.sequence', '550.L1S142.s.1.sequence',
'550.L1S176.s.1.sequence']
Expand Down

0 comments on commit 2c30853

Please sign in to comment.