Skip to content

Commit

Permalink
fix: fix plot bug recompute size (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Aug 2, 2022
1 parent f588758 commit 30486b2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
5 changes: 4 additions & 1 deletion docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,17 @@ def plot_image_sprites(
import matplotlib.pyplot as plt

img_per_row = ceil(sqrt(len(self)))
img_per_col = ceil(len(self) / img_per_row)
img_size = int(canvas_size / img_per_row)

if img_size < min_size:
# image is too small, recompute the size
img_size = min_size
img_per_row = int(canvas_size / img_size)

if img_per_row == 0:
img_per_row = 1

img_per_col = ceil(len(self) / img_per_row)
max_num_img = img_per_row * img_per_col
sprite_img = np.zeros(
[img_size * img_per_col, img_size * img_per_row, 3], dtype='uint8'
Expand Down
19 changes: 16 additions & 3 deletions tests/unit/array/mixins/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from docarray import DocumentArray, Document
from docarray.array.qdrant import DocumentArrayQdrant
from docarray.array.sqlite import DocumentArraySqlite
from docarray.array.annlite import DocumentArrayAnnlite, AnnliteConfig
from docarray.array.storage.qdrant import QdrantConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.weaviate import DocumentArrayWeaviate
Expand Down Expand Up @@ -68,8 +67,17 @@ def test_sprite_fail_tensor_success_uri(
(DocumentArrayElastic, lambda: ElasticConfig(n_dim=128)),
],
)
@pytest.mark.parametrize('canvas_size', [50, 512])
@pytest.mark.parametrize('min_size', [16, 64])
def test_sprite_image_generator(
pytestconfig, tmpdir, image_source, da_cls, config_gen, start_storage
pytestconfig,
tmpdir,
image_source,
da_cls,
config_gen,
canvas_size,
min_size,
start_storage,
):
files = [
f'{pytestconfig.rootdir}/tests/image-data/*.jpg',
Expand All @@ -80,7 +88,12 @@ def test_sprite_image_generator(
else:
da = da_cls.from_files(files)
da.apply(lambda d: d.load_uri_to_image_tensor())
da.plot_image_sprites(tmpdir / 'sprint_da.png', image_source=image_source)
da.plot_image_sprites(
tmpdir / 'sprint_da.png',
image_source=image_source,
canvas_size=canvas_size,
min_size=min_size,
)
assert os.path.exists(tmpdir / 'sprint_da.png')


Expand Down

0 comments on commit 30486b2

Please sign in to comment.