Skip to content

Commit

Permalink
Add more documentation and util to filter invalid images
Browse files Browse the repository at this point in the history
  • Loading branch information
chsasank committed Aug 31, 2019
1 parent 1e8e42d commit d67aa06
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 11 deletions.
20 changes: 17 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,26 @@ features that just work.
The idea is to create a simple Python package to do this:

```python
from image_embedding import image_embedding
embedding = image_embedding('your_image.png')
from image_features import image_features
embedding = image_features(['your_image_1.png', 'your_image_2.jpg'])
```

So that you can use the following:


```python
from sklearn import linear_model
from image_features import image_features
X_train = image_features(['your_image_1.png', 'your_image_2.jpg'])
y_train = ['cat', 'dog']
clf = linear_model.LogisticRegressionCV()
clf.fit(X_train, y_train)
```

The package internally uses a pretrained deep learning model such as resnet50 (the default).

## Aim

* Inspired by [face_recognition](https://github.com/ageitgey/face_recognition) and how it just works most of the time.
* Minimal dependencies (no torch/tf etc.)
* Simple yet fairly complete implementation.
* If there is enough interest in this, it can be published on PyPI.
3 changes: 2 additions & 1 deletion image_features/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .image_features import image_features
from .image_features import image_features
from .utils import filter_invalid_images
39 changes: 32 additions & 7 deletions image_features/image_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@ def get_model(model_name):


class ImageLoader():
def __init__(self, img_paths, model, img_size=224):
def __init__(self, img_paths, model, img_size=224, augment=False):
self.load_img = utils.LoadImage()
self.tf_img = utils.TransformImage(model, scale=img_size / 256)
additional_args = {}
if augment:
additional_args = {
'random_crop': True, 'random_hflip': False,
'random_vflip': False
}
self.tf_img = utils.TransformImage(
model, scale=img_size / 256, **additional_args)
self.img_paths = img_paths

def __len__(self):
Expand All @@ -27,21 +34,39 @@ def __getitem__(self, idx):

def image_features(
img_paths, model_name='resnet50', use_gpu=torch.cuda.is_available(),
batch_size=32, num_workers=4, progress=False):
batch_size=32, num_workers=4, progress=False, augment=False):
"""
Extract deep learning image features from images.
Args:
img_paths(list): List of paths of images to extract features from.
model_name(str, optional): Deep learning model to use for feature
extraction. Default is resnet50. The list of available models is here:
https://github.com/Cadene/pretrained-models.pytorch
use_gpu(bool): If gpu is to be used for feature extraction. By default,
uses cuda if nvidia driver is installed.
batch_size(int): Batch size to be used for feature extraction.
num_workers(int): Number of workers to use for image loading.
progress(bool): If true, enables progressbar.
augment(bool): If true, images are augmented before passing through
the model. Useful if you're training a classifier based on these
features.
"""
if use_gpu:
device = torch.device('cuda')
else:
device = torch.device('cpu')

if isinstance(img_paths, str):
img_paths = [img_paths]
raise ValueError(f'img_paths should be a list of image paths.')

model = get_model(model_name).to(device)
dataset = ImageLoader(img_paths, model)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
dataset = ImageLoader(img_paths, model, augment=augment)
dataloader = torch.utils.data.DataLoader(
dataset, shuffle=False, batch_size=batch_size, num_workers=num_workers)
with torch.no_grad():
if progress:
pbar = tqdm(total=len(img_paths), desc='Computing features')
pbar = tqdm(total=len(img_paths), desc='Computing image features')

output_features = []
for batch in dataloader:
Expand Down
31 changes: 31 additions & 0 deletions image_features/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import multiprocessing
import pretrainedmodels.utils as utils
from tqdm import tqdm

# Image loader instance created once at import time; _is_valid_img calls it
# to probe whether a path can be loaded.
load_img = utils.LoadImage()


def _is_valid_img(img_path):
try:
load_img(img_path)
return True
except Exception:
return False


def filter_invalid_images(img_paths, num_workers=4, progress=False):
    """Filter invalid images before computing expensive features.

    Every path is probed with ``_is_valid_img`` in a pool of worker
    processes, and only the paths that could be loaded are kept.

    Args:
        img_paths(iterable): Paths of the images to check. Any iterable is
            accepted, including generators.
        num_workers(int): Number of worker processes used for the checks.
        progress(bool): If true, enables progressbar.

    Returns:
        list: Paths from ``img_paths`` that could be loaded, in their
        original order.
    """
    # Materialize once: the paths are sized for the progress bar and walked
    # a second time when pairing them with the results, which a one-shot
    # iterable (e.g. a generator) would not survive.
    img_paths = list(img_paths)
    if not img_paths:
        # Nothing to check, so skip starting worker processes altogether.
        return []

    with multiprocessing.Pool(num_workers) as p:
        if progress:
            # imap yields results in input order as they complete, which
            # lets tqdm advance incrementally.
            load_works = list(tqdm(
                p.imap(_is_valid_img, img_paths),
                total=len(img_paths),
                desc="Filtering invalid images"))
        else:
            load_works = p.map(_is_valid_img, img_paths)

    return [
        img_path for img_path, is_loadable in
        zip(img_paths, load_works) if is_loadable
    ]
Binary file added tests/data/example_image_corrupted.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 20 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import os
from image_features.utils import filter_invalid_images


# Directory named 'data' next to this test file; the example image paths
# below are built from it.
data_path = os.path.join(os.path.dirname(__file__), 'data')


def test_image_features():
    """Check that filter_invalid_images drops the corrupted example image.

    The first two example images are expected to load; the third is the
    deliberately corrupted fixture and must be filtered out.
    """
    example_imgs = [
        os.path.join(data_path, 'example_image.jpg'),
        os.path.join(data_path, 'example_image_2.JPG'),
        # The fixture is stored with a lowercase extension. Using '.JPG'
        # here would point at a missing file on case-sensitive filesystems,
        # so the test would pass without ever reading the corrupted image.
        os.path.join(data_path, 'example_image_corrupted.jpg')
    ]

    valid_imgs = filter_invalid_images(example_imgs)
    assert valid_imgs == example_imgs[:2]


# Allow running this test directly as a script, without a test runner.
if __name__ == '__main__':
    test_image_features()

0 comments on commit d67aa06

Please sign in to comment.