Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Feature/progress bar #28

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 9 additions & 3 deletions embetter/text/_s2v.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from tqdm.auto import tqdm
from sense2vec import Sense2Vec

from embetter.base import BaseEstimator


Expand All @@ -13,11 +13,17 @@ class Sense2VecEncoder(BaseEstimator):

Arguments:
path: path to downloaded model
show_progress_bar: Show a progress bar when encoding phrases
"""

def __init__(self, path):
def __init__(self, path, show_progress_bar=False):
    """Load the sense2vec model from disk once, at construction time."""
    self.show_progress_bar = show_progress_bar
    self.s2v = Sense2Vec().from_disk(path)

def transform(self, X, y=None):
    """Transforms the phrase text into a numeric representation.

    Arguments:
        X: iterable of phrase keys known to the sense2vec model
        y: ignored, present for scikit-learn pipeline compatibility
    """
    # Wrap the input in tqdm only when a progress bar was requested.
    instances = tqdm(X, desc="Sense2Vec Encoding") if self.show_progress_bar else X
    # A comprehension replaces the manual append-in-a-loop (ruff PERF401).
    return np.array([self.s2v[x] for x in instances])
6 changes: 4 additions & 2 deletions embetter/text/_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class SentenceEncoder(EmbetterBase):

Arguments:
name: name of model, see available options
show_progress_bar: Display progress bar when encoding sentences.

The following model names should be supported:

Expand Down Expand Up @@ -65,10 +66,11 @@ class SentenceEncoder(EmbetterBase):
```
"""

def __init__(self, name="all-MiniLM-L6-v2"):
def __init__(self, name="all-MiniLM-L6-v2", show_progress_bar=False):
    # Record configuration first, then instantiate the underlying
    # sentence-transformers model (downloads weights on first use).
    self.name = name
    self.show_progress_bar = show_progress_bar
    self.tfm = SBERT(name)

def transform(self, X, y=None):
    """Transforms the text into a numeric representation."""
    # Delegate to sentence-transformers, forwarding the progress-bar flag.
    embeddings = self.tfm.encode(X, show_progress_bar=self.show_progress_bar)
    return embeddings
8 changes: 6 additions & 2 deletions embetter/vision/_colorhist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from tqdm.auto import tqdm
from embetter.base import EmbetterBase


Expand All @@ -10,6 +11,7 @@ class ColorHistogramEncoder(EmbetterBase):

Arguments:
n_buckets: number of buckets per color
show_progress_bar: Show a progress bar when encoding images

**Usage**:

