Skip to content

Commit

Permalink
General refactor and default configuration change
Browse files Browse the repository at this point in the history
  • Loading branch information
mkpaszkiewicz committed Jan 4, 2017
1 parent cc38f99 commit d85a51b
Show file tree
Hide file tree
Showing 16 changed files with 71 additions and 64 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -10,7 +10,7 @@ Package requires [OpenCV](http://www.opencv.org). To install it I highly recomme

Run the commands below in a terminal:
```
$ conda create -n opencv numpy scipy python=3
$ conda create -n opencv numpy python=3
$ source activate opencv
$ conda install -c https://conda.binstar.org/menpo opencv3
```
Expand Down
8 changes: 4 additions & 4 deletions setup.py
@@ -1,14 +1,14 @@
from setuptools import setup

setup(name='vse',
version='0.1.4',
version='0.1.5',
author='Marcin K. Paszkiewicz',
author_email='mkpaszkiewicz@gmail.com',
description='Configurable visual search engine based on the OpenCV',
description='A visual search engine using local features descriptors and bag of words, based on OpenCV',
url='https://github.com/mkpaszkiewicz/vse',
download_url='https://github.com/mkpaszkiewicz/vse/tarball/0.1.3',
download_url='https://github.com/mkpaszkiewicz/vse/tarball/0.1.5',
packages=['vse'],
keywords=['visual', 'search', 'engine', 'computer', 'vision'],
keywords=['visual search engine computer vision local descriptors BoW'],
install_requires=[
'NumPy'
]
Expand Down
6 changes: 3 additions & 3 deletions tests/__main__.py
@@ -1,9 +1,9 @@
"""Run all tests"""

import unittest
from tests.comparator_test import ComparatorTest
from tests.engine_test import VisualSearchEngineTest, BagOfVisualWordsTest
from tests.utils_test import UtilityTest
from .comparator_test import ComparatorTest
from .engine_test import VisualSearchEngineTest, BagOfVisualWordsTest
from .utils_test import UtilityTest


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion tests/comparator_test.py
@@ -1,6 +1,6 @@
import cv2
import unittest
from unittest.mock import *
from unittest.mock import Mock, patch, call
from vse.comparator import *
from vse.comparator import cosine_angle

Expand Down
4 changes: 2 additions & 2 deletions tests/engine_test.py
@@ -1,7 +1,7 @@
import unittest
from unittest.mock import *
from unittest.mock import Mock, patch

from vse import *
from vse import VisualSearchEngine, BagOfVisualWords


class VisualSearchEngineTest(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils_test.py
@@ -1,5 +1,5 @@
import unittest
from unittest.mock import *
from unittest.mock import Mock, patch, mock_open

from vse import *

Expand Down
Binary file removed vocabulary/vocabulary_sift_100.dat
Binary file not shown.
Binary file added vocabulary/vocabulary_sift_1k.dat
Binary file not shown.
Binary file added vocabulary/vocabulary_surf_1k.dat
Binary file not shown.
2 changes: 1 addition & 1 deletion vse/__init__.py
Expand Up @@ -6,4 +6,4 @@
from vse.ranker import *
from vse.utils import *

__version__ = '0.1.4'
__version__ = '0.1.5'
8 changes: 4 additions & 4 deletions vse/comparator.py
Expand Up @@ -17,7 +17,7 @@


class HistComparator(metaclass=ABCMeta):
REVERSED = False
reversed = False

@abstractmethod
def compare(self, h1, h2):
Expand All @@ -26,7 +26,7 @@ def compare(self, h1, h2):


class Correlation(HistComparator):
REVERSED = True
reversed = True

def compare(self, h1, h2):
return cv2.compareHist(h1, h2, cv2.HISTCMP_CORREL)
Expand All @@ -38,7 +38,7 @@ def compare(self, h1, h2):


class Intersection(HistComparator):
REVERSED = True
reversed = True

def compare(self, h1, h2):
return cv2.compareHist(h1, h2, cv2.HISTCMP_INTERSECT)
Expand Down Expand Up @@ -70,7 +70,7 @@ def compare(self, h1, h2):


class CosineAngle(HistComparator):
REVERSED = True
reversed = True

def compare(self, h1, h2):
return cosine_angle(h1, h2)
Expand Down
33 changes: 16 additions & 17 deletions vse/engine.py
Expand Up @@ -8,30 +8,31 @@
"""

import cv2
from vse.index import InvertedIndex
from vse.ranker import SimpleRanker
from vse.comparator import Intersection
from vse.utils import *
from vse.utils import load, save


def create_vse(vocabulary_path, recognized_visual_words=100):
def create_vse(vocabulary_path, recognized_visual_words=1000):
"""Create visual search engine with default configuration."""
ranker = SimpleRanker(hist_comparator=Intersection())
index = InvertedIndex(ranker=ranker, recognized_visual_words=recognized_visual_words)
bovw = BagOfVisualWords(extractor=cv2.xfeatures2d.SIFT_create(),
matcher=cv2.BFMatcher(normType=cv2.NORM_L2),
vocabulary=load(vocabulary_path))
return VisualSearchEngine(index, bovw)
inverted_index = InvertedIndex(ranker=ranker, recognized_visual_words=recognized_visual_words)
bag_of_visual_words = BagOfVisualWords(extractor=cv2.xfeatures2d.SURF_create(),
matcher=cv2.BFMatcher(normType=cv2.NORM_L2),
vocabulary=load(vocabulary_path))
return VisualSearchEngine(inverted_index, bag_of_visual_words)


class VisualSearchEngine:
def __init__(self, image_index, bovw):
def __init__(self, image_index, bag_of_visual_words):
self.image_index = image_index
self.bovw = bovw
self.bag_of_visual_words = bag_of_visual_words

def add_to_index(self, image_id, image):
"""Adds image id and its histogram to index. Argument image contains binary image."""
hist = self.bovw.generate_hist(image)
hist = self.bag_of_visual_words.generate_hist(image)
self.image_index[image_id] = hist

def remove_from_index(self, image_id):
Expand All @@ -40,7 +41,7 @@ def remove_from_index(self, image_id):

def find_similar(self, image, n=1):
"""Returns at most n similar images."""
query_hist = self.bovw.generate_hist(image)
query_hist = self.bag_of_visual_words.generate_hist(image)
return self.image_index.find(query_hist, n)


Expand All @@ -57,15 +58,13 @@ def generate_hist(self, image):
return hist


def cluster_voc_from_img(images, extractor, recognized_visual_words=100, filename=''):
def cluster_vocabulary_from_img(images, extractor, recognized_visual_words=1000, filename=''):
"""Generates visual words vocabulary from images. Saves to file if filename given."""
desc = []
for image in images:
desc.append(extractor.detectAndCompute(image, None)[1])
return cluster_voc_from_desc(desc, recognized_visual_words, filename)
descriptors = [extractor.detectAndCompute(image, None)[1] for image in images]
return cluster_vocabulary_from_descriptors(descriptors, recognized_visual_words, filename)


def cluster_voc_from_desc(descriptors, recognized_visual_words=100, filename=''):
def cluster_vocabulary_from_descriptors(descriptors, recognized_visual_words=1000, filename=''):
"""Generates visual words vocabulary from images descriptors. Saves to file if filename given."""
bow_kmeans_trainer = cv2.BOWKMeansTrainer(recognized_visual_words)
for desc in descriptors:
Expand Down
26 changes: 13 additions & 13 deletions vse/error.py
@@ -1,4 +1,4 @@
import vse.engine
import vse


class VisualSearchEngineError(Exception):
Expand All @@ -9,33 +9,33 @@ class VisualSearchEngineError(Exception):
class DuplicatedImageError(VisualSearchEngineError):
"""Raised when trying to add already existing image to the image index."""

def __init__(self, image_path):
msg = 'Image {} already exists in the index'.format(image_path)
VisualSearchEngineError.__init__(self, msg)
def __init__(self, image_id):
message = 'Image {} already exists in the index'.format(image_id)
VisualSearchEngineError.__init__(self, message)


class NoImageError(VisualSearchEngineError):
"""Raised when trying to delete non-existing image path from the image index."""

def __init__(self, image_path):
msg = 'Image {} does not exist in the index'.format(image_path)
VisualSearchEngineError.__init__(self, msg)
def __init__(self, image_id):
message = 'Image {} does not exist in the index'.format(image_id)
VisualSearchEngineError.__init__(self, message)


class ImageSizeError(VisualSearchEngineError):
"""Raised if loaded image width or height is smaller than IMAGE_MIN_SIZE."""

def __init__(self, image_path='image'):
msg = 'Both width and height of the {} must be greater than {}'.format(image_path, vse.engine.IMAGE_MIN_SIZE)
VisualSearchEngineError.__init__(self, msg)
def __init__(self, image_id='image'):
message = 'Both width and height of the {} must be greater than {}'.format(image_id, vse.utils.IMAGE_MIN_SIZE)
VisualSearchEngineError.__init__(self, message)


class ImageLoaderError(VisualSearchEngineError):
"""Raised if cannot read image from file or buffer"""

def __init__(self, image_path=''):
if image_path:
msg = 'Cannot read file: {}'.format(image_path)
message = 'Cannot read file: {}'.format(image_path)
else:
msg = 'Cannot read image from buffer'
VisualSearchEngineError.__init__(self, msg)
message = 'Cannot read image from buffer'
VisualSearchEngineError.__init__(self, message)
2 changes: 1 addition & 1 deletion vse/index.py
@@ -1,5 +1,5 @@
import abc
from vse.error import *
from vse.error import NoImageError, DuplicatedImageError


class Index:
Expand Down
37 changes: 22 additions & 15 deletions vse/ranker.py
Expand Up @@ -6,7 +6,7 @@

__all__ = ['Ranker',
'SimpleRanker',
'WeightingRanker'
'WeighingRanker'
]


Expand All @@ -31,8 +31,12 @@ def rank(self, query_hist, items, n, freq_vector):
"""Ranks index items by similarity to query_hist. Returns list of tuples: (image_id, diff_ratio)."""
pass

def _rank_best_results(self, items, n, diff_ratio_function):
results = [(image_id, diff_ratio_function(hist)) for image_id, hist in items]
return self._n_best_results(results, n)

def _n_best_results(self, results, n):
if self.hist_comparator.REVERSED:
if self.hist_comparator.reversed:
function = heapq.nlargest
else:
function = heapq.nsmallest
Expand All @@ -44,21 +48,24 @@ def __init__(self, hist_comparator):
Ranker.__init__(self, hist_comparator)

def rank(self, query_hist, items, n, freq_vector=None):
results = [(image_id, self.hist_comparator.compare(hist, query_hist)) for image_id, hist in items]
return self._n_best_results(results, n)

def diff_ratio_function(hist):
return self.hist_comparator.compare(hist, query_hist)

return self._rank_best_results(items, n, diff_ratio_function)

class WeightingRanker(Ranker):
def __init__(self, hist_comparator, query_weight=tfidf, item_weight=tfidf):

class WeighingRanker(Ranker):
def __init__(self, hist_comparator, query_weigh_function=tfidf, item_weigh_function=tfidf):
Ranker.__init__(self, hist_comparator)
self.query_weight = query_weight
self.item_weight = item_weight
self.query_weigh_function = query_weigh_function
self.item_weigh_function = item_weigh_function

def rank(self, query_hist, items, n, freq_vector):
results = []
weighted_query_hist = normalize(self.query_weight(query_hist, freq_vector))
for image_id, hist in items:
weighted_item_hist = normalize(self.item_weight(hist, freq_vector))
diff_ratio = self.hist_comparator.compare(weighted_item_hist, weighted_query_hist)
results.append((image_id, diff_ratio))
return self._n_best_results(results, n)
weighted_query_hist = normalize(self.query_weigh_function(query_hist, freq_vector))

def diff_ratio_function(hist):
weighted_item_hist = normalize(self.item_weigh_function(hist, freq_vector))
return self.hist_comparator.compare(weighted_item_hist, weighted_query_hist)

return self._rank_best_results(items, n, diff_ratio_function)
3 changes: 2 additions & 1 deletion vse/utils.py
Expand Up @@ -78,4 +78,5 @@ def save(filename, data, protocol=pickle.HIGHEST_PROTOCOL):

def normalize(hist):
"""Normalizes histogram by casting values to [0, 1]."""
return numpy.array([val / sum(hist) for val in hist], dtype=numpy.float32)
total_sum = sum(hist)
return numpy.array([val / total_sum for val in hist], dtype=numpy.float32)

0 comments on commit d85a51b

Please sign in to comment.