Skip to content

Commit

Permalink
Made CLUE more error-tolerant
Browse files Browse the repository at this point in the history
  • Loading branch information
Callidior committed Oct 25, 2017
1 parent 75c54ea commit 98bde4c
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions clue.py
Expand Up @@ -11,7 +11,7 @@

## CLUE ##

def clue(features, queries, select_clusters, k = 200, max_clusters = 10, T = 0.9, show_progress = False):
def clue(features, queries, select_clusters, k = 200, max_clusters = 10, T = 0.9, min_cluster_size = 2, show_progress = False):
""" CLUE method for cluster-based relevance feedback in image retrieval.
Reference:
Expand All @@ -33,6 +33,8 @@ def clue(features, queries, select_clusters, k = 200, max_clusters = 10, T = 0.9
T - Threshold for the n-cut value. Nodes with an n-cut value larger than this threshold won't be subdivided any further.
min_cluster_size - Minimum number of items per cluster.
show_progress - If True, a progress bar will be shown (requires tqdm).
Returns: re-ranked retrieval results as dictionary mapping query IDs to tuples consisting of an ordered list of retrieved image IDs
Expand All @@ -49,7 +51,7 @@ def clue(features, queries, select_clusters, k = 200, max_clusters = 10, T = 0.9
query_feat = features[query['img_id']]

# Spectral clustering of top results
tree = RecursiveNormalizedCuts(max_clusters, T)
tree = RecursiveNormalizedCuts(max_clusters, T, min_cluster_size)
tree.fit([(id, features[id]) for id in ret[:k]])
clusters = tree.clusters()

Expand All @@ -74,10 +76,11 @@ def clue(features, queries, select_clusters, k = 200, max_clusters = 10, T = 0.9

class RecursiveNormalizedCuts(object):

def __init__(self, max_clusters, T):
def __init__(self, max_clusters, T, min_cluster_size = 2):
object.__init__(self)
self.max_clusters = max_clusters
self.T = T
self.min_cluster_size = min_cluster_size
self.tree = { 'depth' : 0, 'height' : 0, 'size' : 0, 'leafs' : 1, 'children' : [], 'parent' : None, 'items' : [], 'affinity' : [] }


Expand All @@ -93,7 +96,7 @@ def fit(self, feat):
queue = []
heapq.heappush(queue, (-1 * len(self.tree['items']), np.random.rand(), self.tree))
while (self.tree['leafs'] < self.max_clusters) and (len(queue) > 0):
if len(queue[0][2]['items']) < 2:
if len(queue[0][2]['items']) <= self.min_cluster_size:
break
left, right, ncut_value = self.split(heapq.heappop(queue)[2])
if ncut_value > self.T:
Expand All @@ -106,14 +109,19 @@ def fit(self, feat):
def split(self, node):

# Perform normalized cut
ind = SpectralClustering(2, affinity = 'precomputed', assign_labels = 'discretize').fit_predict(node['affinity'])
try:
ind = SpectralClustering(2, affinity = 'precomputed', assign_labels = 'discretize').fit_predict(node['affinity'])
except KeyboardInterrupt:
raise
except:
return None, None, 0

# Create left and right node
mask1, mask2 = (ind == 0), (ind == 1)
if not (np.any(mask1) and np.any(mask2)):
return None, None, 0
left = { 'depth' : node['depth'] + 1, 'height' : 0, 'size' : 0, 'leafs' : 1, 'children' : [], 'parent' : node, 'items' : [f for i, f in enumerate(node['items']) if ind[i] == 0], 'affinity' : node['affinity'][mask1,:][:,mask1] }
right = { 'depth' : node['depth'] + 1, 'height' : 0, 'size' : 0, 'leafs' : 1, 'children' : [], 'parent' : node, 'items' : [f for i, f in enumerate(node['items']) if ind[i] == 1], 'affinity' : node['affinity'][mask2,:][:,mask2] }
left = { 'depth' : node['depth'] + 1, 'height' : 0, 'size' : 0, 'leafs' : 1, 'children' : [], 'parent' : node, 'items' : [f for i, f in enumerate(node['items']) if ind[i] == 0], 'affinity' : node['affinity'][np.ix_(mask1, mask1)] }
right = { 'depth' : node['depth'] + 1, 'height' : 0, 'size' : 0, 'leafs' : 1, 'children' : [], 'parent' : node, 'items' : [f for i, f in enumerate(node['items']) if ind[i] == 1], 'affinity' : node['affinity'][np.ix_(mask2, mask2)] }

# Force the node with the lower minimum distance to the query to be the left node
if ind[0] == 1: # items are already sorted when passed to fit(), so we just need to look at the first item instead of re-computing all distances
Expand Down

0 comments on commit 98bde4c

Please sign in to comment.