Expand Down Expand Up @@ -37,16 +39,18 @@ class ColorHistogramEncoder(EmbetterBase):
```
"""

def __init__(self, n_buckets=256):
def __init__(self, n_buckets=256, show_progress_bar=False):
    # Persist configuration only; no work happens until transform() runs.
    self.show_progress_bar = show_progress_bar
    self.n_buckets = n_buckets

def transform(self, X, y=None):
"""
Takes a sequence of `PIL.Image` and returns a numpy array representing
a color histogram for each.
"""
output = np.zeros((len(X), self.n_buckets * 3))
for i, x in enumerate(X):
instances = tqdm(X, desc="ColorHistogramEncoder") if self.show_progress_bar else X
for i, x in enumerate(instances):
arr = np.array(x)
output[i, :] = np.concatenate(
[
Expand Down
10 changes: 7 additions & 3 deletions embetter/vision/_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from embetter.base import EmbetterBase


Expand All @@ -12,6 +13,7 @@ class ImageLoader(EmbetterBase):
Arguments:
convert: Color [conversion setting](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.convert) from the Python image library.
out: What kind of image output format to expect.
show_progress_bar: Show a progress bar when loading images

**Usage**

Expand Down Expand Up @@ -50,9 +52,10 @@ class ImageLoader(EmbetterBase):

"""

def __init__(self, convert: str = "RGB", out: str = "pil") -> None:
def __init__(self, convert: str = "RGB", out: str = "pil", show_progress_bar=False) -> None:
    # Store settings only; images are opened lazily in transform().
    self.show_progress_bar = show_progress_bar
    self.convert = convert
    self.out = out

def fit(self, X, y=None):
"""
Expand All @@ -69,7 +72,8 @@ def transform(self, X, y=None):
"""
Turn a file path into numpy array containing pixel values.
"""
instances = tqdm(X, desc="Loading Images") if self.show_progress_bar else X
if self.out == "pil":
return [Image.open(x).convert(self.convert) for x in X]
return [Image.open(x).convert(self.convert) for x in instances]
if self.out == "numpy":
return np.array([np.array(Image.open(x).convert(self.convert)) for x in X])
return np.array([np.array(Image.open(x).convert(self.convert)) for x in instances])
11 changes: 9 additions & 2 deletions embetter/vision/_torchvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from timm.data import resolve_data_config

import numpy as np
from tqdm.auto import tqdm
from embetter.base import EmbetterBase


Expand All @@ -18,6 +19,7 @@ class TimmEncoder(EmbetterBase):
Arguments:
name: name of the model to use
encode_predictions: output the predictions instead of the pooled embedding layer before
show_progress_bar: Show a progress bar when processing images

**Usage**:

Expand Down Expand Up @@ -45,18 +47,23 @@ class TimmEncoder(EmbetterBase):
```
"""

def __init__(self, name="mobilenetv3_large_100", encode_predictions=False):
def __init__(self, name="mobilenetv3_large_100", encode_predictions=False, show_progress_bar=False):
    """Build the timm model once and derive its preprocessing transform.

    Arguments:
        name: name of the timm model to use
        encode_predictions: output the predictions instead of the pooled
            embedding layer before
        show_progress_bar: show a progress bar when processing images
    """
    self.name = name
    self.encode_predictions = encode_predictions
    self.show_progress_bar = show_progress_bar
    # Build the model exactly once: the original created the headless model
    # and then a second full model when encode_predictions was True, paying
    # for two pretrained-model instantiations. With the classification head
    # we get raw predictions; with num_classes=0 we get pooled embeddings.
    if self.encode_predictions:
        self.model = timm.create_model(name, pretrained=True)
    else:
        self.model = timm.create_model(name, pretrained=True, num_classes=0)
    self.config = resolve_data_config({}, model=self.model)
    self.transform_img = create_transform(**self.config)

def transform(self, X, y=None):
    """Transforms grabbed images into numeric representations.

    Arguments:
        X: iterable of images accepted by `self.transform_img`
        y: ignored, present for scikit-learn pipeline compatibility
    """
    batch = [self.transform_img(x).unsqueeze(0) for x in X]
    # Wrap in tqdm only when a progress bar was requested.
    instances = tqdm(batch, desc="Encoding using Timm") if self.show_progress_bar else batch
    # A comprehension replaces the manual append-in-a-loop (ruff PERF401).
    return np.array([self.model(x).squeeze(0).detach().numpy() for x in instances])
13 changes: 13 additions & 0 deletions tests/test_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np

from embetter.text import SentenceEncoder


def test_basic_sentence_encoder():
    """Smoke-test SentenceEncoder end to end with the progress bar enabled."""
    sentences = ["This is a test sentence!", "And this is another one", "\rUnicode stuff: ♣️,♦️,❤️,♠️\n"]
    enc = SentenceEncoder(show_progress_bar=True)
    expected_dim = enc.tfm._modules['1'].word_embedding_dimension
    result = enc.fit_transform(sentences)
    assert isinstance(result, np.ndarray)
    assert result.shape == (len(sentences), expected_dim)

4 changes: 2 additions & 2 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def test_color_hist_resize(n_buckets):
    """Make sure we can resize and it fits"""
    images = ImageLoader().fit_transform(["tests/data/thiscatdoesnotexist.jpeg"])
    encoder = ColorHistogramEncoder(n_buckets=n_buckets, show_progress_bar=True)
    assert encoder.fit_transform(images).shape == (1, n_buckets * 3)

Expand All @@ -15,6 +15,6 @@ def test_color_hist_resize(n_buckets):
def test_basic_timm(encode_predictions, size):
    """Super basic check for torch image model."""
    images = ImageLoader(show_progress_bar=True).fit_transform(["tests/data/thiscatdoesnotexist.jpeg"])
    encoder = TimmEncoder("mobilenetv2_120d", encode_predictions=encode_predictions)
    assert encoder.fit_transform(images).shape == (1, size)